Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions compiler/rustc_ast/src/expand/autodiff_attrs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::{Ty, TyKind};

Expand Down Expand Up @@ -84,6 +85,8 @@ pub struct AutoDiffItem {
/// The name of the function being generated
pub target: String,
pub attrs: AutoDiffAttrs,
pub inputs: Vec<TypeTree>,
pub output: TypeTree,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
Expand Down Expand Up @@ -275,14 +278,22 @@ impl AutoDiffAttrs {
!matches!(self.mode, DiffMode::Error | DiffMode::Source)
}

pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
AutoDiffItem { source, target, attrs: self }
pub fn into_item(
self,
source: String,
target: String,
inputs: Vec<TypeTree>,
output: TypeTree,
) -> AutoDiffItem {
AutoDiffItem { source, target, inputs, output, attrs: self }
}
}

impl fmt::Display for AutoDiffItem {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
write!(f, " with attributes: {:?}", self.attrs)
write!(f, " with attributes: {:?}", self.attrs)?;
write!(f, " with inputs: {:?}", self.inputs)?;
write!(f, " with output: {:?}", self.output)
}
}
1 change: 1 addition & 0 deletions compiler/rustc_ast/src/expand/typetree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pub enum Kind {
Half,
Float,
Double,
F128,
Unknown,
}

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a reason for making a separate datastructure here on the rust side and converting it to an Enzyme TypeTree, instead of just calling the API for Enzyme's TypeTree directly?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure, this was already there, @ZuseZ4 can tell, but I will check again and will remove it

Copy link
Member

@ZuseZ4 ZuseZ4 Sep 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the rustc frontend (rustc_ast) shouldn't depend on an optional plugin for the LLVM backend, which itself is only one out of multiple rustc backends. We generally try to avoid directly using unsafe LLVM / C functions.

Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_gcc/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,7 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
_src_align: Align,
size: RValue<'gcc>,
flags: MemFlags,
_tt: Option<rustc_ast::expand::typetree::FncTree>, // Autodiff TypeTrees are LLVM-only, ignored in GCC backend
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_size_t(), false);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_gcc/src/intrinsic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -771,6 +771,7 @@ impl<'gcc, 'tcx> ArgAbiExt<'gcc, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(self.layout.size.bytes()),
MemFlags::empty(),
None,
);

bx.lifetime_end(scratch, scratch_size);
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ impl<'ll, 'tcx> ArgAbiExt<'ll, 'tcx> for ArgAbi<'tcx, Ty<'tcx>> {
scratch_align,
bx.const_usize(copy_bytes),
MemFlags::empty(),
None,
);
bx.lifetime_end(llscratch, scratch_size);
}
Expand Down
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_llvm/src/back/lto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,8 @@ fn enable_autodiff_settings(ad: &[config::AutoDiff]) {
config::AutoDiff::Enable => {}
// We handle this below
config::AutoDiff::NoPostopt => {}
// Disables TypeTree generation
config::AutoDiff::NoTT => {}
}
}
// This helps with handling enums for now.
Expand Down
15 changes: 13 additions & 2 deletions compiler/rustc_codegen_llvm/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::borrow::{Borrow, Cow};
use std::ops::Deref;
use std::{iter, ptr};

use rustc_ast::expand::typetree::FncTree;
pub(crate) mod autodiff;
pub(crate) mod gpu_offload;

Expand Down Expand Up @@ -1107,11 +1108,12 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align: Align,
size: &'ll Value,
flags: MemFlags,
tt: Option<FncTree>,
) {
assert!(!flags.contains(MemFlags::NONTEMPORAL), "non-temporal memcpy not supported");
let size = self.intcast(size, self.type_isize(), false);
let is_volatile = flags.contains(MemFlags::VOLATILE);
unsafe {
let memcpy = unsafe {
llvm::LLVMRustBuildMemCpy(
self.llbuilder,
dst,
Expand All @@ -1120,7 +1122,16 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
src_align.bytes() as c_uint,
size,
is_volatile,
);
)
};

// TypeTree metadata for memcpy is especially important: when Enzyme encounters
// a memcpy during autodiff, it needs to know the structure of the data being
// copied to properly track derivatives. For example, copying an array of floats
// vs. copying a struct with mixed types requires different derivative handling.
// The TypeTree tells Enzyme exactly what memory layout to expect.
if let Some(tt) = tt {
crate::typetree::add_tt(self.cx().llmod, self.cx().llcx, memcpy, tt);
}
}

Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_codegen_llvm/src/builder/autodiff.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::ptr;

use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
use rustc_ast::expand::typetree::FncTree;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
Expand Down Expand Up @@ -294,6 +295,7 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
fn_args: &[&'ll Value],
attrs: AutoDiffAttrs,
dest: PlaceRef<'tcx, &'ll Value>,
fnc_tree: FncTree,
) {
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
let mut ad_name: String = match attrs.mode {
Expand Down Expand Up @@ -370,6 +372,10 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
fn_args,
);

if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
}

let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);

builder.store_to_place(call, dest.val);
Expand Down
4 changes: 4 additions & 0 deletions compiler/rustc_codegen_llvm/src/intrinsic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1213,6 +1213,9 @@ fn codegen_autodiff<'ll, 'tcx>(
&mut diff_attrs.input_activity,
);

let fnc_tree =
rustc_middle::ty::fnc_typetrees(tcx, fn_source.ty(tcx, TypingEnv::fully_monomorphized()));

// Build body
generate_enzyme_call(
bx,
Expand All @@ -1223,6 +1226,7 @@ fn codegen_autodiff<'ll, 'tcx>(
&val_arr,
diff_attrs.clone(),
result,
fnc_tree,
);
}

Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_codegen_llvm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ mod llvm_util;
mod mono_item;
mod type_;
mod type_of;
mod typetree;
mod va_arg;
mod value;

Expand Down
Loading
Loading