Skip to content

Commit 3409f1a

Browse files
committed
feat: support kwargs in macros
1 parent d18fd40 commit 3409f1a

File tree

2 files changed

+25
-5
lines changed

2 files changed

+25
-5
lines changed

src/Compiler.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ function compile_mlir!(
588588
optimize::Union{Bool,Symbol}=true,
589589
no_nan::Bool=false,
590590
backend="gpu",
591+
fn_kwargs=(),
591592
)
592593
# Explicitly don't use block! to avoid creating a closure, which creates
593594
# both compile-time and relocatability issues
@@ -597,7 +598,7 @@ function compile_mlir!(
597598
activate_callcache!(callcache)
598599

599600
mlir_fn_res = try
600-
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
601+
Reactant.TracedUtils.make_mlir_fn(f, args, fn_kwargs, "main", true)
601602
finally
602603
deactivate_callcache!(callcache)
603604
MLIR.IR.deactivate!(MLIR.IR.body(mod))
@@ -984,6 +985,7 @@ function compile_call_expr(mod, compiler, options, args...)
984985
call = only(args)
985986
f_symbol = gensym(:f)
986987
args_symbol = gensym(:args)
988+
kwargs_symbol = gensym(:kwargs)
987989
compiled_symbol = gensym(:compiled)
988990

989991
if Meta.isexpr(call, :call)
@@ -999,19 +1001,32 @@ function compile_call_expr(mod, compiler, options, args...)
9991001
else
10001002
:($(fname))
10011003
end
1002-
args_rhs = Expr(:tuple, call.args[2:end]...)
1004+
args_rhs = call.args[2:end]
10031005
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
10041006
fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1])))
1005-
args_rhs = only(call.args[2:end])
1007+
@assert length(call.args[2:end]) == 1
1008+
args_rhs = (call.args[2],)
10061009
else
10071010
error("Invalid function call: $(call)")
10081011
end
10091012

1013+
if length(args_rhs) 1 && Meta.isexpr(args_rhs[1], :parameters)
1014+
kwargs_rhs = args_rhs[1].args
1015+
args_rhs = args_rhs[2:end]
1016+
else
1017+
kwargs_rhs = []
1018+
end
1019+
args_rhs = Expr(:tuple, args_rhs...)
1020+
10101021
return quote
10111022
$(f_symbol) = $(fname)
10121023
$(args_symbol) = $(args_rhs)
1024+
$(kwargs_symbol) = (; $(kwargs_rhs...))
10131025
$(compiled_symbol) = $(compiler)(
1014-
$(f_symbol), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...)
1026+
$(f_symbol),
1027+
$(args_symbol);
1028+
fn_kwargs=$(kwargs_symbol),
1029+
$(Expr.(:kw, keys(options), values(options))...),
10151030
)
10161031
end,
10171032
(; compiled=compiled_symbol, args=args_symbol)

src/TracedUtils.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,14 +251,19 @@ function make_mlir_fn(
251251
# Explicitly don't use block! to avoid creating a closure, which creates
252252
# both compile-time and relocatability issues
253253
MLIR.IR.activate!(fnbody)
254+
254255
result = try
255256
for (i, arg) in enumerate(linear_args)
256257
raw_arg = MLIR.IR.argument(fnbody, i)
257258
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
258259
set_mlir_data!(arg, row_maj_arg)
259260
end
260261

261-
Reactant.call_with_reactant(f, traced_args...)
262+
if isempty(kwargs)
263+
Reactant.call_with_reactant(f, traced_args...)
264+
else
265+
Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...)
266+
end
262267
finally
263268
MLIR.IR.deactivate!(fnbody)
264269
end

0 commit comments

Comments
 (0)