Skip to content

Commit c5adbf2

Browse files
committed
addressing most feedback
1 parent 1f64d63 commit c5adbf2

File tree

8 files changed

+46
-50
lines changed

8 files changed

+46
-50
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

+10
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,16 @@ pub struct AutoDiffAttrs {
7777
/// e.g. in the [JAX
7878
/// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
7979
pub mode: DiffMode,
80+
/// A user-provided, batching width. If not given, we will default to 1 (no batching).
81+
/// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
82+
/// - Calling the function 50 times with a batch size of 2
83+
/// - Calling the function 25 times with a batch size of 4,
84+
/// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
85+
/// cache locality, better re-usal of primal values, and other optimizations.
86+
/// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
87+
/// times, so this massively increases code size. As such, values like 1024 are unlikely to
88+
/// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
89+
/// experiments for now and focus on documenting the implications of a large width.
8090
pub width: u32,
8191
pub ret_activity: DiffActivity,
8292
pub input_activity: Vec<DiffActivity>,

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+20-27
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ use std::ptr;
33
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, AutoDiffItem, DiffActivity, DiffMode};
44
use rustc_codegen_ssa::ModuleCodegen;
55
use rustc_codegen_ssa::back::write::ModuleConfig;
6+
use rustc_codegen_ssa::common::TypeKind;
7+
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
68
use rustc_errors::FatalError;
79
use rustc_middle::bug;
810
use tracing::{debug, trace};
@@ -18,18 +20,18 @@ use crate::value::Value;
1820
use crate::{CodegenContext, LlvmCodegenBackend, ModuleLlvm, attributes, llvm};
1921

2022
fn get_params(fnc: &Value) -> Vec<&Value> {
23+
let param_num = llvm::LLVMCountParams(fnc) as usize;
24+
let mut fnc_args: Vec<&Value> = vec![];
25+
fnc_args.reserve(param_num);
2126
unsafe {
22-
let param_num = llvm::LLVMCountParams(fnc) as usize;
23-
let mut fnc_args: Vec<&Value> = vec![];
24-
fnc_args.reserve(param_num);
2527
llvm::LLVMGetParams(fnc, fnc_args.as_mut_ptr());
2628
fnc_args.set_len(param_num);
27-
fnc_args
2829
}
30+
fnc_args
2931
}
3032

