Skip to content

Commit

Permalink
Allow CodeInstance in Expr(:invoke)
Browse files Browse the repository at this point in the history
This is a quick experiment to see what it would look like to switch
Expr(:invoke) to use CodeInstance rather than MethodInstance. There
is some unresolved semantic questions here about whether this is
a good idea or if there's some other representation that's better,
but discussion might be easier with an implementation.

The larger context here is the question of whether and how to do
more general method specialization. I had written some thoughts
in [1], but as mentioned there remains ongoing discussion of whether
this is the correct direction or not.

[1] https://hackmd.io/@Og2_pcUySm6R_RPqbZ06JA/S1bqP1D_6
  • Loading branch information
Keno committed Jan 24, 2024
1 parent d741c24 commit 9751fe0
Show file tree
Hide file tree
Showing 9 changed files with 182 additions and 133 deletions.
3 changes: 2 additions & 1 deletion base/compiler/ssair/EscapeAnalysis/EscapeAnalysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,8 @@ end

# escape statically-resolved call, i.e. `Expr(:invoke, ::MethodInstance, ...)`
function escape_invoke!(astate::AnalysisState, pc::Int, args::Vector{Any})
mi = first(args)::MethodInstance
arg1 = first(args)
mi = isa(arg1, Core.CodeInstance) ? arg1.def : arg1::MethodInstance
first_idx, last_idx = 2, length(args)
# TODO inspect `astate.ir.stmts[pc][:info]` and use const-prop'ed `InferenceResult` if available
cache = astate.get_escape_cache(mi)
Expand Down
59 changes: 40 additions & 19 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ struct SomeCase
end

struct InvokeCase
invoke::MethodInstance
invoke::Union{MethodInstance, CodeInstance}
effects::Effects
info::CallInfo
end
Expand Down Expand Up @@ -831,29 +831,44 @@ function compileable_specialization(mi::MethodInstance, effects::Effects,
return InvokeCase(mi_invoke, effects, info)
end

function compileable_specialization(ci::CodeInstance, effects::Effects,
et::InliningEdgeTracker, @nospecialize(info::CallInfo); compilesig_invokes::Bool=true)
if compilesig_invokes ? !isa_compileable_sig(ci.def.specTypes, ci.def.sparam_vals, ci.def.def) :
any(@nospecialize(t)->isa(t, TypeVar), ci.def.sparam_vals)
return compileable_specialization(ci.def, effects, et, info; compilesig_invokes)
end
add_inlining_backedge!(et, ci.def) # to the dispatch lookup
push!(et.edges, ci.def.def.sig, ci.def) # add_inlining_backedge to the invoke call
return InvokeCase(ci, effects, info)
end

function compileable_specialization(match::MethodMatch, effects::Effects,
et::InliningEdgeTracker, @nospecialize(info::CallInfo); compilesig_invokes::Bool=true)
mi = specialize_method(match)
return compileable_specialization(mi, effects, et, info; compilesig_invokes)
end

struct InferredResult
ci::Union{Nothing, CodeInstance}
src::Any
effects::Effects
InferredResult(@nospecialize(src), effects::Effects) = new(src, effects)
InferredResult(ci::Union{Nothing, CodeInstance}, @nospecialize(src), effects::Effects) = new(ci, src, effects)
end
@inline function get_cached_result(code::CodeInstance)
if use_const_api(code)
# in this case function can be inlined to a constant
return ConstantCase(quoted(code.rettype_const))
end
src = @atomic :monotonic code.inferred
effects = decode_effects(code.ipo_purity_bits)
return InferredResult(code, src, effects)
end
@inline function get_cached_result(state::InliningState, mi::MethodInstance)
code = get(code_cache(state), mi, nothing)
if code isa CodeInstance
if use_const_api(code)
# in this case function can be inlined to a constant
return ConstantCase(quoted(code.rettype_const))
end
src = @atomic :monotonic code.inferred
effects = decode_effects(code.ipo_purity_bits)
return InferredResult(src, effects)
return get_cached_result(code)
end
return InferredResult(nothing, Effects())
return InferredResult(nothing, nothing, Effects())
end
@inline function get_local_result(inf_result::InferenceResult)
effects = inf_result.ipo_effects
Expand All @@ -864,11 +879,11 @@ end
return ConstantCase(quoted(res.val))
end
end
return InferredResult(inf_result.src, effects)
return InferredResult(nothing, inf_result.src, effects)
end

# the general resolver for usual and const-prop'ed calls
function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,VolatileInferenceResult},
function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,VolatileInferenceResult,CodeInstance},
@nospecialize(info::CallInfo), flag::UInt32, state::InliningState;
invokesig::Union{Nothing,Vector{Any}}=nothing)
et = InliningEdgeTracker(state, invokesig)
Expand All @@ -887,17 +902,18 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,
add_inlining_backedge!(et, mi)
return inferred_result
end
(; src, effects) = inferred_result
(; src, ci, effects) = inferred_result
invoke_target = mi

