Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 0243b2b

Browse files
committedMar 13, 2025
autodiff batching mvp
1 parent f2d69d5 commit 0243b2b

File tree

14 files changed

+746
-286
lines changed

14 files changed

+746
-286
lines changed
 

‎compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ pub struct AutoDiffAttrs {
7777
/// e.g. in the [JAX
7878
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
7979
pub mode: DiffMode,
80+
pub width: u32,
8081
pub ret_activity: DiffActivity,
8182
pub input_activity: Vec<DiffActivity>,
8283
}
@@ -222,13 +223,15 @@ impl AutoDiffAttrs {
222223
pub const fn error() -> Self {
223224
AutoDiffAttrs {
224225
mode: DiffMode::Error,
226+
width: 0,
225227
ret_activity: DiffActivity::None,
226228
input_activity: Vec::new(),
227229
}
228230
}
229231
pub fn source() -> Self {
230232
AutoDiffAttrs {
231233
mode: DiffMode::Source,
234+
width: 0,
232235
ret_activity: DiffActivity::None,
233236
input_activity: Vec::new(),
234237
}

‎compiler/rustc_builtin_macros/messages.ftl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
7676
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
7777
builtin_macros_autodiff_number_activities = expected {$expected} activities, but found {$found}
7878
builtin_macros_autodiff_ty_activity = {$act} can not be used for this type
79+
builtin_macros_autodiff_width = autodiff width must fit u32, but is {$width}
7980
8081
builtin_macros_autodiff_unknown_activity = did not recognize Activity: `{$act}`
8182
builtin_macros_bad_derive_target = `derive` may only be applied to `struct`s, `enum`s and `union`s

‎compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 289 additions & 158 deletions
Large diffs are not rendered by default.

‎compiler/rustc_builtin_macros/src/errors.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,14 @@ mod autodiff {
193193
pub(crate) mode: String,
194194
}
195195

196+
#[derive(Diagnostic)]
197+
#[diag(builtin_macros_autodiff_width)]
198+
pub(crate) struct AutoDiffInvalidWidth {
199+
#[primary_span]
200+
pub(crate) span: Span,
201+
pub(crate) width: u128,
202+
}
203+
196204
#[derive(Diagnostic)]
197205
#[diag(builtin_macros_autodiff)]
198206
pub(crate) struct AutoDiffInvalidApplication {

‎compiler/rustc_codegen_llvm/messages.ftl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
codegen_llvm_autodiff_unused_args = implementation bug, failed to match all args on llvm level
12
codegen_llvm_autodiff_without_enable = using the autodiff feature requires -Z autodiff=Enable
23
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
34

‎compiler/rustc_codegen_llvm/src/back/lto.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,10 @@ pub(crate) fn run_pass_manager(
655655
unsafe {
656656
write::llvm_optimize(cgcx, dcx, module, None, config, opt_level, opt_stage, stage)?;
657657
}
658+
// This is the final IR, so people should be able to inspect the optimized autodiff output.
659+
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
660+
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
661+
}
658662

659663
if cfg!(llvm_enzyme) && enable_ad {
660664
let opt_stage = llvm::OptStage::FatLTO;

‎compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 309 additions & 120 deletions
Large diffs are not rendered by default.

‎compiler/rustc_codegen_llvm/src/consts.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ impl<'ll> CodegenCx<'ll, '_> {
405405
let val_llty = self.val_ty(v);
406406

407407
let g = self.get_static_inner(def_id, val_llty);
408-
let llty = llvm::LLVMGlobalGetValueType(g);
408+
let llty = self.get_type_of_global(g);
409409

410410
let g = if val_llty == llty {
411411
g

‎compiler/rustc_codegen_llvm/src/context.rs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ use crate::debuginfo::metadata::apply_vcall_visibility_metadata;
3737
use crate::llvm::Metadata;
3838
use crate::type_::Type;
3939
use crate::value::Value;
40-
use crate::{attributes, coverageinfo, debuginfo, llvm, llvm_util};
40+
use crate::{attributes, common, coverageinfo, debuginfo, llvm, llvm_util};
4141

4242
/// `TyCtxt` (and related cache datastructures) can't be move between threads.
4343
/// However, there are various cx related functions which we want to be available to the builder and
@@ -642,7 +642,18 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
642642
llvm::set_section(g, c"llvm.metadata");
643643
}
644644
}
645-
645+
impl<'ll> SimpleCx<'ll> {
646+
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
647+
assert!(unsafe { llvm::LLVMRustIsFunctionTy(ty) });
648+
unsafe { llvm::LLVMGetReturnType(ty) }
649+
}
650+
pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type {
651+
unsafe { llvm::LLVMGlobalGetValueType(val) }
652+
}
653+
pub(crate) fn val_ty(&self, v: &'ll Value) -> &'ll Type {
654+
common::val_ty(v)
655+
}
656+
}
646657
impl<'ll> SimpleCx<'ll> {
647658
pub(crate) fn new(
648659
llmod: &'ll llvm::Module,
@@ -659,6 +670,11 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
659670
llvm::LLVMMetadataAsValue(self.llcx(), metadata)
660671
}
661672

673+
pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value {
674+
let ty = unsafe { llvm::LLVMInt64TypeInContext(self.llcx()) };
675+
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }
676+
}
677+
662678
pub(crate) fn get_function(&self, name: &str) -> Option<&'ll Value> {
663679
let name = SmallCStr::new(name);
664680
unsafe { llvm::LLVMGetNamedFunction((**self).borrow().llmod, name.as_ptr()) }

‎compiler/rustc_codegen_llvm/src/errors.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,10 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> {
9494
#[diag(codegen_llvm_autodiff_without_lto)]
9595
pub(crate) struct AutoDiffWithoutLTO;
9696

97+
#[derive(Diagnostic)]
98+
#[diag(codegen_llvm_autodiff_unused_args)]
99+
pub(crate) struct AutoDiffUnusedArgs;
100+
97101
#[derive(Diagnostic)]
98102
#[diag(codegen_llvm_autodiff_without_enable)]
99103
pub(crate) struct AutoDiffWithoutEnable;

‎compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
use libc::{c_char, c_uint};
55

66
use super::MetadataKindId;
7-
use super::ffi::{BasicBlock, Metadata, Module, Type, Value};
7+
use super::ffi::{AttributeKind, BasicBlock, Metadata, Module, Type, Value};
88
use crate::llvm::Bool;
99

1010
#[link(name = "llvm-wrapper", kind = "static")]
@@ -17,6 +17,11 @@ unsafe extern "C" {
1717
pub(crate) fn LLVMRustEraseInstFromParent(V: &Value);
1818
pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
1919
pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
20+
pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
21+
22+
pub(crate) fn LLVMRustIsFunctionTy(Ty: &Type) -> bool;
23+
pub(crate) fn LLVMRustIsArrayTy(Ty: &Type) -> bool;
24+
pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
2025
}
2126

2227
unsafe extern "C" {

‎compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use rustc_abi::ExternAbi;
44
use rustc_ast::expand::autodiff_attrs::{
55
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
66
};
7-
use rustc_ast::{MetaItem, MetaItemInner, attr};
7+
use rustc_ast::{LitKind, MetaItem, MetaItemInner, attr};
88
use rustc_attr_parsing::ReprAttr::ReprAlign;
99
use rustc_attr_parsing::{AttributeKind, InlineAttr, InstructionSetAttr, OptimizeAttr};
1010
use rustc_data_structures::fx::FxHashMap;
@@ -819,8 +819,8 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
819819
return Some(AutoDiffAttrs::source());
820820
}
821821

822-
let [mode, input_activities @ .., ret_activity] = &list[..] else {
823-
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode and activities");
822+
let [mode, width_meta, input_activities @ .., ret_activity] = &list[..] else {
823+
span_bug!(attr.span(), "rustc_autodiff attribute must contain mode, width and activities");
824824
};
825825
let mode = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = mode {
826826
p1.segments.first().unwrap().ident
@@ -837,6 +837,34 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
837837
}
838838
};
839839

840+
let width: u32;
841+
match width_meta {
842+
MetaItemInner::MetaItem(MetaItem { path: p1, .. }) => {
843+
let w = p1.segments.first().unwrap().ident;
844+
width = match w.as_str().parse() {
845+
Ok(val) => val,
846+
Err(_) => {
847+
span_bug!(w.span, "rustc_autodiff width should fit u32");
848+
}
849+
};
850+
}
851+
MetaItemInner::Lit(lit) => {
852+
if let LitKind::Int(val, _) = lit.kind {
853+
width = match val.get().try_into() {
854+
Ok(val) => val,
855+
Err(_) => {
856+
span_bug!(lit.span, "rustc_autodiff width should fit u32");
857+
}
858+
};
859+
} else {
860+
span_bug!(lit.span, "rustc_autodiff width should be an integer");
861+
}
862+
}
863+
_ => {
864+
span_bug!(width_meta.span(), "failed to parse rustc_autodiff width");
865+
}
866+
}
867+
840868
// First read the ret symbol from the attribute
841869
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity {
842870
p1.segments.first().unwrap().ident
@@ -883,7 +911,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
883911
span_bug!(attr.span(), "Invalid return activity {} for {} mode", ret_activity, mode);
884912
}
885913

886-
Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
914+
Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities })
887915
}
888916

889917
pub(crate) fn provide(providers: &mut Providers) {

‎compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,12 @@ static inline void AddAttributes(T *t, unsigned Index, LLVMAttributeRef *Attrs,
384384
t->setAttributes(PALNew);
385385
}
386386

387+
extern "C" bool LLVMRustHasAttributeAtIndex(LLVMValueRef Fn, unsigned Index,
388+
LLVMRustAttributeKind RustAttr) {
389+
Function *F = unwrap<Function>(Fn);
390+
return F->hasParamAttribute(Index, fromRust(RustAttr));
391+
}
392+
387393
extern "C" void LLVMRustAddFunctionAttributes(LLVMValueRef Fn, unsigned Index,
388394
LLVMAttributeRef *Attrs,
389395
size_t AttrsLen) {
@@ -635,6 +641,17 @@ static InlineAsm::AsmDialect fromRust(LLVMRustAsmDialect Dialect) {
635641
report_fatal_error("bad AsmDialect.");
636642
}
637643
}
644+
extern "C" bool LLVMRustIsFunctionTy(LLVMTypeRef Ty) {
645+
return unwrap(Ty)->isFunctionTy();
646+
}
647+
648+
extern "C" bool LLVMRustIsArrayTy(LLVMTypeRef Ty) {
649+
return unwrap(Ty)->isArrayTy();
650+
}
651+
652+
extern "C" uint64_t LLVMRustGetArrayNumElements(LLVMTypeRef Ty) {
653+
return unwrap(Ty)->getArrayNumElements();
654+
}
638655

639656
extern "C" LLVMValueRef
640657
LLVMRustInlineAsm(LLVMTypeRef Ty, char *AsmString, size_t AsmStringLen,

‎tests/codegen/autodiffv.rs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
//@ compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat
2+
//@ no-prefer-dynamic
3+
//@ needs-enzyme
4+
#![feature(autodiff)]
5+
6+
use std::autodiff::autodiff;
7+
8+
#[autodiff(d_square, Reverse, 4, Duplicated, Active)]
9+
#[no_mangle]
10+
fn square(x: &f64) -> f64 {
11+
x * x
12+
}
13+
14+
// CHECK:define internal fastcc void @diffe4square([4 x ptr] %"x'"
15+
// CHECK-NEXT:invertstart:
16+
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
17+
// CHECK-NEXT: %1 = load double, ptr %0, align 8, !alias.scope !15950, !noalias !15953
18+
// CHECK-NEXT: %2 = fadd fast double %1, 6.000000e+00
19+
// CHECK-NEXT: store double %2, ptr %0, align 8, !alias.scope !15950, !noalias !15953
20+
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 1
21+
// CHECK-NEXT: %4 = load double, ptr %3, align 8, !alias.scope !15958, !noalias !15959
22+
// CHECK-NEXT: %5 = fadd fast double %4, 6.000000e+00
23+
// CHECK-NEXT: store double %5, ptr %3, align 8, !alias.scope !15958, !noalias !15959
24+
// CHECK-NEXT: %6 = extractvalue [4 x ptr] %"x'", 2
25+
// CHECK-NEXT: %7 = load double, ptr %6, align 8, !alias.scope !15960, !noalias !15961
26+
// CHECK-NEXT: %8 = fadd fast double %7, 6.000000e+00
27+
// CHECK-NEXT: store double %8, ptr %6, align 8, !alias.scope !15960, !noalias !15961
28+
// CHECK-NEXT: %9 = extractvalue [4 x ptr] %"x'", 3
29+
// CHECK-NEXT: %10 = load double, ptr %9, align 8, !alias.scope !15962, !noalias !15963
30+
// CHECK-NEXT: %11 = fadd fast double %10, 6.000000e+00
31+
// CHECK-NEXT: store double %11, ptr %9, align 8, !alias.scope !15962, !noalias !15963
32+
// CHECK-NEXT: ret void
33+
// CHECK-NEXT:}
34+
35+
fn main() {
36+
let x = 3.0;
37+
let output = square(&x);
38+
assert_eq!(9.0, output);
39+
40+
let mut df_dx1 = 0.0;
41+
let mut df_dx2 = 0.0;
42+
let mut df_dx3 = 0.0;
43+
let mut df_dx4 = 0.0;
44+
let [o1, o2, o3, o4] = d_square(&x, &mut df_dx1, &mut df_dx2, &mut df_dx3, &mut df_dx4, 1.0);
45+
assert_eq!(output, o1);
46+
assert_eq!(output, o2);
47+
assert_eq!(output, o3);
48+
assert_eq!(output, o4);
49+
assert_eq!(6.0, df_dx1);
50+
assert_eq!(6.0, df_dx2);
51+
assert_eq!(6.0, df_dx3);
52+
assert_eq!(6.0, df_dx4);
53+
}

0 commit comments

Comments
 (0)
Please sign in to comment.