diff --git a/Cargo.lock b/Cargo.lock index 4d54b5aeb4e1d..71ec774a33582 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4121,6 +4121,7 @@ dependencies = [ name = "rustc_monomorphize" version = "0.0.0" dependencies = [ + "rustc_ast", "rustc_data_structures", "rustc_errors", "rustc_fluent_macro", @@ -4129,6 +4130,7 @@ dependencies = [ "rustc_middle", "rustc_session", "rustc_span", + "rustc_symbol_mangling", "rustc_target", "serde", "serde_json", diff --git a/compiler/rustc_ast/src/expand/autodiff_attrs.rs b/compiler/rustc_ast/src/expand/autodiff_attrs.rs index 05714731b9d4d..dded4e2ba33a4 100644 --- a/compiler/rustc_ast/src/expand/autodiff_attrs.rs +++ b/compiler/rustc_ast/src/expand/autodiff_attrs.rs @@ -6,7 +6,6 @@ use std::fmt::{self, Display, Formatter}; use std::str::FromStr; -use crate::expand::typetree::TypeTree; use crate::expand::{Decodable, Encodable, HashStable_Generic}; use crate::ptr::P; use crate::{Ty, TyKind}; @@ -79,10 +78,6 @@ pub struct AutoDiffItem { /// The name of the function being generated pub target: String, pub attrs: AutoDiffAttrs, - /// Describe the memory layout of input types - pub inputs: Vec, - /// Describe the memory layout of the output type - pub output: TypeTree, } #[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)] pub struct AutoDiffAttrs { @@ -266,18 +261,14 @@ impl AutoDiffAttrs { self, source: String, target: String, - inputs: Vec, - output: TypeTree, ) -> AutoDiffItem { - AutoDiffItem { source, target, inputs, output, attrs: self } + AutoDiffItem { source, target, attrs: self } } } impl fmt::Display for AutoDiffItem { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "Differentiating {} -> {}", self.source, self.target)?; - write!(f, " with attributes: {:?}", self.attrs)?; - write!(f, " with inputs: {:?}", self.inputs)?; - write!(f, " with output: {:?}", self.output) + write!(f, " with attributes: {:?}", self.attrs) } } diff --git a/compiler/rustc_codegen_llvm/messages.ftl b/compiler/rustc_codegen_llvm/messages.ftl index 0950e4bb26bac..26d699d5cc1d1 100644 --- a/compiler/rustc_codegen_llvm/messages.ftl +++ b/compiler/rustc_codegen_llvm/messages.ftl @@ -51,6 +51,10 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO codegen_llvm_run_passes = failed to run LLVM passes codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err} +codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error} +codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error} +codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto + codegen_llvm_sanitizer_memtag_requires_mte = `-Zsanitizer=memtag` requires `-Ctarget-feature=+mte` diff --git a/compiler/rustc_codegen_llvm/src/attributes.rs b/compiler/rustc_codegen_llvm/src/attributes.rs index 2c5ec9dad59f1..a94ef2d5e8eea 100644 --- a/compiler/rustc_codegen_llvm/src/attributes.rs +++ b/compiler/rustc_codegen_llvm/src/attributes.rs @@ -1,5 +1,6 @@ //! Set and unset common attributes on LLVM values. +use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs; use rustc_attr::{InlineAttr, InstructionSetAttr, OptimizeAttr}; use rustc_codegen_ssa::traits::*; use rustc_hir::def_id::DefId; @@ -332,6 +333,7 @@ pub(crate) fn llfn_attrs_from_instance<'ll, 'tcx>( instance: ty::Instance<'tcx>, ) { let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id()); + let autodiff_attrs: &AutoDiffAttrs = cx.tcx.autodiff_attrs(instance.def_id()); let mut to_add = SmallVec::<[_; 16]>::new(); @@ -349,6 +351,8 @@ pub(crate) fn llfn_attrs_from_instance<'ll, 'tcx>( let inline = if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) { InlineAttr::Hint + } else if autodiff_attrs.is_active() { + InlineAttr::Never } else { codegen_fn_attrs.inline }; diff --git a/compiler/rustc_codegen_llvm/src/back/lto.rs b/compiler/rustc_codegen_llvm/src/back/lto.rs index 1f7a923dd2c68..5327ef18667c7 100644 --- a/compiler/rustc_codegen_llvm/src/back/lto.rs +++ b/compiler/rustc_codegen_llvm/src/back/lto.rs @@ -616,7 +616,11 @@ pub(crate) fn run_pass_manager( } let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO }; let opt_level = config.opt_level.unwrap_or(config::OptLevel::No); - write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage)?; + + // We will run this again with different values in the context of automatic differentiation. + let first_run = true; + debug!("running llvm pm opt pipeline"); + write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run)?; } debug!("lto done"); Ok(()) @@ -723,7 +727,11 @@ pub(crate) unsafe fn optimize_thin_module( let llcx = unsafe { llvm::LLVMRustContextCreate(cgcx.fewer_names) }; let llmod_raw = parse_module(llcx, module_name, thin_module.data(), dcx)? as *const _; let mut module = ModuleCodegen { - module_llvm: ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) }, + module_llvm: ModuleLlvm { + llmod_raw, + llcx, + tm: ManuallyDrop::new(tm), + }, name: thin_module.name().to_string(), kind: ModuleKind::Regular, }; diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index afdd2b581b86e..2d9a9d914f444 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -4,10 +4,12 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use std::{fs, slice, str}; -use libc::{c_char, c_int, c_void, size_t}; +use libc::{c_char, c_int, c_uint, c_void, size_t}; use llvm::{ - LLVMRustLLVMHasZlibCompressionForDebugSymbols, LLVMRustLLVMHasZstdCompressionForDebugSymbols, + LLVMRustLLVMHasZlibCompressionForDebugSymbols, + LLVMRustLLVMHasZstdCompressionForDebugSymbols, }; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::link::ensure_removed; use rustc_codegen_ssa::back::write::{ BitcodeSection, CodegenContext, EmitObj, ModuleConfig, TargetMachineFactoryConfig, @@ -22,12 +24,13 @@ use rustc_fs_util::{link_or_copy, path_to_c_string}; use rustc_middle::ty::TyCtxt; use rustc_session::Session; use rustc_session::config::{ - self, Lto, OutputType, Passes, RemapPathScopeComponents, SplitDwarfKind, SwitchWithOptPath, + self, AutoDiff, Lto, OutputType, Passes, RemapPathScopeComponents, SplitDwarfKind, + SwitchWithOptPath, }; use rustc_span::InnerSpan; use rustc_span::symbol::sym; use rustc_target::spec::{CodeModel, RelocModel, SanitizerSet, SplitDebuginfo, TlsModel}; -use tracing::debug; +use tracing::{debug, trace}; use crate::back::lto::ThinBuffer; use crate::back::owned_target_machine::OwnedTargetMachine; @@ -39,7 +42,17 @@ use crate::errors::{ WithLlvmError, WriteBytecode, }; use crate::llvm::diagnostic::OptimizationDiagnosticKind::*; -use crate::llvm::{self, DiagnosticInfo, PassManager}; +use crate::llvm::{ + self, AttributeKind, DiagnosticInfo, + LLVMCreateStringAttribute, LLVMDumpModule, + LLVMGetFirstFunction, LLVMGetNextFunction, + LLVMGetStringAttributeAtIndex, LLVMIsEnumAttribute, + LLVMIsStringAttribute, + LLVMRemoveStringAttributeAtIndex, LLVMRustAddEnumAttributeAtIndex, + LLVMRustAddFunctionAttributes, + LLVMRustGetEnumAttributeAtIndex, + LLVMRustRemoveEnumAttributeAtIndex, PassManager, +}; use crate::type_::Type; use crate::{LlvmCodegenBackend, ModuleLlvm, base, common, llvm_util}; @@ -515,9 +528,34 @@ pub(crate) unsafe fn llvm_optimize( config: &ModuleConfig, opt_level: config::OptLevel, opt_stage: llvm::OptStage, + first_run: bool, ) -> Result<(), FatalError> { - let unroll_loops = - opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + // Enzyme: + // The whole point of compiler based AD is to differentiate optimized IR instead of unoptimized + // source code. However, benchmarks show that optimizations increasing the code size + // tend to reduce AD performance. Therefore deactivate them before AD, then differentiate the code + // and finally re-optimize the module, now with all optimizations available. + // TODO: In a future update we could figure out how to only optimize functions getting + // differentiated. + + let unroll_loops; + let vectorize_slp; + let vectorize_loop; + + if first_run { + unroll_loops = false; + vectorize_slp = false; + vectorize_loop = false; + } else { + unroll_loops = + opt_level != config::OptLevel::Size && opt_level != config::OptLevel::SizeMin; + vectorize_slp = config.vectorize_slp; + vectorize_loop = config.vectorize_loop; + } + trace!( + "Enzyme: Running with unroll_loops: {}, vectorize_slp: {}, vectorize_loop: {}", + unroll_loops, vectorize_slp, vectorize_loop + ); let using_thin_buffers = opt_stage == llvm::OptStage::PreLinkThinLTO || config.bitcode_needed(); let pgo_gen_path = get_pgo_gen_path(config); let pgo_use_path = get_pgo_use_path(config); @@ -581,8 +619,8 @@ pub(crate) unsafe fn llvm_optimize( using_thin_buffers, config.merge_functions, unroll_loops, - config.vectorize_slp, - config.vectorize_loop, + vectorize_slp, + vectorize_loop, config.no_builtins, config.emit_lifetime_markers, sanitizer_options.as_ref(), @@ -605,6 +643,162 @@ pub(crate) unsafe fn llvm_optimize( result.into_result().map_err(|()| llvm_err(dcx, LlvmError::RunLlvmPasses)) } +#[allow(non_snake_case)] +unsafe fn EnzymeSetCLBool(ptr: *mut c_int, val: u8) { + unsafe { + //let ptr = ptr as *mut c_int; + *ptr = val as c_int; + } +} + +pub(crate) fn differentiate( + module: &ModuleCodegen, + cgcx: &CodegenContext, + diff_items: Vec, + config: &ModuleConfig, +) -> Result<(), FatalError> { + for item in &diff_items { + trace!("{}", item); + } + + let llmod = module.module_llvm.llmod(); + let llcx = &module.module_llvm.llcx; + let diag_handler = cgcx.create_dcx(); + + //llvm::set_strict_aliasing(false); + + let ad = &config.autodiff; + + if ad.contains(&AutoDiff::LooseTypes) { + dbg!("Setting loose types to true"); + #[allow(non_upper_case_globals)] + static mut looseTypeAnalysis: c_int = 0; + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), true as u8); + } + } + + // Before dumping the module, we want all the tt to become part of the module. + for item in diff_items.iter() { + let name = CString::new(item.source.clone()).unwrap(); + let fn_def: Option<&llvm::Value> = + unsafe { llvm::LLVMGetNamedFunction(llmod, name.as_ptr()) }; + let fn_def = match fn_def { + Some(x) => x, + None => return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff { + src: item.source.clone(), + target: item.target.clone(), + error: "could not find source function".to_owned(), + })), + }; + let tgt_name = CString::new(item.target.clone()).unwrap(); + dbg!("Target name: {:?}", &tgt_name); + let fn_target: Option<&llvm::Value> = + unsafe { llvm::LLVMGetNamedFunction(llmod, tgt_name.as_ptr()) }; + let fn_target = match fn_target { + Some(x) => x, + None => return Err(llvm_err(diag_handler.handle(), LlvmError::PrepareAutoDiff { + src: item.source.clone(), + target: item.target.clone(), + error: "could not find target function".to_owned(), + })), + }; + + crate::builder::add_opt_dbg_helper2( + llmod, + llcx, + fn_def, + fn_target, + item.attrs.clone(), + ); + } + + if ad.contains(&AutoDiff::PrintModBefore) { + unsafe { + LLVMDumpModule(llmod); + } + } + + if ad.contains(&AutoDiff::Inline) { + trace!("Setting Enzyme inline to true"); + #[allow(non_upper_case_globals)] + static mut EnzymeInline: c_int = 0; + unsafe { + EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), true as u8); + } + } + + unsafe { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let attr = LLVMGetStringAttributeAtIndex( + lf, + c_uint::MAX, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + ); + if LLVMIsStringAttribute(attr) { + LLVMRemoveStringAttributeAtIndex( + lf, + c_uint::MAX, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + ); + } else { + LLVMRustRemoveEnumAttributeAtIndex( + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + } + } else { + break; + } + } + } + + if let Some(opt_level) = config.opt_level { + let opt_stage = match cgcx.lto { + Lto::Fat => llvm::OptStage::PreLinkFatLTO, + Lto::Thin | Lto::ThinLocal => llvm::OptStage::PreLinkThinLTO, + _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, + _ => llvm::OptStage::PreLinkNoLTO, + }; + let mut first_run = false; + dbg!("Running Module Optimization after differentiation"); + if ad.contains(&AutoDiff::NoVecUnroll) { + // disables vectorization and loop unrolling + first_run = true; + } + unsafe { + llvm_optimize( + cgcx, + diag_handler.handle(), + module, + config, + opt_level, + opt_stage, + first_run, + )? + }; + } + if ad.contains(&AutoDiff::PrintModAfterEnzyme) { + unsafe {LLVMDumpModule(llmod)}; + } + + if ad.contains(&AutoDiff::PrintModAfterOpts) { + unsafe { + LLVMDumpModule(llmod); + } + } + dbg!("Done with differentiate()"); + + Ok(()) +} + // Unsafe due to LLVM calls. pub(crate) unsafe fn optimize( cgcx: &CodegenContext, @@ -627,6 +821,47 @@ pub(crate) unsafe fn optimize( unsafe { llvm::LLVMWriteBitcodeToFile(llmod, out.as_ptr()) }; } + // This code enables Enzyme to differentiate code containing Rust enums. + // By adding the SanitizeHWAddress attribute we prevent LLVM from Optimizing + // away the enums and allows Enzyme to understand why a value can be of different types in + // different code sections. We remove this attribute after Enzyme is done, to not affect the + // rest of the compilation. + // TODO: only enable this code when at least one function gets differentiated. + unsafe { + let mut f = LLVMGetFirstFunction(llmod); + loop { + if let Some(lf) = f { + f = LLVMGetNextFunction(lf); + let myhwattr = "enzyme_hw"; + let myhwv = ""; + let prevattr = LLVMRustGetEnumAttributeAtIndex( + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + if LLVMIsEnumAttribute(prevattr) { + let attr = LLVMCreateStringAttribute( + llcx, + myhwattr.as_ptr() as *const c_char, + myhwattr.as_bytes().len() as c_uint, + myhwv.as_ptr() as *const c_char, + myhwv.as_bytes().len() as c_uint, + ); + LLVMRustAddFunctionAttributes(lf, c_uint::MAX, &attr, 1); + } else { + LLVMRustAddEnumAttributeAtIndex( + llcx, + lf, + c_uint::MAX, + AttributeKind::SanitizeHWAddress, + ); + } + } else { + break; + } + } + } + if let Some(opt_level) = config.opt_level { let opt_stage = match cgcx.lto { Lto::Fat => llvm::OptStage::PreLinkFatLTO, @@ -634,7 +869,12 @@ pub(crate) unsafe fn optimize( _ if cgcx.opts.cg.linker_plugin_lto.enabled() => llvm::OptStage::PreLinkThinLTO, _ => llvm::OptStage::PreLinkNoLTO, }; - return unsafe { llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage) }; + + // Second run only relevant for AD + let first_run = true; + return unsafe { + llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run) + }; } Ok(()) } diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 8702532c36eee..2c248000d5c54 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -5,6 +5,7 @@ use std::{iter, ptr}; use libc::{c_char, c_uint}; use rustc_abi as abi; use rustc_abi::{Align, Size, WrappingRange}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode}; use rustc_codegen_ssa::MemFlags; use rustc_codegen_ssa::common::{IntPredicate, RealPredicate, SynchronizationScope, TypeKind}; use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue}; @@ -24,17 +25,218 @@ use rustc_span::Span; use rustc_target::abi::call::FnAbi; use rustc_target::spec::{HasTargetSpec, SanitizerSet, Target}; use smallvec::SmallVec; -use tracing::{debug, instrument}; +use tracing::{debug, instrument, trace}; use crate::abi::FnAbiLlvmExt; use crate::attributes; use crate::common::Funclet; use crate::context::CodegenCx; -use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, True}; +use crate::llvm::{self, AtomicOrdering, AtomicRmwBinOp, BasicBlock, False, Metadata, True}; use crate::type_::Type; use crate::type_of::LayoutLlvmExt; use crate::value::Value; +fn get_params(fnc: &Value) -> Vec<&Value> { + unsafe { + let param_num = llvm::LLVMCountParams(fnc) as usize; + let mut fnc_args: Vec<&Value> = vec![]; + fnc_args.reserve(param_num); + llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr()); + fnc_args.set_len(param_num); + fnc_args + } +} + +pub(crate) fn add_opt_dbg_helper2<'ll>( + llmod: &'ll llvm::Module, + llcx: &'ll llvm::Context, + todiff: &'ll Value, + outer_fn: &'ll Value, + attrs: AutoDiffAttrs, +) { + let inputs = attrs.input_activity; + let output = attrs.ret_activity; + let mut ad_name: String = match attrs.mode { + DiffMode::Forward => "__enzyme_fwddiff", + DiffMode::Reverse => "__enzyme_autodiff", + DiffMode::ForwardFirst => "__enzyme_fwddiff", + DiffMode::ReverseFirst => "__enzyme_autodiff", + _ => panic!("Why are we here?"), + }.to_string(); + // add outer_fn name to ad_name to make it unique + let outer_fn_name = unsafe { + let mut len: usize = 0; + let name = llvm::LLVMGetValueName2(outer_fn, &mut len as *mut usize); + std::ffi::CStr::from_ptr(name).to_str().unwrap() + }; + ad_name.push_str(outer_fn_name.to_string().as_str()); + + // Assuming that our todiff is the fnc square, want to generate the following llvm-ir: + // declare double @__enzyme_autodiff(...) + // + // define double @dsquare(double %x) { + // entry: + // %0 = tail call double (...) @__enzyme_autodiff(double (double)* nonnull @square, double %x) + // ret double %0 + // } + + unsafe { + let fn_ty = llvm::LLVMRustGetFunctionType(outer_fn); + let ret_ty = llvm::LLVMGetReturnType(fn_ty); + + // First we add the declaration of the __enzyme function + let enzyme_ty = llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, True); + let ad_fn = llvm::LLVMRustGetOrInsertFunction( + llmod, + ad_name.as_ptr() as *const c_char, + ad_name.len().try_into().unwrap(), + enzyme_ty, + ); + llvm::LLVMRustAddEnumAttributeAtIndex( + llcx, + ad_fn, + c_uint::MAX, + llvm::AttributeKind::NoInline, + ); + + // first, remove all calls from fnc + let entry = llvm::LLVMGetFirstBasicBlock(outer_fn); + let br = llvm::LLVMRustGetTerminator(entry); + llvm::LLVMRustEraseInstFromParent(br); + + let builder = llvm::LLVMCreateBuilderInContext(llcx); + let last_inst = llvm::LLVMRustGetLastInstruction(entry).unwrap(); + llvm::LLVMPositionBuilderAtEnd(builder, entry); + + let num_args = llvm::LLVMCountParams(todiff); + let mut args = Vec::with_capacity(num_args as usize + 1); + args.push(todiff); + let enzyme_const = + llvm::LLVMMDStringInContext2(llcx, "enzyme_const".as_ptr() as *const c_char, 12); + let enzyme_out = + llvm::LLVMMDStringInContext2(llcx, "enzyme_out".as_ptr() as *const c_char, 10); + let enzyme_dup = + llvm::LLVMMDStringInContext2(llcx, "enzyme_dup".as_ptr() as *const c_char, 10); + let enzyme_dupnoneed = + llvm::LLVMMDStringInContext2(llcx, "enzyme_dupnoneed".as_ptr() as *const c_char, 16); + let enzyme_primal_ret = + llvm::LLVMMDStringInContext2(llcx, "enzyme_primal_return".as_ptr() as *const c_char, 20); + + match output { + DiffActivity::Dual => { + args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_primal_ret)); + }, + DiffActivity::Active => { + args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_primal_ret)); + }, + _ => {}, + } + + trace!("Matching args"); + // Idea: We follow the outer function's arguments and activities. + // We still keep track of our inner function's arguments, but just + // to verify that they match. + let mut outer_pos: usize = 0; + let mut activity_pos = 0; + let outer_args: Vec<&llvm::Value> = get_params(outer_fn); + + while activity_pos < inputs.len() { + let activity = inputs[activity_pos as usize]; + let (activity, duplicated): (&Metadata, bool) = match activity { + DiffActivity::None => panic!(), + DiffActivity::Const => (enzyme_const, false), + DiffActivity::Active => (enzyme_out, false), + DiffActivity::ActiveOnly => (enzyme_out, false), + DiffActivity::Dual => (enzyme_dup, true), + DiffActivity::DualOnly => (enzyme_dupnoneed, true), + DiffActivity::Duplicated => (enzyme_dup, true), + DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true), + DiffActivity::FakeActivitySize => (enzyme_const, false), + }; + let outer_arg = outer_args[outer_pos]; + args.push(llvm::LLVMMetadataAsValue(llcx, activity)); + args.push(outer_arg); + // Now if we have a slice and duplicate, then it get's interesting. + // + if duplicated { + let next_outer_arg = outer_args[outer_pos + 1]; + let next_outer_ty = llvm::LLVMTypeOf(next_outer_arg); + let slice = { + if activity_pos + 1 >= inputs.len() { + // If there is no arg following our ptr, it also can't be a slice, + // since that would lead to a ptr, int pair. + false + } else { + let next_activity = inputs[activity_pos + 1]; + next_activity == DiffActivity::FakeActivitySize + } + }; + if slice { + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer); + let next_outer_arg2 = outer_args[outer_pos + 2]; + let next_outer_ty2 = llvm::LLVMTypeOf(next_outer_arg2); + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty2) == llvm::TypeKind::Pointer); + let next_outer_arg3 = outer_args[outer_pos + 3]; + let next_outer_ty3 = llvm::LLVMTypeOf(next_outer_arg3); + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty3) == llvm::TypeKind::Integer); + args.push(next_outer_arg2); + args.push(llvm::LLVMMetadataAsValue(llcx, enzyme_const)); + args.push(next_outer_arg); + outer_pos += 4; + activity_pos += 2; + } else { + assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer); + args.push(next_outer_arg); + outer_pos += 2; + activity_pos += 1; + } + } else { + outer_pos += 1; + activity_pos += 1; + } + } + + let call = llvm::LLVMBuildCall2( + builder, + enzyme_ty, + ad_fn, + args.as_mut_ptr(), + args.len().try_into().unwrap(), + ad_name.as_ptr() as *const c_char, + ); + + // Add dummy dbg info to our newly generated call, if we have any. + let md_ty = llvm::LLVMGetMDKindIDInContext( + llcx, + "dbg".as_ptr() as *const c_char, + "dbg".len() as c_uint, + ); + + if llvm::LLVMRustHasMetadata(last_inst, md_ty) { + let md = llvm::LLVMRustDIGetInstMetadata(last_inst); + let md_todiff = llvm::LLVMMetadataAsValue(llcx, md); + let _md2 = llvm::LLVMSetMetadata(call, md_ty, md_todiff); + } else { + trace!("No dbg info"); + } + llvm::LLVMRustEraseInstBefore(entry, last_inst); + + + let void_ty = llvm::LLVMVoidTypeInContext(llcx); + if llvm::LLVMTypeOf(call) != void_ty { + llvm::LLVMBuildRet(builder, call); + } else { + llvm::LLVMBuildRetVoid(builder); + }; + llvm::LLVMDisposeBuilder(builder); + + let _fnc_ok = + llvm::LLVMVerifyFunction(outer_fn, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + let _fnc_ok = + llvm::LLVMVerifyFunction(todiff, llvm::LLVMVerifierFailureAction::LLVMAbortProcessAction); + } +} + // All Builders must have an llfn associated with them #[must_use] pub(crate) struct Builder<'a, 'll, 'tcx> { diff --git a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs index 61e474031bb08..0545e79b80da0 100644 --- a/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs +++ b/compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs @@ -427,7 +427,7 @@ struct UsageSets<'tcx> { /// Prepare sets of definitions that are relevant to deciding whether something /// is an "unused function" for coverage purposes. fn prepare_usage_sets<'tcx>(tcx: TyCtxt<'tcx>) -> UsageSets<'tcx> { - let (all_mono_items, cgus) = tcx.collect_and_partition_mono_items(()); + let (all_mono_items, _, cgus) = tcx.collect_and_partition_mono_items(()); // Obtain a MIR body for each function participating in codegen, via an // arbitrary instance. diff --git a/compiler/rustc_codegen_llvm/src/errors.rs b/compiler/rustc_codegen_llvm/src/errors.rs index 0d436e1891ece..718504cd2c7fe 100644 --- a/compiler/rustc_codegen_llvm/src/errors.rs +++ b/compiler/rustc_codegen_llvm/src/errors.rs @@ -80,6 +80,11 @@ impl Diagnostic<'_, G> for ParseTargetMachineConfig<'_> { } } +#[derive(Diagnostic)] +#[diag(codegen_llvm_autodiff_without_lto)] +#[note] +pub(crate) struct AutoDiffWithoutLTO; + #[derive(Diagnostic)] #[diag(codegen_llvm_lto_disallowed)] pub(crate) struct LtoDisallowed; @@ -122,6 +127,8 @@ pub enum LlvmError<'a> { PrepareThinLtoModule, #[diag(codegen_llvm_parse_bitcode)] ParseBitcode, + #[diag(codegen_llvm_prepare_autodiff)] + PrepareAutoDiff { src: String, target: String, error: String }, } pub(crate) struct WithLlvmError<'a>(pub LlvmError<'a>, pub String); @@ -143,6 +150,7 @@ impl Diagnostic<'_, G> for WithLlvmError<'_> { } PrepareThinLtoModule => fluent::codegen_llvm_prepare_thin_lto_module_with_llvm_err, ParseBitcode => fluent::codegen_llvm_parse_bitcode_with_llvm_err, + PrepareAutoDiff { .. } => fluent::codegen_llvm_prepare_autodiff_with_llvm_err, }; self.0 .into_diag(dcx, level) diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index b85d28a2f1f71..4f6784934d666 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -27,9 +27,10 @@ use std::mem::ManuallyDrop; use back::owned_target_machine::OwnedTargetMachine; use back::write::{create_informational_target_machine, create_target_machine}; -use errors::ParseTargetMachineConfig; +use errors::{AutoDiffWithoutLTO, ParseTargetMachineConfig}; pub use llvm_util::target_features; use rustc_ast::expand::allocator::AllocatorKind; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_codegen_ssa::back::lto::{LtoModuleCodegen, SerializedModule, ThinModule}; use rustc_codegen_ssa::back::write::{ CodegenContext, FatLtoInput, ModuleConfig, TargetMachineFactoryConfig, TargetMachineFactoryFn, @@ -43,7 +44,7 @@ use rustc_middle::dep_graph::{WorkProduct, WorkProductId}; use rustc_middle::ty::TyCtxt; use rustc_middle::util::Providers; use rustc_session::Session; -use rustc_session::config::{OptLevel, OutputFilenames, PrintKind, PrintRequest}; +use rustc_session::config::{Lto, OptLevel, OutputFilenames, PrintKind, PrintRequest}; use rustc_span::symbol::Symbol; mod back { @@ -211,6 +212,7 @@ impl WriteBackendMethods for LlvmCodegenBackend { ) -> Result<(Vec>, Vec), FatalError> { back::lto::run_thin(cgcx, modules, cached_modules) } + // Manuel unsafe fn optimize( cgcx: &CodegenContext, dcx: DiagCtxtHandle<'_>, @@ -250,6 +252,19 @@ impl WriteBackendMethods for LlvmCodegenBackend { fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer) { (module.name, back::lto::ModuleBuffer::new(module.module_llvm.llmod())) } + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + config: &ModuleConfig, + ) -> Result<(), FatalError> { + if cgcx.lto != Lto::Fat { + let dcx = cgcx.create_dcx(); + return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutLTO {})); + } + back::write::differentiate(module, cgcx, diff_fncs, config) + } } unsafe impl Send for LlvmCodegenBackend {} // Llvm is on a per-thread basis @@ -405,6 +420,7 @@ impl CodegenBackend for LlvmCodegenBackend { } } +#[allow(dead_code)] pub struct ModuleLlvm { llcx: &'static mut llvm::Context, llmod_raw: *const llvm::Module, @@ -459,7 +475,11 @@ impl ModuleLlvm { } }; - Ok(ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) }) + Ok(ModuleLlvm { + llmod_raw, + llcx, + tm: ManuallyDrop::new(tm), + }) } } diff --git a/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs new file mode 100644 index 0000000000000..1bd1c44a5a691 --- /dev/null +++ b/compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs @@ -0,0 +1,66 @@ +#![allow(non_camel_case_types)] + +use libc::{c_char, c_uint, size_t}; +use super::ffi::*; + +extern "C" { + // Enzyme + pub fn LLVMRustAddFncParamAttr<'a>(F: &'a Value, index: c_uint, Attr: &'a Attribute); + pub fn LLVMRustAddRetFncAttr(F: &Value, attr: &Attribute); + pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool; + pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value); + pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>; + pub fn LLVMRustDIGetInstMetadata(I: &Value) -> &Metadata; + pub fn LLVMRustEraseInstFromParent(V: &Value); + pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value; + pub fn LLVMRustAddEnumAttributeAtIndex( + C: &Context, + V: &Value, + index: c_uint, + attr: AttributeKind, + ); + pub fn LLVMRustRemoveEnumAttributeAtIndex(V: &Value, index: c_uint, attr: AttributeKind); + pub fn LLVMRustGetEnumAttributeAtIndex( + V: &Value, + index: c_uint, + attr: AttributeKind, + ) -> &Attribute; + pub fn LLVMRustAddParamAttr<'a>(Instr: &'a Value, index: c_uint, Attr: &'a Attribute); + + + pub fn LLVMGetReturnType(T: &Type) -> &Type; + pub fn LLVMDumpModule(M: &Module); + pub fn LLVMCountStructElementTypes(T: &Type) -> c_uint; + pub fn LLVMVerifyFunction(V: &Value, action: LLVMVerifierFailureAction) -> bool; + pub fn LLVMGetParams(Fnc: &Value, parms: *mut &Value); + pub fn LLVMBuildCall2<'a>( + arg1: &Builder<'a>, + ty: &Type, + func: &Value, + args: *mut &Value, + num_args: size_t, + name: *const c_char, + ) -> &'a Value; + pub fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>; + pub fn LLVMGetNextFunction(V: &Value) -> Option<&Value>; + pub fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>; + pub fn LLVMRustGetFunctionType(fnc: &Value) -> &Type; + + pub fn LLVMRemoveStringAttributeAtIndex(F: &Value, Idx: c_uint, K: *const c_char, KLen: c_uint); + pub fn LLVMGetStringAttributeAtIndex( + F: &Value, + Idx: c_uint, + K: *const c_char, + KLen: c_uint, + ) -> &Attribute; + pub fn LLVMIsEnumAttribute(A: &Attribute) -> bool; + pub fn LLVMIsStringAttribute(A: &Attribute) -> bool; + +} + +#[repr(C)] +pub enum LLVMVerifierFailureAction { + LLVMAbortProcessAction, + LLVMPrintMessageAction, + LLVMReturnStatusAction, +} diff --git a/compiler/rustc_codegen_llvm/src/llvm/mod.rs b/compiler/rustc_codegen_llvm/src/llvm/mod.rs index e837022044ee0..7e582f949a101 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/mod.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/mod.rs @@ -20,8 +20,10 @@ pub use self::RealPredicate::*; pub mod archive_ro; pub mod diagnostic; +pub mod enzyme_ffi; mod ffi; +pub use self::enzyme_ffi::*; pub use self::ffi::*; impl LLVMRustResult { diff --git a/compiler/rustc_codegen_ssa/messages.ftl b/compiler/rustc_codegen_ssa/messages.ftl index d07274920feaf..d376fdf0f1428 100644 --- a/compiler/rustc_codegen_ssa/messages.ftl +++ b/compiler/rustc_codegen_ssa/messages.ftl @@ -351,3 +351,6 @@ codegen_ssa_use_cargo_directive = use the `cargo:rustc-link-lib` directive to sp codegen_ssa_version_script_write_failure = failed to write version script: {$error} codegen_ssa_visual_studio_not_installed = you may need to install Visual Studio build tools with the "C++ build tools" workload + +codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto + diff --git a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs index 11bcd727501c9..27767e498fa36 100644 --- a/compiler/rustc_codegen_ssa/src/assert_module_sources.rs +++ b/compiler/rustc_codegen_ssa/src/assert_module_sources.rs @@ -48,7 +48,7 @@ pub fn assert_module_sources(tcx: TyCtxt<'_>, set_reuse: &dyn Fn(&mut CguReuseTr } let available_cgus = - tcx.collect_and_partition_mono_items(()).1.iter().map(|cgu| cgu.name()).collect(); + tcx.collect_and_partition_mono_items(()).2.iter().map(|cgu| cgu.name()).collect(); let mut ams = AssertModuleSource { tcx, diff --git a/compiler/rustc_codegen_ssa/src/back/lto.rs b/compiler/rustc_codegen_ssa/src/back/lto.rs index ab8b06a05fc74..1356e7240e422 100644 --- a/compiler/rustc_codegen_ssa/src/back/lto.rs +++ b/compiler/rustc_codegen_ssa/src/back/lto.rs @@ -1,11 +1,13 @@ use std::ffi::CString; use std::sync::Arc; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::memmap::Mmap; use rustc_errors::FatalError; use super::write::CodegenContext; use crate::ModuleCodegen; +use crate::back::write::ModuleConfig; use crate::traits::*; pub struct ThinModule { @@ -72,6 +74,23 @@ impl LtoModuleCodegen { } } + /// Run autodiff on Fat LTO module + pub unsafe fn autodiff( + self, + cgcx: &CodegenContext, + diff_fncs: Vec, + config: &ModuleConfig, + ) -> Result, FatalError> { + match &self { + LtoModuleCodegen::Fat(module) => { + B::autodiff(cgcx, &module, diff_fncs, config)?; + } + _ => panic!("Unreachable? Autodiff called with non-fat LTO module"), + } + + Ok(self) + } + /// A "gauge" of how costly it is to optimize this module, used to sort /// biggest modules first. pub fn cost(&self) -> u64 { diff --git a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs index 77c35a1fe79e9..c638546f779ff 100644 --- a/compiler/rustc_codegen_ssa/src/back/symbol_export.rs +++ b/compiler/rustc_codegen_ssa/src/back/symbol_export.rs @@ -293,7 +293,7 @@ fn exported_symbols_provider_local( // external linkage is enough for monomorphization to be linked to. let need_visibility = tcx.sess.target.dynamic_linking && !tcx.sess.target.only_cdylib; - let (_, cgus) = tcx.collect_and_partition_mono_items(()); + let (_, _, cgus) = tcx.collect_and_partition_mono_items(()); // The symbols created in this loop are sorted below it #[allow(rustc::potential_query_instability)] diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index e3d11cfaf4fe3..e05db9529c4ed 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -8,6 +8,7 @@ use std::{fs, io, mem, str, thread}; use jobserver::{Acquired, Client}; use rustc_ast::attr; +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_data_structures::fx::{FxHashMap, FxIndexMap}; use rustc_data_structures::memmap::Mmap; use rustc_data_structures::profiling::{SelfProfilerRef, VerboseTimingGuard}; @@ -41,7 +42,7 @@ use tracing::debug; use super::link::{self, ensure_removed}; use super::lto::{self, SerializedModule}; use super::symbol_export::symbol_name_for_instance_in_crate; -use crate::errors::ErrorCreatingRemarkDir; +use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir}; use crate::traits::*; use crate::{ CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind, @@ -120,6 +121,7 @@ pub struct ModuleConfig { pub merge_functions: bool, pub emit_lifetime_markers: bool, pub llvm_plugins: Vec, + pub autodiff: Vec, } impl ModuleConfig { @@ -280,6 +282,7 @@ impl ModuleConfig { emit_lifetime_markers: sess.emit_lifetime_markers(), llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]), + autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]), } } @@ -401,6 +404,7 @@ impl CodegenContext { fn generate_lto_work( cgcx: &CodegenContext, + autodiff: Vec, needs_fat_lto: Vec>, needs_thin_lto: Vec<(String, B::ThinBuffer)>, import_only_modules: Vec<(SerializedModule, WorkProduct)>, @@ -409,11 +413,19 @@ fn generate_lto_work( if !needs_fat_lto.is_empty() { assert!(needs_thin_lto.is_empty()); - let module = + let mut module = B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise()); + if cgcx.lto == Lto::Fat { + let config = cgcx.config(ModuleKind::Regular); + module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() }; + } // We are adding a single work item, so the cost doesn't matter. vec![(WorkItem::LTO(module), 0)] } else { + if !autodiff.is_empty() { + let dcx = cgcx.create_dcx(); + dcx.handle().emit_fatal(AutodiffWithoutLto {}); + } assert!(needs_fat_lto.is_empty()); let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules) .unwrap_or_else(|e| e.raise()); @@ -1041,6 +1053,9 @@ pub(crate) enum Message { /// Sent from a backend worker thread. WorkItem { result: Result, Option>, worker_id: usize }, + /// A vector containing all the AutoDiff tasks that we have to pass to Enzyme. + AddAutoDiffItems(Vec), + /// The frontend has finished generating something (backend IR or a /// post-LTO artifact) for a codegen unit, and it should be passed to the /// backend. Sent from the main thread. @@ -1367,6 +1382,7 @@ fn start_executing_work( // This is where we collect codegen units that have gone all the way // through codegen and LLVM. + let mut autodiff_items = Vec::new(); let mut compiled_modules = vec![]; let mut compiled_allocator_module = None; let mut needs_link = Vec::new(); @@ -1478,9 +1494,13 @@ fn start_executing_work( let needs_thin_lto = mem::take(&mut needs_thin_lto); let import_only_modules = mem::take(&mut lto_import_only_modules); - for (work, cost) in - generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules) - { + for (work, cost) in generate_lto_work( + &cgcx, + autodiff_items.clone(), + needs_fat_lto, + needs_thin_lto, + import_only_modules, + ) { let insertion_index = work_items .binary_search_by_key(&cost, |&(_, cost)| cost) .unwrap_or_else(|e| e); @@ -1615,6 +1635,10 @@ fn start_executing_work( main_thread_state = MainThreadState::Idle; } + Message::AddAutoDiffItems(mut items) => { + autodiff_items.append(&mut items); + } + Message::CodegenComplete => { if codegen_state != Aborted { codegen_state = Completed; @@ -2092,6 +2116,10 @@ impl OngoingCodegen { drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::))); } + pub(crate) fn submit_autodiff_items(&self, items: Vec) { + drop(self.coordinator.sender.send(Box::new(Message::::AddAutoDiffItems(items)))); + } + pub(crate) fn check_for_errors(&self, sess: &Session) { self.shared_emitter_main.check(sess, false); } diff --git a/compiler/rustc_codegen_ssa/src/base.rs b/compiler/rustc_codegen_ssa/src/base.rs index a726ee73aaa26..e201d93986e07 100644 --- a/compiler/rustc_codegen_ssa/src/base.rs +++ b/compiler/rustc_codegen_ssa/src/base.rs @@ -621,7 +621,8 @@ pub fn codegen_crate( // Run the monomorphization collector and partition the collected items into // codegen units. - let codegen_units = tcx.collect_and_partition_mono_items(()).1; + let (_, autodiff_fncs, codegen_units) = tcx.collect_and_partition_mono_items(()); + let autodiff_fncs = autodiff_fncs.to_vec(); // Force all codegen_unit queries so they are already either red or green // when compile_codegen_unit accesses them. We are not able to re-execute @@ -692,6 +693,10 @@ pub fn codegen_crate( ); } + if !autodiff_fncs.is_empty() { + ongoing_codegen.submit_autodiff_items(autodiff_fncs); + } + // For better throughput during parallel processing by LLVM, we used to sort // CGUs largest to smallest. This would lead to better thread utilization // by, for example, preventing a large CGU from being processed last and @@ -1051,7 +1056,7 @@ pub(crate) fn provide(providers: &mut Providers) { config::OptLevel::SizeMin => config::OptLevel::Default, }; - let (defids, _) = tcx.collect_and_partition_mono_items(cratenum); + let (defids, _, _) = tcx.collect_and_partition_mono_items(cratenum); let any_for_speed = defids.items().any(|id| { let CodegenFnAttrs { optimize, .. } = tcx.codegen_fn_attrs(*id); diff --git a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs index a5bd3adbcddc9..3edaeb9fee992 100644 --- a/compiler/rustc_codegen_ssa/src/codegen_attrs.rs +++ b/compiler/rustc_codegen_ssa/src/codegen_attrs.rs @@ -1,4 +1,9 @@ -use rustc_ast::{MetaItemInner, MetaItemKind, ast, attr}; +use std::str::FromStr; + +use rustc_ast::expand::autodiff_attrs::{ + AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, +}; +use rustc_ast::{MetaItem, MetaItemInner, MetaItemKind, ast, attr}; use rustc_attr::{InlineAttr, InstructionSetAttr, OptimizeAttr, list_contains_name}; use rustc_data_structures::fx::FxHashMap; use rustc_errors::codes::*; @@ -779,6 +784,133 @@ fn check_link_name_xor_ordinal( } } +/// We now check the #[rustc_autodiff] attributes which we generated from the #[autodiff(...)] +/// macros. There are two forms. The pure one without args to mark primal functions (the functions +/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the +/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never +/// panic, unless we introduced a bug when parsing the autodiff macro. +fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs { + let attrs = tcx.get_attrs(id, sym::rustc_autodiff); + + let attrs = + attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::>(); + + // check for exactly one autodiff attribute on placeholder functions. + // There should only be one, since we generate a new placeholder per ad macro. + // TODO: re-enable this. We should fix that rustc_autodiff isn't applied multiple times to the + // source function. + let msg_once = "cg_ssa: implementation bug. Autodiff attribute can only be applied once"; + let attr = match attrs.len() { + 0 => return AutoDiffAttrs::error(), + 1 => attrs.get(0).unwrap(), + _ => { + attrs.get(0).unwrap() + //tcx.dcx().struct_span_err(attrs[1].span, msg_once).with_note("more than one").emit(); + //return AutoDiffAttrs::error(); + } + }; + + let list = attr.meta_item_list().unwrap_or_default(); + + // empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions + if list.len() == 0 { + return AutoDiffAttrs::source(); + } + + let [mode, input_activities @ .., ret_activity] = &list[..] else { + tcx.dcx() + .struct_span_err(attr.span, msg_once) + .with_note("Implementation bug in autodiff_attrs. Please report this!") + .emit(); + return AutoDiffAttrs::error(); + }; + let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode { + p1.segments.first().unwrap().ident + } else { + let msg = "autodiff attribute must contain autodiff mode"; + tcx.dcx().struct_span_err(attr.span, msg).with_note("empty argument list").emit(); + return AutoDiffAttrs::error(); + }; + + // parse mode + let msg_mode = "mode should be either forward or reverse"; + let mode = match mode.as_str() { + "Forward" => DiffMode::Forward, + "Reverse" => DiffMode::Reverse, + "ForwardFirst" => DiffMode::ForwardFirst, + "ReverseFirst" => DiffMode::ReverseFirst, + _ => { + tcx.dcx().struct_span_err(attr.span, msg_mode).with_note("invalid mode").emit(); + return AutoDiffAttrs::error(); + } + }; + + // First read the ret symbol from the attribute + let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity { + p1.segments.first().unwrap().ident + } else { + let msg = "autodiff attribute must contain the return activity"; + tcx.dcx().struct_span_err(attr.span, msg).with_note("missing return activity").emit(); + return AutoDiffAttrs::error(); + }; + + // Then parse it into an actual DiffActivity + let msg_unknown_ret_activity = "unknown return activity"; + let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) { + Ok(x) => x, + Err(_) => { + tcx.dcx() + .struct_span_err(attr.span, msg_unknown_ret_activity) + .with_note("invalid return activity") + .emit(); + return AutoDiffAttrs::error(); + } + }; + + // Now parse all the intermediate (input) activities + let msg_arg_activity = "autodiff attribute must contain the return activity"; + let mut arg_activities: Vec = vec![]; + for arg in input_activities { + let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg { + p2.segments.first().unwrap().ident + } else { + tcx.dcx() + .struct_span_err(attr.span, msg_arg_activity) + .with_note("Implementation bug, please report this!") + .emit(); + return AutoDiffAttrs::error(); + }; + + match DiffActivity::from_str(arg_symbol.as_str()) { + Ok(arg_activity) => arg_activities.push(arg_activity), + Err(_) => { + tcx.dcx() + .struct_span_err(attr.span, msg_unknown_ret_activity) + .with_note("invalid input activity") + .emit(); + return AutoDiffAttrs::error(); + } + } + } + + let mut msg = "".to_string(); + for &input in &arg_activities { + if !valid_input_activity(mode, input) { + msg = format!("Invalid input activity {} for {} mode", input, mode); + } + } + if !valid_ret_activity(mode, ret_activity) { + msg = format!("Invalid return activity {} for {} mode", ret_activity, mode); + } + if msg != "".to_string() { + tcx.dcx().struct_span_err(attr.span, msg).with_note("invalid activity").emit(); + return AutoDiffAttrs::error(); + } + + AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities } +} + pub(crate) fn provide(providers: &mut Providers) { - *providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers }; + *providers = + Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers }; } diff --git a/compiler/rustc_codegen_ssa/src/errors.rs b/compiler/rustc_codegen_ssa/src/errors.rs index d67cf0e3a6d5f..e4f60286ee19c 100644 --- a/compiler/rustc_codegen_ssa/src/errors.rs +++ b/compiler/rustc_codegen_ssa/src/errors.rs @@ -37,6 +37,10 @@ pub(crate) struct CguNotRecorded<'a> { pub cgu_name: &'a str, } +#[derive(Diagnostic)] +#[diag(codegen_ssa_autodiff_without_lto)] +pub struct AutodiffWithoutLto; + #[derive(Diagnostic)] #[diag(codegen_ssa_unknown_reuse_kind)] pub(crate) struct UnknownReuseKind { diff --git a/compiler/rustc_codegen_ssa/src/traits/write.rs b/compiler/rustc_codegen_ssa/src/traits/write.rs index aabe9e33c4aa1..986632a071408 100644 --- a/compiler/rustc_codegen_ssa/src/traits/write.rs +++ b/compiler/rustc_codegen_ssa/src/traits/write.rs @@ -1,3 +1,4 @@ +use rustc_ast::expand::autodiff_attrs::AutoDiffItem; use rustc_errors::{DiagCtxtHandle, FatalError}; use rustc_middle::dep_graph::WorkProduct; @@ -12,6 +13,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { type ModuleBuffer: ModuleBufferMethods; type ThinData: Send + Sync; type ThinBuffer: ThinBufferMethods; + //type TypeTree: Clone; /// Merge all modules into main_module and returning it fn run_link( @@ -36,6 +38,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { ) -> Result<(Vec>, Vec), FatalError>; fn print_pass_timings(&self); fn print_statistics(&self); + // does enzyme prep work, should do ad too. unsafe fn optimize( cgcx: &CodegenContext, dcx: DiagCtxtHandle<'_>, @@ -61,6 +64,13 @@ pub trait WriteBackendMethods: 'static + Sized + Clone { want_summary: bool, ) -> (String, Self::ThinBuffer); fn serialize_module(module: ModuleCodegen) -> (String, Self::ModuleBuffer); + /// Generate autodiff rules + fn autodiff( + cgcx: &CodegenContext, + module: &ModuleCodegen, + diff_fncs: Vec, + config: &ModuleConfig, + ) -> Result<(), FatalError>; } pub trait ThinBufferMethods: Send + Sync { diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index d3762e739db80..04f822d91ebdc 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -758,6 +758,7 @@ fn test_unstable_options_tracking_hash() { tracked!(allow_features, Some(vec![String::from("lang_items")])); tracked!(always_encode_mir, true); tracked!(assume_incomplete_release, true); + tracked!(autodiff, vec![String::from("ad_flags")]); tracked!(binary_dep_depinfo, true); tracked!(box_noalias, false); tracked!( diff --git a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp index cb75888abd76d..24792926f3643 100644 --- a/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp +++ b/compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp @@ -344,11 +344,44 @@ extern "C" void LLVMRustAddCallSiteAttributes(LLVMValueRef Instr, AddAttributes(Call, Index, Attrs, AttrsLen); } +extern "C" LLVMValueRef LLVMRustGetTerminator(LLVMBasicBlockRef BB) { + Instruction *ret = unwrap(BB)->getTerminator(); + return wrap(ret); +} + +extern "C" void LLVMRustEraseInstFromParent(LLVMValueRef Instr) { + if (auto I = dyn_cast(unwrap(Instr))) { + I->eraseFromParent(); + } +} + +extern "C" LLVMTypeRef LLVMRustGetFunctionType(LLVMValueRef Fn) { + auto Ftype = unwrap(Fn)->getFunctionType(); + return wrap(Ftype); +} + +extern "C" void LLVMRustRemoveEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + LLVMRemoveEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + extern "C" LLVMAttributeRef LLVMRustCreateAttrNoValue(LLVMContextRef C, LLVMRustAttribute RustAttr) { return wrap(Attribute::get(*unwrap(C), fromRust(RustAttr))); } +extern "C" void LLVMRustAddEnumAttributeAtIndex(LLVMContextRef C, + LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + LLVMAddAttributeAtIndex(F, index, LLVMRustCreateAttrNoValue(C, RustAttr)); +} + +extern "C" LLVMAttributeRef +LLVMRustGetEnumAttributeAtIndex(LLVMValueRef F, size_t index, + LLVMRustAttribute RustAttr) { + return LLVMGetEnumAttributeAtIndex(F, index, fromRust(RustAttr)); +} + extern "C" LLVMAttributeRef LLVMRustCreateAlignmentAttr(LLVMContextRef C, uint64_t Bytes) { return wrap(Attribute::getWithAlignment(*unwrap(C), llvm::Align(Bytes))); @@ -872,6 +905,60 @@ extern "C" bool LLVMRustHasModuleFlag(LLVMModuleRef M, const char *Name, return unwrap(M)->getModuleFlag(StringRef(Name, Len)) != nullptr; } +// pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>; +extern "C" LLVMValueRef LLVMRustGetLastInstruction(LLVMBasicBlockRef BB) { + auto Point = unwrap(BB)->rbegin(); + if (Point != unwrap(BB)->rend()) + return wrap(&*Point); + return nullptr; +} + +extern "C" void LLVMRustEraseInstBefore(LLVMBasicBlockRef bb, LLVMValueRef I) { + auto &BB = *unwrap(bb); + auto &Inst = *unwrap(I); + auto It = BB.begin(); + while (&*It != &Inst) + ++It; + assert(It != BB.end()); + // Delete in rev order to ensure no dangling references. + while (It != BB.begin()) { + auto Prev = std::prev(It); + It->eraseFromParent(); + It = Prev; + } + It->eraseFromParent(); +} + +extern "C" bool LLVMRustHasMetadata(LLVMValueRef inst, unsigned kindID) { + if (auto *I = dyn_cast(unwrap(inst))) { + return I->hasMetadata(kindID); + } + return false; +} + +extern "C" void LLVMRustAddFncParamAttr(LLVMValueRef F, unsigned i, + LLVMAttributeRef RustAttr) { + if (auto *Fn = dyn_cast(unwrap(F))) { + Fn->addParamAttr(i, unwrap(RustAttr)); + } +} + +extern "C" void LLVMRustAddRetFncAttr(LLVMValueRef F, + LLVMAttributeRef RustAttr) { + if (auto *Fn = dyn_cast(unwrap(F))) { + Fn->addRetAttr(unwrap(RustAttr)); + } +} + +extern "C" LLVMMetadataRef LLVMRustDIGetInstMetadata(LLVMValueRef x) { + if (auto *I = dyn_cast(unwrap(x))) { + // auto *MD = I->getMetadata(LLVMContext::MD_dbg); + auto *MD = I->getDebugLoc().getAsMDNode(); + return wrap(MD); + } + return nullptr; +} + extern "C" void LLVMRustGlobalAddMetadata(LLVMValueRef Global, unsigned Kind, LLVMMetadataRef MD) { unwrap(Global)->addMetadata(Kind, *unwrap(MD)); diff --git a/compiler/rustc_middle/messages.ftl b/compiler/rustc_middle/messages.ftl index 39485a324f2e2..d1a9235064ae1 100644 --- a/compiler/rustc_middle/messages.ftl +++ b/compiler/rustc_middle/messages.ftl @@ -1,3 +1,7 @@ +middle_autodiff_unsafe_inner_const_ref = reading from a `Duplicated` const {$ty} is unsafe + +middle_unsupported_union = we don't support unions yet: '{$ty_name}' + middle_adjust_for_foreign_abi_error = target architecture {$arch} does not support `extern {$abi}` ABI diff --git a/compiler/rustc_middle/src/arena.rs b/compiler/rustc_middle/src/arena.rs index 52fe9956b4777..8279348378fd7 100644 --- a/compiler/rustc_middle/src/arena.rs +++ b/compiler/rustc_middle/src/arena.rs @@ -86,6 +86,7 @@ macro_rules! arena_types { [] dyn_compatibility_violations: rustc_middle::traits::DynCompatibilityViolation, [] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>, [decode] attribute: rustc_ast::Attribute, + [] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem, [] name_set: rustc_data_structures::unord::UnordSet, [] ordered_name_set: rustc_data_structures::fx::FxIndexSet, [] pats: rustc_middle::ty::PatternKind<'tcx>, diff --git a/compiler/rustc_middle/src/error.rs b/compiler/rustc_middle/src/error.rs index 5c2aa0005d405..aa9a931d42389 100644 --- a/compiler/rustc_middle/src/error.rs +++ b/compiler/rustc_middle/src/error.rs @@ -30,6 +30,20 @@ pub struct OpaqueHiddenTypeMismatch<'tcx> { pub sub: TypeMismatchReason, } +#[derive(Diagnostic)] +#[diag(middle_unsupported_union)] +pub struct UnsupportedUnion { + pub ty_name: String, +} + +#[derive(Diagnostic)] +#[diag(middle_autodiff_unsafe_inner_const_ref)] +pub struct AutodiffUnsafeInnerConstRef { + #[primary_span] + pub span: Span, + pub ty: String, +} + #[derive(Subdiagnostic)] pub enum TypeMismatchReason { #[label(middle_conflict_types)] diff --git a/compiler/rustc_middle/src/query/erase.rs b/compiler/rustc_middle/src/query/erase.rs index 5f8427bd707aa..c4eefa82f0950 100644 --- a/compiler/rustc_middle/src/query/erase.rs +++ b/compiler/rustc_middle/src/query/erase.rs @@ -221,6 +221,9 @@ impl EraseType for (&'_ T0, &'_ [T1]) { impl EraseType for (&'_ T0, Result<(), ErrorGuaranteed>) { type Result = [u8; size_of::<(&'static (), Result<(), ErrorGuaranteed>)>()]; } +impl EraseType for (&'_ T0, &'_ [T1], &'_ [T2]) { + type Result = [u8; size_of::<(&'static (), &'static [()], &'static [()])>()]; +} macro_rules! trivial { ($($ty:ty),+ $(,)?) => { diff --git a/compiler/rustc_middle/src/query/mod.rs b/compiler/rustc_middle/src/query/mod.rs index d03fc39c9ade1..a5637ae99ef33 100644 --- a/compiler/rustc_middle/src/query/mod.rs +++ b/compiler/rustc_middle/src/query/mod.rs @@ -14,6 +14,7 @@ use std::sync::Arc; use rustc_arena::TypedArena; use rustc_ast::expand::StrippedCfgItem; use rustc_ast::expand::allocator::AllocatorKind; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem}; use rustc_data_structures::fingerprint::Fingerprint; use rustc_data_structures::fx::{FxIndexMap, FxIndexSet}; use rustc_data_structures::sorted_map::SortedMap; @@ -1270,6 +1271,13 @@ rustc_queries! { feedable } + /// The list autodiff extern functions in current crate + query autodiff_attrs(def_id: DefId) -> &'tcx AutoDiffAttrs { + desc { |tcx| "computing autodiff attributes of `{}`", tcx.def_path_str(def_id) } + arena_cache + cache_on_disk_if { def_id.is_local() } + } + query asm_target_features(def_id: DefId) -> &'tcx FxIndexSet { desc { |tcx| "computing target features for inline asm of `{}`", tcx.def_path_str(def_id) } } @@ -1977,7 +1985,7 @@ rustc_queries! { separate_provide_extern } - query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [CodegenUnit<'tcx>]) { + query collect_and_partition_mono_items(_: ()) -> (&'tcx DefIdSet, &'tcx [AutoDiffItem], &'tcx [CodegenUnit<'tcx>]) { eval_always desc { "collect_and_partition_mono_items" } } diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index 854147648178f..46614a696287f 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -213,6 +213,9 @@ pub struct ResolverAstLowering { pub next_node_id: ast::NodeId, + /// Mapping of autodiff function IDs + pub autodiff_map: FxHashMap, + pub node_id_to_def_id: NodeMap, pub trait_map: NodeMap>, diff --git a/compiler/rustc_monomorphize/Cargo.toml b/compiler/rustc_monomorphize/Cargo.toml index c7f1b9fa78454..bce8d9c4a9878 100644 --- a/compiler/rustc_monomorphize/Cargo.toml +++ b/compiler/rustc_monomorphize/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] # tidy-alphabetical-start +rustc_ast = { path = "../rustc_ast" } rustc_data_structures = { path = "../rustc_data_structures" } rustc_errors = { path = "../rustc_errors" } rustc_fluent_macro = { path = "../rustc_fluent_macro" } @@ -13,6 +14,7 @@ rustc_macros = { path = "../rustc_macros" } rustc_middle = { path = "../rustc_middle" } rustc_session = { path = "../rustc_session" } rustc_span = { path = "../rustc_span" } +rustc_symbol_mangling = { path = "../rustc_symbol_mangling" } rustc_target = { path = "../rustc_target" } serde = "1" serde_json = "1" diff --git a/compiler/rustc_monomorphize/src/collector.rs b/compiler/rustc_monomorphize/src/collector.rs index 3f9a0df030150..8b41fe2582950 100644 --- a/compiler/rustc_monomorphize/src/collector.rs +++ b/compiler/rustc_monomorphize/src/collector.rs @@ -250,7 +250,7 @@ pub(crate) enum MonoItemCollectionStrategy { pub(crate) struct UsageMap<'tcx> { // Maps every mono item to the mono items used by it. - used_map: UnordMap, Vec>>, + pub used_map: UnordMap, Vec>>, // Maps every mono item to the mono items that use it. user_map: UnordMap, Vec>>, diff --git a/compiler/rustc_monomorphize/src/partitioning.rs b/compiler/rustc_monomorphize/src/partitioning.rs index 9bf7e67417eff..2d1ce1d275f8c 100644 --- a/compiler/rustc_monomorphize/src/partitioning.rs +++ b/compiler/rustc_monomorphize/src/partitioning.rs @@ -98,6 +98,7 @@ use std::fs::{self, File}; use std::io::Write; use std::path::{Path, PathBuf}; +use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity}; use rustc_data_structures::fx::{FxIndexMap, FxIndexSet}; use rustc_data_structures::sync; use rustc_data_structures::unord::{UnordMap, UnordSet}; @@ -114,13 +115,14 @@ use rustc_middle::mir::mono::{ }; use rustc_middle::ty::print::{characteristic_def_id_of_type, with_no_trimmed_paths}; use rustc_middle::ty::visit::TypeVisitableExt; -use rustc_middle::ty::{self, InstanceKind, TyCtxt}; +use rustc_middle::ty::{self, InstanceKind, ParamEnv, Ty, TyCtxt}; use rustc_middle::util::Providers; use rustc_session::CodegenUnits; use rustc_session::config::{DumpMonoStatsFormat, SwitchWithOptPath}; use rustc_span::symbol::Symbol; +use rustc_symbol_mangling::symbol_name_for_instance_in_crate; use rustc_target::spec::SymbolVisibility; -use tracing::debug; +use tracing::{debug, trace}; use crate::collector::{self, MonoItemCollectionStrategy, UsageMap}; use crate::errors::{CouldntDumpMonoStats, SymbolAlreadyDefined, UnknownCguCollectionMode}; @@ -251,7 +253,14 @@ where &mut can_be_internalized, export_generics, ); - if visibility == Visibility::Hidden && can_be_internalized { + + // We can't differentiate something that got inlined. + let autodiff_active = match characteristic_def_id { + Some(def_id) => cx.tcx.autodiff_attrs(def_id).is_active(), + None => false, + }; + + if !autodiff_active && visibility == Visibility::Hidden && can_be_internalized { internalization_candidates.insert(mono_item); } let size_estimate = mono_item.size_estimate(cx.tcx); @@ -1102,7 +1111,66 @@ where } } -fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[CodegenUnit<'_>]) { +pub(crate) fn adjust_activity_to_abi<'tcx>( + tcx: TyCtxt<'tcx>, + fn_ty: Ty<'tcx>, + da: &mut Vec) { + if !fn_ty.is_fn() { + // Error? + return; + } + let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx); + + // If rustc compiles the unmodified primal, we know that this copy of the function + // also has correct lifetimes. We know that Enzyme won't free the shadow too early + // (or actually at all), so let's strip lifetimes when computing the layout. + // Recommended by compiler-errors: + // https://discord.com/channels/273534239310479360/957720175619215380/1223454360676208751 + let x = tcx.instantiate_bound_regions_with_erased(fnc_binder); + let mut new_activities = vec![]; + let mut new_positions = vec![]; + for (i, ty) in x.inputs().iter().enumerate() { + if ty.is_unsafe_ptr() || ty.is_ref() || ty.is_box() { + if ty.is_fn_ptr() { + unimplemented!("what to do whith fn ptr?"); + } + let inner_ty = ty.builtin_deref(true).unwrap(); + if inner_ty.is_slice() { + // We know that the lenght will be passed as extra arg. + if !da.is_empty() { + // We are looking at a slice. The length of that slice will become an + // extra integer on llvm level. Integers are always const. + // However, if the slice get's duplicated, we want to know to later check the + // size. So we mark the new size argument as FakeActivitySize. + let activity = match da[i] { + DiffActivity::DualOnly + | DiffActivity::Dual + | DiffActivity::DuplicatedOnly + | DiffActivity::Duplicated => DiffActivity::FakeActivitySize, + DiffActivity::Const => DiffActivity::Const, + _ => panic!("unexpected activity for ptr/ref"), + }; + new_activities.push(activity); + new_positions.push(i + 1); + } + trace!("ABI MATCHING!"); + continue; + } + } + } + // now add the extra activities coming from slices + // Reverse order to not invalidate the indices + for _ in 0..new_activities.len() { + let pos = new_positions.pop().unwrap(); + let activity = new_activities.pop().unwrap(); + da.insert(pos, activity); + } +} + +fn collect_and_partition_mono_items( + tcx: TyCtxt<'_>, + (): (), +) -> (&DefIdSet, &[AutoDiffItem], &[CodegenUnit<'_>]) { let collection_strategy = match tcx.sess.opts.unstable_opts.print_mono_items { Some(ref s) => { let mode = s.to_lowercase(); @@ -1164,6 +1232,59 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co }) .collect(); + let autodiff_items2: Vec<_> = items + .iter() + .filter_map(|item| match *item { + MonoItem::Fn(ref instance) => Some((item, instance)), + _ => None, + }) + .collect(); + let mut autodiff_items: Vec = vec![]; + + for (item, instance) in autodiff_items2 { + let target_id = instance.def_id(); + let target_attrs: &AutoDiffAttrs = tcx.autodiff_attrs(target_id); + let mut input_activities: Vec = target_attrs.input_activity.clone(); + if target_attrs.is_source() { + trace!("source found: {:?}", target_id); + } + if !target_attrs.apply_autodiff() { + continue; + } + + let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE); + + let source = + usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item { + MonoItem::Fn(ref instance_s) => { + let source_id = instance_s.def_id(); + if tcx.autodiff_attrs(source_id).is_active() { + return Some(instance_s); + } + None + } + _ => None, + }); + let inst = match source { + Some(source) => source, + None => continue, + }; + + println!("source_id: {:?}", inst.def_id()); + let fn_ty = inst.ty(tcx, ParamEnv::empty()); + assert!(fn_ty.is_fn()); + adjust_activity_to_abi(tcx, fn_ty, &mut input_activities); + //let (inputs, output) = (fnc_tree.args, fnc_tree.ret); + let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE); + + let mut new_target_attrs = target_attrs.clone(); + new_target_attrs.input_activity = input_activities; + let itm = new_target_attrs.into_item(symb, target_symbol); + //let itm = new_target_attrs.into_item(symb, target_symbol, inputs, output); + autodiff_items.push(itm); + } + let autodiff_items = tcx.arena.alloc_from_iter(autodiff_items); + // Output monomorphization stats per def_id if let SwitchWithOptPath::Enabled(ref path) = tcx.sess.opts.unstable_opts.dump_mono_stats { if let Err(err) = @@ -1224,7 +1345,14 @@ fn collect_and_partition_mono_items(tcx: TyCtxt<'_>, (): ()) -> (&DefIdSet, &[Co } } - (tcx.arena.alloc(mono_items), codegen_units) + if autodiff_items.len() > 0 { + trace!("AUTODIFF ITEMS EXIST"); + for item in &mut *autodiff_items { + trace!("{}", &item); + } + } + + (tcx.arena.alloc(mono_items), autodiff_items, codegen_units) } /// Outputs stats about instantiation counts and estimated size, per `MonoItem`'s @@ -1308,12 +1436,12 @@ pub(crate) fn provide(providers: &mut Providers) { providers.collect_and_partition_mono_items = collect_and_partition_mono_items; providers.is_codegened_item = |tcx, def_id| { - let (all_mono_items, _) = tcx.collect_and_partition_mono_items(()); + let (all_mono_items, _, _) = tcx.collect_and_partition_mono_items(()); all_mono_items.contains(&def_id) }; providers.codegen_unit = |tcx, name| { - let (_, all) = tcx.collect_and_partition_mono_items(()); + let (_, _, all) = tcx.collect_and_partition_mono_items(()); all.iter() .find(|cgu| cgu.name() == name) .unwrap_or_else(|| panic!("failed to find cgu with name {name:?}")) diff --git a/compiler/rustc_passes/messages.ftl b/compiler/rustc_passes/messages.ftl index 3f98236595bd7..c1fd3efdb3b7a 100644 --- a/compiler/rustc_passes/messages.ftl +++ b/compiler/rustc_passes/messages.ftl @@ -13,6 +13,10 @@ passes_abi_ne = passes_abi_of = fn_abi_of({$fn_name}) = {$fn_abi} +passes_autodiff_attr = + `#[autodiff]` should be applied to a function + .label = not a function + passes_allow_incoherent_impl = `rustc_allow_incoherent_impl` attribute should be applied to impl items .label = the only currently supported targets are inherent methods @@ -49,10 +53,6 @@ passes_attr_crate_level = passes_attr_only_in_functions = `{$attr}` attribute can only be used on functions -passes_autodiff_attr = - `#[autodiff]` should be applied to a function - .label = not a function - passes_both_ffi_const_and_pure = `#[ffi_const]` function cannot be `#[ffi_pure]` diff --git a/compiler/rustc_resolve/src/lib.rs b/compiler/rustc_resolve/src/lib.rs index 35d491cfc18e7..14e6091ef2118 100644 --- a/compiler/rustc_resolve/src/lib.rs +++ b/compiler/rustc_resolve/src/lib.rs @@ -1168,6 +1168,8 @@ pub struct Resolver<'ra, 'tcx> { node_id_to_def_id: NodeMap>, def_id_to_node_id: IndexVec, + autodiff_map: FxHashMap, + /// Indices of unnamed struct or variant fields with unresolved attributes. placeholder_field_indices: FxHashMap, /// When collecting definitions from an AST fragment produced by a macro invocation `ExpnId` @@ -1539,6 +1541,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> { next_node_id: CRATE_NODE_ID, node_id_to_def_id, def_id_to_node_id, + autodiff_map: Default::default(), placeholder_field_indices: Default::default(), invocation_parents, trait_impl_items: Default::default(), @@ -1668,6 +1671,7 @@ impl<'ra, 'tcx> Resolver<'ra, 'tcx> { .into_items() .map(|(k, f)| (k, f.key())) .collect(), + autodiff_map: self.autodiff_map, trait_map: self.trait_map, lifetime_elision_allowed: self.lifetime_elision_allowed, lint_buffer: Steal::new(self.lint_buffer), diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index d733e32f209db..999e401b71ad7 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -194,6 +194,39 @@ impl Default for CoverageLevel { } } +/// The different settings that the `-Z ad` flag can have. +#[derive(Clone, Copy, PartialEq, Hash, Debug)] +pub enum AutoDiff { + /// Print TypeAnalysis information + PrintTA, + /// Print ActivityAnalysis Information + PrintAA, + /// Print Performance Warnings from Enzyme + PrintPerf, + /// Combines the three print flags above. + Print, + /// Print the whole module, before running opts. + PrintModBefore, + /// Print the whole module just before we pass it to Enzyme. + /// For Debug purpose, prefer the OPT flag below + PrintModAfterOpts, + /// Print the module after Enzyme differentiated everything. + PrintModAfterEnzyme, + + /// Enzyme's loose type debug helper (can cause incorrect gradients) + LooseTypes, + + /// More flags + NoModOptAfter, + /// Tell Enzyme to run LLVM Opts on each function it generated. By default off, + /// since we already optimize the whole module after Enzyme is done. + EnableFncOpt, + NoVecUnroll, + RuntimeActivity, + /// Runs Enzyme specific Inlining + Inline, +} + /// Settings for `-Z instrument-xray` flag. #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] pub struct InstrumentXRay { @@ -3022,7 +3055,7 @@ pub(crate) mod dep_tracking { }; use super::{ - BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions, + AutoDiff, BranchProtection, CFGuard, CFProtection, CollapseMacroDebuginfo, CoverageOptions, CrateType, DebugInfo, DebugInfoCompression, ErrorOutputType, FmtDebug, FunctionReturn, InliningThreshold, InstrumentCoverage, InstrumentXRay, LinkerPluginLto, LocationDetail, LtoCli, NextSolverConfig, OomStrategy, OptLevel, OutFileName, OutputType, OutputTypes, @@ -3070,6 +3103,7 @@ pub(crate) mod dep_tracking { } impl_dep_tracking_hash_via_hash!( + AutoDiff, bool, usize, NonZero, diff --git a/compiler/rustc_session/src/config/cfg.rs b/compiler/rustc_session/src/config/cfg.rs index ccc01728958b8..d9da9e9f26a45 100644 --- a/compiler/rustc_session/src/config/cfg.rs +++ b/compiler/rustc_session/src/config/cfg.rs @@ -176,6 +176,8 @@ pub(crate) fn default_configuration(sess: &Session) -> Cfg { // NOTE: These insertions should be kept in sync with // `CheckCfg::fill_well_known` below. + ins_none!(sym::autodiff_fallback); + if sess.opts.debug_assertions { ins_none!(sym::debug_assertions); } @@ -339,6 +341,7 @@ impl CheckCfg { // Don't forget to update `src/doc/rustc/src/check-cfg.md` // in the unstable book as well! + ins!(sym::autodiff_fallback, no_values); ins!(sym::debug_assertions, no_values); ins!(sym::fmt_debug, empty_values).extend(FmtDebug::all()); diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 54a4621db2462..472baf8ec0174 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -370,6 +370,7 @@ mod desc { pub(crate) const parse_list: &str = "a space-separated list of strings"; pub(crate) const parse_list_with_polarity: &str = "a comma-separated list of strings, with elements beginning with + or -"; + pub(crate) const parse_autodiff: &str = "various values"; pub(crate) const parse_comma_list: &str = "a comma-separated list of strings"; pub(crate) const parse_opt_comma_list: &str = parse_comma_list; pub(crate) const parse_number: &str = "a number"; @@ -996,6 +997,35 @@ mod parse { } } + pub(crate) fn parse_autodiff(slot: &mut Vec, v: Option<&str>) -> bool { + let Some(v) = v else { + *slot = vec![]; + return true; + }; + let mut v: Vec<&str> = v.split(",").collect(); + v.sort_unstable(); + for &val in v.iter() { + let variant = match val { + "PrintTA" => AutoDiff::PrintTA, + "PrintAA" => AutoDiff::PrintAA, + "PrintPerf" => AutoDiff::PrintPerf, + "Print" => AutoDiff::Print, + "PrintModBefore" => AutoDiff::PrintModBefore, + "PrintModAfterOpts" => AutoDiff::PrintModAfterOpts, + "PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme, + "LooseTypes" => AutoDiff::LooseTypes, + "NoModOptAfter" => AutoDiff::NoModOptAfter, + "EnableFncOpt" => AutoDiff::EnableFncOpt, + "NoVecUnroll" => AutoDiff::NoVecUnroll, + "Inline" => AutoDiff::Inline, + _ => return false, + }; + slot.push(variant); + } + + true + } + pub(crate) fn parse_instrument_coverage( slot: &mut InstrumentCoverage, v: Option<&str>, @@ -1680,6 +1710,8 @@ options! { either `loaded` or `not-loaded`."), assume_incomplete_release: bool = (false, parse_bool, [TRACKED], "make cfg(version) treat the current version as incomplete (default: no)"), + autodiff: Vec = (Vec::new(), parse_autodiff, [TRACKED], + "a list autodiff flags to enable (comma separated)"), #[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")] binary_dep_depinfo: bool = (false, parse_bool, [TRACKED], "include artifacts (sysroot, crate dependencies) used during compilation in dep-info \ diff --git a/config.example.toml b/config.example.toml index a52968e9a414d..440274d51e344 100644 --- a/config.example.toml +++ b/config.example.toml @@ -158,6 +158,9 @@ # Whether to build the clang compiler. #clang = false +# Wheter to build Enzyme as AutoDiff backend. +#enzyme = true + # Whether to enable llvm compilation warnings. #enable-warnings = false diff --git a/src/bootstrap/src/core/build_steps/compile.rs b/src/bootstrap/src/core/build_steps/compile.rs index 27bbc8bd8ff22..0033673079067 100644 --- a/src/bootstrap/src/core/build_steps/compile.rs +++ b/src/bootstrap/src/core/build_steps/compile.rs @@ -1068,10 +1068,6 @@ pub fn rustc_cargo( // . cargo.rustflag("-Zon-broken-pipe=kill"); - if builder.config.llvm_enzyme { - cargo.rustflag("-l").rustflag("Enzyme-19"); - } - // We currently don't support cross-crate LTO in stage0. This also isn't hugely necessary // and may just be a time sink. if compiler.stage != 0 {