From 2213001d8960235491e492973c218b40282c8c2e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 22 Feb 2025 11:02:01 -0600 Subject: [PATCH] feat: support kwargs in macros --- src/Compiler.jl | 26 +++++++++++++++++--- src/TracedUtils.jl | 7 +++++- test/basic.jl | 60 +++++++++------------------------------------- 3 files changed, 40 insertions(+), 53 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index dcf08aa1da..ce07730d07 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -588,6 +588,7 @@ function compile_mlir!( optimize::Union{Bool,Symbol}=true, no_nan::Bool=false, backend="gpu", + fn_kwargs=(), ) # Explicitly don't use block! to avoid creating a closure, which creates # both compile-time and relocatability issues @@ -597,7 +598,7 @@ function compile_mlir!( activate_callcache!(callcache) mlir_fn_res = try - Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true) + Reactant.TracedUtils.make_mlir_fn(f, args, fn_kwargs, "main", true) finally deactivate_callcache!(callcache) MLIR.IR.deactivate!(MLIR.IR.body(mod)) @@ -984,6 +985,7 @@ function compile_call_expr(mod, compiler, options, args...) call = only(args) f_symbol = gensym(:f) args_symbol = gensym(:args) + kwargs_symbol = gensym(:kwargs) compiled_symbol = gensym(:compiled) if Meta.isexpr(call, :call) @@ -999,10 +1001,24 @@ function compile_call_expr(mod, compiler, options, args...) else :($(fname)) end - args_rhs = Expr(:tuple, call.args[2:end]...) + args_rhs = call.args[2:end] + + # if (;) is used, we need to extract the kwargs + if length(args_rhs) ≥ 1 && Meta.isexpr(args_rhs[1], :parameters) + kwargs_rhs = args_rhs[1].args + args_rhs = args_rhs[2:end] + else + kwargs_rhs = () + end + kw_idxs = findall(Base.Fix2(Meta.isexpr, :kw), args_rhs) + arg_idxs = setdiff(1:length(args_rhs), kw_idxs) + + kwargs_rhs = (kwargs_rhs..., args_rhs[kw_idxs]...) + args_rhs = Expr(:tuple, args_rhs[arg_idxs]...) elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple) fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1]))) args_rhs = only(call.args[2:end]) + kwargs_rhs = () else error("Invalid function call: $(call)") end @@ -1010,8 +1026,12 @@ function compile_call_expr(mod, compiler, options, args...) return quote $(f_symbol) = $(fname) $(args_symbol) = $(args_rhs) + $(kwargs_symbol) = (; $(kwargs_rhs...)) $(compiled_symbol) = $(compiler)( - $(f_symbol), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...) + $(f_symbol), + $(args_symbol); + fn_kwargs=$(kwargs_symbol), + $(Expr.(:kw, keys(options), values(options))...), ) end, (; compiled=compiled_symbol, args=args_symbol) diff --git a/src/TracedUtils.jl b/src/TracedUtils.jl index 457a826899..b31833be6f 100644 --- a/src/TracedUtils.jl +++ b/src/TracedUtils.jl @@ -251,6 +251,7 @@ function make_mlir_fn( # Explicitly don't use block! to avoid creating a closure, which creates # both compile-time and relocatability issues MLIR.IR.activate!(fnbody) + result = try for (i, arg) in enumerate(linear_args) raw_arg = MLIR.IR.argument(fnbody, i) @@ -258,7 +259,11 @@ function make_mlir_fn( set_mlir_data!(arg, row_maj_arg) end - Reactant.call_with_reactant(f, traced_args...) + if isempty(kwargs) + Reactant.call_with_reactant(f, traced_args...) + else + Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...) + end finally MLIR.IR.deactivate!(fnbody) end diff --git a/test/basic.jl b/test/basic.jl index 827d70acaf..0ab61b4768 100644 --- a/test/basic.jl +++ b/test/basic.jl @@ -163,55 +163,17 @@ end x = randn(2, 3, 4) x_ca = ConcreteRArray(x) - # XXX: @jit doesn't work with `;` - # @test @jit(mean(x_ca)) ≈ mean(x) - # @test @jit(mean(x_ca; dims=1)) ≈ mean(x; dims=1) - # @test @jit(mean(x_ca; dims=(1, 2))) ≈ mean(x; dims=(1, 2)) - # @test @jit(mean(x_ca; dims=(1, 3))) ≈ mean(x; dims=(1, 3)) - - mean_fn1(x) = mean(x) - mean_fn2(x) = mean(x; dims=1) - mean_fn3(x) = mean(x; dims=(1, 2)) - mean_fn4(x) = mean(x; dims=(1, 3)) - mean_f1abs2(x) = mean(abs2, x) - mean_f2abs2(x) = mean(abs2, x; dims=1) - - mean_fn1_compiled = @compile mean_fn1(x_ca) - mean_fn2_compiled = @compile mean_fn2(x_ca) - mean_fn3_compiled = @compile mean_fn3(x_ca) - mean_fn4_compiled = @compile mean_fn4(x_ca) - mean_f1abs2_compiled = @compile mean_f1abs2(x_ca) - mean_f2abs2_compiled = @compile mean_f2abs2(x_ca) - - @test mean_fn1(x) ≈ mean_fn1_compiled(x_ca) - @test mean_fn2(x) ≈ mean_fn2_compiled(x_ca) - @test mean_fn3(x) ≈ mean_fn3_compiled(x_ca) - @test mean_fn4(x) ≈ mean_fn4_compiled(x_ca) - @test mean_f1abs2(x) ≈ mean_f1abs2_compiled(x_ca) - @test mean_f2abs2(x) ≈ mean_f2abs2_compiled(x_ca) - - # XXX: @jit doesn't work with `;` - # @test @jit(var(x_ca)) ≈ var(x) - # @test @jit(var(x_ca; dims=1)) ≈ var(x; dims=1) - # @test @jit(var(x_ca; dims=(1, 2), corrected=false)) ≈ - # var(x; dims=(1, 2), corrected=false) - # @test @jit(var(x_ca; dims=(1, 3), corrected=false)) ≈ - # var(x; dims=(1, 3), corrected=false) - - var_fn1(x) = var(x) - var_fn2(x) = var(x; dims=1) - var_fn3(x) = var(x; dims=(1, 2), corrected=false) - var_fn4(x) = var(x; dims=(1, 3), corrected=false) - - var_fn1_compiled = @compile var_fn1(x_ca) - var_fn2_compiled = @compile var_fn2(x_ca) - var_fn3_compiled = @compile var_fn3(x_ca) - var_fn4_compiled = @compile var_fn4(x_ca) - - @test var_fn1(x) ≈ var_fn1_compiled(x_ca) - @test var_fn2(x) ≈ var_fn2_compiled(x_ca) - @test var_fn3(x) ≈ var_fn3_compiled(x_ca) - @test var_fn4(x) ≈ var_fn4_compiled(x_ca) + @test @jit(mean(x_ca)) ≈ mean(x) + @test @jit(mean(x_ca; dims=1)) ≈ mean(x; dims=1) + @test @jit(mean(x_ca; dims=(1, 2))) ≈ mean(x; dims=(1, 2)) + @test @jit(mean(x_ca; dims=(1, 3))) ≈ mean(x; dims=(1, 3)) + + @test @jit(var(x_ca)) ≈ var(x) + @test @jit(var(x_ca, dims=1)) ≈ var(x; dims=1) + @test @jit(var(x_ca, dims=(1, 2); corrected=false)) ≈ + var(x; dims=(1, 2), corrected=false) + @test @jit(var(x_ca; dims=(1, 3), corrected=false)) ≈ + var(x; dims=(1, 3), corrected=false) end @testset "concatenation" begin