Skip to content

Commit 36640f3

Browse files
committed
feat: support kwargs in macros
1 parent d18fd40 commit 36640f3

File tree

3 files changed

+40
-53
lines changed

3 files changed

+40
-53
lines changed

src/Compiler.jl

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -588,6 +588,7 @@ function compile_mlir!(
588588
optimize::Union{Bool,Symbol}=true,
589589
no_nan::Bool=false,
590590
backend="gpu",
591+
fn_kwargs=(),
591592
)
592593
# Explicitly don't use block! to avoid creating a closure, which creates
593594
# both compile-time and relocatability issues
@@ -597,7 +598,7 @@ function compile_mlir!(
597598
activate_callcache!(callcache)
598599

599600
mlir_fn_res = try
600-
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
601+
Reactant.TracedUtils.make_mlir_fn(f, args, fn_kwargs, "main", true)
601602
finally
602603
deactivate_callcache!(callcache)
603604
MLIR.IR.deactivate!(MLIR.IR.body(mod))
@@ -984,6 +985,7 @@ function compile_call_expr(mod, compiler, options, args...)
984985
call = only(args)
985986
f_symbol = gensym(:f)
986987
args_symbol = gensym(:args)
988+
kwargs_symbol = gensym(:kwargs)
987989
compiled_symbol = gensym(:compiled)
988990

989991
if Meta.isexpr(call, :call)
@@ -999,19 +1001,37 @@ function compile_call_expr(mod, compiler, options, args...)
9991001
else
10001002
:($(fname))
10011003
end
1002-
args_rhs = Expr(:tuple, call.args[2:end]...)
1004+
args_rhs = call.args[2:end]
1005+
1006+
# if (;) is used, we need to extract the kwargs
1007+
if length(args_rhs) 1 && Meta.isexpr(args_rhs[1], :parameters)
1008+
kwargs_rhs = args_rhs[1].args
1009+
args_rhs = args_rhs[2:end]
1010+
else
1011+
kwargs_rhs = ()
1012+
end
1013+
kw_idxs = findall(Base.Fix2(Meta.isexpr, :kw), args_rhs)
1014+
arg_idxs = setdiff(1:length(args_rhs), kw_idxs)
1015+
1016+
kwargs_rhs = (kwargs_rhs..., args_rhs[kw_idxs]...)
1017+
args_rhs = Expr(:tuple, args_rhs[arg_idxs]...)
10031018
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
10041019
fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1])))
10051020
args_rhs = only(call.args[2:end])
1021+
kwargs_rhs = ()
10061022
else
10071023
error("Invalid function call: $(call)")
10081024
end
10091025

10101026
return quote
10111027
$(f_symbol) = $(fname)
10121028
$(args_symbol) = $(args_rhs)
1029+
$(kwargs_symbol) = (; $(kwargs_rhs...))
10131030
$(compiled_symbol) = $(compiler)(
1014-
$(f_symbol), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...)
1031+
$(f_symbol),
1032+
$(args_symbol);
1033+
fn_kwargs=$(kwargs_symbol),
1034+
$(Expr.(:kw, keys(options), values(options))...),
10151035
)
10161036
end,
10171037
(; compiled=compiled_symbol, args=args_symbol)

src/TracedUtils.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,14 +251,19 @@ function make_mlir_fn(
251251
# Explicitly don't use block! to avoid creating a closure, which creates
252252
# both compile-time and relocatability issues
253253
MLIR.IR.activate!(fnbody)
254+
254255
result = try
255256
for (i, arg) in enumerate(linear_args)
256257
raw_arg = MLIR.IR.argument(fnbody, i)
257258
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
258259
set_mlir_data!(arg, row_maj_arg)
259260
end
260261

261-
Reactant.call_with_reactant(f, traced_args...)
262+
if isempty(kwargs)
263+
Reactant.call_with_reactant(f, traced_args...)
264+
else
265+
Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...)
266+
end
262267
finally
263268
MLIR.IR.deactivate!(fnbody)
264269
end

test/basic.jl

Lines changed: 11 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -163,55 +163,17 @@ end
163163
x = randn(2, 3, 4)
164164
x_ca = ConcreteRArray(x)
165165