# the duplicated check might have been done already within `analyze_method!`, but still
# we need it here too since we may come here directly using a constant-prop' result
if !OptimizationParams(state.interp).inlining || is_stmt_noinline(flag)
return compileable_specialization(mi, effects, et, info;
return compileable_specialization(invoke_target, effects, et, info;
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)
end

src = inlining_policy(state.interp, src, info, flag)
src === nothing && return compileable_specialization(mi, effects, et, info;
src === nothing && return compileable_specialization(invoke_target, effects, et, info;
compilesig_invokes=OptimizationParams(state.interp).compilesig_invokes)

add_inlining_backedge!(et, mi)
Expand All @@ -906,15 +922,21 @@ function resolve_todo(mi::MethodInstance, result::Union{Nothing,InferenceResult,
end

# the special resolver for :invoke-d call
function resolve_todo(mi::MethodInstance, @nospecialize(info::CallInfo), flag::UInt32,
function resolve_todo(mici::Union{MethodInstance, CodeInstance}, @nospecialize(info::CallInfo), flag::UInt32,
state::InliningState)
if !OptimizationParams(state.interp).inlining || is_stmt_noinline(flag)
return nothing
end

et = InliningEdgeTracker(state)

cached_result = get_cached_result(state, mi)
if isa(mici, CodeInstance)
cached_result = get_cached_result(mici)
mi = mici.def
else
cached_result = get_cached_result(state, mici)
mi = mici
end
if cached_result isa ConstantCase
add_inlining_backedge!(et, mi)
return cached_result
Expand Down Expand Up @@ -1632,8 +1654,7 @@ end

function handle_invoke_expr!(todo::Vector{Pair{Int,Any}}, ir::IRCode,
idx::Int, stmt::Expr, @nospecialize(info::CallInfo), flag::UInt32, sig::Signature, state::InliningState)
mi = stmt.args[1]::MethodInstance
case = resolve_todo(mi, info, flag, state)
case = resolve_todo(stmt.args[1], info, flag, state)
handle_single_case!(todo, ir, idx, stmt, case, false)
return nothing
end
Expand Down
4 changes: 3 additions & 1 deletion base/compiler/ssair/irinterp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,9 @@ function reprocess_instruction!(interp::AbstractInterpreter, inst::Instruction,
(; rt, effects) = abstract_eval_statement_expr(interp, stmt, nothing, irsv)
add_flag!(inst, flags_for_effects(effects))
elseif head === :invoke
rt, (nothrow, noub) = concrete_eval_invoke(interp, stmt, stmt.args[1]::MethodInstance, irsv)
arg1 = stmt.args[1]
mi = isa(arg1, CodeInstance) ? arg1.def : arg1::MethodInstance
rt, (nothrow, noub) = concrete_eval_invoke(interp, stmt, mi, irsv)
if nothrow
add_flag!(inst, IR_FLAG_NOTHROW)
end
Expand Down
12 changes: 9 additions & 3 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1477,9 +1477,15 @@ end
# NOTE we resolve the inlining source here as we don't want to serialize `Core.Compiler`
# data structure into the global cache (see the comment in `handle_finalizer_call!`)
function try_inline_finalizer!(ir::IRCode, argexprs::Vector{Any}, idx::Int,
mi::MethodInstance, @nospecialize(info::CallInfo), inlining::InliningState,
code_or_mi::Union{MethodInstance, CodeInstance}, @nospecialize(info::CallInfo), inlining::InliningState,
attach_after::Bool)
code = get(code_cache(inlining), mi, nothing)
if isa(code_or_mi, CodeInstance)
code = code_or_mi
mi = code.def
else
mi = code_or_mi
code = get(code_cache(inlining), mi, nothing)
end
et = InliningEdgeTracker(inlining)
if code isa CodeInstance
if use_const_api(code)
Expand Down Expand Up @@ -1646,7 +1652,7 @@ function try_resolve_finalizer!(ir::IRCode, idx::Int, finalizer_idx::Int, defuse
if inline === nothing
# No code in the function - Nothing to do
else
mi = finalizer_stmt.args[5]::MethodInstance
mi = finalizer_stmt.args[5]::Union{MethodInstance, CodeInstance}
if inline::Bool && try_inline_finalizer!(ir, argexprs, loc, mi, info, inlining, attach_after)
# the finalizer body has been inlined
else
Expand Down
3 changes: 2 additions & 1 deletion base/compiler/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ function print_stmt(io::IO, idx::Int, @nospecialize(stmt), used::BitSet, maxleng
stmt = stmt::Expr
# TODO: why is this here, and not in Base.show_unquoted
print(io, "invoke ")
linfo = stmt.args[1]::Core.MethodInstance
arg1 = stmt.args[1]
linfo = isa(arg1, Core.CodeInstance) ? arg1.def : arg1::Core.MethodInstance
show_unquoted(io, stmt.args[2], indent)
print(io, "(")
# XXX: this is wrong if `sig` is not a concretetype method
Expand Down
169 changes: 90 additions & 79 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4653,95 +4653,106 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
bool handled = false;
jl_cgval_t result;
if (lival.constant) {
jl_method_instance_t *mi = (jl_method_instance_t*)lival.constant;
assert(jl_is_method_instance(mi));
if (mi == ctx.linfo) {
// handle self-recursion specially
jl_returninfo_t::CallingConv cc = jl_returninfo_t::CallingConv::Boxed;
FunctionType *ft = ctx.f->getFunctionType();
StringRef protoname = ctx.f->getName();
if (ft == ctx.types().T_jlfunc) {
result = emit_call_specfun_boxed(ctx, ctx.rettype, protoname, nullptr, argv, nargs, rt);
handled = true;
}
else if (ft != ctx.types().T_jlfuncparams) {
unsigned return_roots = 0;
result = emit_call_specfun_other(ctx, mi, ctx.rettype, protoname, nullptr, argv, nargs, &cc, &return_roots, rt);
handled = true;
}
}
else {
jl_value_t *ci = ctx.params->lookup(mi, ctx.world, ctx.world); // TODO: need to use the right pair world here
if (ci != jl_nothing) {
jl_code_instance_t *codeinst = (jl_code_instance_t*)ci;
auto invoke = jl_atomic_load_acquire(&codeinst->invoke);
// check if we know how to handle this specptr
if (invoke == jl_fptr_const_return_addr) {
result = mark_julia_const(ctx, codeinst->rettype_const);
jl_code_instance_t *codeinst = NULL;
jl_method_instance_t *mi = NULL;
if (jl_is_method_instance(lival.constant)) {
mi = (jl_method_instance_t*)lival.constant;
assert(jl_is_method_instance(mi));
if (mi == ctx.linfo) {
// handle self-recursion specially
jl_returninfo_t::CallingConv cc = jl_returninfo_t::CallingConv::Boxed;
FunctionType *ft = ctx.f->getFunctionType();
StringRef protoname = ctx.f->getName();
if (ft == ctx.types().T_jlfunc) {
result = emit_call_specfun_boxed(ctx, ctx.rettype, protoname, nullptr, argv, nargs, rt);
handled = true;
}
else if (invoke != jl_fptr_sparam_addr) {
bool specsig, needsparams;
std::tie(specsig, needsparams) = uses_specsig(mi, codeinst->rettype, ctx.params->prefer_specsig);
std::string name;
StringRef protoname;
bool need_to_emit = true;
bool cache_valid = ctx.use_cache || ctx.external_linkage;
bool external = false;

// Check if we already queued this up
auto it = ctx.call_targets.find(codeinst);
if (need_to_emit && it != ctx.call_targets.end()) {
protoname = it->second.decl->getName();
need_to_emit = cache_valid = false;
}
else if (ft != ctx.types().T_jlfuncparams) {
unsigned return_roots = 0;
result = emit_call_specfun_other(ctx, mi, ctx.rettype, protoname, nullptr, argv, nargs, &cc, &return_roots, rt);
handled = true;
}
goto done;
} else {
jl_value_t *ci = ctx.params->lookup(mi, ctx.world, ctx.world); // TODO: need to use the right pair world here
if (ci == jl_nothing)
goto done;
codeinst = (jl_code_instance_t*)ci;
}
} else {
assert(jl_is_code_instance(lival.constant));
codeinst = (jl_code_instance_t*)lival.constant;
// TODO: Separate copy of the callsig in the CodeInstance
mi = codeinst->def;
}

// Check if it is already compiled (either JIT or externally)
if (cache_valid) {
// optimization: emit the correct name immediately, if we know it
// TODO: use `emitted` map here too to try to consolidate names?
// WARNING: isspecsig is protected by the codegen-lock. If that lock is removed, then the isspecsig load needs to be properly atomically sequenced with this.
auto fptr = jl_atomic_load_relaxed(&codeinst->specptr.fptr);
if (fptr) {
while (!(jl_atomic_load_acquire(&codeinst->specsigflags) & 0b10)) {
jl_cpu_pause();
}
invoke = jl_atomic_load_relaxed(&codeinst->invoke);
if (specsig ? jl_atomic_load_relaxed(&codeinst->specsigflags) & 0b1 : invoke == jl_fptr_args_addr) {
protoname = jl_ExecutionEngine->getFunctionAtAddress((uintptr_t)fptr, codeinst);
if (ctx.external_linkage) {
// TODO: Add !specsig support to aotcompile.cpp
// Check that the codeinst is containing native code
if (specsig && jl_atomic_load_relaxed(&codeinst->specsigflags) & 0b100) {
external = true;
need_to_emit = false;
}
}
else { // ctx.use_cache
need_to_emit = false;
}
auto invoke = jl_atomic_load_acquire(&codeinst->invoke);
// check if we know how to handle this specptr
if (invoke == jl_fptr_const_return_addr) {
result = mark_julia_const(ctx, codeinst->rettype_const);
handled = true;
}
else if (invoke != jl_fptr_sparam_addr) {
bool specsig, needsparams;
std::tie(specsig, needsparams) = uses_specsig(mi, codeinst->rettype, ctx.params->prefer_specsig);
std::string name;
StringRef protoname;
bool need_to_emit = true;
bool cache_valid = ctx.use_cache || ctx.external_linkage;
bool external = false;

// Check if we already queued this up
auto it = ctx.call_targets.find(codeinst);
if (need_to_emit && it != ctx.call_targets.end()) {
protoname = it->second.decl->getName();
need_to_emit = cache_valid = false;
}

// Check if it is already compiled (either JIT or externally)
if (cache_valid) {
// optimization: emit the correct name immediately, if we know it
// TODO: use `emitted` map here too to try to consolidate names?
// WARNING: isspecsig is protected by the codegen-lock. If that lock is removed, then the isspecsig load needs to be properly atomically sequenced with this.
auto fptr = jl_atomic_load_relaxed(&codeinst->specptr.fptr);
if (fptr) {
while (!(jl_atomic_load_acquire(&codeinst->specsigflags) & 0b10)) {
jl_cpu_pause();
}
invoke = jl_atomic_load_relaxed(&codeinst->invoke);
if (specsig ? jl_atomic_load_relaxed(&codeinst->specsigflags) & 0b1 : invoke == jl_fptr_args_addr) {
protoname = jl_ExecutionEngine->getFunctionAtAddress((uintptr_t)fptr, codeinst);
if (ctx.external_linkage) {
// TODO: Add !specsig support to aotcompile.cpp
// Check that the codeinst is containing native code
if (specsig && jl_atomic_load_relaxed(&codeinst->specsigflags) & 0b100) {
external = true;
need_to_emit = false;
}
}
}
if (need_to_emit) {
raw_string_ostream(name) << (specsig ? "j_" : "j1_") << name_from_method_instance(mi) << "_" << jl_atomic_fetch_add(&globalUniqueGeneratedNames, 1);
protoname = StringRef(name);
}
jl_returninfo_t::CallingConv cc = jl_returninfo_t::CallingConv::Boxed;
unsigned return_roots = 0;
if (specsig)
result = emit_call_specfun_other(ctx, mi, codeinst->rettype, protoname, external ? codeinst : nullptr, argv, nargs, &cc, &return_roots, rt);
else
result = emit_call_specfun_boxed(ctx, codeinst->rettype, protoname, external ? codeinst : nullptr, argv, nargs, rt);
handled = true;
if (need_to_emit) {
Function *trampoline_decl = cast<Function>(jl_Module->getNamedValue(protoname));
ctx.call_targets[codeinst] = {cc, return_roots, trampoline_decl, specsig};
else { // ctx.use_cache
need_to_emit = false;
}
}
}
}
if (need_to_emit) {
raw_string_ostream(name) << (specsig ? "j_" : "j1_") << name_from_method_instance(mi) << "_" << jl_atomic_fetch_add(&globalUniqueGeneratedNames, 1);
protoname = StringRef(name);
}
jl_returninfo_t::CallingConv cc = jl_returninfo_t::CallingConv::Boxed;
unsigned return_roots = 0;
if (specsig)
result = emit_call_specfun_other(ctx, mi, codeinst->rettype, protoname, external ? codeinst : nullptr, argv, nargs, &cc, &return_roots, rt);
else
result = emit_call_specfun_boxed(ctx, codeinst->rettype, protoname, external ? codeinst : nullptr, argv, nargs, rt);
handled = true;
if (need_to_emit) {
Function *trampoline_decl = cast<Function>(jl_Module->getNamedValue(protoname));
ctx.call_targets[codeinst] = {cc, return_roots, trampoline_decl, specsig};
}
}
}
done:
if (!handled) {
Value *r = emit_jlcall(ctx, jlinvoke_func, boxed(ctx, lival), argv, nargs, julia_call2);
result = mark_julia_type(ctx, r, true, rt);
Expand Down
Loading

0 comments on commit 9751fe0

Please sign in to comment.