Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions tests/codegen-llvm/autodiff/abi_handling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,14 @@ fn square(x: f32) -> f32 {
// CHECK-LABEL: ; abi_handling::df1
// CHECK-NEXT: Function Attrs
// debug-NEXT: define internal { float, float }
// debug-SAME: (ptr align 4 %x, ptr align 4 %bx_0)
// debug-SAME: (ptr {{.*}}%x, ptr {{.*}}%bx_0)
// release-NEXT: define internal fastcc float
// release-SAME: (float %x.0.val, float %x.4.val)

// CHECK-LABEL: ; abi_handling::f1
// CHECK-NEXT: Function Attrs
// debug-NEXT: define internal float
// debug-SAME: (ptr align 4 %x)
// debug-SAME: (ptr {{.*}}%x)
// release-NEXT: define internal fastcc noundef float
// release-SAME: (float %x.0.val, float %x.4.val)
#[autodiff_forward(df1, Dual, Dual)]
Expand All @@ -58,7 +58,7 @@ fn f1(x: &[f32; 2]) -> f32 {
// CHECK-NEXT: Function Attrs
// debug-NEXT: define internal { float, float }
// debug-SAME: (ptr %f, float %x, float %dret)
// release-NEXT: define internal fastcc float
// release-NEXT: define internal fastcc noundef float
// release-SAME: (float noundef %x)

// CHECK-LABEL: ; abi_handling::f2
Expand All @@ -77,13 +77,13 @@ fn f2(f: fn(f32) -> f32, x: f32) -> f32 {
// CHECK-NEXT: Function Attrs
// debug-NEXT: define internal { float, float }
// debug-SAME: (ptr align 4 %x, ptr align 4 %bx_0, ptr align 4 %y, ptr align 4 %by_0)
// release-NEXT: define internal fastcc { float, float }
// release-NEXT: define internal fastcc float
// release-SAME: (float %x.0.val)

// CHECK-LABEL: ; abi_handling::f3
// CHECK-NEXT: Function Attrs
// debug-NEXT: define internal float
// debug-SAME: (ptr align 4 %x, ptr align 4 %y)
// debug-SAME: (ptr {{.*}}%x, ptr {{.*}}%y)
// release-NEXT: define internal fastcc noundef float
// release-SAME: (float %x.0.val)
#[autodiff_forward(df3, Dual, Dual, Dual)]
Expand Down Expand Up @@ -160,7 +160,7 @@ fn f6(i: NestedInput) -> f32 {
// CHECK-LABEL: ; abi_handling::f7
// CHECK-NEXT: Function Attrs
// debug-NEXT: define internal float
// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1)
// debug-SAME: (ptr {{.*}}%x.0, ptr {{.*}}%x.1)
// release-NEXT: define internal fastcc noundef float
// release-SAME: (float %x.0.0.val, float %x.1.0.val)
#[autodiff_forward(df7, Dual, Dual)]
Expand Down
83 changes: 16 additions & 67 deletions tests/codegen-llvm/autodiff/batched.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
//@ compile-flags: -Zautodiff=Enable,NoTT,NoPostopt -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
//
// In Enzyme, we test against a large range of LLVM versions (5+) and don't have overly many
// breakages. One benefit is that we match the IR generated by Enzyme only after running it
// through LLVM's O3 pipeline, which will remove most of the noise.
// However, our integration test could also be affected by changes in how rustc lowers MIR into
// LLVM-IR, which could cause additional noise and thus breakages. If that's the case, we should
// reduce this test to only match the first lines and the ret instructions.

// This test combines two features of Enzyme, automatic differentiation and batching. As such, it is
// especially prone to breakages. I reduced it therefore to a minimal check matches argument/return
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think it's especially prone to breakages, but it definitely means the IR doens't have a guaranteed stable output (in contrast to functionally equivalent being stable).

Also note Enzyme does have this in MLIR already

// types. Based on the original batching author, implementing the batching feature over MLIR instead
// of LLVM should give significantly more reliable performance.

#![feature(autodiff)]

Expand All @@ -22,69 +20,20 @@ fn square(x: &f32) -> f32 {
x * x
}

// The base ("scalar") case d_square3, without batching.
// CHECK: define internal fastcc float @fwddiffesquare(float %x.0.val, float %"x'.0.val")
// CHECK: %0 = fadd fast float %"x'.0.val", %"x'.0.val"
// CHECK-NEXT: %1 = fmul fast float %0, %x.0.val
// CHECK-NEXT: ret float %1
// CHECK-NEXT: }

// d_square2
// CHECK: define internal [4 x float] @fwddiffe4square(ptr noalias noundef readonly align 4 captures(none) dereferenceable(4) %x, [4 x ptr] %"x'")
// CHECK-NEXT: start:
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
// CHECK-NEXT: %_2 = load float, ptr %x, align 4
// CHECK-NEXT: %4 = fmul fast float %"_2'ipl", %_2
// CHECK-NEXT: %5 = fmul fast float %"_2'ipl1", %_2
// CHECK-NEXT: %6 = fmul fast float %"_2'ipl2", %_2
// CHECK-NEXT: %7 = fmul fast float %"_2'ipl3", %_2
// CHECK-NEXT: %8 = fmul fast float %"_2'ipl", %_2
// CHECK-NEXT: %9 = fmul fast float %"_2'ipl1", %_2
// CHECK-NEXT: %10 = fmul fast float %"_2'ipl2", %_2
// CHECK-NEXT: %11 = fmul fast float %"_2'ipl3", %_2
// CHECK-NEXT: %12 = fadd fast float %4, %8
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
// CHECK-NEXT: %14 = fadd fast float %5, %9
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
// CHECK-NEXT: %16 = fadd fast float %6, %10
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
// CHECK-NEXT: %18 = fadd fast float %7, %11
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
// CHECK-NEXT: ret [4 x float] %19
// CHECK: define internal fastcc [4 x float] @fwddiffe4square(float %x.0.val, [4 x ptr] %"x'")
// CHECK: ret [4 x float]
// CHECK-NEXT: }

// d_square3, the extra float is the original return value (x * x)
// CHECK: define internal { float, [4 x float] } @fwddiffe4square.1(ptr noalias noundef readonly align 4 captures(none) dereferenceable(4) %x, [4 x ptr] %"x'")
// CHECK-NEXT: start:
// CHECK-NEXT: %0 = extractvalue [4 x ptr] %"x'", 0
// CHECK-NEXT: %"_2'ipl" = load float, ptr %0, align 4
// CHECK-NEXT: %1 = extractvalue [4 x ptr] %"x'", 1
// CHECK-NEXT: %"_2'ipl1" = load float, ptr %1, align 4
// CHECK-NEXT: %2 = extractvalue [4 x ptr] %"x'", 2
// CHECK-NEXT: %"_2'ipl2" = load float, ptr %2, align 4
// CHECK-NEXT: %3 = extractvalue [4 x ptr] %"x'", 3
// CHECK-NEXT: %"_2'ipl3" = load float, ptr %3, align 4
// CHECK-NEXT: %_2 = load float, ptr %x, align 4
// CHECK-NEXT: %_0 = fmul float %_2, %_2
// CHECK-NEXT: %4 = fmul fast float %"_2'ipl", %_2
// CHECK-NEXT: %5 = fmul fast float %"_2'ipl1", %_2
// CHECK-NEXT: %6 = fmul fast float %"_2'ipl2", %_2
// CHECK-NEXT: %7 = fmul fast float %"_2'ipl3", %_2
// CHECK-NEXT: %8 = fmul fast float %"_2'ipl", %_2
// CHECK-NEXT: %9 = fmul fast float %"_2'ipl1", %_2
// CHECK-NEXT: %10 = fmul fast float %"_2'ipl2", %_2
// CHECK-NEXT: %11 = fmul fast float %"_2'ipl3", %_2
// CHECK-NEXT: %12 = fadd fast float %4, %8
// CHECK-NEXT: %13 = insertvalue [4 x float] undef, float %12, 0
// CHECK-NEXT: %14 = fadd fast float %5, %9
// CHECK-NEXT: %15 = insertvalue [4 x float] %13, float %14, 1
// CHECK-NEXT: %16 = fadd fast float %6, %10
// CHECK-NEXT: %17 = insertvalue [4 x float] %15, float %16, 2
// CHECK-NEXT: %18 = fadd fast float %7, %11
// CHECK-NEXT: %19 = insertvalue [4 x float] %17, float %18, 3
// CHECK-NEXT: %20 = insertvalue { float, [4 x float] } undef, float %_0, 0
// CHECK-NEXT: %21 = insertvalue { float, [4 x float] } %20, [4 x float] %19, 1
// CHECK-NEXT: ret { float, [4 x float] } %21
// CHECK: define internal fastcc { float, [4 x float] } @fwddiffe4square.{{.*}}(float %x.0.val, [4 x ptr] %"x'")
// CHECK: ret { float, [4 x float] }
// CHECK-NEXT: }

fn main() {
Expand Down
43 changes: 28 additions & 15 deletions tests/codegen-llvm/autodiff/generic.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
//@ compile-flags: -Zautodiff=Enable -Zautodiff=NoPostopt -C opt-level=3 -Clto=fat
//@ no-prefer-dynamic
//@ needs-enzyme
//@ revisions: F32 F64 Main

// Here we verify that the function `square` can be differentiated over f64.
// This is interesting to test, since the user never calls `square` with f64, so on it's own rustc
// would have no reason to monomorphize it that way. However, Enzyme needs the f64 version of
// `square` in order to be able to differentiate it, so we have logic to enforce the
// monomorphization. Here, we test this logic.

#![feature(autodiff)]

use std::autodiff::autodiff_reverse;
Expand All @@ -12,32 +20,37 @@ fn square<T: std::ops::Mul<Output = T> + Copy>(x: &T) -> T {
}

// Ensure that `d_square::<f32>` code is generated
//
// CHECK: ; generic::square
// CHECK-NEXT: ; Function Attrs: {{.*}}
// CHECK-NEXT: define internal {{.*}} float
// CHECK-NEXT: start:
// CHECK-NOT: ret
// CHECK: fmul float

// F32-LABEL: ; generic::square::<f32>
// F32-NEXT: ; Function Attrs: {{.*}}
// F32-NEXT: define internal {{.*}} float
// F32-NEXT: start:
// F32-NOT: ret
// F32: fmul float

// Ensure that `d_square::<f64>` code is generated even if `square::<f64>` was never called
//
// CHECK: ; generic::square
// CHECK-NEXT: ; Function Attrs:
// CHECK-NEXT: define internal {{.*}} double
// CHECK-NEXT: start:
// CHECK-NOT: ret
// CHECK: fmul double

// F64-LABEL: ; generic::d_square::<f64>
// F64-NEXT: ; Function Attrs: {{.*}}
// F64-NEXT: define internal {{.*}} void
// F64-NEXT: start:
// F64-NEXT: {{(tail )?}}call {{(fastcc )?}}void @diffe_{{.*}}(double {{.*}}, ptr {{.*}})
// F64-NEXT: ret void

// Main-LABEL: ; generic::main
// Main: ; call generic::square::<f32>
// Main: ; call generic::d_square::<f64>

fn main() {
let xf32: f32 = std::hint::black_box(3.0);
let xf64: f64 = std::hint::black_box(3.0);
let seed: f64 = std::hint::black_box(1.0);

let outputf32 = square::<f32>(&xf32);
assert_eq!(9.0, outputf32);

let mut df_dxf64: f64 = std::hint::black_box(0.0);

let output_f64 = d_square::<f64>(&xf64, &mut df_dxf64, 1.0);
let output_f64 = d_square::<f64>(&xf64, &mut df_dxf64, seed);
assert_eq!(6.0, df_dxf64);
}
8 changes: 4 additions & 4 deletions tests/codegen-llvm/autodiff/identical_fnc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
// merged placeholder function anymore, and compilation would fail. We prevent this by disabling
// LLVM's merge_function pass before AD. Here we implicetely test that our solution keeps working.
// We also explicetly test that we keep running merge_function after AD, by checking for two
// identical function calls in the LLVM-IR, while having two different calls in the Rust code.
// identical function calls in the LLVM-IR, despite having two different calls in the Rust code.
#![feature(autodiff)]

use std::autodiff::autodiff_reverse;
Expand All @@ -27,14 +27,14 @@ fn square2(x: &f64) -> f64 {

// CHECK:; identical_fnc::main
// CHECK-NEXT:; Function Attrs:
// CHECK-NEXT:define internal void @_ZN13identical_fnc4main17h6009e4f751bf9407E()
// CHECK-NEXT:define internal void
// CHECK-NEXT:start:
// CHECK-NOT:br
// CHECK-NOT:ret
// CHECK:; call identical_fnc::d_square
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square[[HASH:.+]](double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx1)
// CHECK-NEXT:call fastcc void @[[HASH:.+]](double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx1)
// CHECK:; call identical_fnc::d_square
// CHECK-NEXT:call fastcc void @_ZN13identical_fnc8d_square[[HASH]](double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx2)
// CHECK-NEXT:call fastcc void @[[HASH]](double %x.val, ptr noalias noundef align 8 dereferenceable(8) %dx2)

fn main() {
let x = std::hint::black_box(3.0);
Expand Down
Loading