Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Makefile and run.sh into build.jl script #1

Merged
merged 11 commits into from
Jan 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions deps/Project.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[deps]
LLVM_full_jll = "a3ccf953-465e-511d-b87f-60a6490c289d"

[compat]
LLVM_full_jll = "15"
86 changes: 86 additions & 0 deletions deps/build.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
using LLVM_full_jll

println("Environment")
println("- llvm-config = $(LLVM_full_jll.get_llvm_config_path())")
println("- clang = $(LLVM_full_jll.get_clang_path())")

CXXFLAGS = `$(llvm_config()) --cxxflags` |> readchomp |> split
LDFLAGS = `$(llvm_config()) --ldflags` |> readchomp |> split
println("- CXXFLAGS = $CXXFLAGS")
println("- LDFLAGS = $LDFLAGS")

INCLUDE_PATH = joinpath(LLVM_full_jll.artifact_dir, "include")
DIALECTS_PATH = joinpath(INCLUDE_PATH, "mlir", "Dialect")
println("- INCLUDE_PATH = $INCLUDE_PATH")
println("- DIALECTS_PATH = $DIALECTS_PATH")

# compile TableGen generator
println("Compiling TableGen generator...")
files = [joinpath(@__DIR__, "tblgen", "mlir-jl-tblgen.cc"), joinpath(@__DIR__, "tblgen", "jl-generators.cc")]
output = ["-o", "mlir-jl-tblgen"]
libs = ["-lLLVM", "-lMLIR", "-lMLIRTableGen", "-lLLVMTableGen"]

extra = ["-rpath", joinpath(LLVM_full_jll.artifact_dir, "lib")]
if Base.Sys.isapple()
isysroot = strip(read(`xcrun --show-sdk-path`, String))
append!(extra, [
"-isysroot",
isysroot,
"-lc++",
])
elseif Base.Sys.islinux()
append!(extra, [
"-lstdc++",
])
end
println("- extra flags = $extra")

run(`$(clang()) $files $CXXFLAGS $LDFLAGS $extra $libs $output`)

# generate bindings
println("Generating bindings...")

target_dialects = [
("Builtin.jl", "../IR/BuiltinOps.td"),
("AMDGPU.jl", "AMDGPU/AMDGPU.td"),
("AMX.jl", "AMX/AMX.td"),
("Affine.jl", "Affine/IR/AffineOps.td"),
("Arithmetic.jl", "Arithmetic/IR/ArithmeticOps.td"),
# ("ArmNeon.jl", "ArmNeon/ArmNeon.td"),
("ArmSVE.jl", "ArmSVE/ArmSVE.td"),
("Async.jl", "Async/IR/AsyncOps.td"),
("Bufferization.jl", "Bufferization/IR/BufferizationOps.td"),
("Complex.jl", "Complex/IR/ComplexOps.td"),
("ControlFlow.jl", "ControlFlow/IR/ControlFlowOps.td"),
# ("DLTI.jl", "DLTI/DLTI.td"),
("EmitC.jl", "EmitC/IR/EmitC.td"),
("Func.jl", "Func/IR/FuncOps.td"),
# ("GPU.jl", "GPU/IR/GPUOps.td"),
("Linalg.jl", "Linalg/IR/LinalgOps.td"),
# ("LinalgStructured.jl", "Linalg/IR/LinalgStructuredOps.td"),
("LLVMIR.jl", "LLVMIR/LLVMOps.td"),
# ("MLProgram.jl", "MLProgram/IR/MLProgramOps.td"),
("Math.jl", "Math/IR/MathOps.td"),
("MemRef.jl", "MemRef/IR/MemRefOps.td"),
("NVGPU.jl", "NVGPU/IR/NVGPU.td"),
# ("OpenACC.jl", "OpenACC/OpenACCOps.td"),
# ("OpenMP.jl", "OpenMP/OpenMPOps.td"),
# ("PDL.jl", "PDL/IR/PDLOps.td"),
# ("PDLInterp.jl", "PDLInterp/IR/PDLInterpOps.td"),
("Quant.jl", "Quant/QuantOps.td"),
# ("SCF.jl", "SCF/IR/SCFOps.td"),
# ("SPIRV.jl", "SPIRV/IR/SPIRVOps.td"),
("Shape.jl", "Shape/IR/ShapeOps.td"),
("SparseTensor.jl", "SparseTensor/IR/SparseTensorOps.td"),
("Tensor.jl", "Tensor/IR/TensorOps.td"),
# ("Tosa.jl", "Tosa/IR/TosaOps.td"),
("Transform.jl", "Transform/IR/TransformOps.td"),
("Vector.jl", "Vector/IR/VectorOps.td"),
# ("X86Vector.jl", "X86Vector/X86Vector.td"),
]

for (file, path) in target_dialects
output = joinpath(@__DIR__, "..", "src", "dialects", file)
run(`./mlir-jl-tblgen --generator=jl-op-defs $(joinpath(DIALECTS_PATH, path)) -I$INCLUDE_PATH -o $output`)
println("- Generated \"$output\" from \"$path\"")
end
10 changes: 0 additions & 10 deletions deps/tblgen/Makefile

