Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add @code_hlo macro to get mlir code #39

Merged
merged 2 commits into from
Jul 13, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 96 additions & 48 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1099,6 +1099,64 @@ pad_dot_general<1>(1);
enzyme-hlo-remove-transform
"""

function compile_to_module(mod, f, args; optimize=true)
fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
mod, f, args, (), "main", true
)

concrete_seen = IdDict()

concrete_result = make_tracer(
concrete_seen, traced_result, ("result",), TracedToConcrete
)

if optimize
XLA.RunPassPipeline(
opt_passes *
",enzyme,arith-raise{stablehlo=true},canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math," *
opt_passes,
mod,
)
end

preserved_args = Tuple{TracedRArray,Int}[]
results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
nresults = MLIR.IR.Value[]
linear_results2 = TracedRArray[]
for (i, op) in enumerate(results)
if !MLIR.IR.is_block_arg(op)
push!(nresults, op)
push!(linear_results2, linear_results[i])
continue
end
push!(preserved_args, (linear_results[i], MLIR.IR.block_arg_num(op)))
end
fnbody = MLIR.IR.block(ret)
MLIR.API.mlirOperationDestroy(ret.operation)
ret.operation = MLIR.API.MlirOperation(C_NULL)
MLIR.IR.block!(fnbody) do
return MLIR.Dialects.func.return_(nresults)
end

out_tys2 = [MLIR.IR.type(a) for a in nresults]

func3 = MLIR.Dialects.func.func_(;
sym_name="main",
function_type=MLIR.IR.FunctionType(in_tys, out_tys2),
body=MLIR.IR.Region(),
)
MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func3, 1), MLIR.IR.region(func2, 1))

push!(MLIR.IR.body(mod), func3)

MLIR.API.mlirOperationDestroy(func2.operation)
func2.operation = MLIR.API.MlirOperation(C_NULL)

return linear_args,
linear_results2, preserved_args, seen_args, concrete_result,
fnwrapped
end

function compile(
f::FTy, args::VAT; pipeline_options="", client=nothing
) where {FTy,VAT<:Tuple}
Expand All @@ -1109,14 +1167,8 @@ function compile(
MLIR.IR.context!(ctx) do
mod = MLIR.IR.Module(MLIR.IR.Location())
MLIR.IR.mmodule!(mod) do
fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys, linear_results = make_mlir_fn(
mod, f, args, (), "main", true
)

concrete_seen = IdDict()

concrete_result = make_tracer(
concrete_seen, traced_result, ("result",), TracedToConcrete
linear_args, linear_results2, preserved_args, seen_args, concrete_result, fnwrapped = compile_to_module(
mod, f, args; optimize=true
)

if isnothing(client)
Expand All @@ -1133,46 +1185,6 @@ function compile(
end
end

XLA.RunPassPipeline(
opt_passes *
",enzyme,arith-raise{stablehlo=true},canonicalize, remove-unnecessary-enzyme-ops, enzyme-simplify-math," *
opt_passes,
mod,
)

preserved_args = Tuple{TracedRArray,Int}[]
results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
nresults = MLIR.IR.Value[]
linear_results2 = TracedRArray[]
for (i, op) in enumerate(results)
if !MLIR.IR.is_block_arg(op)
push!(nresults, op)
push!(linear_results2, linear_results[i])
continue
end
push!(preserved_args, (linear_results[i], MLIR.IR.block_arg_num(op)))
end
fnbody = MLIR.IR.block(ret)
MLIR.API.mlirOperationDestroy(ret.operation)
ret.operation = MLIR.API.MlirOperation(C_NULL)
MLIR.IR.block!(fnbody) do
return MLIR.Dialects.func.return_(nresults)
end

out_tys2 = [MLIR.IR.type(a) for a in nresults]

func3 = MLIR.Dialects.func.func_(;
sym_name="main",
function_type=MLIR.IR.FunctionType(in_tys, out_tys2),
body=MLIR.IR.Region(),
)
MLIR.API.mlirRegionTakeBody(MLIR.IR.region(func3, 1), MLIR.IR.region(func2, 1))

push!(MLIR.IR.body(mod), func3)

MLIR.API.mlirOperationDestroy(func2.operation)
func2.operation = MLIR.API.MlirOperation(C_NULL)

return generate_jlfunc(
concrete_result,
client,
Expand All @@ -1186,6 +1198,42 @@ function compile(
end
end

struct CompiledModule
mod::MLIR.IR.Module
ctx::MLIR.IR.Context
end

Base.show(io::IO, cm::CompiledModule) = show(io, cm.mod)

"""
@code_hlo [optimize = ...] f(args...)
"""
macro code_hlo(options, maybe_call=nothing)
call = something(maybe_call, options)
options = isnothing(maybe_call) ? :(optimize = true) : options
Meta.isexpr(call, :call) || error("@code_mlir: expected call, got $call")
if !Meta.isexpr(options, :(=)) || options.args[1] != :optimize
error("@code_mlir: expected options in format optimize=value, got $options")
end

options = Expr(:tuple, Expr(:parameters, Expr(:kw, options.args...)))

quote
options = $(esc(options))
f = $(esc(call.args[1]))
args = $(esc(Expr(:vect, call.args[2:end]...)))

ctx = MLIR.IR.Context()
Base.append!(registry[]; context=ctx)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
MLIR.IR.context!(ctx) do
mod = MLIR.IR.Module(MLIR.IR.Location())
compile_to_module(mod, f, args; optimize=options.optimize)
CompiledModule(mod, ctx)
end
end
end

function set_default_backend(backend::XLA.Client)
return XLA.default_backend[] = backend
end
Expand Down
Loading