Skip to content

Commit

Permalink
only include rustc_codegen_llvm autodiff changes
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Oct 1, 2024
1 parent fb4aebd commit 8726c99
Show file tree
Hide file tree
Showing 24 changed files with 2,283 additions and 27 deletions.
79 changes: 79 additions & 0 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/// This crate handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
/// we create an `AutoDiffItem` which contains the source and target function names. The source
/// is the function to which the autodiff attribute is applied, and the target is the function
/// getting generated by us (with a name given by the user as the first autodiff arg).
use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};

#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffMode {
/// No autodiff is applied (usually used during error handling).
Inactive,
/// The primal function which we will differentiate.
Source,
/// The target function, to be created using forward mode AD.
Forward,
/// The target function, to be created using reverse mode AD.
Reverse,
/// The target function, to be created using forward mode AD.
/// This target function will also be used as a source for higher order derivatives,
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
ForwardFirst,
/// The target function, to be created using reverse mode AD.
/// This target function will also be used as a source for higher order derivatives,
/// so compute it before all Forward/Reverse targets and optimize it through llvm.
ReverseFirst,
}

/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
/// we add to the previous shadow value. To not surprise users, we picked different names.
/// Dual numbers is also a quite well known name for forward mode AD types.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffActivity {
/// Implicit or Explicit () return type, so a special case of Const.
None,
/// Don't compute derivatives with respect to this input/output.
Const,
/// Reverse Mode, Compute derivatives for this scalar input/output.
Active,
/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
/// the original return value.
ActiveOnly,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it.
Dual,
/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
/// with it. Drop the code which updates the original input/output for maximum performance.
DualOnly,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
Duplicated,
/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
/// Drop the code which updates the original input for maximum performance.
DuplicatedOnly,
/// All Integers must be Const, but these are used to mark the integer which represents the
/// length of a slice/vec. This is used for safety checks on slices.
FakeActivitySize,
}
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
/// The name of the function getting differentiated
pub source: String,
/// The name of the function being generated
pub target: String,
pub attrs: AutoDiffAttrs,
/// Despribe the memory layout of input types
pub inputs: Vec<TypeTree>,
/// Despribe the memory layout of the output type
pub output: TypeTree,
}
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
/// e.g. in the [JAX
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
pub mode: DiffMode,
pub ret_activity: DiffActivity,
pub input_activity: Vec<DiffActivity>,
}
2 changes: 2 additions & 0 deletions compiler/rustc_ast/src/expand/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ use rustc_span::symbol::Ident;
use crate::MetaItem;

pub mod allocator;
pub mod autodiff_attrs;
pub mod typetree;

#[derive(Debug, Clone, Encodable, Decodable, HashStable_Generic)]
pub struct StrippedCfgItem<ModId = DefId> {
Expand Down
69 changes: 69 additions & 0 deletions compiler/rustc_ast/src/expand/typetree.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::fmt;

use crate::expand::{Decodable, Encodable, HashStable_Generic};

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum Kind {
Anything,
Integer,
Pointer,
Half,
Float,
Double,
Unknown,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct TypeTree(pub Vec<Type>);

impl TypeTree {
pub fn new() -> Self {
Self(Vec::new())
}
pub fn all_ints() -> Self {
Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
}
pub fn int(size: usize) -> Self {
let mut ints = Vec::with_capacity(size);
for i in 0..size {
ints.push(Type {
offset: i as isize,
size: 1,
kind: Kind::Integer,
child: TypeTree::new(),
});
}
Self(ints)
}
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct FncTree {
pub args: Vec<TypeTree>,
pub ret: TypeTree,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct Type {
pub offset: isize,
pub size: usize,
pub kind: Kind,
pub child: TypeTree,
}

impl Type {
pub fn add_offset(self, add: isize) -> Self {
let offset = match self.offset {
-1 => add,
x => add + x,
};

Self { size: self.size, kind: self.kind, child: self.child, offset }
}
}

impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
<Self as fmt::Debug>::fmt(self, f)
}
}
4 changes: 4 additions & 0 deletions compiler/rustc_codegen_llvm/messages.ftl
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ codegen_llvm_prepare_thin_lto_module_with_llvm_err = failed to prepare thin LTO
codegen_llvm_run_passes = failed to run LLVM passes
codegen_llvm_run_passes_with_llvm_err = failed to run LLVM passes: {$llvm_err}
codegen_llvm_prepare_autodiff = failed to prepare AutoDiff: src: {$src}, target: {$target}, {$error}
codegen_llvm_prepare_autodiff_with_llvm_err = failed to prepare AutoDiff: {$llvm_err}, src: {$src}, target: {$target}, {$error}
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
codegen_llvm_sanitizer_memtag_requires_mte =
`-Zsanitizer=memtag` requires `-Ctarget-feature=+mte`
Expand Down
7 changes: 7 additions & 0 deletions compiler/rustc_codegen_llvm/src/attributes.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
//! Set and unset common attributes on LLVM values.
use rustc_attr::{InlineAttr, InstructionSetAttr, OptimizeAttr};
// FIXME(ZuseZ4): Re-enable once the middle-end is merged.
//use rustc_ast::expand::autodiff_attrs::AutoDiffAttrs;
use rustc_codegen_ssa::traits::*;
use rustc_hir::def_id::DefId;
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, PatchableFunctionEntry};
Expand Down Expand Up @@ -333,6 +335,8 @@ pub(crate) fn llfn_attrs_from_instance<'ll, 'tcx>(
instance: ty::Instance<'tcx>,
) {
let codegen_fn_attrs = cx.tcx.codegen_fn_attrs(instance.def_id());
// FIXME(ZuseZ4): Re-enable once the middle-end is merged.
//let autodiff_attrs: &AutoDiffAttrs = cx.tcx.autodiff_attrs(instance.def_id());

let mut to_add = SmallVec::<[_; 16]>::new();

Expand All @@ -350,6 +354,9 @@ pub(crate) fn llfn_attrs_from_instance<'ll, 'tcx>(
let inline =
if codegen_fn_attrs.inline == InlineAttr::None && instance.def.requires_inline(cx.tcx) {
InlineAttr::Hint
// FIXME(ZuseZ4): re-enable once the middle-end is merged.
//} else if autodiff_attrs.is_active() {
// InlineAttr::Never
} else {
codegen_fn_attrs.inline
};
Expand Down
14 changes: 12 additions & 2 deletions compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,12 @@ pub(crate) fn run_pass_manager(
}
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage)?;

// We will run this again with different values in the context of automatic differentiation.
let first_run = true;
let noop = false;
debug!("running llvm pm opt pipeline");
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, first_run, noop)?;
}
debug!("lto done");
Ok(())
Expand Down Expand Up @@ -723,7 +728,12 @@ pub(crate) unsafe fn optimize_thin_module(
let llcx = unsafe { llvm::LLVMRustContextCreate(cgcx.fewer_names) };
let llmod_raw = parse_module(llcx, module_name, thin_module.data(), dcx)? as *const _;
let mut module = ModuleCodegen {
module_llvm: ModuleLlvm { llmod_raw, llcx, tm: ManuallyDrop::new(tm) },
module_llvm: ModuleLlvm {
llmod_raw,
llcx,
tm: ManuallyDrop::new(tm),
typetrees: Default::default(),
},
name: thin_module.name().to_string(),
kind: ModuleKind::Regular,
};
Expand Down
Loading

0 comments on commit 8726c99

Please sign in to comment.