@@ -588,6 +588,7 @@ function compile_mlir!(
588588 optimize:: Union{Bool,Symbol} = true ,
589589 no_nan:: Bool = false ,
590590 backend= " gpu" ,
591+ fn_kwargs= (),
591592)
592593 # Explicitly don't use block! to avoid creating a closure, which creates
593594 # both compile-time and relocatability issues
@@ -597,7 +598,7 @@ function compile_mlir!(
597598 activate_callcache! (callcache)
598599
599600 mlir_fn_res = try
600- Reactant. TracedUtils. make_mlir_fn (f, args, () , " main" , true )
601+ Reactant. TracedUtils. make_mlir_fn (f, args, fn_kwargs , " main" , true )
601602 finally
602603 deactivate_callcache! (callcache)
603604 MLIR. IR. deactivate! (MLIR. IR. body (mod))
@@ -984,6 +985,7 @@ function compile_call_expr(mod, compiler, options, args...)
984985 call = only (args)
985986 f_symbol = gensym (:f )
986987 args_symbol = gensym (:args )
988+ kwargs_symbol = gensym (:kwargs )
987989 compiled_symbol = gensym (:compiled )
988990
989991 if Meta. isexpr (call, :call )
@@ -999,19 +1001,32 @@ function compile_call_expr(mod, compiler, options, args...)
9991001 else
10001002 :($ (fname))
10011003 end
1002- args_rhs = Expr ( :tuple , call. args[2 : end ]. .. )
1004+ args_rhs = call. args[2 : end ]
10031005 elseif Meta. isexpr (call, :(.), 2 ) && Meta. isexpr (call. args[2 ], :tuple )
10041006 fname = :($ (Base. Broadcast. BroadcastFunction)($ (call. args[1 ])))
1005- args_rhs = only (call. args[2 : end ])
1007+ @assert length (call. args[2 : end ]) == 1
1008+ args_rhs = (call. args[2 ],)
10061009 else
10071010 error (" Invalid function call: $(call) " )
10081011 end
10091012
1013+ if length (args_rhs) ≥ 1 && Meta. isexpr (args_rhs[1 ], :parameters )
1014+ kwargs_rhs = args_rhs[1 ]. args
1015+ args_rhs = args_rhs[2 : end ]
1016+ else
1017+ kwargs_rhs = []
1018+ end
1019+ args_rhs = Expr (:tuple , args_rhs... )
1020+
10101021 return quote
10111022 $ (f_symbol) = $ (fname)
10121023 $ (args_symbol) = $ (args_rhs)
1024+ $ (kwargs_symbol) = (; $ (kwargs_rhs... ))
10131025 $ (compiled_symbol) = $ (compiler)(
1014- $ (f_symbol), $ (args_symbol); $ (Expr .(:kw , keys (options), values (options))... )
1026+ $ (f_symbol),
1027+ $ (args_symbol);
1028+ fn_kwargs= $ (kwargs_symbol),
1029+ $ (Expr .(:kw , keys (options), values (options))... ),
10151030 )
10161031 end ,
10171032 (; compiled= compiled_symbol, args= args_symbol)
0 commit comments