Skip to content

Commit 59222c6

Browse files
committed
Partially disable ABI optimizations for ad functions
1 parent d00435f commit 59222c6

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use std::str::FromStr;
22

3-
use rustc_abi::ExternAbi;
3+
use rustc_abi::{ExternAbi, Size};
44
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
55
use rustc_ast::{LitKind, MetaItem, MetaItemInner, attr};
66
use rustc_attr_data_structures::ReprAttr::ReprAlign;
@@ -16,7 +16,7 @@ use rustc_middle::middle::codegen_fn_attrs::{
1616
use rustc_middle::mir::mono::Linkage;
1717
use rustc_middle::query::Providers;
1818
use rustc_middle::span_bug;
19-
use rustc_middle::ty::{self as ty, TyCtxt};
19+
use rustc_middle::ty::{self as ty, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv};
2020
use rustc_session::parse::feature_err;
2121
use rustc_session::{Session, lint};
2222
use rustc_span::{Ident, Span, sym};
@@ -138,6 +138,28 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs {
138138
sym::rustc_allocator_zeroed => {
139139
codegen_fn_attrs.flags |= CodegenFnAttrFlags::ALLOCATOR_ZEROED
140140
}
141+
sym::rustc_autodiff => {
142+
let list = attr.meta_item_list().unwrap_or_default();
143+
if list.is_empty() {
144+
// Add the flag only to the primal function so LLVM can
145+
// optimize the derivative function.
146+
if let Some(sig) = fn_sig() {
147+
let sig = sig.skip_binder();
148+
149+
let has_problematic_args = sig
150+
.skip_binder()
151+
.inputs()
152+
.iter()
153+
.any(|ty| is_abi_opt_sensitive(tcx, *ty));
154+
155+
if has_problematic_args {
156+
codegen_fn_attrs.flags |= CodegenFnAttrFlags::RUSTC_AUTODIFF_NO_ABI_OPT;
157+
}
158+
}
159+
160+
// TODO(Sa4dUs): Handle static variable passed as argument case.
161+
}
162+
}
141163
sym::naked => codegen_fn_attrs.flags |= CodegenFnAttrFlags::NAKED,
142164
sym::no_mangle => {
143165
no_mangle_span = Some(attr.span());
@@ -899,6 +921,44 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
899921
Some(AutoDiffAttrs { mode, width, ret_activity, input_activity: arg_activities })
900922
}
901923

924+
fn is_abi_opt_sensitive<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> bool {
925+
match ty.kind() {
926+
ty::Ref(_, inner, _) | ty::RawPtr(inner, _) => {
927+
match inner.kind() {
928+
ty::Slice(_) => {
929+
// Since we cannot guarantee that the slice length is large enough
930+
// to avoid optimization, we assume it is ABI-opt sensitive.
931+
return true;
932+
}
933+
ty::Array(elem_ty, len) => {
934+
let Some(len_val) = len.try_to_target_usize(tcx) else {
935+
return false;
936+
};
937+
938+
let pci = PseudoCanonicalInput {
939+
typing_env: TypingEnv::fully_monomorphized(),
940+
value: *elem_ty,
941+
};
942+
943+
if elem_ty.is_scalar() {
944+
let elem_size =
945+
tcx.layout_of(pci).ok().map(|layout| layout.size).unwrap_or(Size::ZERO);
946+
947+
if elem_size.bytes() * len_val <= tcx.data_layout.pointer_size.bytes() * 2 {
948+
return true;
949+
}
950+
}
951+
}
952+
_ => {}
953+
}
954+
955+
false
956+
}
957+
ty::FnPtr(_, _) => true,
958+
_ => false,
959+
}
960+
}
961+
902962
pub(crate) fn provide(providers: &mut Providers) {
903963
*providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
904964
}

compiler/rustc_middle/src/middle/codegen_fn_attrs.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,10 @@ bitflags::bitflags! {
136136
const ALLOCATOR_ZEROED = 1 << 14;
137137
/// `#[no_builtins]`: indicates that disable implicit builtin knowledge of functions for the function.
138138
const NO_BUILTINS = 1 << 15;
139+
/// `#[rustc_autodiff_no_abi_opt]`: internal marker applied to `#[rustc_autodiff]` primal functions
140+
/// whose argument layout may be sensitive to ABI-level optimizations. This marker prevents certain
141+
/// optimizations that could otherwise break compatibility with Enzyme's expectations.
142+
const RUSTC_AUTODIFF_NO_ABI_OPT = 1 << 16;
139143
}
140144
}
141145
rustc_data_structures::external_bitflags_debug! { CodegenFnAttrFlags }
@@ -175,6 +179,7 @@ impl CodegenFnAttrs {
175179
self.flags.contains(CodegenFnAttrFlags::NO_MANGLE)
176180
|| self.flags.contains(CodegenFnAttrFlags::RUSTC_STD_INTERNAL_SYMBOL)
177181
|| self.export_name.is_some()
182+
|| self.flags.contains(CodegenFnAttrFlags::RUSTC_AUTODIFF_NO_ABI_OPT)
178183
|| match self.linkage {
179184
// These are private, so make sure we don't try to consider
180185
// them external.

0 commit comments

Comments
 (0)