166-
# XXX: @jit doesn't work with `;`
167-
# @test @jit(mean(x_ca)) ≈ mean(x)
168-
# @test @jit(mean(x_ca; dims=1)) ≈ mean(x; dims=1)
169-
# @test @jit(mean(x_ca; dims=(1, 2))) ≈ mean(x; dims=(1, 2))
170-
# @test @jit(mean(x_ca; dims=(1, 3))) ≈ mean(x; dims=(1, 3))
171-
172-
mean_fn1(x) = mean(x)
173-
mean_fn2(x) = mean(x; dims=1)
174-
mean_fn3(x) = mean(x; dims=(1, 2))
175-
mean_fn4(x) = mean(x; dims=(1, 3))
176-
mean_f1abs2(x) = mean(abs2, x)
177-
mean_f2abs2(x) = mean(abs2, x; dims=1)
178-
179-
mean_fn1_compiled = @compile mean_fn1(x_ca)
180-
mean_fn2_compiled = @compile mean_fn2(x_ca)
181-
mean_fn3_compiled = @compile mean_fn3(x_ca)
182-
mean_fn4_compiled = @compile mean_fn4(x_ca)
183-
mean_f1abs2_compiled = @compile mean_f1abs2(x_ca)
184-
mean_f2abs2_compiled = @compile mean_f2abs2(x_ca)
185-
186-
@test mean_fn1(x) mean_fn1_compiled(x_ca)
187-
@test mean_fn2(x) mean_fn2_compiled(x_ca)
188-
@test mean_fn3(x) mean_fn3_compiled(x_ca)
189-
@test mean_fn4(x) mean_fn4_compiled(x_ca)
190-
@test mean_f1abs2(x) mean_f1abs2_compiled(x_ca)
191-
@test mean_f2abs2(x) mean_f2abs2_compiled(x_ca)
192-
193-
# XXX: @jit doesn't work with `;`
194-
# @test @jit(var(x_ca)) ≈ var(x)
195-
# @test @jit(var(x_ca; dims=1)) ≈ var(x; dims=1)
196-
# @test @jit(var(x_ca; dims=(1, 2), corrected=false)) ≈
197-
# var(x; dims=(1, 2), corrected=false)
198-
# @test @jit(var(x_ca; dims=(1, 3), corrected=false)) ≈
199-
# var(x; dims=(1, 3), corrected=false)
200-
201-
var_fn1(x) = var(x)
202-
var_fn2(x) = var(x; dims=1)
203-
var_fn3(x) = var(x; dims=(1, 2), corrected=false)
204-
var_fn4(x) = var(x; dims=(1, 3), corrected=false)
205-
206-
var_fn1_compiled = @compile var_fn1(x_ca)
207-
var_fn2_compiled = @compile var_fn2(x_ca)
208-
var_fn3_compiled = @compile var_fn3(x_ca)
209-
var_fn4_compiled = @compile var_fn4(x_ca)
210-
211-
@test var_fn1(x) var_fn1_compiled(x_ca)
212-
@test var_fn2(x) var_fn2_compiled(x_ca)
213-
@test var_fn3(x) var_fn3_compiled(x_ca)
214-
@test var_fn4(x) var_fn4_compiled(x_ca)
166+
@test @jit(mean(x_ca)) mean(x)
167+
@test @jit(mean(x_ca; dims=1)) mean(x; dims=1)
168+
@test @jit(mean(x_ca; dims=(1, 2))) mean(x; dims=(1, 2))
169+
@test @jit(mean(x_ca; dims=(1, 3))) mean(x; dims=(1, 3))
170+
171+
@test @jit(var(x_ca)) var(x)
172+
@test @jit(var(x_ca, dims=1)) var(x; dims=1)
173+
@test @jit(var(x_ca, dims=(1, 2); corrected=false))
174+
var(x; dims=(1, 2), corrected=false)
175+
@test @jit(var(x_ca; dims=(1, 3), corrected=false))
176+
var(x; dims=(1, 3), corrected=false)
215177
end
216178

217179
@testset "concatenation" begin

0 commit comments

Comments
 (0)