3133
fn has_sret(fnc: &Value) -> bool {
32-
let num_args = unsafe { llvm::LLVMCountParams(fnc) as usize };
34+
let num_args = llvm::LLVMCountParams(fnc) as usize;
3335
if num_args == 0 {
3436
false
3537
} else {
@@ -121,23 +123,15 @@ fn match_args_from_caller_to_enzyme<'ll>(
121123
// (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
122124
// FIXME(ZuseZ4): We will upstream a safety check later which asserts that
123125
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
124-
assert!(unsafe {
125-
llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Integer
126-
});
126+
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
127127

128128
for _ in 0..width {
129129
let next_outer_arg2 = outer_args[outer_pos + 2];
130130
let next_outer_ty2 = cx.val_ty(next_outer_arg2);
131-
assert!(
132-
unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty2) }
133-
== llvm::TypeKind::Pointer
134-
);
131+
assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
135132
let next_outer_arg3 = outer_args[outer_pos + 3];
136133
let next_outer_ty3 = cx.val_ty(next_outer_arg3);
137-
assert!(
138-
unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty3) }
139-
== llvm::TypeKind::Integer
140-
);
134+
assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
141135
args.push(next_outer_arg2);
142136
}
143137
args.push(cx.get_metadata_value(enzyme_const));
@@ -150,10 +144,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
150144
// (..., metadata! enzyme_dup, ptr, ptr, ...).
151145
if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly)
152146
{
153-
assert!(
154-
unsafe { llvm::LLVMRustGetTypeKind(next_outer_ty) }
155-
== llvm::TypeKind::Pointer
156-
);
147+
assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Pointer);
157148
}
158149
// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
159150
args.push(next_outer_arg);
@@ -213,8 +204,8 @@ fn compute_enzyme_fn_ty<'ll>(
213204
todo!("Handle sret for scalar ad");
214205
} else {
215206
// First we check if we also have to deal with the primal return.
216-
if attrs.mode.is_fwd() {
217-
match attrs.ret_activity {
207+
match attrs.mode {
208+
DiffMode::Forward => match attrs.ret_activity {
218209
DiffActivity::Dual => {
219210
let arr_ty =
220211
unsafe { llvm::LLVMArrayType2(inner_ret_ty, attrs.width as u64 + 1) };
@@ -231,11 +222,13 @@ fn compute_enzyme_fn_ty<'ll>(
231222
_ => {
232223
bug!("unreachable");
233224
}
225+
},
226+
DiffMode::Reverse => {
227+
todo!("Handle sret for reverse mode");
228+
}
229+
_ => {
230+
bug!("unreachable");
234231
}
235-
} else if attrs.mode.is_rev() {
236-
todo!("Handle sret for reverse mode");
237-
} else {
238-
bug!("unreachable");
239232
}
240233
}
241234
}
@@ -395,7 +388,7 @@ fn generate_enzyme_call<'ll>(
395388
// now store the result of the enzyme call into the sret pointer.
396389
let sret_ptr = outer_args[0];
397390
let call_ty = cx.val_ty(call);
398-
assert!(llvm::LLVMRustIsArrayTy(call_ty));
391+
assert_eq!(cx.type_kind(call_ty), TypeKind::Array);
399392
llvm::LLVMBuildStore(&builder.llbuilder, call, sret_ptr);
400393
}
401394
builder.ret_void();

compiler/rustc_codegen_llvm/src/context.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::str;
88
use rustc_abi::{HasDataLayout, Size, TargetDataLayout, VariantIdx};
99
use rustc_codegen_ssa::back::versioned_llvm_target;
1010
use rustc_codegen_ssa::base::{wants_msvc_seh, wants_wasm_eh};
11+
use rustc_codegen_ssa::common::TypeKind;
1112
use rustc_codegen_ssa::errors as ssa_errors;
1213
use rustc_codegen_ssa::traits::*;
1314
use rustc_data_structures::base_n::{ALPHANUMERIC_ONLY, ToBaseN};
@@ -645,7 +646,7 @@ impl<'ll, 'tcx> CodegenCx<'ll, 'tcx> {
645646
}
646647
impl<'ll> SimpleCx<'ll> {
647648
pub(crate) fn get_return_type(&self, ty: &'ll Type) -> &'ll Type {
648-
assert!(unsafe { llvm::LLVMRustIsFunctionTy(ty) });
649+
assert_eq!(self.type_kind(ty), TypeKind::Function);
649650
unsafe { llvm::LLVMGetReturnType(ty) }
650651
}
651652
pub(crate) fn get_type_of_global(&self, val: &'ll Value) -> &'ll Type {
@@ -671,6 +672,8 @@ impl<'ll, CX: Borrow<SCx<'ll>>> GenericCx<'ll, CX> {
671672
llvm::LLVMMetadataAsValue(self.llcx(), metadata)
672673
}
673674

675+
// FIXME(autodiff): We should split `ConstCodegenMethods` to pull the reusable parts
676+
// onto a trait that is also implemented for GenericCx.
674677
pub(crate) fn get_const_i64(&self, n: u64) -> &'ll Value {
675678
let ty = unsafe { llvm::LLVMInt64TypeInContext(self.llcx()) };
676679
unsafe { llvm::LLVMConstInt(ty, n, llvm::False) }

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

-3
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,6 @@ unsafe extern "C" {
1818
pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
1919
pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
2020
pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
21-
22-
pub(crate) fn LLVMRustIsFunctionTy(Ty: &Type) -> bool;
23-
pub(crate) fn LLVMRustIsArrayTy(Ty: &Type) -> bool;
2421
pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
2522
}
2623

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1172,7 +1172,7 @@ unsafe extern "C" {
11721172

11731173
// Operations on parameters
11741174
pub(crate) fn LLVMIsAArgument(Val: &Value) -> Option<&Value>;
1175-
pub(crate) fn LLVMCountParams(Fn: &Value) -> c_uint;
1175+
pub(crate) safe fn LLVMCountParams(Fn: &Value) -> c_uint;
11761176
pub(crate) fn LLVMGetParam(Fn: &Value, Index: c_uint) -> &Value;
11771177

11781178
// Operations on basic blocks

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

+3-3
Original file line numberDiff line numberDiff line change
@@ -831,7 +831,7 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
831831
Err(_) => {
832832
span_bug!(w.span, "rustc_autodiff width should fit u32");
833833
}
834-
};
834+
}
835835
}
836836
MetaItemInner::Lit(lit) => {
837837
if let LitKind::Int(val, _) = lit.kind {
@@ -840,12 +840,12 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
840840
Err(_) => {
841841
span_bug!(lit.span, "rustc_autodiff width should fit u32");
842842
}
843-
};
843+
}
844844
} else {
845845
span_bug!(lit.span, "rustc_autodiff width should be an integer");
846846
}
847847
}
848-
}
848+
};
849849

