1
1
use std:: str:: FromStr ;
2
2
3
- use rustc_abi:: ExternAbi ;
3
+ use rustc_abi:: { ExternAbi , Size } ;
4
4
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
5
5
use rustc_ast:: { LitKind , MetaItem , MetaItemInner , attr} ;
6
6
use rustc_attr_data_structures:: ReprAttr :: ReprAlign ;
@@ -16,7 +16,7 @@ use rustc_middle::middle::codegen_fn_attrs::{
16
16
use rustc_middle:: mir:: mono:: Linkage ;
17
17
use rustc_middle:: query:: Providers ;
18
18
use rustc_middle:: span_bug;
19
- use rustc_middle:: ty:: { self as ty, TyCtxt } ;
19
+ use rustc_middle:: ty:: { self as ty, PseudoCanonicalInput , Ty , TyCtxt , TypingEnv } ;
20
20
use rustc_session:: parse:: feature_err;
21
21
use rustc_session:: { Session , lint} ;
22
22
use rustc_span:: { Ident , Span , sym} ;
@@ -138,6 +138,28 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs {
138
138
sym:: rustc_allocator_zeroed => {
139
139
codegen_fn_attrs. flags |= CodegenFnAttrFlags :: ALLOCATOR_ZEROED
140
140
}
141
+ sym:: rustc_autodiff => {
142
+ let list = attr. meta_item_list ( ) . unwrap_or_default ( ) ;
143
+ if list. is_empty ( ) {
144
+ // Add the flag only to the primal function so LLVM can
145
+ // optimize the derivative function.
146
+ if let Some ( sig) = fn_sig ( ) {
147
+ let sig = sig. skip_binder ( ) ;
148
+
149
+ let has_problematic_args = sig
150
+ . skip_binder ( )
151
+ . inputs ( )
152
+ . iter ( )
153
+ . any ( |ty| is_abi_opt_sensitive ( tcx, * ty) ) ;
154
+
155
+ if has_problematic_args {
156
+ codegen_fn_attrs. flags |= CodegenFnAttrFlags :: RUSTC_AUTODIFF_NO_ABI_OPT ;
157
+ }
158
+ }
159
+
160
+ // TODO(Sa4dUs): Handle static variable passed as argument case.
161
+ }
162
+ }
141
163
sym:: naked => codegen_fn_attrs. flags |= CodegenFnAttrFlags :: NAKED ,
142
164
sym:: no_mangle => {
143
165
no_mangle_span = Some ( attr. span ( ) ) ;
@@ -899,6 +921,44 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
899
921
Some ( AutoDiffAttrs { mode, width, ret_activity, input_activity : arg_activities } )
900
922
}
901
923
924
+ fn is_abi_opt_sensitive < ' tcx > ( tcx : TyCtxt < ' tcx > , ty : Ty < ' tcx > ) -> bool {
925
+ match ty. kind ( ) {
926
+ ty:: Ref ( _, inner, _) | ty:: RawPtr ( inner, _) => {
927
+ match inner. kind ( ) {
928
+ ty:: Slice ( _) => {
929
+ // Since we cannot guarantee that the slice length is large enough
930
+ // to avoid optimization, we assume it is ABI-opt sensitive.
931
+ return true ;
932
+ }
933
+ ty:: Array ( elem_ty, len) => {
934
+ let Some ( len_val) = len. try_to_target_usize ( tcx) else {
935
+ return false ;
936
+ } ;
937
+
938
+ let pci = PseudoCanonicalInput {
939
+ typing_env : TypingEnv :: fully_monomorphized ( ) ,
940
+ value : * elem_ty,
941
+ } ;
942
+
943
+ if elem_ty. is_scalar ( ) {
944
+ let elem_size =
945
+ tcx. layout_of ( pci) . ok ( ) . map ( |layout| layout. size ) . unwrap_or ( Size :: ZERO ) ;
946
+
947
+ if elem_size. bytes ( ) * len_val <= tcx. data_layout . pointer_size . bytes ( ) * 2 {
948
+ return true ;
949
+ }
950
+ }
951
+ }
952
+ _ => { }
953
+ }
954
+
955
+ false
956
+ }
957
+ ty:: FnPtr ( _, _) => true ,
958
+ _ => false ,
959
+ }
960
+ }
961
+
902
962
pub ( crate ) fn provide ( providers : & mut Providers ) {
903
963
* providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..* providers } ;
904
964
}
0 commit comments