Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Adapt to Julia 1.7 #18

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 53 additions & 53 deletions src/Mixtape.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@ module Mixtape
using LLVM
using LLVM.Interop
using LLVM_full_jll
using MacroTools: @capture,
postwalk,
rmlines,
using MacroTools: @capture,
postwalk,
rmlines,
unblock
using CodeInfoTools
using CodeInfoTools: resolve
using Core: MethodInstance,
using Core: MethodInstance,
CodeInstance,
CodeInfo
using Core.Compiler: WorldView,
Expand All @@ -30,7 +30,7 @@ using Core.Compiler: WorldView,
verify_ir,
verify_linetable

using GPUCompiler: cached_compilation,
using GPUCompiler: cached_compilation,
FunctionSpec

import GPUCompiler: AbstractCompilerTarget,
Expand Down Expand Up @@ -62,10 +62,10 @@ import Core.Compiler: InferenceState,
##### Exports
#####

export CompilationContext,
export CompilationContext,
NoContext,
allow,
transform,
allow,
transform,
optimize!,
OptimizationBundle,
get_ir,
Expand Down Expand Up @@ -120,7 +120,7 @@ end

function StaticInterpreter(; ctx = NoContext(), opt = false)
return StaticInterpreter(Dict{MethodInstance, CodeInstance}(),
NativeInterpreter(),
NativeInterpreter(),
Tuple{MethodInstance, Int, String}[],
opt,
ctx)
Expand Down Expand Up @@ -151,7 +151,7 @@ end
#####

