From 727b8b553470cdb01aff01dc8b7e840afb8ac52b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Sun, 15 Jun 2025 15:56:30 +0000 Subject: [PATCH 1/5] Adjust autodiff activities for abi transformations (small tuples and structs) --- .../src/partitioning/autodiff.rs | 35 +++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs index 22d593b80b895..9f6bcf18e8943 100644 --- a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs +++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs @@ -1,3 +1,4 @@ +use rustc_abi::HasDataLayout; use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity}; use rustc_hir::def_id::LOCAL_CRATE; use rustc_middle::bug; @@ -16,6 +17,7 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec // We don't actually pass the types back into the type system. // All we do is decide how to handle the arguments. 
let sig = fn_ty.fn_sig(tcx).skip_binder(); + let pointer_size = tcx.data_layout().pointer_size; let mut new_activities = vec![]; let mut new_positions = vec![]; @@ -70,6 +72,25 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec continue; } } + + let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty }; + + let layout = match tcx.layout_of(pci) { + Ok(layout) => layout.layout, + Err(_) => { + bug!("failed to compute layout for type {:?}", ty); + } + }; + + let is_product = |t: Ty<'tcx>| matches!(t.kind(), ty::Tuple(_) | ty::Adt(_, _)); + + if layout.size() <= pointer_size * 2 && is_product(*ty) { + let n_scalars = count_scalar_fields(tcx, *ty); + for _ in 0..n_scalars.saturating_sub(1) { + new_activities.push(da[i].clone()); + new_positions.push(i + 1); + } + } } // now add the extra activities coming from slices // Reverse order to not invalidate the indices @@ -80,6 +101,20 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec } } +fn count_scalar_fields<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> usize { + match ty.kind() { + ty::Float(_) | ty::Int(_) | ty::Uint(_) => 1, + ty::Adt(def, substs) if def.is_struct() => def + .non_enum_variant() + .fields + .iter() + .map(|f| count_scalar_fields(tcx, f.ty(tcx, substs))) + .sum(), + ty::Tuple(substs) => substs.iter().map(|t| count_scalar_fields(tcx, t)).sum(), + _ => 0, + } +} + pub(crate) fn find_autodiff_source_functions<'tcx>( tcx: TyCtxt<'tcx>, usage_map: &UsageMap<'tcx>, From a96d085be4865893febea8b0da50f191d157bf0a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Sun, 15 Jun 2025 16:04:26 +0000 Subject: [PATCH 2/5] Add codegen tests Note(Sa4dUs): As LLVM-IR opt passes are executed after passing LLVM to Enzyme, most of the cases have turned out to not be problematic. Anyways, we still test them to prevent any kind of regression. 
--- tests/codegen/autodiff/abi_handling.rs | 263 +++++++++++++++++++++++++ 1 file changed, 263 insertions(+) create mode 100644 tests/codegen/autodiff/abi_handling.rs diff --git a/tests/codegen/autodiff/abi_handling.rs b/tests/codegen/autodiff/abi_handling.rs new file mode 100644 index 0000000000000..4f6cb575a00b4 --- /dev/null +++ b/tests/codegen/autodiff/abi_handling.rs @@ -0,0 +1,263 @@ +//@ revisions: debug release + +//@[debug] compile-flags: -Zautodiff=Enable -C opt-level=0 -Clto=fat +//@[release] compile-flags: -Zautodiff=Enable -C opt-level=3 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// This does only test the funtion attribute handling for autodiff. +// Function argument changes are troublesome for Enzyme, so we have to +// ensure that arguments remain the same, or if we change them, be aware +// of the changes to handle it correctly. + +#![feature(autodiff)] + +use std::autodiff::{autodiff_forward, autodiff_reverse}; + +#[derive(Copy, Clone)] +struct Input { + x: f32, + y: f32, +} + +#[derive(Copy, Clone)] +struct Wrapper { + z: f32, +} + +#[derive(Copy, Clone)] +struct NestedInput { + x: f32, + y: Wrapper, +} + +fn square(x: f32) -> f32 { + x * x +} + +// CHECK: ; abi_handling::f1 +// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} +// debug-NEXT: define internal float @_ZN12abi_handling2f117h536ac8081c1e4101E(ptr align 4 %x) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f117h536ac8081c1e4101E(float %x.0.val, float %x.4.val) +#[autodiff_forward(df1, Dual, Dual)] +fn f1(x: &[f32; 2]) -> f32 { + x[0] + x[1] +} + +// CHECK: ; abi_handling::f2 +// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} +// debug-NEXT: define internal float @_ZN12abi_handling2f217h33732e9f83c91bc9E(ptr %f, float %x) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f217h33732e9f83c91bc9E(float noundef %x) +#[autodiff_reverse(df2, Const, Active, Active)] +fn f2(f: fn(f32) -> f32, x: f32) -> f32 { + f(x) +} + +// 
CHECK: ; abi_handling::f3 +// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} +// debug-NEXT: define internal float @_ZN12abi_handling2f317h9cd1fc602b0815a4E(ptr align 4 %x, ptr align 4 %y) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f317h9cd1fc602b0815a4E(float %x.0.val) +#[autodiff_forward(df3, Dual, Dual, Dual)] +fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 { + *x * *y +} + +// CHECK: ; abi_handling::f4 +// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} +// debug-NEXT: define internal float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float %x.0, float %x.1) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float noundef %x.0, float noundef %x.1) +#[autodiff_forward(df4, Dual, Dual)] +fn f4(x: (f32, f32)) -> f32 { + x.0 * x.1 +} + +// CHECK: ; abi_handling::f5 +// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} +// debug-NEXT: define internal float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float %i.0, float %i.1) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float noundef %i.0, float noundef %i.1) +#[autodiff_forward(df5, Dual, Dual)] +fn f5(i: Input) -> f32 { + i.x + i.y +} + +// CHECK: ; abi_handling::f6 +// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} +// debug-NEXT: define internal float @_ZN12abi_handling2f617h5784b207bbb2483eE(float %i.0, float %i.1) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f617h5784b207bbb2483eE(float noundef %i.0, float noundef %i.1) +#[autodiff_forward(df6, Dual, Dual)] +fn f6(i: NestedInput) -> f32 { + i.x + i.y.z * i.y.z +} + +// df1 +// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E(float %x.0.val, float %x.4.val) +// release-NEXT: start: +// release-NEXT: %_0 = fadd float %x.0.val, %x.4.val +// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0 +// release-NEXT: %1 = insertvalue { float, float } 
%0, float 1.000000e+00, 1 +// release-NEXT: ret { float, float } %1 +// release-NEXT: } + +// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E(ptr align 4 %x, ptr align 4 %"x'") +// debug-NEXT: start: +// debug-NEXT: %"'ipg" = getelementptr inbounds float, ptr %"x'", i64 0 +// debug-NEXT: %0 = getelementptr inbounds nuw float, ptr %x, i64 0 +// debug-NEXT: %"_2'ipl" = load float, ptr %"'ipg", align 4, !alias.scope !4, !noalias !7 +// debug-NEXT: %_2 = load float, ptr %0, align 4, !alias.scope !7, !noalias !4 +// debug-NEXT: %"'ipg2" = getelementptr inbounds float, ptr %"x'", i64 1 +// debug-NEXT: %1 = getelementptr inbounds nuw float, ptr %x, i64 1 +// debug-NEXT: %"_5'ipl" = load float, ptr %"'ipg2", align 4, !alias.scope !4, !noalias !7 +// debug-NEXT: %_5 = load float, ptr %1, align 4, !alias.scope !7, !noalias !4 +// debug-NEXT: %_0 = fadd float %_2, %_5 +// debug-NEXT: %2 = fadd fast float %"_2'ipl", %"_5'ipl" +// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0 +// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1 +// debug-NEXT: ret { float, float } %4 +// debug-NEXT: } + +// df2 +// release: define internal fastcc { float, float } @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E(float noundef %x) +// release-NEXT: invertstart: +// release-NEXT: %_0.i = fmul float %x, %x +// release-NEXT: %0 = insertvalue { float, float } undef, float %_0.i, 0 +// release-NEXT: %1 = insertvalue { float, float } %0, float 0.000000e+00, 1 +// release-NEXT: ret { float, float } %1 +// release-NEXT: } + +// debug: define internal { float, float } @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E(ptr %f, float %x, float %differeturn) +// debug-NEXT: start: +// debug-NEXT: %"x'de" = alloca float, align 4 +// debug-NEXT: store float 0.000000e+00, ptr %"x'de", align 4 +// debug-NEXT: %toreturn = alloca float, align 4 +// debug-NEXT: %_0 = call float %f(float %x) #12 +// debug-NEXT: store float %_0, ptr %toreturn, align 
4 +// debug-NEXT: br label %invertstart +// debug-EMPTY: +// debug-NEXT: invertstart: ; preds = %start +// debug-NEXT: %retreload = load float, ptr %toreturn, align 4 +// debug-NEXT: %0 = load float, ptr %"x'de", align 4 +// debug-NEXT: %1 = insertvalue { float, float } undef, float %retreload, 0 +// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1 +// debug-NEXT: ret { float, float } %2 +// debug-NEXT: } + +// df3 +// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E(float %x.0.val) +// release-NEXT: start: +// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.val, 0 +// release-NEXT: %1 = insertvalue { float, float } %0, float 0x40099999A0000000, 1 +// release-NEXT: ret { float, float } %1 +// release-NEXT: } + +// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E(ptr align 4 %x, ptr align 4 %"x'", ptr align 4 %y, ptr align 4 %"y'") +// debug-NEXT: start: +// debug-NEXT: %"_3'ipl" = load float, ptr %"x'", align 4, !alias.scope !9, !noalias !12 +// debug-NEXT: %_3 = load float, ptr %x, align 4, !alias.scope !12, !noalias !9 +// debug-NEXT: %"_4'ipl" = load float, ptr %"y'", align 4, !alias.scope !14, !noalias !17 +// debug-NEXT: %_4 = load float, ptr %y, align 4, !alias.scope !17, !noalias !14 +// debug-NEXT: %_0 = fmul float %_3, %_4 +// debug-NEXT: %0 = fmul fast float %"_3'ipl", %_4 +// debug-NEXT: %1 = fmul fast float %"_4'ipl", %_3 +// debug-NEXT: %2 = fadd fast float %0, %1 +// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0 +// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1 +// debug-NEXT: ret { float, float } %4 +// debug-NEXT: } + +// df4 +// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float noundef %x.0, float %"x.0'") +// release-NEXT: start: +// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0, 0 +// release-NEXT: %1 = insertvalue { 
float, float } %0, float %"x.0'", 1 +// release-NEXT: ret { float, float } %1 +// release-NEXT: } + +// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float %x.0, float %"x.0'", float %x.1, float %"x.1'") +// debug-NEXT: start: +// debug-NEXT: %_0 = fmul float %x.0, %x.1 +// debug-NEXT: %0 = fmul fast float %"x.0'", %x.1 +// debug-NEXT: %1 = fmul fast float %"x.1'", %x.0 +// debug-NEXT: %2 = fadd fast float %0, %1 +// debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0 +// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1 +// debug-NEXT: ret { float, float } %4 +// debug-NEXT: } + +// df5 +// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float noundef %i.0, float %"i.0'") +// release-NEXT: start: +// release-NEXT: %_0 = fadd float %i.0, 1.000000e+00 +// release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0 +// release-NEXT: %1 = insertvalue { float, float } %0, float %"i.0'", 1 +// release-NEXT: ret { float, float } %1 +// release-NEXT: } + +// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float %i.0, float %"i.0'", float %i.1, float %"i.1'") +// debug-NEXT: start: +// debug-NEXT: %_0 = fadd float %i.0, %i.1 +// debug-NEXT: %0 = fadd fast float %"i.0'", %"i.1'" +// debug-NEXT: %1 = insertvalue { float, float } undef, float %_0, 0 +// debug-NEXT: %2 = insertvalue { float, float } %1, float %0, 1 +// debug-NEXT: ret { float, float } %2 +// debug-NEXT: } + +// df6 +// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE(float noundef %i.0, float %"i.0'", float noundef %i.1, float %"i.1'") +// release-NEXT: start: +// release-NEXT: %_3 = fmul float %i.1, %i.1 +// release-NEXT: %0 = fadd fast float %"i.1'", %"i.1'" +// release-NEXT: %1 = fmul fast float %0, %i.1 +// release-NEXT: %_0 = fadd float %i.0, %_3 +// release-NEXT: %2 = fadd fast float 
%"i.0'", %1 +// release-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0 +// release-NEXT: %4 = insertvalue { float, float } %3, float %2, 1 +// release-NEXT: ret { float, float } %4 +// release-NEXT: } + +// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE(float %i.0, float %"i.0'", float %i.1, float %"i.1'") +// debug-NEXT: start: +// debug-NEXT: %_3 = fmul float %i.1, %i.1 +// debug-NEXT: %0 = fmul fast float %"i.1'", %i.1 +// debug-NEXT: %1 = fmul fast float %"i.1'", %i.1 +// debug-NEXT: %2 = fadd fast float %0, %1 +// debug-NEXT: %_0 = fadd float %i.0, %_3 +// debug-NEXT: %3 = fadd fast float %"i.0'", %2 +// debug-NEXT: %4 = insertvalue { float, float } undef, float %_0, 0 +// debug-NEXT: %5 = insertvalue { float, float } %4, float %3, 1 +// debug-NEXT: ret { float, float } %5 +// debug-NEXT: } + +fn main() { + let x = std::hint::black_box(2.0); + let y = std::hint::black_box(3.0); + let z = std::hint::black_box(4.0); + static Y: f32 = std::hint::black_box(3.2); + + let in_f1 = [x, y]; + dbg!(f1(&in_f1)); + let res_f1 = df1(&in_f1, &[1.0, 0.0]); + dbg!(res_f1); + + dbg!(f2(square, x)); + let res_f2 = df2(square, x, 1.0); + dbg!(res_f2); + + dbg!(f3(&x, &Y)); + let res_f3 = df3(&x, &Y, &1.0, &0.0); + dbg!(res_f3); + + let in_f4 = (x, y); + dbg!(f4(in_f4)); + let res_f4 = df4(in_f4, (1.0, 0.0)); + dbg!(res_f4); + + let in_f5 = Input { x, y }; + dbg!(f5(in_f5)); + let res_f5 = df5(in_f5, Input { x: 1.0, y: 0.0 }); + dbg!(res_f5); + + let in_f6 = NestedInput { x, y: Wrapper { z: y } }; + dbg!(f6(in_f6)); + let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } }); + dbg!(res_f6); +} From 56dbc6a3a5537f2a935c62628fa9d2091464bbb4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 24 Jun 2025 14:24:06 +0000 Subject: [PATCH 3/5] Split tests for tidy checks --- tests/codegen/autodiff/abi_handling.rs | 78 ++++++++++++++++++-------- 1 file changed, 54 insertions(+), 24 deletions(-) diff 
--git a/tests/codegen/autodiff/abi_handling.rs b/tests/codegen/autodiff/abi_handling.rs index 4f6cb575a00b4..5de4463c6cf17 100644 --- a/tests/codegen/autodiff/abi_handling.rs +++ b/tests/codegen/autodiff/abi_handling.rs @@ -37,8 +37,10 @@ fn square(x: f32) -> f32 { // CHECK: ; abi_handling::f1 // CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} -// debug-NEXT: define internal float @_ZN12abi_handling2f117h536ac8081c1e4101E(ptr align 4 %x) -// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f117h536ac8081c1e4101E(float %x.0.val, float %x.4.val) +// debug-NEXT: define internal float @_ZN12abi_handling2f117h536ac8081c1e4101E +// debug-SAME: (ptr align 4 %x) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f117h536ac8081c1e4101E +// release-SAME: (float %x.0.val, float %x.4.val) #[autodiff_forward(df1, Dual, Dual)] fn f1(x: &[f32; 2]) -> f32 { x[0] + x[1] @@ -46,8 +48,10 @@ fn f1(x: &[f32; 2]) -> f32 { // CHECK: ; abi_handling::f2 // CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} -// debug-NEXT: define internal float @_ZN12abi_handling2f217h33732e9f83c91bc9E(ptr %f, float %x) -// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f217h33732e9f83c91bc9E(float noundef %x) +// debug-NEXT: define internal float @_ZN12abi_handling2f217h33732e9f83c91bc9E +// debug-SAME: (ptr %f, float %x) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f217h33732e9f83c91bc9E +// release-SAME: (float noundef %x) #[autodiff_reverse(df2, Const, Active, Active)] fn f2(f: fn(f32) -> f32, x: f32) -> f32 { f(x) @@ -55,8 +59,10 @@ fn f2(f: fn(f32) -> f32, x: f32) -> f32 { // CHECK: ; abi_handling::f3 // CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} -// debug-NEXT: define internal float @_ZN12abi_handling2f317h9cd1fc602b0815a4E(ptr align 4 %x, ptr align 4 %y) -// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f317h9cd1fc602b0815a4E(float %x.0.val) +// debug-NEXT: define 
internal float @_ZN12abi_handling2f317h9cd1fc602b0815a4E +// debug-SAME: (ptr align 4 %x, ptr align 4 %y) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f317h9cd1fc602b0815a4E +// release-SAME: (float %x.0.val) #[autodiff_forward(df3, Dual, Dual, Dual)] fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 { *x * *y @@ -64,8 +70,10 @@ fn f3<'a>(x: &'a f32, y: &'a f32) -> f32 { // CHECK: ; abi_handling::f4 // CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} -// debug-NEXT: define internal float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float %x.0, float %x.1) -// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float noundef %x.0, float noundef %x.1) +// debug-NEXT: define internal float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE +// debug-SAME: (float %x.0, float %x.1) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f417h2f4a9a7492d91e9fE +// release-SAME: (float noundef %x.0, float noundef %x.1) #[autodiff_forward(df4, Dual, Dual)] fn f4(x: (f32, f32)) -> f32 { x.0 * x.1 @@ -73,8 +81,10 @@ fn f4(x: (f32, f32)) -> f32 { // CHECK: ; abi_handling::f5 // CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} -// debug-NEXT: define internal float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float %i.0, float %i.1) -// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float noundef %i.0, float noundef %i.1) +// debug-NEXT: define internal float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E +// debug-SAME: (float %i.0, float %i.1) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f517hf8d4ac4d2c2a3976E +// release-SAME: (float noundef %i.0, float noundef %i.1) #[autodiff_forward(df5, Dual, Dual)] fn f5(i: Input) -> f32 { i.x + i.y @@ -82,15 +92,19 @@ fn f5(i: Input) -> f32 { // CHECK: ; abi_handling::f6 // CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} -// debug-NEXT: define internal float 
@_ZN12abi_handling2f617h5784b207bbb2483eE(float %i.0, float %i.1) -// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f617h5784b207bbb2483eE(float noundef %i.0, float noundef %i.1) +// debug-NEXT: define internal float @_ZN12abi_handling2f617h5784b207bbb2483eE +// debug-SAME: (float %i.0, float %i.1) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f617h5784b207bbb2483eE +// release-SAME: (float noundef %i.0, float noundef %i.1) #[autodiff_forward(df6, Dual, Dual)] fn f6(i: NestedInput) -> f32 { i.x + i.y.z * i.y.z } // df1 -// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E(float %x.0.val, float %x.4.val) +// release: define internal fastcc { float, float } +// release-SAME: @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E +// release-SAME: (float %x.0.val, float %x.4.val) // release-NEXT: start: // release-NEXT: %_0 = fadd float %x.0.val, %x.4.val // release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0 @@ -98,7 +112,8 @@ fn f6(i: NestedInput) -> f32 { // release-NEXT: ret { float, float } %1 // release-NEXT: } -// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E(ptr align 4 %x, ptr align 4 %"x'") +// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E +// debug-SAME: (ptr align 4 %x, ptr align 4 %"x'") // debug-NEXT: start: // debug-NEXT: %"'ipg" = getelementptr inbounds float, ptr %"x'", i64 0 // debug-NEXT: %0 = getelementptr inbounds nuw float, ptr %x, i64 0 @@ -116,7 +131,9 @@ fn f6(i: NestedInput) -> f32 { // debug-NEXT: } // df2 -// release: define internal fastcc { float, float } @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E(float noundef %x) +// release: define internal fastcc { float, float } +// release-SAME: @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E +// release-SAME: (float noundef %x) // release-NEXT: invertstart: // release-NEXT: %_0.i = fmul float %x, %x 
// release-NEXT: %0 = insertvalue { float, float } undef, float %_0.i, 0 @@ -124,7 +141,8 @@ fn f6(i: NestedInput) -> f32 { // release-NEXT: ret { float, float } %1 // release-NEXT: } -// debug: define internal { float, float } @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E(ptr %f, float %x, float %differeturn) +// debug: define internal { float, float } @diffe_ZN12abi_handling2f217h33732e9f83c91bc9E +// debug-SAME: (ptr %f, float %x, float %differeturn) // debug-NEXT: start: // debug-NEXT: %"x'de" = alloca float, align 4 // debug-NEXT: store float 0.000000e+00, ptr %"x'de", align 4 @@ -142,14 +160,17 @@ fn f6(i: NestedInput) -> f32 { // debug-NEXT: } // df3 -// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E(float %x.0.val) +// release: define internal fastcc { float, float } +// release-SAME: @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E +// release-SAME: (float %x.0.val) // release-NEXT: start: // release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.val, 0 // release-NEXT: %1 = insertvalue { float, float } %0, float 0x40099999A0000000, 1 // release-NEXT: ret { float, float } %1 // release-NEXT: } -// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E(ptr align 4 %x, ptr align 4 %"x'", ptr align 4 %y, ptr align 4 %"y'") +// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E +// debug-SAME: (ptr align 4 %x, ptr align 4 %"x'", ptr align 4 %y, ptr align 4 %"y'") // debug-NEXT: start: // debug-NEXT: %"_3'ipl" = load float, ptr %"x'", align 4, !alias.scope !9, !noalias !12 // debug-NEXT: %_3 = load float, ptr %x, align 4, !alias.scope !12, !noalias !9 @@ -165,14 +186,17 @@ fn f6(i: NestedInput) -> f32 { // debug-NEXT: } // df4 -// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float noundef %x.0, float %"x.0'") +// release: define internal fastcc { float, float } +// 
release-SAME: @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE +// release-SAME: (float noundef %x.0, float %"x.0'") // release-NEXT: start: // release-NEXT: %0 = insertvalue { float, float } undef, float %x.0, 0 // release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'", 1 // release-NEXT: ret { float, float } %1 // release-NEXT: } -// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE(float %x.0, float %"x.0'", float %x.1, float %"x.1'") +// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f417h2f4a9a7492d91e9fE +// debug-SAME: (float %x.0, float %"x.0'", float %x.1, float %"x.1'") // debug-NEXT: start: // debug-NEXT: %_0 = fmul float %x.0, %x.1 // debug-NEXT: %0 = fmul fast float %"x.0'", %x.1 @@ -184,7 +208,9 @@ fn f6(i: NestedInput) -> f32 { // debug-NEXT: } // df5 -// release: define internal fastcc { float, float } @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float noundef %i.0, float %"i.0'") +// release: define internal fastcc { float, float } +// release-SAME: @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E +// release-SAME: (float noundef %i.0, float %"i.0'") // release-NEXT: start: // release-NEXT: %_0 = fadd float %i.0, 1.000000e+00 // release-NEXT: %0 = insertvalue { float, float } undef, float %_0, 0 @@ -192,7 +218,8 @@ fn f6(i: NestedInput) -> f32 { // release-NEXT: ret { float, float } %1 // release-NEXT: } -// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E(float %i.0, float %"i.0'", float %i.1, float %"i.1'") +// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f517hf8d4ac4d2c2a3976E +// debug-SAME: (float %i.0, float %"i.0'", float %i.1, float %"i.1'") // debug-NEXT: start: // debug-NEXT: %_0 = fadd float %i.0, %i.1 // debug-NEXT: %0 = fadd fast float %"i.0'", %"i.1'" @@ -202,7 +229,9 @@ fn f6(i: NestedInput) -> f32 { // debug-NEXT: } // df6 -// release: define internal fastcc { float, float } 
@fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE(float noundef %i.0, float %"i.0'", float noundef %i.1, float %"i.1'") +// release: define internal fastcc { float, float } +// release-SAME: @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE +// release-SAME: (float noundef %i.0, float %"i.0'", float noundef %i.1, float %"i.1'") // release-NEXT: start: // release-NEXT: %_3 = fmul float %i.1, %i.1 // release-NEXT: %0 = fadd fast float %"i.1'", %"i.1'" @@ -214,7 +243,8 @@ fn f6(i: NestedInput) -> f32 { // release-NEXT: ret { float, float } %4 // release-NEXT: } -// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE(float %i.0, float %"i.0'", float %i.1, float %"i.1'") +// debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f617h5784b207bbb2483eE +// debug-SAME: (float %i.0, float %"i.0'", float %i.1, float %"i.1'") // debug-NEXT: start: // debug-NEXT: %_3 = fmul float %i.1, %i.1 // debug-NEXT: %0 = fmul fast float %"i.1'", %i.1 From 39d1efc4a9c1b129a079686c68dd1e25ddf94a5b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Thu, 26 Jun 2025 13:38:57 +0000 Subject: [PATCH 4/5] Update activity adjustment logic Update `count_scalar_fields`->`count_leaf_fields` to support more types Add extra activities only if `count_leaf_fields` returns at most 2 Logic can be optimized if needed Removed metadata specific fields from test to avoid future failures. 
--- .../src/partitioning/autodiff.rs | 29 +++++++--- tests/codegen/autodiff/abi_handling.rs | 56 ++++++++++++++++--- 2 files changed, 68 insertions(+), 17 deletions(-) diff --git a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs index 9f6bcf18e8943..a5313296ed287 100644 --- a/compiler/rustc_monomorphize/src/partitioning/autodiff.rs +++ b/compiler/rustc_monomorphize/src/partitioning/autodiff.rs @@ -84,11 +84,15 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec let is_product = |t: Ty<'tcx>| matches!(t.kind(), ty::Tuple(_) | ty::Adt(_, _)); + // NOTE: When a tuple or ADT (Algebraic Data Type) has at most two leaf fields and a total size of at most pointer_size * 2, + // LLVM will pass its fields separately instead of as a single aggregate. if layout.size() <= pointer_size * 2 && is_product(*ty) { - let n_scalars = count_scalar_fields(tcx, *ty); - for _ in 0..n_scalars.saturating_sub(1) { - new_activities.push(da[i].clone()); - new_positions.push(i + 1); + let n_scalars = count_leaf_fields(tcx, *ty); + if n_scalars <= 2 { + for _ in 0..n_scalars.saturating_sub(1) { + new_activities.push(da[i].clone()); + new_positions.push(i + 1); + } } } } @@ -101,16 +105,25 @@ fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec } } -fn count_scalar_fields<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> usize { +fn count_leaf_fields<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> usize { match ty.kind() { - ty::Float(_) | ty::Int(_) | ty::Uint(_) => 1, + ty::Bool | ty::Char | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::FnPtr(_, _) => 1, + ty::RawPtr(ty, _) => count_leaf_fields(tcx, *ty), + ty::Ref(_, ty, _) => count_leaf_fields(tcx, *ty), + ty::Array(ty, len) => { + if let Some(len) = len.try_to_target_usize(tcx) { + count_leaf_fields(tcx, *ty) * len as usize + } else { + 1 // Not sure about how to handle this case + } + } ty::Adt(def, substs) if def.is_struct() => def 
.non_enum_variant() .fields .iter() - .map(|f| count_scalar_fields(tcx, f.ty(tcx, substs))) + .map(|f| count_leaf_fields(tcx, f.ty(tcx, substs))) .sum(), - ty::Tuple(substs) => substs.iter().map(|t| count_scalar_fields(tcx, t)).sum(), + ty::Tuple(substs) => substs.iter().map(|t| count_leaf_fields(tcx, t)).sum(), _ => 0, } } diff --git a/tests/codegen/autodiff/abi_handling.rs b/tests/codegen/autodiff/abi_handling.rs index 5de4463c6cf17..18ff0bbc1ff2e 100644 --- a/tests/codegen/autodiff/abi_handling.rs +++ b/tests/codegen/autodiff/abi_handling.rs @@ -101,6 +101,17 @@ fn f6(i: NestedInput) -> f32 { i.x + i.y.z * i.y.z } +// CHECK: ; abi_handling::f7 +// CHECK-NEXT: ; Function Attrs: {{.*}}noinline{{.*}} +// debug-NEXT: define internal float @_ZN12abi_handling2f717h44e3cff234e3b2d5E +// debug-SAME: (ptr align 4 %x.0, ptr align 4 %x.1) +// release-NEXT: define internal fastcc noundef float @_ZN12abi_handling2f717h44e3cff234e3b2d5E +// release-SAME: (float %x.0.0.val, float %x.1.0.val) +#[autodiff_forward(df7, Dual, Dual)] +fn f7(x: (&f32, &f32)) -> f32 { + x.0 * x.1 +} + // df1 // release: define internal fastcc { float, float } // release-SAME: @fwddiffe_ZN12abi_handling2f117h536ac8081c1e4101E @@ -117,12 +128,12 @@ fn f6(i: NestedInput) -> f32 { // debug-NEXT: start: // debug-NEXT: %"'ipg" = getelementptr inbounds float, ptr %"x'", i64 0 // debug-NEXT: %0 = getelementptr inbounds nuw float, ptr %x, i64 0 -// debug-NEXT: %"_2'ipl" = load float, ptr %"'ipg", align 4, !alias.scope !4, !noalias !7 -// debug-NEXT: %_2 = load float, ptr %0, align 4, !alias.scope !7, !noalias !4 +// debug-NEXT: %"_2'ipl" = load float, ptr %"'ipg", align 4 +// debug-NEXT: %_2 = load float, ptr %0, align 4 // debug-NEXT: %"'ipg2" = getelementptr inbounds float, ptr %"x'", i64 1 // debug-NEXT: %1 = getelementptr inbounds nuw float, ptr %x, i64 1 -// debug-NEXT: %"_5'ipl" = load float, ptr %"'ipg2", align 4, !alias.scope !4, !noalias !7 -// debug-NEXT: %_5 = load float, ptr %1, align 4, 
!alias.scope !7, !noalias !4 +// debug-NEXT: %"_5'ipl" = load float, ptr %"'ipg2", align 4 +// debug-NEXT: %_5 = load float, ptr %1, align 4 // debug-NEXT: %_0 = fadd float %_2, %_5 // debug-NEXT: %2 = fadd fast float %"_2'ipl", %"_5'ipl" // debug-NEXT: %3 = insertvalue { float, float } undef, float %_0, 0 @@ -147,7 +158,7 @@ fn f6(i: NestedInput) -> f32 { // debug-NEXT: %"x'de" = alloca float, align 4 // debug-NEXT: store float 0.000000e+00, ptr %"x'de", align 4 // debug-NEXT: %toreturn = alloca float, align 4 -// debug-NEXT: %_0 = call float %f(float %x) #12 +// debug-NEXT: %_0 = call float %f(float %x) // debug-NEXT: store float %_0, ptr %toreturn, align 4 // debug-NEXT: br label %invertstart // debug-EMPTY: @@ -172,10 +183,10 @@ fn f6(i: NestedInput) -> f32 { // debug: define internal { float, float } @fwddiffe_ZN12abi_handling2f317h9cd1fc602b0815a4E // debug-SAME: (ptr align 4 %x, ptr align 4 %"x'", ptr align 4 %y, ptr align 4 %"y'") // debug-NEXT: start: -// debug-NEXT: %"_3'ipl" = load float, ptr %"x'", align 4, !alias.scope !9, !noalias !12 -// debug-NEXT: %_3 = load float, ptr %x, align 4, !alias.scope !12, !noalias !9 -// debug-NEXT: %"_4'ipl" = load float, ptr %"y'", align 4, !alias.scope !14, !noalias !17 -// debug-NEXT: %_4 = load float, ptr %y, align 4, !alias.scope !17, !noalias !14 +// debug-NEXT: %"_3'ipl" = load float, ptr %"x'", align 4 +// debug-NEXT: %_3 = load float, ptr %x, align 4 +// debug-NEXT: %"_4'ipl" = load float, ptr %"y'", align 4 +// debug-NEXT: %_4 = load float, ptr %y, align 4 // debug-NEXT: %_0 = fmul float %_3, %_4 // debug-NEXT: %0 = fmul fast float %"_3'ipl", %_4 // debug-NEXT: %1 = fmul fast float %"_4'ipl", %_3 @@ -257,6 +268,28 @@ fn f6(i: NestedInput) -> f32 { // debug-NEXT: ret { float, float } %5 // debug-NEXT: } +// df7 +// release: define internal fastcc { float, float } +// release-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E +// release-SAME: (float %x.0.0.val, float %"x.0'.0.val") +// release-NEXT: start: 
+// release-NEXT: %0 = insertvalue { float, float } undef, float %x.0.0.val, 0 +// release-NEXT: %1 = insertvalue { float, float } %0, float %"x.0'.0.val", 1 +// release-NEXT: ret { float, float } %1 +// release-NEXT: } + +// debug: define internal { float, float } +// debug-SAME: @fwddiffe_ZN12abi_handling2f717h44e3cff234e3b2d5E +// debug-SAME: (ptr align 4 %x.0, ptr align 4 %"x.0'", ptr align 4 %x.1, ptr align 4 %"x.1'") +// debug-NEXT: start: +// debug-NEXT: %0 = call fast { float, float } @"fwddiffe_ZN49_{{.*}}" +// debug-NEXT: %1 = extractvalue { float, float } %0, 0 +// debug-NEXT: %2 = extractvalue { float, float } %0, 1 +// debug-NEXT: %3 = insertvalue { float, float } undef, float %1, 0 +// debug-NEXT: %4 = insertvalue { float, float } %3, float %2, 1 +// debug-NEXT: ret { float, float } %4 +// debug-NEXT: } + fn main() { let x = std::hint::black_box(2.0); let y = std::hint::black_box(3.0); @@ -290,4 +323,9 @@ fn main() { dbg!(f6(in_f6)); let res_f6 = df6(in_f6, NestedInput { x, y: Wrapper { z } }); dbg!(res_f6); + + let in_f7 = (&x, &y); + dbg!(f7(in_f7)); + let res_f7 = df7(in_f7, (&1.0, &0.0)); + dbg!(res_f7); } From d25ed22102b3426113e83f79f246f3a8ae9ae6af Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Fri, 27 Jun 2025 07:15:00 +0000 Subject: [PATCH 5/5] Update test header comment --- tests/codegen/autodiff/abi_handling.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/codegen/autodiff/abi_handling.rs b/tests/codegen/autodiff/abi_handling.rs index 18ff0bbc1ff2e..bb8dc61c7ccb5 100644 --- a/tests/codegen/autodiff/abi_handling.rs +++ b/tests/codegen/autodiff/abi_handling.rs @@ -5,10 +5,10 @@ //@ no-prefer-dynamic //@ needs-enzyme -// This does only test the funtion attribute handling for autodiff. -// Function argument changes are troublesome for Enzyme, so we have to -// ensure that arguments remain the same, or if we change them, be aware -// of the changes to handle it correctly. 
+// This test checks that Rust types are lowered to LLVM-IR types in a way +// we expect and Enzyme can handle. We explicitly check release mode to +// ensure that LLVM's O3 pipeline doesn't rewrite function signatures +// into forms that Enzyme can't process correctly. #![feature(autodiff)]