850850
// First read the ret symbol from the attribute
851851
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: p1, .. }) = ret_activity {

compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

-7
Original file line numberDiff line numberDiff line change
@@ -641,13 +641,6 @@ static InlineAsm::AsmDialect fromRust(LLVMRustAsmDialect Dialect) {
641641
report_fatal_error("bad AsmDialect.");
642642
}
643643
}
644-
extern "C" bool LLVMRustIsFunctionTy(LLVMTypeRef Ty) {
645-
return unwrap(Ty)->isFunctionTy();
646-
}
647-
648-
extern "C" bool LLVMRustIsArrayTy(LLVMTypeRef Ty) {
649-
return unwrap(Ty)->isArrayTy();
650-
}
651644

652645
extern "C" uint64_t LLVMRustGetArrayNumElements(LLVMTypeRef Ty) {
653646
return unwrap(Ty)->getArrayNumElements();

tests/codegen/autodiffv.rs

+8-8
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@ fn square(x: &f32) -> f32 {
1818
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
1919
// CHECK-NEXT: start:
2020
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
21-
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4, !alias.scope !38, !noalias !39
21+
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
2222
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
23-
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4, !alias.scope !40, !noalias !41
23+
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
2424
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
25-
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4, !alias.scope !42, !noalias !43
25+
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
2626
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
27-
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4, !alias.scope !44, !noalias !45
27+
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
2828
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
2929
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1
3030
// CHECK-NEXT: %6 = insertelement <4 x float> %5, float %"_2'ipl2", i64 2
@@ -48,13 +48,13 @@ fn square(x: &f32) -> f32 {
4848
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.1(float %x.0.val, [4 x ptr] %"x'")
4949
// CHECK-NEXT: start:
5050
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
51-
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4, !alias.scope !46, !noalias !47
51+
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
5252
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
53-
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4, !alias.scope !48, !noalias !49
53+
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
5454
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
55-
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4, !alias.scope !50, !noalias !51
55+
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
5656
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
57-
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4, !alias.scope !52, !noalias !53
57+
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
5858
// CHECK-NEXT: %_0 = fmul float %x.0.val, %x.0.val
5959
// CHECK-NEXT: %4 = insertelement <4 x float> poison, float %"_2'ipl", i64 0
6060
// CHECK-NEXT: %5 = insertelement <4 x float> %4, float %"_2'ipl1", i64 1

0 commit comments

Comments
 (0)