From 0a447e69348118a349f0a1b3239b3e40ae47bffa Mon Sep 17 00:00:00 2001 From: Paul Berg Date: Fri, 29 Sep 2023 15:39:29 +0200 Subject: [PATCH] IR Module (#11) Co-authored-by: jumerckx <31353884+jumerckx@users.noreply.github.com> --- examples/brutus.jl | 280 +++++++++++++++ src/Dialects.jl | 166 +++++++++ src/IR/IR.jl | 830 +++++++++++++++++++++++++++++++++++++++++++++ src/IR/Pass.jl | 176 ++++++++++ src/IR/Support.jl | 133 ++++++++ src/MLIR.jl | 5 +- 6 files changed, 1589 insertions(+), 1 deletion(-) create mode 100644 examples/brutus.jl create mode 100644 src/Dialects.jl create mode 100644 src/IR/IR.jl create mode 100644 src/IR/Pass.jl create mode 100644 src/IR/Support.jl diff --git a/examples/brutus.jl b/examples/brutus.jl new file mode 100644 index 00000000..36df2877 --- /dev/null +++ b/examples/brutus.jl @@ -0,0 +1,280 @@ +module Brutus + +import LLVM +using MLIR.IR +using MLIR.Dialects: arith, func, cf, std +using Core: PhiNode, GotoNode, GotoIfNot, SSAValue, Argument, ReturnNode, PiNode + +const BrutusScalar = Union{Bool,Int64,Int32,Float32,Float64} + +function cmpi_pred(predicate) + function(ctx, ops; loc=Location(ctx)) + arith.cmpi(ctx, predicate, ops; loc) + end +end + +function single_op_wrapper(fop) + (ctx::Context, block::Block, args::Vector{Value}; loc=Location(ctx)) -> push!(block, fop(ctx, args; loc)) +end + +const intrinsics_to_mlir = Dict([ + Base.add_int => single_op_wrapper(arith.addi), + Base.sle_int => single_op_wrapper(cmpi_pred(arith.Predicates.sle)), + Base.slt_int => single_op_wrapper(cmpi_pred(arith.Predicates.slt)), + Base.:(===) => single_op_wrapper(cmpi_pred(arith.Predicates.eq)), + Base.mul_int => single_op_wrapper(arith.muli), + Base.mul_float => single_op_wrapper(arith.mulf), + Base.not_int => function(ctx, block, args; loc=Location(ctx)) + arg = only(args) + ones = push!(block, arith.constant(ctx, -1, IR.get_type(arg); loc)) |> IR.get_result + push!(block, arith.xori(ctx, Value[arg, ones]; loc)) + end, +]) + +"Generates a block argument for each phi node present in the block." +function prepare_block(ctx, ir, bb) + b = Block() + + for sidx in bb.stmts + stmt = ir.stmts[sidx] + inst = stmt[:inst] + inst isa Core.PhiNode || continue + + type = stmt[:type] + IR.push_argument!(b, MLIRType(ctx, type), Location(ctx)) + end + + return b +end + +"Values to populate the Phi Node when jumping from `from` to `to`." +function collect_value_arguments(ir, from, to) + to = ir.cfg.blocks[to] + values = [] + for s in to.stmts + stmt = ir.stmts[s] + inst = stmt[:inst] + inst isa Core.PhiNode || continue + + edge = findfirst(==(from), inst.edges) + if isnothing(edge) # use dummy scalar val instead + val = zero(stmt[:type]) + push!(values, val) + else + push!(values, inst.values[edge]) + end + end + values +end + +""" + code_mlir(f, types::Type{Tuple}; ctx=Context()) -> IR.Operation + +Returns a `func.func` operation corresponding to the ircode of the provided method. +This only supports a few Julia Core primitives and scalar types of type $BrutusScalar. + +!!! note + The Julia SSAIR to MLIR conversion implemented is very primitive and only supports a + handful of primitives. A better to perform this conversion would to create a dialect + representing Julia IR and progressively lower it to base MLIR dialects. +""" +function code_mlir(f, types; ctx=Context()) + ir, ret = Core.Compiler.code_ircode(f, types) |> only + @assert first(ir.argtypes) isa Core.Const + + values = Vector{Value}(undef, length(ir.stmts)) + + for dialect in (LLVM.version() >= v"15" ? ("func", "cf") : ("std",)) + IR.get_or_load_dialect!(ctx, dialect) + end + + blocks = [ + prepare_block(ctx, ir, bb) + for bb in ir.cfg.blocks + ] + + current_block = entry_block = blocks[begin] + + for argtype in types.parameters + IR.push_argument!(entry_block, MLIRType(ctx, argtype), Location(ctx)) + end + + function get_value(x)::Value + if x isa Core.SSAValue + @assert isassigned(values, x.id) "value $x was not assigned" + values[x.id] + elseif x isa Core.Argument + IR.get_argument(entry_block, x.n - 1) + elseif x isa BrutusScalar + IR.get_result(push!(current_block, arith.constant(ctx, x))) + else + error("could not use value $x inside MLIR") + end + end + + for (block_id, (b, bb)) in enumerate(zip(blocks, ir.cfg.blocks)) + current_block = b + n_phi_nodes = 0 + + for sidx in bb.stmts + stmt = ir.stmts[sidx] + inst = stmt[:inst] + line = ir.linetable[stmt[:line]] + + if Meta.isexpr(inst, :call) + val_type = stmt[:type] + if !(val_type <: BrutusScalar) + error("type $val_type is not supported") + end + out_type = MLIRType(ctx, val_type) + + called_func = first(inst.args) + if called_func isa GlobalRef # TODO: should probably use something else here + called_func = getproperty(called_func.mod, called_func.name) + end + + fop! = intrinsics_to_mlir[called_func] + args = get_value.(@view inst.args[begin+1:end]) + + loc = Location(ctx, string(line.file), line.line, 0) + res = IR.get_result(fop!(ctx, current_block, args; loc)) + + values[sidx] = res + elseif inst isa PhiNode + values[sidx] = IR.get_argument(current_block, n_phi_nodes += 1) + elseif inst isa PiNode + values[sidx] = get_value(inst.val) + elseif inst isa GotoNode + args = get_value.(collect_value_arguments(ir, block_id, inst.label)) + dest = blocks[inst.label] + loc = Location(ctx, string(line.file), line.line, 0) + brop = LLVM.version() >= v"15" ? cf.br : std.br + push!(current_block, brop(ctx, dest, args; loc)) + elseif inst isa GotoIfNot + false_args = get_value.(collect_value_arguments(ir, block_id, inst.dest)) + cond = get_value(inst.cond) + @assert length(bb.succs) == 2 # NOTE: We assume that length(bb.succs) == 2, this might be wrong + other_dest = setdiff(bb.succs, inst.dest) |> only + true_args = get_value.(collect_value_arguments(ir, block_id, other_dest)) + other_dest = blocks[other_dest] + dest = blocks[inst.dest] + + loc = Location(ctx, string(line.file), line.line, 0) + cond_brop = LLVM.version() >= v"15" ? cf.cond_br : std.cond_br + cond_br = cond_brop(ctx, cond, other_dest, dest, true_args, false_args; loc) + push!(current_block, cond_br) + elseif inst isa ReturnNode + line = ir.linetable[stmt[:line]] + retop = LLVM.version() >= v"15" ? func.return_ : std.return_ + loc = Location(ctx, string(line.file), line.line, 0) + push!(current_block, retop(ctx, [get_value(inst.val)]; loc)) + elseif Meta.isexpr(inst, :code_coverage_effect) + # Skip + else + error("unhandled ir $(inst)") + end + end + end + + func_name = nameof(f) + + region = Region() + for b in blocks + push!(region, b) + end + + LLVM15 = LLVM.version() >= v"15" + + input_types = MLIRType[ + IR.get_type(IR.get_argument(entry_block, i)) + for i in 1:IR.num_arguments(entry_block) + ] + result_types = [MLIRType(ctx, ret)] + + ftype = MLIRType(ctx, input_types => result_types) + op = IR.create_operation( + LLVM15 ? "func.func" : "builtin.func", + Location(ctx); + attributes = [ + NamedAttribute(ctx, "sym_name", IR.Attribute(ctx, string(func_name))), + NamedAttribute(ctx, LLVM15 ? "function_type" : "type", IR.Attribute(ftype)), + ], + owned_regions = Region[region], + result_inference=false, + ) + + IR.verifyall(op) + + op +end + +""" + @code_mlir f(args...) +""" +macro code_mlir(call) + @assert Meta.isexpr(call, :call) "only calls are supported" + + f = first(call.args) |> esc + args = Expr(:curly, + Tuple, + map(arg -> :($(Core.Typeof)($arg)), + call.args[begin+1:end])..., + ) |> esc + + quote + code_mlir($f, $args) + end +end + +end # module Brutus + +# --- + +function pow(x::F, n) where {F} + p = one(F) + for _ in 1:n + p *= x + end + p +end + +function f(x) + if x == 1 + 2 + else + 3 + end +end + +# --- + +using Test +using MLIR.IR, MLIR + +ctx = Context() +# IR.enable_multithreading!(ctx, false) + +op = Brutus.code_mlir(pow, Tuple{Int, Int}; ctx) + +mod = MModule(ctx, Location(ctx)) +body = IR.get_body(mod) +push!(body, op) + +pm = IR.PassManager(ctx) +opm = IR.OpPassManager(pm) + +# IR.enable_ir_printing!(pm) +IR.enable_verifier!(pm, true) + +MLIR.API.mlirRegisterAllPasses() +MLIR.API.mlirRegisterAllLLVMTranslations(ctx) +IR.add_pipeline!(opm, Brutus.LLVM.version() >= v"15" ? "convert-arith-to-llvm,convert-func-to-llvm" : "convert-std-to-llvm") + +IR.run!(pm, mod) + +jit = MLIR.API.mlirExecutionEngineCreate(mod, 0, 0, C_NULL) +fptr = MLIR.API.mlirExecutionEngineLookup(jit, "pow") + +x, y = 3, 4 + +@test ccall(fptr, Int, (Int, Int), x, y) == pow(x, y) diff --git a/src/Dialects.jl b/src/Dialects.jl new file mode 100644 index 00000000..4cb400eb --- /dev/null +++ b/src/Dialects.jl @@ -0,0 +1,166 @@ +module Dialects + +module arith + +using ...IR + +for (f, t) in Iterators.product( + (:add, :sub, :mul), + (:i, :f), +) + fname = Symbol(f, t) + @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + end +end + +for fname in (:xori, :andi, :ori) + @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + end +end + +for (f, t) in Iterators.product( + (:div, :max, :min), + (:si, :ui, :f), +) + fname = Symbol(f, t) + @eval function $fname(context, operands, type=IR.get_type(first(operands)); loc=Location(context)) + IR.create_operation($(string("arith.", fname)), loc; operands, results=[type]) + end +end + +# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithindex_cast-mlirarithindexcastop +for f in (:index_cast, :index_castui) + @eval function $f(context, operand; loc=Location(context)) + IR.create_operation( + $(string("arith.", f)), + loc; + operands=[operand], + results=[IR.IndexType(context)], + ) + end +end + +# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithextf-mlirarithextfop +function extf(context, operand, type; loc=Location(context)) + IR.create_operation("arith.exf", loc; operands=[operand], results=[type]) +end + +# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop +function constant(context, value, type=MLIRType(context, typeof(value)); loc=Location(context)) + IR.create_operation( + "arith.constant", + loc; + results=[type], + attributes=[ + IR.NamedAttribute(context, "value", + Attribute(context, value, type)), + ], + ) +end + +module Predicates + const eq = 0 + const ne = 1 + const slt = 2 + const sle = 3 + const sgt = 4 + const sge = 5 + const ult = 6 + const ule = 7 + const ugt = 8 + const uge = 9 +end + +function cmpi(context, predicate, operands; loc=Location(context)) + IR.create_operation( + "arith.cmpi", + loc; + operands, + results=[MLIRType(context, Bool)], + attributes=[ + IR.NamedAttribute(context, "predicate", + Attribute(context, predicate)) + ], + ) +end + +end # module arith + +module std +# for llvm 14 + +using ...IR + +function return_(context, operands; loc=Location(context)) + IR.create_operation("std.return", loc; operands, result_inference=false) +end + +function br(context, dest, operands; loc=Location(context)) + IR.create_operation("std.br", loc; operands, successors=[dest], result_inference=false) +end + +function cond_br( + context, cond, + true_dest, false_dest, + true_dest_operands, + false_dest_operands; + loc=Location(context), +) + IR.create_operation( + "std.cond_br", + loc; + successors=[true_dest, false_dest], + operands=[cond, true_dest_operands..., false_dest_operands...], + attributes=[ + IR.NamedAttribute(context, "operand_segment_sizes", + IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) + ], + result_inference=false, + ) +end + +end # module std + +module func +# https://mlir.llvm.org/docs/Dialects/Func/ + +using ...IR + +function return_(context, operands; loc=Location(context)) + IR.create_operation("func.return", loc; operands, result_inference=false) +end + +end # module func + +module cf + +using ...IR + +function br(context, dest, operands; loc=Location(context)) + IR.create_operation("cf.br", loc; operands, successors=[dest], result_inference=false) +end + +function cond_br( + context, cond, + true_dest, false_dest, + true_dest_operands, + false_dest_operands; + loc=Location(context), +) + IR.create_operation( + "cf.cond_br", loc; + operands=[cond, true_dest_operands..., false_dest_operands...], + successors=[true_dest, false_dest], + attributes=[ + IR.NamedAttribute(context, "operand_segment_sizes", + IR.Attribute(context, Int32[1, length(true_dest_operands), length(false_dest_operands)])) + ], + result_inference=false, + ) +end + +end # module cf + +end # module Dialects diff --git a/src/IR/IR.jl b/src/IR/IR.jl new file mode 100644 index 00000000..3606d329 --- /dev/null +++ b/src/IR/IR.jl @@ -0,0 +1,830 @@ +module IR + +import ..API: API + +export + Operation, + OperationState, + Location, + Context, + MModule, + Value, + MLIRType, + Region, + Block, + Attribute, + NamedAttribute + +import Base: ==, String +using .API: + MlirDialectRegistry, + MlirDialectHandle, + MlirAttribute, + MlirNamedAttribute, + MlirDialect, + MlirStringRef, + MlirOperation, + MlirOperationState, + MlirLocation, + MlirBlock, + MlirRegion, + MlirModule, + MlirContext, + MlirType, + MlirValue, + MlirIdentifier, + MlirPassManager, + MlirOpPassManager + +function print_callback(str::MlirStringRef, userdata) + data = unsafe_wrap(Array, Base.convert(Ptr{Cchar}, str.data), str.length; own=false) + write(userdata isa Base.RefValue ? userdata[] : userdata, data) + return Cvoid() +end + +### Dialect + +struct Dialect + dialect::MlirDialect + + Dialect(dialect) = begin + @assert !mlirIsNull(dialect) "cannot create Dialect from null MlirDialect" + new(dialect) + end +end + +Base.convert(::Type{MlirDialect}, dialect::Dialect) = dialect.dialect +function Base.show(io::IO, dialect::Dialect) + print(io, "Dialect(\"", String(API.mlirDialectGetNamespace(dialect)), "\")") +end + +### DialectHandle + +struct DialectHandle + handle::API.MlirDialectHandle +end + +function DialectHandle(s::Symbol) + s = Symbol("mlirGetDialectHandle__", s, "__") + DialectHandle(getproperty(API, s)()) +end + +Base.convert(::Type{MlirDialectHandle}, handle::DialectHandle) = handle.handle + +### Dialect Registry + +mutable struct DialectRegistry + registry::MlirDialectRegistry +end +function DialectRegistry() + registry = API.mlirDialectRegistryCreate() + @assert !mlirIsNull(registry) "cannot create DialectRegistry with null MlirDialectRegistry" + finalizer(DialectRegistry(registry)) do registry + API.mlirDialectRegistryDestroy(registry.registry) + end +end + +function Base.insert!(registry::DialectRegistry, handle::DialectHandle) + API.mlirDialectHandleInsertDialect(registry, handle) +end + +### Context + +mutable struct Context + context::MlirContext +end +function Context() + context = API.mlirContextCreate() + @assert !mlirIsNull(context) "cannot create Context with null MlirContext" + finalizer(Context(context)) do context + API.mlirContextDestroy(context.context) + end +end + +Base.convert(::Type{MlirContext}, c::Context) = c.context + +num_loaded_dialects(context) = API.mlirContextGetNumLoadedDialects(context) +function get_or_load_dialect!(context, handle::DialectHandle) + mlir_dialect = API.mlirDialectHandleLoadDialect(handle, context) + if mlirIsNull(mlir_dialect) + error("could not load dialect from handle $handle") + else + Dialect(mlir_dialect) + end +end +function get_or_load_dialect!(context, dialect::String) + get_or_load_dialect!(context, DialectHandle(Symbol(dialect))) +end + +function enable_multithreading!(context, enable=true) + API.mlirContextEnableMultithreading(context, enable) + context +end + +is_registered_operation(context, opname) = API.mlirContextIsRegisteredOperation(context, opname) + +### Location + +struct Location + location::MlirLocation + + Location(location) = begin + @assert !mlirIsNull(location) "cannot create Location with null MlirLocation" + new(location) + end +end + +Location(context::Context) = Location(API.mlirLocationUnknownGet(context)) +Location(context::Context, filename, line, column) = + Location(API.mlirLocationFileLineColGet(context, filename, line, column)) + +Base.convert(::Type{MlirLocation}, location::Location) = location.location + +function Base.show(io::IO, location::Location) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + print(io, "Location(#= ") + API.mlirLocationPrint(location, c_print_callback, ref) + print(io, " =#)") +end + +### Type + +struct MLIRType + type::MlirType + + MLIRType(type) = begin + @assert !mlirIsNull(type) + new(type) + end +end + +MLIRType(t::MLIRType) = t +MLIRType(context::Context, T::Type{<:Signed}) = + MLIRType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) +MLIRType(context::Context, T::Type{<:Unsigned}) = + MLIRType(API.mlirIntegerTypeGet(context, sizeof(T) * 8)) +MLIRType(context::Context, ::Type{Bool}) = + MLIRType(API.mlirIntegerTypeGet(context, 1)) +MLIRType(context::Context, ::Type{Float32}) = + MLIRType(API.mlirF32TypeGet(context)) +MLIRType(context::Context, ::Type{Float64}) = + MLIRType(API.mlirF64TypeGet(context)) +MLIRType(context::Context, ft::Pair) = + MLIRType(API.mlirFunctionTypeGet(context, + length(ft.first), [MLIRType(t) for t in ft.first], + length(ft.second), [MLIRType(t) for t in ft.second])) +MLIRType(context, a::AbstractArray{T}) where {T} = MLIRType(context, MLIRType(context, T), size(a)) +MLIRType(context, ::Type{<:AbstractArray{T,N}}, dims) where {T,N} = + MLIRType(API.mlirRankedTensorTypeGetChecked( + Location(context), + N, collect(dims), + MLIRType(context, T), + Attribute(), + )) +MLIRType(context, element_type::MLIRType, dims) = + MLIRType(API.mlirRankedTensorTypeGetChecked( + Location(context), + length(dims), collect(dims), + element_type, + Attribute(), + )) +MLIRType(context, ::T) where {T<:Real} = MLIRType(context, T) +MLIRType(_, type::MLIRType) = type + +IndexType(context) = MLIRType(API.mlirIndexTypeGet(context)) + +Base.convert(::Type{MlirType}, mtype::MLIRType) = mtype.type +Base.parse(::Type{MLIRType}, context, s) = + MLIRType(API.mlirTypeParseGet(context, s)) + +function Base.eltype(type::MLIRType) + if API.mlirTypeIsAShaped(type) + MLIRType(API.mlirShapedTypeGetElementType(type)) + else + type + end +end + +function show_inner(io::IO, type::MLIRType) + if API.mlirTypeIsAInteger(type) + is_signless = API.mlirIntegerTypeIsSignless(type) + is_signed = API.mlirIntegerTypeIsSigned(type) + + width = API.mlirIntegerTypeGetWidth(type) + t = if is_signed + "si" + elseif is_signless + "i" + else + "u" + end + print(io, t, width) + elseif API.mlirTypeIsAF64(type) + print(io, "f64") + elseif API.mlirTypeIsAF32(type) + print(io, "f32") + elseif API.mlirTypeIsARankedTensor(type) + print(io, "tensor<") + s = size(type) + print(io, join(s, "x"), "x") + show_inner(io, eltype(type)) + print(io, ">") + elseif API.mlirTypeIsAIndex(type) + print(io, "index") + else + print(io, "unknown") + end +end + +function Base.show(io::IO, type::MLIRType) + print(io, "MLIRType(#= ") + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + API.mlirTypePrint(type, c_print_callback, ref) + print(io, " =#)") +end + +function inttype(size, issigned) + size == 1 && issigned && return Bool + ints = (Int8, Int16, Int32, Int64, Int128) + IT = ints[Int(log2(size)) - 2] + issigned ? IT : unsigned(IT) +end + +function julia_type(type::MLIRType) + if API.mlirTypeIsAInteger(type) + is_signed = API.mlirIntegerTypeIsSigned(type) || + API.mlirIntegerTypeIsSignless(type) + width = API.mlirIntegerTypeGetWidth(type) + + try + inttype(width, is_signed) + catch + t = is_signed ? "i" : "u" + throw("could not convert type $(t)$(width) to julia") + end + elseif API.mlirTypeIsAF32(type) + Float32 + elseif API.mlirTypeIsAF64(type) + Float64 + else + throw("could not convert type $type to julia") + end +end + +Base.ndims(type::MLIRType) = + if API.mlirTypeIsAShaped(type) && API.mlirShapedTypeHasRank(type) + API.mlirShapedTypeGetRank(type) + else + 0 + end + +Base.size(type::MLIRType, i::Int) = API.mlirShapedTypeGetDimSize(type, i - 1) +Base.size(type::MLIRType) = Tuple(size(type, i) for i in 1:ndims(type)) + +function is_tensor(type::MLIRType) + API.mlirTypeIsAShaped(type) +end + +function is_integer(type::MLIRType) + API.mlirTypeIsAInteger(type) +end + +is_function_type(mtype) = API.mlirTypeIsAFunction(mtype) + +function num_inputs(ftype::MLIRType) + @assert is_function_type(ftype) "cannot get the number of inputs on type $(ftype), expected a function type" + API.mlirFunctionTypeGetNumInputs(ftype) +end +function num_results(ftype::MLIRType) + @assert is_function_type(ftype) "cannot get the number of results on type $(ftype), expected a function type" + API.mlirFunctionTypeGetNumResults(ftype) +end + +function get_input(ftype::MLIRType, pos) + @assert is_function_type(ftype) "cannot get input on type $(ftype), expected a function type" + MLIRType(API.mlirFunctionTypeGetInput(ftype, pos - 1)) +end +function get_result(ftype::MLIRType, pos=1) + @assert is_function_type(ftype) "cannot get result on type $(ftype), expected a function type" + MLIRType(API.mlirFunctionTypeGetResult(ftype, pos - 1)) +end + +### Attribute + +struct Attribute + attribute::MlirAttribute +end + +Attribute() = Attribute(API.mlirAttributeGetNull()) +Attribute(context, s::AbstractString) = Attribute(API.mlirStringAttrGet(context, s)) +Attribute(type::MLIRType) = Attribute(API.mlirTypeAttrGet(type)) +Attribute(context, f::F, type=MLIRType(context, F)) where {F<:AbstractFloat} = Attribute( + API.mlirFloatAttrDoubleGet(context, type, Float64(f)) +) +Attribute(context, i::T) where {T<:Integer} = Attribute( + API.mlirIntegerAttrGet(MLIRType(context, T), Int64(i)) +) +function Attribute(context, values::T) where {T<:AbstractArray{Int32}} + type = MLIRType(context, T, size(values)) + Attribute( + API.mlirDenseElementsAttrInt32Get(type, length(values), values) + ) +end +function Attribute(context, values::T) where {T<:AbstractArray{Int64}} + type = MLIRType(context, T, size(values)) + Attribute( + API.mlirDenseElementsAttrInt64Get(type, length(values), values) + ) +end +function Attribute(context, values::T) where {T<:AbstractArray{Float64}} + type = MLIRType(context, T, size(values)) + Attribute( + API.mlirDenseElementsAttrDoubleGet(type, length(values), values) + ) +end +function Attribute(context, values::T) where {T<:AbstractArray{Float32}} + type = MLIRType(context, T, size(values)) + Attribute( + API.mlirDenseElementsAttrFloatGet(type, length(values), values) + ) +end +function Attribute(context, values::AbstractArray{Int32}, type) + Attribute( + API.mlirDenseElementsAttrInt32Get(type, length(values), values) + ) +end +function Attribute(context, values::AbstractArray{Int}, type) + Attribute( + API.mlirDenseElementsAttrInt64Get(type, length(values), values) + ) +end +function Attribute(context, values::AbstractArray{Float32}, type) + Attribute( + API.mlirDenseElementsAttrFloatGet(type, length(values), values) + ) +end +function ArrayAttribute(context, values::AbstractVector{Int}) + elements = Attribute.((context,), values) + Attribute( + API.mlirArrayAttrGet(context, length(elements), elements) + ) +end +function ArrayAttribute(context, attributes::Vector{Attribute}) + Attribute( + API.mlirArrayAttrGet(context, length(attributes), attributes), + ) +end +function DenseArrayAttribute(context, values::AbstractVector{Int}) + Attribute( + API.mlirDenseI64ArrayGet(context, length(values), collect(values)) + ) +end +function Attribute(context, value::Int, type::MLIRType) + Attribute( + API.mlirIntegerAttrGet(type, value) + ) +end +function Attribute(context, value::Bool, ::MLIRType=nothing) + Attribute( + API.mlirBoolAttrGet(context, value) + ) +end + +Base.convert(::Type{MlirAttribute}, attribute::Attribute) = attribute.attribute +Base.parse(::Type{Attribute}, context, s) = + Attribute(API.mlirAttributeParseGet(context, s)) + +function get_type(attribute::Attribute) + MLIRType(API.mlirAttributeGetType(attribute)) +end +function type_value(attribute) + @assert API.mlirAttributeIsAType(attribute) "attribute $(attribute) is not a type" + MLIRType(API.mlirTypeAttrGetValue(attribute)) +end +function bool_value(attribute) + @assert API.mlirAttributeIsABool(attribute) "attribute $(attribute) is not a boolean" + API.mlirBoolAttrGetValue(attribute) +end +function string_value(attribute) + @assert API.mlirAttributeIsAString(attribute) "attribute $(attribute) is not a string attribute" + String(API.mlirStringAttrGetValue(attribute)) +end + +function Base.show(io::IO, attribute::Attribute) + print(io, "Attribute(#= ") + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + API.mlirAttributePrint(attribute, c_print_callback, ref) + print(io, " =#)") +end + +### Named Attribute + +struct NamedAttribute + named_attribute::MlirNamedAttribute +end + +function NamedAttribute(context, name, attribute) + @assert !mlirIsNull(attribute.attribute) + NamedAttribute(API.mlirNamedAttributeGet( + API.mlirIdentifierGet(context, name), + attribute + )) +end + +Base.convert(::Type{MlirAttribute}, named_attribute::NamedAttribute) = + named_attribute.named_attribute + +### Value + +struct Value + value::MlirValue + + Value(value) = begin + @assert !mlirIsNull(value) "cannot create Value with null MlirValue" + new(value) + end +end + +get_type(value) = MLIRType(API.mlirValueGetType(value)) + +Base.convert(::Type{MlirValue}, value::Value) = value.value +Base.size(value::Value) = Base.size(get_type(value)) +Base.ndims(value::Value) = Base.ndims(get_type(value)) + +function Base.show(io::IO, value::Value) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + API.mlirValuePrint(value, c_print_callback, ref) +end + +is_a_op_result(value) = API.mlirValueIsAOpResult(value) +is_a_block_argument(value) = API.mlirValueIsABlockArgument(value) + +function set_type!(value, type) + @assert is_a_block_argument(value) "could not set type, value is not a block argument" + API.mlirBlockArgumentSetType(value, type) + value +end + +function get_owner(value::Value) + if is_a_block_argument(value) + raw_block = API.mlirBlockArgumentGetOwner(value) + if mlirIsNull(raw_block) + return nothing + end + + return Block(raw_block, false) + end + + raw_op = API.mlirOpResultGetOwner(value) + if mlirIsNull(raw_op) + return nothing + end + + return Operation(raw_op, false) +end + +### Operation + +mutable struct Operation + operation::MlirOperation + @atomic owned::Bool + + Operation(operation, owned=true) = begin + @assert !mlirIsNull(operation) "cannot create Operation with null MlirOperation" + finalizer(new(operation, owned)) do op + if op.owned + API.mlirOperationDestroy(op.operation) + end + end + end +end + +function create_operation( + name, loc; + results=nothing, + operands=nothing, + owned_regions=nothing, + successors=nothing, + attributes=nothing, + result_inference=isnothing(results), +) + GC.@preserve name loc begin + state = Ref(API.mlirOperationStateGet(name, loc)) + if !isnothing(results) + if result_inference + error("Result inference and provided results conflict") + end + API.mlirOperationStateAddResults(state, length(results), results) + end + if !isnothing(operands) + API.mlirOperationStateAddOperands(state, length(operands), operands) + end + if !isnothing(owned_regions) + lose_ownership!.(owned_regions) + GC.@preserve owned_regions begin + mlir_regions = Base.unsafe_convert.(MlirRegion, owned_regions) + API.mlirOperationStateAddOwnedRegions(state, length(mlir_regions), mlir_regions) + end + end + if !isnothing(successors) + GC.@preserve successors begin + mlir_blocks = Base.unsafe_convert.(MlirBlock, successors) + API.mlirOperationStateAddSuccessors( + state, + length(mlir_blocks), + mlir_blocks, + ) + end + end + if !isnothing(attributes) + API.mlirOperationStateAddAttributes(state, length(attributes), attributes) + end + if result_inference + API.mlirOperationStateEnableResultTypeInference(state) + end + op = API.mlirOperationCreate(state) + if mlirIsNull(op) + error("Create Operation failed") + end + Operation(op, true) + end +end + +Base.copy(operation::Operation) = Operation(API.mlirOperationClone(operation)) + +num_regions(operation) = API.mlirOperationGetNumRegions(operation) +function get_region(operation, i) + i ∉ 1:num_regions(operation) && throw(BoundsError(operation, i)) + Region(API.mlirOperationGetRegion(operation, i - 1), false) +end +num_results(operation::Operation) = API.mlirOperationGetNumResults(operation) +get_results(operation) = [ + get_result(operation, i) + for i in 1:num_results(operation) +] +function get_result(operation::Operation, i=1) + i ∉ 1:num_results(operation) && throw(BoundsError(operation, i)) + Value(API.mlirOperationGetResult(operation, i - 1)) +end +num_operands(operation) = API.mlirOperationGetNumOperands(operation) +function get_operand(operation, i=1) + i ∉ 1:num_operands(operation) && throw(BoundsError(operation, i)) + Value(API.mlirOperationGetOperand(operation, i - 1)) +end +function set_operand!(operation, i, value) + i ∉ 1:num_operands(operation) && throw(BoundsError(operation, i)) + API.mlirOperationSetOperand(operation, i - 1, value) + value +end + +function get_attribute_by_name(operation, name) + raw_attr = API.mlirOperationGetAttributeByName(operation, name) + if mlirIsNull(raw_attr) + return nothing + end + Attribute(raw_attr) +end +function set_attribute_by_name!(operation, name, attribute) + API.mlirOperationSetAttributeByName(operation, name, attribute) + operation +end + +location(operation) = Location(API.mlirOperationGetLocation(operation)) +name(operation) = String(API.mlirOperationGetName(operation)) +block(operation) = Block(API.mlirOperationGetBlock(operation), false) +parent_operation(operation) = Operation(API.mlirOperationGetParentOperation(operation), false) +dialect(operation) = first(split(get_name(operation), '.')) |> Symbol + +function get_first_region(op::Operation) + reg = iterate(RegionIterator(op)) + isnothing(reg) && return nothing + first(reg) +end +function get_first_block(op::Operation) + reg = get_first_region(op) + isnothing(reg) && return nothing + block = iterate(BlockIterator(reg)) + isnothing(block) && return nothing + first(block) +end +function get_first_child_op(op::Operation) + block = get_first_block(op) + isnothing(block) && return nothing + cop = iterate(OperationIterator(block)) + first(cop) +end + +op::Operation == other::Operation = API.mlirOperationEqual(op, other) + +Base.cconvert(::Type{MlirOperation}, operation::Operation) = operation +Base.unsafe_convert(::Type{MlirOperation}, operation::Operation) = operation.operation + +function lose_ownership!(operation::Operation) + @assert operation.owned + @atomic operation.owned = false + operation +end + +function Base.show(io::IO, operation::Operation) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + flags = API.mlirOpPrintingFlagsCreate() + get(io, :debug, false) && API.mlirOpPrintingFlagsEnableDebugInfo(flags, true, true) + API.mlirOperationPrintWithFlags(operation, flags, c_print_callback, ref) + println(io) +end + +verify(operation::Operation) = API.mlirOperationVerify(operation) + +### Block + +mutable struct Block + block::MlirBlock + @atomic owned::Bool + + Block(block::MlirBlock, owned::Bool=true) = begin + @assert !mlirIsNull(block) "cannot create Block with null MlirBlock" + finalizer(new(block, owned)) do block + if block.owned + API.mlirBlockDestroy(block.block) + end + end + end +end + +Block() = Block(MLIRType[], Location[]) +function Block(args::Vector{MLIRType}, locs::Vector{Location}) + @assert length(args) == length(locs) "there should be one args for each locs (got $(length(args)) & $(length(locs)))" + Block(API.mlirBlockCreate(length(args), args, locs)) +end + +function Base.push!(block::Block, op::Operation) + API.mlirBlockAppendOwnedOperation(block, lose_ownership!(op)) + op +end +function Base.insert!(block::Block, pos, op::Operation) + API.mlirBlockInsertOwnedOperation(block, pos - 1, lose_ownership!(op)) + op +end +function Base.pushfirst!(block::Block, op::Operation) + insert!(block, 1, op) + op +end +function insert_after!(block::Block, reference::Operation, op::Operation) + API.mlirBlockInsertOwnedOperationAfter(block, reference, lose_ownership!(op)) + op +end +function insert_before!(block::Block, reference::Operation, op::Operation) + API.mlirBlockInsertOwnedOperationBefore(block, reference, lose_ownership!(op)) + op +end + +num_arguments(block::Block) = + API.mlirBlockGetNumArguments(block) +function get_argument(block::Block, i) + i ∉ 1:num_arguments(block) && throw(BoundsError(block, i)) + Value(API.mlirBlockGetArgument(block, i - 1)) +end +push_argument!(block::Block, type, loc) = + Value(API.mlirBlockAddArgument(block, type, loc)) + +Base.cconvert(::Type{MlirBlock}, block::Block) = block +Base.unsafe_convert(::Type{MlirBlock}, block::Block) = block.block + +function lose_ownership!(block::Block) + @assert block.owned + @atomic block.owned = false + block +end + +function Base.show(io::IO, block::Block) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + API.mlirBlockPrint(block, c_print_callback, ref) +end + +### Region + +mutable struct Region + region::MlirRegion + @atomic owned::Bool + + Region(region, owned=true) = begin + @assert !mlirIsNull(region) + finalizer(new(region, owned)) do region + if region.owned + API.mlirRegionDestroy(region.region) + end + end + end +end + +Region() = Region(API.mlirRegionCreate()) + +function Base.push!(region::Region, block::Block) + API.mlirRegionAppendOwnedBlock(region, lose_ownership!(block)) + block +end +function Base.insert!(region::Region, pos, block::Block) + API.mlirRegionInsertOwnedBlock(region, pos - 1, lose_ownership!(block)) + block +end +function Base.pushfirst!(region::Region, block) + insert!(region, 1, block) + block +end +insert_after!(region::Region, reference::Block, block::Block) = + API.mlirRegionInsertOwnedBlockAfter(region, reference, lose_ownership!(block)) +insert_before!(region::Region, reference::Block, block::Block) = + API.mlirRegionInsertOwnedBlockBefore(region, reference, lose_ownership!(block)) + +function get_first_block(region::Region) + block = iterate(BlockIterator(region)) + isnothing(block) && return nothing + first(block) +end + +function lose_ownership!(region::Region) + @assert region.owned + @atomic region.owned = false + region +end + +Base.cconvert(::Type{MlirRegion}, region::Region) = region +Base.unsafe_convert(::Type{MlirRegion}, region::Region) = region.region + +### Module + +mutable struct MModule + module_::MlirModule + context::Context + + MModule(module_, context) = begin + @assert !mlirIsNull(module_) "cannot create MModule with null MlirModule" + finalizer(API.mlirModuleDestroy, new(module_, context)) + end +end + +MModule(context::Context, loc=Location(context)) = + MModule(API.mlirModuleCreateEmpty(loc), context) +get_operation(module_) = Operation(API.mlirModuleGetOperation(module_), false) +get_body(module_) = Block(API.mlirModuleGetBody(module_), false) +get_first_child_op(mod::MModule) = get_first_child_op(get_operation(mod)) + +Base.convert(::Type{MlirModule}, module_::MModule) = module_.module_ +Base.parse(::Type{MModule}, context, module_) = MModule(API.mlirModuleCreateParse(context, module_), context) + +macro mlir_str(code) + quote + ctx = Context() + parse(MModule, ctx, code) + end +end + +function Base.show(io::IO, module_::MModule) + println(io, "MModule:") + show(io, get_operation(module_)) +end + +### TypeID + +struct TypeID + typeid::API.MlirTypeID +end + +Base.hash(typeid::TypeID) = API.mlirTypeIDHashValue(typeid.typeid) +Base.convert(::Type{API.MlirTypeID}, typeid::TypeID) = typeid.typeid + +@static if isdefined(API, :MlirTypeIDAllocator) + +### TypeIDAllocator + +mutable struct TypeIDAllocator + allocator::API.MlirTypeIDAllocator + + function TypeIDAllocator() + ptr = API.mlirTypeIDAllocatorCreate() + @assert ptr != C_NULL "cannot create TypeIDAllocator" + finalizer(API.mlirTypeIDAllocatorDestroy, new(ptr)) + end +end + +Base.cconvert(::Type{API.MlirTypeIDAllocator}, allocator::TypeIDAllocator) = allocator +Base.unsafe_convert(::Type{API.MlirTypeIDAllocator}, allocator) = allocator.allocator + +TypeID(allocator::TypeIDAllocator) = TypeID(API.mlirTypeIDCreate(allocator)) + +else + +struct TypeIDAllocator end + +end + +include("./Support.jl") +include("./Pass.jl") + +end # module IR diff --git a/src/IR/Pass.jl b/src/IR/Pass.jl new file mode 100644 index 00000000..7eef5b88 --- /dev/null +++ b/src/IR/Pass.jl @@ -0,0 +1,176 @@ +### Pass Manager + +abstract type AbstractPass end + +mutable struct ExternalPassHandle + ctx::Union{Nothing,Context} + pass::AbstractPass +end + +mutable struct PassManager + pass::MlirPassManager + context::Context + allocator::TypeIDAllocator + passes::Dict{TypeID,ExternalPassHandle} + + PassManager(pm::MlirPassManager, context) = begin + @assert !mlirIsNull(pm) "cannot create PassManager with null MlirPassManager" + finalizer(new(pm, context, TypeIDAllocator(), Dict{TypeID,ExternalPassHandle}())) do pm + API.mlirPassManagerDestroy(pm.pass) + end + end +end + +function enable_ir_printing!(pm) + API.mlirPassManagerEnableIRPrinting(pm) + pm +end +function enable_verifier!(pm, enable=true) + API.mlirPassManagerEnableVerifier(pm, enable) + pm +end + +PassManager(context) = + PassManager(API.mlirPassManagerCreate(context), context) + +function run!(pm::PassManager, module_) + status = API.mlirPassManagerRun(pm, module_) + if mlirLogicalResultIsFailure(status) + throw("failed to run pass manager on module") + end + module_ +end + +Base.convert(::Type{MlirPassManager}, pass::PassManager) = pass.pass + +### Op Pass Manager + +struct OpPassManager + op_pass::MlirOpPassManager + pass::PassManager + + OpPassManager(op_pass, pass) = begin + @assert !mlirIsNull(op_pass) "cannot create OpPassManager with null MlirOpPassManager" + new(op_pass, pass) + end +end + +OpPassManager(pm::PassManager) = OpPassManager(API.mlirPassManagerGetAsOpPassManager(pm), pm) +OpPassManager(pm::PassManager, opname) = OpPassManager(API.mlirPassManagerGetNestedUnder(pm, opname), pm) +OpPassManager(opm::OpPassManager, opname) = OpPassManager(API.mlirOpPassManagerGetNestedUnder(opm, opname), opm.pass) + +Base.convert(::Type{MlirOpPassManager}, op_pass::OpPassManager) = op_pass.op_pass + +function Base.show(io::IO, op_pass::OpPassManager) + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + ref = Ref(io) + println(io, "OpPassManager(\"\"\"") + API.mlirPrintPassPipeline(op_pass, c_print_callback, ref) + println(io) + print(io, "\"\"\")") +end + +struct AddPipelineException <: Exception + message::String +end + +function Base.showerror(io::IO, err::AddPipelineException) + print(io, "failed to add pipeline:", err.message) + nothing +end + +function add_pipeline!(op_pass::OpPassManager, pipeline) + @static if isdefined(API, :mlirOpPassManagerAddPipeline) + io = IOBuffer() + c_print_callback = @cfunction(print_callback, Cvoid, (MlirStringRef, Any)) + result = GC.@preserve io API.mlirOpPassManagerAddPipeline(op_pass, pipeline, c_print_callback, io) + if mlirLogicalResultIsFailure(result) + exc = AddPipelineException(String(take!(io))) + throw(exc) + end + else + result = API.mlirParsePassPipeline(op_pass, pipeline) + if mlirLogicalResultIsFailure(result) + throw(AddPipelineException(" " * pipeline)) + end + end + op_pass +end + +function add_owned_pass!(pm::PassManager, pass) + API.mlirPassManagerAddOwnedPass(pm, pass) + pm +end + +function add_owned_pass!(opm::OpPassManager, pass) + API.mlirOpPassManagerAddOwnedPass(opm, pass) + opm +end + + +@static if isdefined(API, :mlirCreateExternalPass) + +### Pass + +# AbstractPass interface: +opname(::AbstractPass) = "" +function pass_run(::Context, ::P, op) where {P<:AbstractPass} + error("pass $P does not implement `MLIR.pass_run`") +end + +function _pass_construct(ptr::ExternalPassHandle) + nothing +end + +function _pass_destruct(ptr::ExternalPassHandle) + nothing +end + +function _pass_initialize(ctx, handle::ExternalPassHandle) + try + handle.ctx = Context(ctx) + mlirLogicalResultSuccess() + catch + mlirLogicalResultFailure() + end +end + +function _pass_clone(handle::ExternalPassHandle) + ExternalPassHandle(handle.ctx, deepcopy(handle.pass)) +end + +function _pass_run(rawop, external_pass, handle::ExternalPassHandle) + op = Operation(rawop, false) + try + pass_run(handle.ctx, handle.pass, op) + catch ex + @error "Something went wrong running pass" exception=(ex,catch_backtrace()) + API.mlirExternalPassSignalFailure(external_pass) + end + nothing +end + +function create_external_pass!(oppass::OpPassManager, args...) + create_external_pass!(oppass.pass, args...) +end +function create_external_pass!(manager, pass, name, argument, + description, opname=opname(pass), + dependent_dialects=MlirDialectHandle[]) + passid = TypeID(manager.allocator) + callbacks = API.MlirExternalPassCallbacks( + @cfunction(_pass_construct, Cvoid, (Any,)), + @cfunction(_pass_destruct, Cvoid, (Any,)), + @cfunction(_pass_initialize, API.MlirLogicalResult, (MlirContext, Any,)), + @cfunction(_pass_clone, Any, (Any,)), + @cfunction(_pass_run, Cvoid, (MlirOperation, API.MlirExternalPass, Any)) + ) + pass_handle = manager.passes[passid] = ExternalPassHandle(nothing, pass) + userdata = Base.pointer_from_objref(pass_handle) + mlir_pass = API.mlirCreateExternalPass(passid, name, argument, description, opname, + length(dependent_dialects), dependent_dialects, + callbacks, userdata) + mlir_pass +end + +end + diff --git a/src/IR/Support.jl b/src/IR/Support.jl new file mode 100644 index 00000000..f84689e3 --- /dev/null +++ b/src/IR/Support.jl @@ -0,0 +1,133 @@ +function mlirIsNull(val) + val.ptr == C_NULL +end + +### Identifier + +String(ident::MlirIdentifier) = String(API.mlirIdentifierStr(ident)) + +### Logical Result + +mlirLogicalResultSuccess() = API.MlirLogicalResult(1) +mlirLogicalResultFailure() = API.MlirLogicalResult(0) + +mlirLogicalResultIsSuccess(result) = result.value != 0 +mlirLogicalResultIsFailure(result) = result.value == 0 + +### Iterators + +""" + BlockIterator(region::Region) + +Iterates over all blocks in the given region. +""" +struct BlockIterator + region::Region +end + +function Base.iterate(it::BlockIterator) + reg = it.region + raw_block = API.mlirRegionGetFirstBlock(reg) + if mlirIsNull(raw_block) + nothing + else + b = Block(raw_block, false) + (b, b) + end +end + +function Base.iterate(it::BlockIterator, block) + raw_block = API.mlirBlockGetNextInRegion(block) + if mlirIsNull(raw_block) + nothing + else + b = Block(raw_block, false) + (b, b) + end +end + +""" + OperationIterator(block::Block) + +Iterates over all operations for the given block. +""" +struct OperationIterator + block::Block +end + +function Base.iterate(it::OperationIterator) + raw_op = API.mlirBlockGetFirstOperation(it.block) + if mlirIsNull(raw_op) + nothing + else + op = Operation(raw_op, false) + (op, op) + end +end + +function Base.iterate(it::OperationIterator, op) + raw_op = API.mlirOperationGetNextInBlock(op) + if mlirIsNull(raw_op) + nothing + else + op = Operation(raw_op, false) + (op, op) + end +end + +""" + RegionIterator(::Operation) + +Iterates over all sub-regions for the given operation. +""" +struct RegionIterator + op::Operation +end + +function Base.iterate(it::RegionIterator) + raw_region = API.mlirOperationGetFirstRegion(it.op) + if mlirIsNull(raw_region) + nothing + else + region = Region(raw_region, false) + (region, region) + end +end + +function Base.iterate(it::RegionIterator, region) + raw_region = API.mlirRegionGetNextInOperation(region) + if mlirIsNull(raw_region) + nothing + else + region = Region(raw_region, false) + (region, region) + end +end + +### Utils + +function visit(f, op) + for region in RegionIterator(op) + for block in BlockIterator(region) + for op in OperationIterator(block) + f(op) + end + end + end +end + +""" + verifyall(operation; debug=false) + +Prints the operations which could not be verified. +""" +function verifyall(operation::Operation; debug=false) + io = IOContext(stdout, :debug => debug) + visit(operation) do op + if !verify(op) + show(io, op) + end + end +end +verifyall(module_::MModule) = get_operation(module_) |> verifyall + diff --git a/src/MLIR.jl b/src/MLIR.jl index c60e67e4..36638296 100644 --- a/src/MLIR.jl +++ b/src/MLIR.jl @@ -35,4 +35,7 @@ function Base.unsafe_convert(::Type{API.MlirStringRef}, s::Union{Symbol, String, return API.MlirStringRef(p, length(s)) end -end # module +include("./IR/IR.jl") +include("./Dialects.jl") + +end # module MLIR