Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ function compile_mlir!(
optimize::Union{Bool,Symbol}=true,
no_nan::Bool=false,
backend="gpu",
fn_kwargs=(),
)
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues
Expand All @@ -597,7 +598,7 @@ function compile_mlir!(
activate_callcache!(callcache)

mlir_fn_res = try
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
Reactant.TracedUtils.make_mlir_fn(f, args, fn_kwargs, "main", true)
finally
deactivate_callcache!(callcache)
MLIR.IR.deactivate!(MLIR.IR.body(mod))
Expand Down Expand Up @@ -984,6 +985,7 @@ function compile_call_expr(mod, compiler, options, args...)
call = only(args)
f_symbol = gensym(:f)
args_symbol = gensym(:args)
kwargs_symbol = gensym(:kwargs)
compiled_symbol = gensym(:compiled)

if Meta.isexpr(call, :call)
Expand All @@ -999,19 +1001,37 @@ function compile_call_expr(mod, compiler, options, args...)
else
:($(fname))
end
args_rhs = Expr(:tuple, call.args[2:end]...)
args_rhs = call.args[2:end]

# if (;) is used, we need to extract the kwargs
if length(args_rhs) ≥ 1 && Meta.isexpr(args_rhs[1], :parameters)
kwargs_rhs = args_rhs[1].args
args_rhs = args_rhs[2:end]
else
kwargs_rhs = ()
end
kw_idxs = findall(Base.Fix2(Meta.isexpr, :kw), args_rhs)
arg_idxs = setdiff(1:length(args_rhs), kw_idxs)

kwargs_rhs = (kwargs_rhs..., args_rhs[kw_idxs]...)
args_rhs = Expr(:tuple, args_rhs[arg_idxs]...)
elseif Meta.isexpr(call, :(.), 2) && Meta.isexpr(call.args[2], :tuple)
fname = :($(Base.Broadcast.BroadcastFunction)($(call.args[1])))
args_rhs = only(call.args[2:end])
kwargs_rhs = ()
else
error("Invalid function call: $(call)")
end

return quote
$(f_symbol) = $(fname)
$(args_symbol) = $(args_rhs)
$(kwargs_symbol) = (; $(kwargs_rhs...))
$(compiled_symbol) = $(compiler)(
$(f_symbol), $(args_symbol); $(Expr.(:kw, keys(options), values(options))...)
$(f_symbol),
$(args_symbol);
fn_kwargs=$(kwargs_symbol),
$(Expr.(:kw, keys(options), values(options))...),
)
end,
(; compiled=compiled_symbol, args=args_symbol)
Expand Down
7 changes: 6 additions & 1 deletion src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -251,14 +251,19 @@ function make_mlir_fn(
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues
MLIR.IR.activate!(fnbody)

result = try
for (i, arg) in enumerate(linear_args)
raw_arg = MLIR.IR.argument(fnbody, i)
row_maj_arg = do_transpose ? transpose_val(raw_arg) : raw_arg
set_mlir_data!(arg, row_maj_arg)
end

Reactant.call_with_reactant(f, traced_args...)
if isempty(kwargs)
Reactant.call_with_reactant(f, traced_args...)
else
Reactant.call_with_reactant(Core.kwcall, kwargs, f, traced_args...)
end
finally
MLIR.IR.deactivate!(fnbody)
end
Expand Down
60 changes: 11 additions & 49 deletions test/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -163,55 +163,17 @@ end
x = randn(2, 3, 4)
x_ca = ConcreteRArray(x)

# XXX: @jit doesn't work with `;`
# @test @jit(mean(x_ca)) ≈ mean(x)
# @test @jit(mean(x_ca; dims=1)) ≈ mean(x; dims=1)
# @test @jit(mean(x_ca; dims=(1, 2))) ≈ mean(x; dims=(1, 2))
# @test @jit(mean(x_ca; dims=(1, 3))) ≈ mean(x; dims=(1, 3))

mean_fn1(x) = mean(x)
mean_fn2(x) = mean(x; dims=1)
mean_fn3(x) = mean(x; dims=(1, 2))
mean_fn4(x) = mean(x; dims=(1, 3))
mean_f1abs2(x) = mean(abs2, x)
mean_f2abs2(x) = mean(abs2, x; dims=1)

mean_fn1_compiled = @compile mean_fn1(x_ca)
mean_fn2_compiled = @compile mean_fn2(x_ca)
mean_fn3_compiled = @compile mean_fn3(x_ca)
mean_fn4_compiled = @compile mean_fn4(x_ca)
mean_f1abs2_compiled = @compile mean_f1abs2(x_ca)
mean_f2abs2_compiled = @compile mean_f2abs2(x_ca)

@test mean_fn1(x) ≈ mean_fn1_compiled(x_ca)
@test mean_fn2(x) ≈ mean_fn2_compiled(x_ca)
@test mean_fn3(x) ≈ mean_fn3_compiled(x_ca)
@test mean_fn4(x) ≈ mean_fn4_compiled(x_ca)
@test mean_f1abs2(x) ≈ mean_f1abs2_compiled(x_ca)
@test mean_f2abs2(x) ≈ mean_f2abs2_compiled(x_ca)

# XXX: @jit doesn't work with `;`
# @test @jit(var(x_ca)) ≈ var(x)
# @test @jit(var(x_ca; dims=1)) ≈ var(x; dims=1)
# @test @jit(var(x_ca; dims=(1, 2), corrected=false)) ≈
# var(x; dims=(1, 2), corrected=false)
# @test @jit(var(x_ca; dims=(1, 3), corrected=false)) ≈
# var(x; dims=(1, 3), corrected=false)

var_fn1(x) = var(x)
var_fn2(x) = var(x; dims=1)
var_fn3(x) = var(x; dims=(1, 2), corrected=false)
var_fn4(x) = var(x; dims=(1, 3), corrected=false)

var_fn1_compiled = @compile var_fn1(x_ca)
var_fn2_compiled = @compile var_fn2(x_ca)
var_fn3_compiled = @compile var_fn3(x_ca)
var_fn4_compiled = @compile var_fn4(x_ca)

@test var_fn1(x) ≈ var_fn1_compiled(x_ca)
@test var_fn2(x) ≈ var_fn2_compiled(x_ca)
@test var_fn3(x) ≈ var_fn3_compiled(x_ca)
@test var_fn4(x) ≈ var_fn4_compiled(x_ca)
@test @jit(mean(x_ca)) ≈ mean(x)
@test @jit(mean(x_ca; dims=1)) ≈ mean(x; dims=1)
@test @jit(mean(x_ca; dims=(1, 2))) ≈ mean(x; dims=(1, 2))
@test @jit(mean(x_ca; dims=(1, 3))) ≈ mean(x; dims=(1, 3))

@test @jit(var(x_ca)) ≈ var(x)
@test @jit(var(x_ca, dims=1)) ≈ var(x; dims=1)
@test @jit(var(x_ca, dims=(1, 2); corrected=false)) ≈
var(x; dims=(1, 2), corrected=false)
@test @jit(var(x_ca; dims=(1, 3), corrected=false)) ≈
var(x; dims=(1, 3), corrected=false)
end

@testset "concatenation" begin
Expand Down