@@ -3,6 +3,8 @@ use std::ptr;
3
3
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
4
4
use rustc_codegen_ssa:: ModuleCodegen ;
5
5
use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6
+ use rustc_codegen_ssa:: common:: TypeKind ;
7
+ use rustc_codegen_ssa:: traits:: BaseTypeCodegenMethods ;
6
8
use rustc_errors:: FatalError ;
7
9
use rustc_middle:: bug;
8
10
use tracing:: { debug, trace} ;
@@ -18,18 +20,18 @@ use crate::value::Value;
18
20
use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
19
21
20
22
fn get_params ( fnc : & Value ) -> Vec < & Value > {
23
+ let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
24
+ let mut fnc_args: Vec < & Value > = vec ! [ ] ;
25
+ fnc_args. reserve ( param_num) ;
21
26
unsafe {
22
- let param_num = llvm:: LLVMCountParams ( fnc) as usize ;
23
- let mut fnc_args: Vec < & Value > = vec ! [ ] ;
24
- fnc_args. reserve ( param_num) ;
25
27
llvm:: LLVMGetParams ( fnc, fnc_args. as_mut_ptr ( ) ) ;
26
28
fnc_args. set_len ( param_num) ;
27
- fnc_args
28
29
}
30
+ fnc_args
29
31
}
30
32
31
33
fn has_sret ( fnc : & Value ) -> bool {
32
- let num_args = unsafe { llvm:: LLVMCountParams ( fnc) as usize } ;
34
+ let num_args = llvm:: LLVMCountParams ( fnc) as usize ;
33
35
if num_args == 0 {
34
36
false
35
37
} else {
@@ -121,23 +123,15 @@ fn match_args_from_caller_to_enzyme<'ll>(
121
123
// (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
122
124
// FIXME(ZuseZ4): We will upstream a safety check later which asserts that
123
125
// int2 >= int1, which means the shadow vector is large enough to store the gradient.
124
- assert ! ( unsafe {
125
- llvm:: LLVMRustGetTypeKind ( next_outer_ty) == llvm:: TypeKind :: Integer
126
- } ) ;
126
+ assert_eq ! ( cx. type_kind( next_outer_ty) , TypeKind :: Integer ) ;
127
127
128
128
for _ in 0 ..width {
129
129
let next_outer_arg2 = outer_args[ outer_pos + 2 ] ;
130
130
let next_outer_ty2 = cx. val_ty ( next_outer_arg2) ;
131
- assert ! (
132
- unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty2) }
133
- == llvm:: TypeKind :: Pointer
134
- ) ;
131
+ assert_eq ! ( cx. type_kind( next_outer_ty2) , TypeKind :: Pointer ) ;
135
132
let next_outer_arg3 = outer_args[ outer_pos + 3 ] ;
136
133
let next_outer_ty3 = cx. val_ty ( next_outer_arg3) ;
137
- assert ! (
138
- unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty3) }
139
- == llvm:: TypeKind :: Integer
140
- ) ;
134
+ assert_eq ! ( cx. type_kind( next_outer_ty3) , TypeKind :: Integer ) ;
141
135
args. push ( next_outer_arg2) ;
142
136
}
143
137
args. push ( cx. get_metadata_value ( enzyme_const) ) ;
@@ -150,10 +144,7 @@ fn match_args_from_caller_to_enzyme<'ll>(
150
144
// (..., metadata! enzyme_dup, ptr, ptr, ...).
151
145
if matches ! ( diff_activity, DiffActivity :: Duplicated | DiffActivity :: DuplicatedOnly )
152
146
{
153
- assert ! (
154
- unsafe { llvm:: LLVMRustGetTypeKind ( next_outer_ty) }
155
- == llvm:: TypeKind :: Pointer
156
- ) ;
147
+ assert_eq ! ( cx. type_kind( next_outer_ty) , TypeKind :: Pointer ) ;
157
148
}
158
149
// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
159
150
args. push ( next_outer_arg) ;
@@ -213,8 +204,8 @@ fn compute_enzyme_fn_ty<'ll>(
213
204
todo ! ( "Handle sret for scalar ad" ) ;
214
205
} else {
215
206
// First we check if we also have to deal with the primal return.
216
- if attrs. mode . is_fwd ( ) {
217
- match attrs. ret_activity {
207
+ match attrs. mode {
208
+ DiffMode :: Forward => match attrs. ret_activity {
218
209
DiffActivity :: Dual => {
219
210
let arr_ty =
220
211
unsafe { llvm:: LLVMArrayType2 ( inner_ret_ty, attrs. width as u64 + 1 ) } ;
@@ -231,11 +222,13 @@ fn compute_enzyme_fn_ty<'ll>(
231
222
_ => {
232
223
bug ! ( "unreachable" ) ;
233
224
}
225
+ } ,
226
+ DiffMode :: Reverse => {
227
+ todo ! ( "Handle sret for reverse mode" ) ;
228
+ }
229
+ _ => {
230
+ bug ! ( "unreachable" ) ;
234
231
}
235
- } else if attrs. mode . is_rev ( ) {
236
- todo ! ( "Handle sret for reverse mode" ) ;
237
- } else {
238
- bug ! ( "unreachable" ) ;
239
232
}
240
233
}
241
234
}
@@ -395,7 +388,7 @@ fn generate_enzyme_call<'ll>(
395
388
// now store the result of the enzyme call into the sret pointer.
396
389
let sret_ptr = outer_args[ 0 ] ;
397
390
let call_ty = cx. val_ty ( call) ;
398
- assert ! ( llvm :: LLVMRustIsArrayTy ( call_ty) ) ;
391
+ assert_eq ! ( cx . type_kind ( call_ty) , TypeKind :: Array ) ;
399
392
llvm:: LLVMBuildStore ( & builder. llbuilder , call, sret_ptr) ;
400
393
}
401
394
builder. ret_void ( ) ;
0 commit comments