This file was deleted.

33 changes: 0 additions & 33 deletions deps/tblgen/run.sh

This file was deleted.

203 changes: 18 additions & 185 deletions src/Dialects.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,190 +18,23 @@ operandsegmentsizes(segments) = namedattribute(
Int32.(segments)
)))

include("dialects/builtin.jl")

include("dialects/llvm.jl")

# include("dialects/arith.jl")

# include("dialects/cf.jl")

# include("dialects/func.jl")

# include("dialects/Gpu.jl")

# include("dialects/Memref.jl")

# include("dialects/Index.jl")

include("dialects/affine.jl")

# include("dialects/Ub.jl")

# include("dialects/SCF.jl")

module arith

using ...IR

for (f, t) in Iterators.product(
(:add, :sub, :mul),
(:i, :f),
)
fname = Symbol(f, t)
@eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location())
IR.create_operation($(string("arith.", fname)), loc; operands, results=[type])
end
end

for fname in (:xori, :andi, :ori)
@eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location())
IR.create_operation($(string("arith.", fname)), loc; operands, results=[type])
end
end

for (f, t) in Iterators.product(
(:div, :max, :min),
(:si, :ui, :f),
)
fname = Symbol(f, t)
@eval function $fname(operands, type=IR.get_type(first(operands)); loc=Location())
IR.create_operation($(string("arith.", fname)), loc; operands, results=[type])
end
end

# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithindex_cast-mlirarithindexcastop
for f in (:index_cast, :index_castui)
@eval function $f(operand; loc=Location())
IR.create_operation(
$(string("arith.", f)),
loc;
operands=[operand],
results=[IR.IndexType()],
)
end
end

# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithextf-mlirarithextfop
function extf(operand, type; loc=Location())
IR.create_operation("arith.exf", loc; operands=[operand], results=[type])
end

# https://mlir.llvm.org/docs/Dialects/ArithOps/#arithconstant-mlirarithconstantop
function constant(value, type=MLIRType(typeof(value)); loc=Location())
IR.create_operation(
"arith.constant",
loc;
results=[type],
attributes=[
IR.NamedAttribute("value",
Attribute(value, type)),
],
)
end

module Predicates
const eq = 0
const ne = 1
const slt = 2
const sle = 3
const sgt = 4
const sge = 5
const ult = 6
const ule = 7
const ugt = 8
const uge = 9
end

function cmpi(predicate, operands; loc=Location())
IR.create_operation(
"arith.cmpi",
loc;
operands,
results=[MLIRType(Bool)],
attributes=[
IR.NamedAttribute("predicate",
Attribute(predicate))
],
)
end

end # module arith

module std
# for llvm 14

using ...IR

function return_(operands; loc=Location())
IR.create_operation("std.return", loc; operands, result_inference=false)
end

function br(dest, operands; loc=Location())
IR.create_operation("std.br", loc; operands, successors=[dest], result_inference=false)
end

function cond_br(
cond,
true_dest, false_dest,
true_dest_operands,
false_dest_operands;
loc=Location(),
)
IR.create_operation(
"std.cond_br",
loc;
successors=[true_dest, false_dest],
operands=[cond, true_dest_operands..., false_dest_operands...],
attributes=[
IR.NamedAttribute("operand_segment_sizes",
IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)]))
],
result_inference=false,
)
end

end # module std

module func
# https://mlir.llvm.org/docs/Dialects/Func/

using ...IR

function return_(operands; loc=Location())
IR.create_operation("func.return", loc; operands, result_inference=false)
end

end # module func

module cf

using ...IR

function br(dest, operands; loc=Location())
IR.create_operation("cf.br", loc; operands, successors=[dest], result_inference=false)
end

function cond_br(
cond,
true_dest, false_dest,
true_dest_operands,
false_dest_operands;
loc=Location(),
)
IR.create_operation(
"cf.cond_br", loc;
operands=[cond, true_dest_operands..., false_dest_operands...],
successors=[true_dest, false_dest],
attributes=[
IR.NamedAttribute("operand_segment_sizes",
IR.Attribute(Int32[1, length(true_dest_operands), length(false_dest_operands)]))
],
result_inference=false,
)
end

end # module cf

include.(filter(contains(r".jl$"), readdir(joinpath(@__DIR__, "dialects"); join=true)))

# module arith

# module Predicates
# const eq = 0
# const ne = 1
# const slt = 2
# const sle = 3
# const sgt = 4
# const sge = 5
# const ult = 6
# const ule = 7
# const ugt = 8
# const uge = 9
# end

# end # module arith

end # module Dialects
6 changes: 3 additions & 3 deletions src/MLIR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,11 @@ end
module IR
import ..API: API

include("./IR/IR.jl")
include("./IR/state.jl")
include("IR/IR.jl")
include("IR/state.jl")
end # module IR

include("./Dialects.jl")
include("Dialects.jl")


end # module MLIR
1 change: 1 addition & 0 deletions src/dialects/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*.jl
Loading