@@ -3,20 +3,19 @@ use std::ptr;
3
3
use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , AutoDiffItem , DiffActivity , DiffMode } ;
4
4
use rustc_codegen_ssa:: ModuleCodegen ;
5
5
use rustc_codegen_ssa:: back:: write:: ModuleConfig ;
6
- use rustc_codegen_ssa:: traits:: { BaseTypeCodegenMethods , BuilderMethods } ;
7
6
use rustc_errors:: FatalError ;
8
- use rustc_middle:: ty:: TyCtxt ;
9
7
use rustc_session:: config:: Lto ;
10
8
use tracing:: { debug, trace} ;
11
9
12
10
use crate :: back:: write:: { llvm_err, llvm_optimize} ;
13
- use crate :: builder:: Builder ;
14
- use crate :: declare:: declare_raw_fn;
11
+ use crate :: builder:: SBuilder ;
12
+ use crate :: context:: SimpleCx ;
13
+ use crate :: declare:: declare_simple_fn;
15
14
use crate :: errors:: LlvmError ;
16
15
use crate :: llvm:: AttributePlace :: Function ;
17
16
use crate :: llvm:: { Metadata , True } ;
18
17
use crate :: value:: Value ;
19
- use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, context , llvm} ;
18
+ use crate :: { CodegenContext , LlvmCodegenBackend , ModuleLlvm , attributes, llvm} ;
20
19
21
20
fn get_params ( fnc : & Value ) -> Vec < & Value > {
22
21
unsafe {
@@ -38,8 +37,8 @@ fn get_params(fnc: &Value) -> Vec<&Value> {
38
37
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
39
38
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
40
39
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
41
- fn generate_enzyme_call < ' ll , ' tcx > (
42
- cx : & context :: CodegenCx < ' ll , ' tcx > ,
40
+ fn generate_enzyme_call < ' ll > (
41
+ cx : & SimpleCx < ' ll > ,
43
42
fn_to_diff : & ' ll Value ,
44
43
outer_fn : & ' ll Value ,
45
44
attrs : AutoDiffAttrs ,
@@ -112,7 +111,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
112
111
//FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
113
112
// think a bit more about what should go here.
114
113
let cc = llvm:: LLVMGetFunctionCallConv ( outer_fn) ;
115
- let ad_fn = declare_raw_fn (
114
+ let ad_fn = declare_simple_fn (
116
115
cx,
117
116
& ad_name,
118
117
llvm:: CallConv :: try_from ( cc) . expect ( "invalid callconv" ) ,
@@ -132,7 +131,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
132
131
llvm:: LLVMRustEraseInstFromParent ( br) ;
133
132
134
133
let last_inst = llvm:: LLVMRustGetLastInstruction ( entry) . unwrap ( ) ;
135
- let mut builder = Builder :: build ( cx, entry) ;
134
+ let mut builder = SBuilder :: build ( cx, entry) ;
136
135
137
136
let num_args = llvm:: LLVMCountParams ( & fn_to_diff) ;
138
137
let mut args = Vec :: with_capacity ( num_args as usize + 1 ) ;
@@ -236,7 +235,7 @@ fn generate_enzyme_call<'ll, 'tcx>(
236
235
}
237
236
}
238
237
239
- let call = builder. call ( enzyme_ty, None , None , ad_fn, & args, None , None ) ;
238
+ let call = builder. call ( enzyme_ty, ad_fn, & args, None ) ;
240
239
241
240
// This part is a bit iffy. LLVM requires that a call to an inlineable function has some
242
241
// metadata attachted to it, but we just created this code oota. Given that the
@@ -274,10 +273,9 @@ fn generate_enzyme_call<'ll, 'tcx>(
274
273
}
275
274
}
276
275
277
- pub ( crate ) fn differentiate < ' ll , ' tcx > (
276
+ pub ( crate ) fn differentiate < ' ll > (
278
277
module : & ' ll ModuleCodegen < ModuleLlvm > ,
279
278
cgcx : & CodegenContext < LlvmCodegenBackend > ,
280
- tcx : TyCtxt < ' tcx > ,
281
279
diff_items : Vec < AutoDiffItem > ,
282
280
config : & ModuleConfig ,
283
281
) -> Result < ( ) , FatalError > {
@@ -286,8 +284,7 @@ pub(crate) fn differentiate<'ll, 'tcx>(
286
284
}
287
285
288
286
let diag_handler = cgcx. create_dcx ( ) ;
289
- let ( _, cgus) = tcx. collect_and_partition_mono_items ( ( ) ) ;
290
- let cx = context:: CodegenCx :: new ( tcx, & cgus. first ( ) . unwrap ( ) , & module. module_llvm ) ;
287
+ let cx = SimpleCx { llmod : module. module_llvm . llmod ( ) , llcx : module. module_llvm . llcx } ;
291
288
292
289
// Before dumping the module, we want all the TypeTrees to become part of the module.
293
290
for item in diff_items. iter ( ) {
0 commit comments