diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 05714731b9d4d..7ef8bc1797384 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -6,7 +6,6 @@ use std::fmt::{self, Display, Formatter}; use std::str::FromStr; -use crate::expand::typetree::TypeTree; use crate::expand::{Decodable, Encodable, HashStable_Generic}; use crate::ptr::P; use crate::{Ty, TyKind}; @@ -79,10 +78,6 @@ pub struct AutoDiffItem { /// The name of the function being generated pub target: String, pub attrs: AutoDiffAttrs, - /// Describe the memory layout of input types - pub inputs: Vec, - /// Describe the memory layout of the output type - pub output: TypeTree, } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub struct AutoDiffAttrs { @@ -262,22 +257,14 @@ impl AutoDiffAttrs { !matches!(self.mode, DiffMode::Error | DiffMode::Source) } - pub fn into_item( - self, - source: String, - target: String, - inputs: Vec, - output: TypeTree, - ) -> AutoDiffItem { - AutoDiffItem { source, target, inputs, output, attrs: self } + pub fn into_item(self, source: String, target: String) -> AutoDiffItem { + AutoDiffItem { source, target, attrs: self } } } impl fmt::Display for AutoDiffItem { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Differentiating {} -> {}", self.source, self.target)?; - write!(f, " with attributes: {:?}", self.attrs)?; - write!(f, " with inputs: {:?}", self.inputs)?; - write!(f, " with output: {:?}", self.output) + write!(f, " with attributes: {:?}", self.attrs) } } diff --git a/compiler/rustc_codegen_gcc/src/lib.rs b/compiler/rustc_codegen_gcc/src/lib.rs index 452e92bffa235..096e9a5961794 100644 --- a/compiler/rustc_codegen_gcc/src/lib.rs +++ b/compiler/rustc_codegen_gcc/src/lib.rs @@ -93,6 +93,7 @@ use gccjit::{CType, Context, OptimizationLevel}; #[cfg(feature = "master")] use gccjit::{TargetInfo, Version}; use rustc_ast::expand::allocator::AllocatorKind; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryFn, @@ -439,6 +440,14 @@ impl WriteBackendMethods for GccCodegenBackend { ) -> Result, FatalError> { back::write::link(cgcx, dcx, modules) } + fn autodiff( + _cgcx: &CodegenContext, + _module: &ModuleCodegen, + _diff_fncs: Vec, + _config: &ModuleConfig, + ) -> Result<(), FatalError> { + unimplemented!() + } } /// This is the entrypoint for a hot plugged rustc_codegen_gccjit diff --git a/compiler/rustc_codegen_llvm/messages.ftl b/compiler/rustc_codegen_llvm/messages.ftl index 63c64269eb805..3982c37528df8 100644 --- a/compiler/rustc_codegen_llvm/messages.ftl +++ b/compiler/rustc_codegen_llvm/messages.ftl @@ -1,3 +1,5 @@ +codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto + codegen_llvm_copy_bitcode = failed to copy bitcode to object file: {$err} codegen_llvm_dynamic_linking_with_lto = @@ -47,6 +49,8 @@ codegen_llvm_parse_bitcode_with_llvm_err = failed to parse bitcode for LTO modul codegen_llvm_parse_target_machine_config = failed to parse target machine config to target machine: {$error} +codegen_llvm_prepare_autodiff = failed to prepare autodiff: src: {$src}, target: {$target}, {$error} +codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare autodiff: {$llvm_err}, src: {$src}, target: {$target}, {$error} codegen_llvm_prepare_thin_lto_context = failed to prepare thin LTO context codegen_llvm_prepare_thin_lto_context_with_llvm_err = failed to prepare thin LTO context: {$llvm_err} diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index f6d2ec24da6f2..4adf99e91d08d 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -604,7 +604,14 @@ pub(crate) fn run_pass_manager( debug!("running the pass manager"); let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO }; let opt_level = config.opt_level.unwrap_or(config::OptLevel::No); - unsafe { write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }?; + + // If this rustc version was build with enzyme/autodiff enabled, and if users applied the + // `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time. + let first_run = true; + debug!("running llvm pm opt pipeline"); + unsafe { + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?; + } debug!("lto done"); Ok(()) } diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index 66ca4e2b4738c..499523699d4f9 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -8,6 +8,7 @@ use libc::{c_char, c_int, c_void, size_t}; use llvm::{ LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, }; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::link::ensure_removed; use rustc_codegen_ssa::back::versioned_llvm_target; use rustc_codegen_ssa::back::write::{ @@ -28,7 +29,7 @@ use rustc_session::config::{ use rustc_span::symbol::sym; use rustc_span::{BytePos, InnerSpan, Pos, SpanData, SyntaxContext}; use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel}; -use tracing::debug; +use tracing::{debug, trace}; use crate::back::lto::ThinBuffer; use crate::back::owned_target_machine::OwnedTargetMachine; @@ -530,9 +531,38 @@ pub(crate) unsafe fn llvm_optimize( config: &ModuleConfig, opt_level: config::OptLevel, opt_stage: llvm::OptStage, + skip_size_increasing_opts: bool, ) -> Result<(), FatalError> { - let unroll_loops = - opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + // Enzyme: + // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized + // source code. However, benchmarks show that optimizations increasing the code size + // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code + // and finally re-optimize the module, now with all optimizations available. + // FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting + // differentiated. + + let unroll_loops; + let vectorize_slp; + let vectorize_loop; + + // When we build rustc with enzyme/autodiff support, we want to postpone size-increasing + // optimizations until after differentiation. FIXME(ZuseZ4): Before shipping on nightly, + // we should make this more granular, or at least check that the user has at least one autodiff + // call in their code, to justify altering the compilation pipeline. + if skip_size_increasing_opts && cfg!(llvm_enzyme) { + unroll_loops = false; + vectorize_slp = false; + vectorize_loop = false; + } else { + unroll_loops = + opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + vectorize_slp = config.vectorize_slp; + vectorize_loop = config.vectorize_loop; + } + trace!( + "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}", + unroll_loops, vectorize_slp, vectorize_loop + ); let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -596,8 +626,8 @@ pub(crate) unsafe fn llvm_optimize( using_thin_buffers, config.merge_functions, unroll_loops, - config.vectorize_slp, - config.vectorize_loop, + vectorize_slp, + vectorize_loop, config.no_builtins, config.emit_lifetime_markers, sanitizer_options.as_ref(), @@ -619,6 +649,83 @@ pub(crate) unsafe fn llvm_optimize( result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses)) } +pub(crate) fn differentiate( + module: &ModuleCodegen, + cgcx: &CodegenContext, + diff_items: Vec, + config: &ModuleConfig, +) -> Result<(), FatalError> { + for item in &diff_items { + trace!("{}", item); + } + + let llmod = module.module_llvm.llmod(); + let llcx = &module.module_llvm.llcx; + let diag_handler = cgcx.create_dcx(); + + // Before dumping the module, we want all the tt to become part of the module. + for item in diff_items.iter() { + let name = CString::new(item.source.clone()).unwrap(); + let fn_def: Option<&llvm::Value> = + unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) }; + let fn_def = match fn_def { + Some(x) => x, + None => { + return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff { + src: item.source.clone(), + target: item.target.clone(), + error: "could not find source function".to_owned(), + })); + } + }; + let target_name = CString::new(item.target.clone()).unwrap(); + debug!("target name: {:?}", &target_name); + let fn_target: Option<&llvm::Value> = + unsafe { llvm::LLVMGetNamedFunction(llmod, target_name.as_ptr()) }; + let fn_target = match fn_target { + Some(x) => x, + None => { + return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff { + src: item.source.clone(), + target: item.target.clone(), + error: "could not find target function".to_owned(), + })); + } + }; + + crate::builder::generate_enzyme_call(llmod, llcx, fn_def, fn_target, item.attrs.clone()); + } + + // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts + + if let Some(opt_level) = config.opt_level { + let opt_stage = match cgcx.lto { + Lto::Fat => llvm::OptStage::PreLinkFatLTO, + Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, + _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, + _ => llvm::OptStage::PreLinkNoLTO, + }; + // This is our second opt call, so now we run all opts, + // to make sure we get the best performance. + let skip_size_increasing_opts = false; + trace!("running Module Optimization after differentiation"); + unsafe { + llvm_optimize( + cgcx, + diag_handler.handle(), + module, + config, + opt_level, + opt_stage, + skip_size_increasing_opts, + )? + }; + } + trace!("done with differentiate()"); + + Ok(()) +} + // Unsafe due to LLVM calls. pub(crate) unsafe fn optimize( cgcx: &CodegenContext, @@ -641,6 +748,8 @@ pub(crate) unsafe fn optimize( unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) }; } + // FIXME(ZuseZ4): support SanitizeHWAddress and prevent illegal/unsupported opts + if let Some(opt_level) = config.opt_level { let opt_stage = match cgcx.lto { Lto::Fat => llvm::OptStage::PreLinkFatLTO, @@ -648,7 +757,20 @@ pub(crate) unsafe fn optimize( _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, _ => llvm::OptStage::PreLinkNoLTO, }; - return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }; + + // If we know that we will later run AD, then we disable vectorization and loop unrolling + let skip_size_increasing_opts = cfg!(llvm_enzyme); + return unsafe { + llvm_optimize( + cgcx, + dcx, + module, + config, + opt_level, + opt_stage, + skip_size_increasing_opts, + ) + }; } Ok(()) } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index b5bb7630ca6c9..6a2c84f610821 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -5,6 +5,7 @@ use std::{iter, ptr}; use libc::{c_char, c_uint}; use rustc_abi as abi; use rustc_abi::{Align, Size, WrappingRange}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_codegen_ssa::MemFlags; use rustc_codegen_ssa::common::{IntPredicate, RealPredicate, SynchronizationScope, TypeKind}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; @@ -24,17 +25,276 @@ use rustc_span::Span; use rustc_target::callconv::FnAbi; use rustc_target::spec::{HasTargetSpec, SanitizerSet, Target}; use smallvec::SmallVec; -use tracing::{debug, instrument}; +use tracing::{debug, instrument, trace}; use crate::abi::FnAbiLlvmExt; use crate::attributes; use crate::common::Funclet; use crate::context::CodegenCx; -use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True}; +use crate::llvm::AttributePlace::Function; +use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, Metadata, True}; use crate::type_::Type; use crate::type_of::LayoutLlvmExt; use crate::value::Value; +fn get_params(fnc: &Value) -> Vec<&Value> { + unsafe { + let param_num = llvm::LLVMCountParams(fnc) as usize; + let mut fnc_args: Vec<&Value> = vec![]; + fnc_args.reserve(param_num); + llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr()); + fnc_args.set_len(param_num); + fnc_args + } +} + +/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another +/// function with expected naming and calling conventions[^1] which will be +/// discovered by the enzyme LLVM pass and its body populated with the differentiated +/// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated +/// function and handle the differences between the Rust calling convention and +/// Enzyme. +/// [^1]: +// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to +// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise. +pub(crate) fn generate_enzyme_call<'ll>( + llmod: &'ll llvm::Module, + llcx: &'ll llvm::Context, + fn_to_diff: &'ll Value, + outer_fn: &'ll Value, + attrs: AutoDiffAttrs, +) { + let inputs = attrs.input_activity; + let output = attrs.ret_activity; + + // We have to pick the name depending on whether we want forward or reverse mode autodiff. + // FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since + // it will handle higher-order derivatives correctly automatically (in theory). Currently + // higher-order derivatives fail, so we should debug that before adjusting this code. + let mut ad_name: String = match attrs.mode { + DiffMode::Forward => "__enzyme_fwddiff", + DiffMode::Reverse => "__enzyme_autodiff", + DiffMode::ForwardFirst => "__enzyme_fwddiff", + DiffMode::ReverseFirst => "__enzyme_autodiff", + _ => panic!("logic bug in autodiff, unrecognized mode"), + } + .to_string(); + + // add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple + // functions. Unwrap will only panic, if LLVM gave us an invalid string. + let name = llvm::get_value_name(outer_fn); + let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap(); + ad_name.push_str(outer_fn_name.to_string().as_str()); + + // Let us assume the user wrote the following function square: + // + // ```llvm + // define double @square(double %x) { + // entry: + // %0 = fmul double %x, %x + // ret double %0 + // } + // ``` + // + // The user now applies autodiff to the function square, in which case fn_to_diff will be `square`. + // Our macro generates the following placeholder code (slightly simplified): + // + // ```llvm + // define double @dsquare(double %x) { + // ; placeholder code + // return 0.0; + // } + // ``` + // + // so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder + // code and inserts an autodiff call. We also add a declaration for the __enzyme_autodiff call. + // Again, the arguments to all functions are slightly simplified. + // ```llvm + // declare double @__enzyme_autodiff_square(...) + // + // define double @dsquare(double %x) { + // entry: + // %0 = tail call double (...) @__enzyme_autodiff_square(double (double)* nonnull @square, double %x) + // ret double %0 + // } + // ``` + unsafe { + // On LLVM-IR, we can luckily declare __enzyme_ functions without specifying the input + // arguments. We do however need to declare them with their correct return type. + // We already figured the correct return type out in our frontend, when generating the outer_fn, + // so we can now just go ahead and use that. FIXME(ZuseZ4): This doesn't handle sret yet. + let fn_ty = llvm::LLVMGlobalGetValueType(outer_fn); + let ret_ty = llvm::LLVMGetReturnType(fn_ty); + + // LLVM can figure out the input types on it's own, so we take a shortcut here. + let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True); + let ad_fn = llvm::LLVMRustGetOrInsertFunction( + llmod, + ad_name.as_ptr() as *const c_char, + ad_name.len().try_into().unwrap(), + enzyme_ty, + ); + // Otherwise LLVM might inline our temporary code before the enzyme pass has a chance to + // do it's work. + let attr = llvm::AttributeKind::NoInline.create_attr(llcx); + attributes::apply_to_llfn(ad_fn, Function, &[attr]); + + // first, remove all calls from fnc + let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); + let br = llvm::LLVMRustGetTerminator(entry); + llvm::LLVMRustEraseInstFromParent(br); + + let builder = llvm::LLVMCreateBuilderInContext(llcx); + let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); + llvm::LLVMPositionBuilderAtEnd(builder, entry); + + let num_args = llvm::LLVMCountParams(&fn_to_diff); + let mut args = Vec::with_capacity(num_args as usize + 1); + args.push(fn_to_diff); + + let enzyme_const = llvm::create_md_string(llcx, "enzyme_const"); + let enzyme_out = llvm::create_md_string(llcx, "enzyme_out"); + let enzyme_dup = llvm::create_md_string(llcx, "enzyme_dup"); + let enzyme_dupnoneed = llvm::create_md_string(llcx, "enzyme_dupnoneed"); + let enzyme_primal_ret = llvm::create_md_string(llcx, "enzyme_primal_return"); + + match output { + DiffActivity::Dual => { + args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_primal_ret)); + } + DiffActivity::Active => { + args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_primal_ret)); + } + _ => {} + } + + trace!("matching autodiff arguments"); + // We now handle the issue that Rust level arguments not always match the llvm-ir level + // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on + // llvm-ir level. The number of activities matches the number of Rust level arguments, so we + // need to match those. + let mut outer_pos: usize = 0; + let mut activity_pos = 0; + let outer_args: Vec<&llvm::Value> = get_params(outer_fn); + while activity_pos < inputs.len() { + let activity = inputs[activity_pos as usize]; + // Duplicated arguments received a shadow argument, into which enzyme will write the + // gradient. + let (activity, duplicated): (&Metadata, bool) = match activity { + DiffActivity::None => panic!("not a valid input activity"), + DiffActivity::Const => (enzyme_const, false), + DiffActivity::Active => (enzyme_out, false), + DiffActivity::ActiveOnly => (enzyme_out, false), + DiffActivity::Dual => (enzyme_dup, true), + DiffActivity::DualOnly => (enzyme_dupnoneed, true), + DiffActivity::Duplicated => (enzyme_dup, true), + DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true), + DiffActivity::FakeActivitySize => (enzyme_const, false), + }; + let outer_arg = outer_args[outer_pos]; + args.push(llvm::LLVMMetadataAsValue(llcx, activity)); + args.push(outer_arg); + if duplicated { + // We know that duplicated args by construction have a following argument, + // so this can not be out of bounds. + let next_outer_arg = outer_args[outer_pos + 1]; + let next_outer_ty = llvm::LLVMTypeOf(next_outer_arg); + // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since + // vectors behind references (&Vec) are already supported. Users can not pass a + // Vec by value for reverse mode, so this would only help forward mode autodiff. + let slice = { + if activity_pos + 1 >= inputs.len() { + // If there is no arg following our ptr, it also can't be a slice, + // since that would lead to a ptr, int pair. + false + } else { + let next_activity = inputs[activity_pos + 1]; + // We analyze the MIR types and add this dummy activity if we visit a slice. + next_activity == DiffActivity::FakeActivitySize + } + }; + if slice { + // A duplicated slice will have the following two outer_fn arguments: + // (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call: + // (..., metadata! enzyme_dup, ptr, ptr, int1, ...). + // FIXME(ZuseZ4): We will upstream a safety check later which asserts that + // int2 >= int1, which means the shadow vector is large enough to store the gradient. + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer); + let next_outer_arg2 = outer_args[outer_pos + 2]; + let next_outer_ty2 = llvm::LLVMTypeOf(next_outer_arg2); + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer); + let next_outer_arg3 = outer_args[outer_pos + 3]; + let next_outer_ty3 = llvm::LLVMTypeOf(next_outer_arg3); + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer); + args.push(next_outer_arg2); + args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_const)); + args.push(next_outer_arg); + outer_pos += 4; + activity_pos += 2; + } else { + // A duplicated pointer will have the following two outer_fn arguments: + // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call: + // (..., metadata! enzyme_dup, ptr, ptr, ...). + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer); + args.push(next_outer_arg); + outer_pos += 2; + activity_pos += 1; + } + } else { + // We do not differentiate with resprect to this argument. + // We already added the metadata and argument above, so just increase the counters. + outer_pos += 1; + activity_pos += 1; + } + } + + let call = llvm::LLVMBuildCall2( + builder, + enzyme_ty, + ad_fn, + args.as_mut_ptr(), + args.len().try_into().unwrap(), + ad_name.as_ptr() as *const c_char, + ); + + // This part is a bit iffy. LLVM requires that a call to an inlineable function has some + // metadata attachted to it, but we just created this code oota. Given that the + // differentiated function already has partly confusing metadata, and given that this + // affects nothing but the auttodiff IR, we take a shortcut and just steal metadata from the + // dummy code which we inserted at a higher level. + // FIXME(ZuseZ4): Work with Enzyme core devs to clarify what debug metadata issues we have, + // and how to best improve it for enzyme core and rust-enzyme. + let md_ty = llvm::LLVMGetMDKindIDInContext( + llcx, + "dbg".as_ptr() as *const c_char, + "dbg".len() as c_uint, + ); + if llvm::LLVMRustHasMetadata(last_inst, md_ty) { + let md = llvm::LLVMRustDIGetInstMetadata(last_inst); + let md_todiff = llvm::LLVMMetadataAsValue(llcx, md); + llvm::LLVMSetMetadata(call, md_ty, md_todiff); + } else { + // We don't panic, since depending on whether we are in debug or release mode, we might + // have no debug info to copy, which would then be ok. + trace!("no dbg info"); + } + // Now that we copied the metadata, get rid of dummy code. + llvm::LLVMRustEraseInstBefore(entry, last_inst); + llvm::LLVMRustEraseInstFromParent(last_inst); + + let void_ty = llvm::LLVMVoidTypeInContext(llcx); + if llvm::LLVMTypeOf(call) != void_ty { + llvm::LLVMBuildRet(builder, call); + } else { + llvm::LLVMBuildRetVoid(builder); + }; + llvm::LLVMDisposeBuilder(builder); + + // Let's crash in case that we messed something up above and generated invalid IR. + llvm::LLVMVerifyFunction(outer_fn, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + } +} + // All Builders must have an llfn associated with them #[must_use] pub(crate) struct Builder<'a, 'll, 'tcx> { diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index 3cdb5b971d908..f340b06e876cd 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -89,6 +89,11 @@ impl Diagnostic<'_, G> for ParseTargetMachineConfig<'_> { } } +#[derive(Diagnostic)] +#[diag(codegen_llvm_autodiff_without_lto)] +#[note] +pub(crate) struct AutoDiffWithoutLTO; + #[derive(Diagnostic)] #[diag(codegen_llvm_lto_disallowed)] pub(crate) struct LtoDisallowed; @@ -131,6 +136,8 @@ pub enum LlvmError<'a> { PrepareThinLtoModule, #[diag(codegen_llvm_parse_bitcode)] ParseBitcode, + #[diag(codegen_llvm_prepare_autodiff)] + PrepareAutoDiff { src: String, target: String, error: String }, } pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String); @@ -152,6 +159,7 @@ impl Diagnostic<'_, G> for WithLlvmError<'_> { } PrepareThinLtoModule => fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err, ParseBitcode => fluent::codegen_llvm_parse_bitcode_with_llvm_err, + PrepareAutoDiff { .. } => fluent::codegen_llvm_prepare_autodiff_with_llvm_err, }; self.0 .into_diag(dcx, level) diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index 5235891a18df2..4b2b50ccf2c28 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -26,9 +26,10 @@ use std::mem::ManuallyDrop; use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; -use errors::ParseTargetMachineConfig; +use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig}; pub use llvm_util::target_features; use rustc_ast::expand::allocator::AllocatorKind; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn, @@ -42,7 +43,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::Session; -use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; +use rustc_session::config::{Lto, OptLevel, OutputFilenames, PrintKind, PrintRequest}; use rustc_span::symbol::Symbol; mod back { @@ -231,6 +232,19 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + config: &ModuleConfig, + ) -> Result<(), FatalError> { + if cgcx.lto != Lto::Fat { + let dcx = cgcx.create_dcx(); + return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO)); + } + back::write::differentiate(module, cgcx, diff_fncs, config) + } } unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs new file mode 100644 index 0000000000000..1cef0296fa7f5 --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -0,0 +1,38 @@ +#![allow(non_camel_case_types)] + +use libc::{c_char, c_uint, size_t}; + +use super::ffi::{Attribute, BasicBlock, Builder, Metadata, Module, Type, Value}; +extern "C" { + // Enzyme + pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool; + pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value); + pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>; + pub fn LLVMRustDIGetInstMetadata(I: &Value) -> &Metadata; + pub fn LLVMRustEraseInstFromParent(V: &Value); + pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value; + + pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMDumpModule(M: &Module); + pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; + pub fn LLVMVerifyFunction(V: &Value, action: LLVMVerifierFailureAction) -> bool; + pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value); + pub fn LLVMBuildCall2<'a>( + arg1: &Builder<'a>, + ty: &Type, + func: &Value, + args: *mut &Value, + num_args: size_t, + name: *const c_char, + ) -> &'a Value; + pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>; + pub fn LLVMIsEnumAttribute(A: &Attribute) -> bool; + pub fn LLVMIsStringAttribute(A: &Attribute) -> bool; +} + +#[repr(C)] +pub enum LLVMVerifierFailureAction { + LLVMAbortProcessAction, + LLVMPrintMessageAction, + LLVMReturnStatusAction, +} diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs index 909afe35a179b..069f90c6878e2 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs @@ -22,8 +22,11 @@ use crate::common::AsCCharPtr; pub mod archive_ro; pub mod diagnostic; +pub mod enzyme_ffi; mod ffi; +pub use self::enzyme_ffi::*; + impl LLVMRustResult { pub fn into_result(self) -> Result<(), ()> { match self { @@ -196,6 +199,10 @@ pub fn set_thread_local_mode(global: &Value, mode: ThreadLocalMode) { } } +pub fn create_md_string<'a>(llcx: &'a Context, s: &str) -> &'a Metadata { + unsafe { LLVMMDStringInContext2(llcx, s.as_c_char_ptr(), s.len()) } +} + impl AttributeKind { /// Create an LLVM Attribute with no associated value. pub fn create_attr(self, llcx: &Context) -> &Attribute { diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index ab8b06a05fc74..9fd984b6419ee 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -1,11 +1,13 @@ use std::ffi::CString; use std::sync::Arc; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::memmap::Mmap; use rustc_errors::FatalError; use super::write::CodegenContext; use crate::ModuleCodegen; +use crate::back::write::ModuleConfig; use crate::traits::*; pub struct ThinModule { @@ -81,6 +83,23 @@ impl LtoModuleCodegen { LtoModuleCodegen::Thin(ref m) => m.cost(), } } + + /// Run autodiff on Fat LTO module + pub unsafe fn autodiff( + self, + cgcx: &CodegenContext, + diff_fncs: Vec, + config: &ModuleConfig, + ) -> Result, FatalError> { + match &self { + LtoModuleCodegen::Fat(module) => { + B::autodiff(cgcx, &module, diff_fncs, config)?; + } + _ => panic!("autodiff called with non-fat LTO module"), + } + + Ok(self) + } } pub enum SerializedModule { diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index aabe9e33c4aa1..97fe614aa10cd 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -1,3 +1,4 @@ +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_errors::{DiagCtxtHandle, FatalError}; use rustc_middle::dep_graph::WorkProduct; @@ -61,6 +62,12 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { want_summary: bool, ) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + config: &ModuleConfig, + ) -> Result<(), FatalError>; } pub trait ThinBufferMethods: Send + Sync { diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index b79205ff946d3..439f664de4b4f 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -388,6 +388,17 @@ extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, AddAttributes(Call, Index, Attrs, AttrsLen); } +extern "C" LLVMValueRef LLVMRustGetTerminator(LLVMBasicBlockRef BB) { + Instruction *ret = unwrap(BB)->getTerminator(); + return wrap(ret); +} + +extern "C" void LLVMRustEraseInstFromParent(LLVMValueRef Instr) { + if (auto I = dyn_cast(unwrap(Instr))) { + I->eraseFromParent(); + } +} + extern "C" LLVMAttributeRef LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttributeKind RustAttr) { return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); @@ -954,6 +965,47 @@ extern "C" void LLVMRustAddModuleFlagString( MDString::get(unwrap(M)->getContext(), StringRef(Value, ValueLen))); } +extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) { + auto Point = unwrap(BB)->rbegin(); + if (Point != unwrap(BB)->rend()) + return wrap(&*Point); + return nullptr; +} + +extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) { + auto &BB = *unwrap(bb); + auto &Inst = *unwrap(I); + auto It = BB.begin(); + while (&*It != &Inst) + ++It; + // Make sure we found the Instruction. + assert(It != BB.end()); + // We don't want to erase the instruction itself. + It--; + // Delete in rev order to ensure no dangling references. + while (It != BB.begin()) { + auto Prev = std::prev(It); + It->eraseFromParent(); + It = Prev; + } + It->eraseFromParent(); +} + +extern "C" bool LLVMRustHasMetadata(LLVMValueRef inst, unsigned kindID) { + if (auto *I = dyn_cast(unwrap(inst))) { + return I->hasMetadata(kindID); + } + return false; +} + +extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadata(LLVMValueRef x) { + if (auto *I = dyn_cast(unwrap(x))) { + auto *MD = I->getDebugLoc().getAsMDNode(); + return wrap(MD); + } + return nullptr; +} + extern "C" void LLVMRustGlobalAddMetadata(LLVMValueRef Global, unsigned Kind, LLVMMetadataRef MD) { unwrap(Global)->addMetadata(Kind, *unwrap(MD));