function resolve_generic(a)
if a <: Function && isdefined(a, :instance)
if a isa Type && a <: Function && isdefined(a, :instance)
return a.instance
else
return resolve(a)
Expand Down Expand Up @@ -185,7 +185,7 @@ end
@static if VERSION >= v"1.7.0-DEV.662"
using Core.Compiler: finish as _finish
else
function _finish(interp::AbstractInterpreter,
function _finish(interp::AbstractInterpreter,
opt::OptimizationState,
params::OptimizationParams, ir, @nospecialize(result))
return Core.Compiler.finish(opt, params, ir, result)
Expand All @@ -211,7 +211,7 @@ get_state(b::OptimizationBundle) = b.sv
Object which holds inferred `ir::Core.Compiler.IRCode` and a `Core.Compiler.OptimizationState`. Provided to the user through [`optimize!`](@ref), so that the user may plug in their own optimizations.
""", OptimizationBundle)

function julia_passes!(ir::Core.Compiler.IRCode, ci::CodeInfo,
function julia_passes!(ir::Core.Compiler.IRCode, ci::CodeInfo,
sv::OptimizationState)
ir = compact!(ir)
ir = ssa_inlining_pass!(ir, ir.linetable, sv.inlining, ci.propagate_inbounds)
Expand Down Expand Up @@ -265,12 +265,12 @@ function get_codeinfo(code::Core.CodeInstance)
end
end

function get_codeinfo(graph::StaticSubGraph,
function get_codeinfo(graph::StaticSubGraph,
cursor::MethodInstance)
return get_codeinfo(get_codeinstance(graph, cursor))
end

function analyze(@nospecialize(f), tt::Type{T};
function analyze(@nospecialize(f), tt::Type{T};
ctx = NoContext(), opt = false) where T <: Tuple
si = StaticInterpreter(; ctx = ctx, opt = opt)
mi = infer(si, f, tt)
Expand Down Expand Up @@ -349,34 +349,34 @@ function cache_lookup(si::StaticInterpreter, mi::MethodInstance,
Base.get(si.code, mi, nothing)
end

# Mostly from GPUCompiler.
# Mostly from GPUCompiler.
# In future, try to upstream any requires changes.
function codegen(job::CompilerJob)
f = job.source.f
tt = job.source.tt
opt = job.params.opt
si, ssg = analyze(f, tt;
si, ssg = analyze(f, tt;
ctx = job.params.ctx, opt = opt) # Populate local cache.
world = get_world_counter(si)
λ_lookup = (mi, min, max) -> cache_lookup(si, mi, min, max)
lookup_cb = @cfunction($λ_lookup, Any, (Any, UInt, UInt))
params = Base.CodegenParams(;
track_allocations = false,
params = Base.CodegenParams(;
track_allocations = false,
code_coverage = false,
prefer_specsig = true,
gnu_pubnames = false,
lookup = Base.unsafe_convert(Ptr{Nothing}, lookup_cb))

GC.@preserve lookup_cb begin
native_code = ccall(:jl_create_native,
native_code = ccall(:jl_create_native,
Ptr{Cvoid},
(Vector{MethodInstance},
Base.CodegenParams, Cint),
(Vector{MethodInstance},
Base.CodegenParams, Cint),
[ssg.entry],
params, 1) # = extern policy = #
@assert native_code != C_NULL
llvm_mod_ref = ccall(:jl_get_llvm_module,
LLVM.API.LLVMModuleRef,
llvm_mod_ref = ccall(:jl_get_llvm_module,
LLVM.API.LLVMModuleRef,
(Ptr{Cvoid},),
native_code)
@assert llvm_mod_ref != C_NULL
Expand All @@ -385,23 +385,23 @@ function codegen(job::CompilerJob)
code = cache_lookup(si, ssg.entry, world, world)
llvm_func_idx = Ref{Int32}(-1)
llvm_specfunc_idx = Ref{Int32}(-1)
ccall(:jl_get_function_id,
Nothing,
ccall(:jl_get_function_id,
Nothing,
(Ptr{Cvoid}, Any, Ptr{Int32}, Ptr{Int32}),
native_code, code, llvm_func_idx, llvm_specfunc_idx)
@assert llvm_specfunc_idx[] != -1
@assert llvm_func_idx[] != -1
llvm_func_ref = ccall(:jl_get_llvm_function,
llvm_func_ref = ccall(:jl_get_llvm_function,
LLVM.API.LLVMValueRef,
(Ptr{Cvoid}, UInt32),
native_code,
(Ptr{Cvoid}, UInt32),
native_code,
llvm_func_idx[] - 1)
@assert llvm_func_ref != C_NULL
llvm_func = LLVM.Function(llvm_func_ref)
llvm_specfunc_ref = ccall(:jl_get_llvm_function,
llvm_specfunc_ref = ccall(:jl_get_llvm_function,
LLVM.API.LLVMValueRef,
(Ptr{Cvoid}, UInt32),
native_code,
(Ptr{Cvoid}, UInt32),
native_code,
llvm_specfunc_idx[] - 1)
@assert llvm_specfunc_ref != C_NULL
llvm_specfunc = LLVM.Function(llvm_specfunc_ref)
Expand Down Expand Up @@ -433,13 +433,13 @@ struct MixtapeCompilerParams <: AbstractCompilerParams
ctx::CompilationContext
end

function MixtapeCompilerParams(; opt = false,
optlevel = Base.JLOptions().opt_level,
function MixtapeCompilerParams(; opt = false,
optlevel = Base.JLOptions().opt_level,
ctx = NoContext())
return MixtapeCompilerParams(opt, optlevel, ctx)
end

function llvm_machine(::MixtapeCompilerTarget,
function llvm_machine(::MixtapeCompilerTarget,
params::MixtapeCompilerParams)
optlevel = get_llvm_optlevel(params.optlevel)
tm = LLVM.JITTargetMachine(; optlevel=optlevel)
Expand All @@ -458,24 +458,24 @@ end
# Slow ABI -- requires array allocation and unpacking. But stable.
expr = quote
args = Any[args...]
ccall(entry.func,
Any,
(Any, Ptr{Any}, Int32),
ccall(entry.func,
Any,
(Any, Ptr{Any}, Int32),
entry.f, args, length(args))
end
return expr
end

const jit_compiled_cache = Dict{UInt, Any}()

function jit(@nospecialize(f), tt::Type{T};
ctx = NoContext(), opt = true,
function jit(@nospecialize(f), tt::Type{T};
ctx = NoContext(), opt = true,
optlevel = Base.JLOptions().opt_level) where T <: Tuple
fspec = FunctionSpec(f, tt, false, nothing) #=name=#
job = CompilerJob(MixtapeCompilerTarget(),
fspec,
MixtapeCompilerParams(;
opt = opt,
job = CompilerJob(MixtapeCompilerTarget(),
fspec,
MixtapeCompilerParams(;
opt = opt,
ctx = ctx,
optlevel = optlevel))
return cached_compilation(jit_compiled_cache, job, _jit, _jitlink)
Expand Down Expand Up @@ -532,7 +532,7 @@ function _emit(job::CompilerJob)
f = job.source.f
tt = job.source.tt
opt = job.params.opt
si, ssg = analyze(f, tt;
si, ssg = analyze(f, tt;
ctx = job.params.ctx, opt = opt) # Populate local cache.
return get_codeinfo(ssg, entrypoint(ssg))
end
Expand All @@ -541,15 +541,15 @@ identity(job::CompilerJob, src) = src

const emit_compiled_cache = Dict{UInt, Any}()

function emit(@nospecialize(f), tt::Type{T};
function emit(@nospecialize(f), tt::Type{T};
ctx = NoContext(), opt = false) where {F <: Function, T <: Tuple}
fspec = FunctionSpec(f, tt, false, nothing) #=name=#
optlevel = Base.JLOptions().opt_level
job = CompilerJob(MixtapeCompilerTarget(),
fspec,
job = CompilerJob(MixtapeCompilerTarget(),
fspec,
MixtapeCompilerParams(;
ctx = ctx,
opt = opt,
opt = opt,
optlevel = optlevel))
return cached_compilation(emit_compiled_cache, job, _emit, identity)
end
Expand All @@ -567,7 +567,7 @@ Emit typed (and optimized if `opt = true`) `CodeInfo` using the Mixtape pipeline

macro load_abi()
expr = quote
function cached_call(entry::Mixtape.Entry{F, RT, TT},
function cached_call(entry::Mixtape.Entry{F, RT, TT},
args...) where {F, RT, TT}

# TODO: Fast ABI.
Expand All @@ -582,15 +582,15 @@ macro load_abi()
return expr
end

@generated function _call(ctx::CompilationContext,
@generated function _call(ctx::CompilationContext,
optlevel::Val{T}, f::Function, args...) where T
TT = Tuple{args...}
entry = jit(f.instance, TT;
entry = jit(f.instance, TT;
ctx = ctx(), opt = true, optlevel = T)
return cached_call(entry, args...)
end

function call(f::T, args...;
function call(f::T, args...;
ctx = NoContext(),
optlevel = Base.JLOptions().opt_level) where T <: Function
_call(ctx, Val(optlevel), f, args...)
Expand All @@ -603,7 +603,7 @@ end
"""
@load_abi()
...expands...
call(f::T, args...; ctx = NoContext(),
call(f::T, args...; ctx = NoContext(),
optlevel = Base.JLOptions().opt_level) where T <: Function


Expand Down