Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

distinguish "inlineable" from "declared as inline" #48250

Merged
merged 1 commit into from
Jan 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 33 additions & 30 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
method = match.method
sig = match.spec_types
mi = specialize_method(match; preexisting=true)
if mi !== nothing && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv)
if mi !== nothing && !const_prop_methodinstance_heuristic(interp, mi, arginfo, sv)
csig = get_compileable_sig(method, sig, match.sparams)
if csig !== nothing && csig !== sig
abstract_call_method(interp, method, csig, match.sparams, multiple_matches, StmtInfo(false), sv)
Expand Down Expand Up @@ -1087,7 +1087,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter,
return nothing
end
mi = mi::MethodInstance
if !force && !const_prop_methodinstance_heuristic(interp, match, mi, arginfo, sv)
if !force && !const_prop_methodinstance_heuristic(interp, mi, arginfo, sv)
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
return nothing
end
Expand Down Expand Up @@ -1239,43 +1239,46 @@ end
# where we would spend a lot of time, but are probably unlikely to get an improved
# result anyway.
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter,
match::MethodMatch, mi::MethodInstance, arginfo::ArgInfo, sv::InferenceState)
method = match.method
mi::MethodInstance, arginfo::ArgInfo, sv::InferenceState)
method = mi.def::Method
if method.is_for_opaque_closure
# Not inlining an opaque closure can be very expensive, so be generous
# with the const-prop-ability. It is quite possible that we can't infer
# anything at all without const-propping, so the inlining check below
# isn't particularly helpful here.
return true
end
# Peek at the inferred result for the function to determine if the optimizer
# was able to cut it down to something simple (inlineable in particular).
# If so, there's a good chance we might be able to const prop all the way
# through and learn something new.
if isdefined(method, :source) && is_inlineable(method.source)
# now check if the source of this method instance is inlineable, since the extended type
# information we have here would be discarded if it is not inlined into the caller's context
# (modulo the inferred return type that can be potentially refined)
if is_declared_inline(method)
# this method is declared as `@inline` and will be inlined
return true
end
flag = get_curr_ssaflag(sv)
if is_stmt_inline(flag)
# force constant propagation for a call that is going to be inlined
# since the inliner will try to find this constant result
# if these constant arguments arrive there
return true
elseif is_stmt_noinline(flag)
# this call won't be inlined, thus this constant-prop' will most likely be unfruitful
return false
else
flag = get_curr_ssaflag(sv)
if is_stmt_inline(flag)
# force constant propagation for a call that is going to be inlined
# since the inliner will try to find this constant result
# if these constant arguments arrive there
return true
elseif is_stmt_noinline(flag)
# this call won't be inlined, thus this constant-prop' will most likely be unfruitful
return false
else
code = get(code_cache(interp), mi, nothing)
if isdefined(code, :inferred)
if isa(code, CodeInstance)
inferred = @atomic :monotonic code.inferred
else
inferred = code.inferred
end
# TODO propagate a specific `CallInfo` that conveys information about this call
if inlining_policy(interp, inferred, NoCallInfo(), IR_FLAG_NULL, mi, arginfo.argtypes) !== nothing
return true
end
# Peek at the inferred result for the method to determine if the optimizer
# was able to cut it down to something simple (inlineable in particular).
# If so, there will be a good chance we might be able to const prop
# all the way through and learn something new.
code = get(code_cache(interp), mi, nothing)
if isdefined(code, :inferred)
if isa(code, CodeInstance)
inferred = @atomic :monotonic code.inferred
else
inferred = code.inferred
end
# TODO propagate a specific `CallInfo` that conveys information about this call
if inlining_policy(interp, inferred, NoCallInfo(), IR_FLAG_NULL, mi, arginfo.argtypes) !== nothing
return true
end
end
end
Expand Down
30 changes: 20 additions & 10 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,25 @@ const TOP_TUPLE = GlobalRef(Core, :tuple)
const InlineCostType = UInt16
const MAX_INLINE_COST = typemax(InlineCostType)
const MIN_INLINE_COST = InlineCostType(10)
const MaybeCompressed = Union{CodeInfo, Vector{UInt8}}

is_inlineable(src::Union{CodeInfo, Vector{UInt8}}) = ccall(:jl_ir_inlining_cost, InlineCostType, (Any,), src) != MAX_INLINE_COST
set_inlineable!(src::CodeInfo, val::Bool) = src.inlining_cost = (val ? MIN_INLINE_COST : MAX_INLINE_COST)
is_inlineable(@nospecialize src::MaybeCompressed) =
ccall(:jl_ir_inlining_cost, InlineCostType, (Any,), src) != MAX_INLINE_COST
set_inlineable!(src::CodeInfo, val::Bool) =
src.inlining_cost = (val ? MIN_INLINE_COST : MAX_INLINE_COST)

# Saturate an integer cost estimate into the range representable by `InlineCostType`,
# i.e. `[MIN_INLINE_COST, MAX_INLINE_COST]`, and convert it to that type.
function inline_cost_clamp(x::Int)::InlineCostType
    if x > MAX_INLINE_COST
        # too expensive to represent: pin to the upper bound
        return MAX_INLINE_COST
    elseif x < MIN_INLINE_COST
        # never report less than the minimum cost
        return MIN_INLINE_COST
    else
        return convert(InlineCostType, x)
    end
end

# Whether the (possibly compressed) source `src` carries an `@inline` declaration.
# The runtime reports the declaration as a small integer; `1` denotes `@inline`.
function is_declared_inline(@nospecialize src::MaybeCompressed)
    inlining_flag = ccall(:jl_ir_flag_inlining, UInt8, (Any,), src)
    return inlining_flag == 1
end

# Whether the (possibly compressed) source `src` carries a `@noinline` declaration.
# The runtime reports the declaration as a small integer; `2` denotes `@noinline`.
function is_declared_noinline(@nospecialize src::MaybeCompressed)
    inlining_flag = ccall(:jl_ir_flag_inlining, UInt8, (Any,), src)
    return inlining_flag == 2
end

#####################
# OptimizationState #
#####################
Expand All @@ -73,16 +82,16 @@ function add_invoke_backedge!(et::EdgeTracker, @nospecialize(invokesig), mi::Met
return nothing
end

is_source_inferred(@nospecialize(src::Union{CodeInfo, Vector{UInt8}})) =
is_source_inferred(@nospecialize src::MaybeCompressed) =
ccall(:jl_ir_flag_inferred, Bool, (Any,), src)

function inlining_policy(interp::AbstractInterpreter,
@nospecialize(src), @nospecialize(info::CallInfo), stmt_flag::UInt8, mi::MethodInstance,
argtypes::Vector{Any})
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
src_inferred = is_source_inferred(src)
if isa(src, MaybeCompressed)
is_source_inferred(src) || return nothing
src_inlineable = is_stmt_inline(stmt_flag) || is_inlineable(src)
return src_inferred && src_inlineable ? src : nothing
return src_inlineable ? src : nothing
elseif src === nothing && is_stmt_inline(stmt_flag)
# if this statement is forced to be inlined, make an additional effort to find the
# inferred source in the local cache
Expand Down Expand Up @@ -413,7 +422,7 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
(; def, specTypes) = linfo

analyzed = nothing # `ConstAPI` if this call can use constant calling convention
force_noinline = _any(x::Expr -> x.head === :meta && x.args[1] === :noinline, ir.meta)
force_noinline = is_declared_noinline(src)

# compute inlining and other related optimizations
result = caller.result
Expand Down Expand Up @@ -483,23 +492,24 @@ function finish(interp::AbstractInterpreter, opt::OptimizationState,
else
force_noinline = true
end
if !is_inlineable(src) && result === Bottom
if !is_declared_inline(src) && result === Bottom
force_noinline = true
end
end
if force_noinline
set_inlineable!(src, false)
elseif isa(def, Method)
if is_inlineable(src) && isdispatchtuple(specTypes)
if is_declared_inline(src) && isdispatchtuple(specTypes)
# obey @inline declaration if a dispatch barrier would not help
set_inlineable!(src, true)
else
# compute the cost (size) of inlining this code
cost_threshold = default = params.inline_cost_threshold
if ⊑(optimizer_lattice(interp), result, Tuple) && !isconcretetype(widenconst(result))
cost_threshold += params.inline_tupleret_bonus
end
# if the method is declared as `@inline`, increase the cost threshold 20x
if is_inlineable(src)
if is_declared_inline(src)
cost_threshold += 19*default
end
# a few functions get special treatment
Expand Down
21 changes: 21 additions & 0 deletions base/compiler/utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,27 @@ function specialize_method(match::MethodMatch; kwargs...)
return specialize_method(match.method, match.spec_types, match.sparams; kwargs...)
end

"""
is_declared_inline(method::Method) -> Bool

Check if `method` is declared as `@inline`.
"""
function is_declared_inline(method::Method)
    # delegate to the shared helper, asking for the `@inline` declaration
    return _is_declared_inline(method, true)
end

"""
is_declared_noinline(method::Method) -> Bool

Check if `method` is declared as `@noinline`.
"""
function is_declared_noinline(method::Method)
    # delegate to the shared helper, asking for the `@noinline` declaration
    return _is_declared_inline(method, false)
end

# Shared implementation of `is_declared_inline(::Method)` / `is_declared_noinline(::Method)`.
# Returns `false` when `method` has no `source` field set or when that source is not a
# `MaybeCompressed` value; otherwise queries the declaration selected by `inline`
# (`true` checks for `@inline`, `false` checks for `@noinline`).
function _is_declared_inline(method::Method, inline::Bool)
    if isdefined(method, :source)
        src = method.source
        if isa(src, MaybeCompressed)
            return inline ? is_declared_inline(src) : is_declared_noinline(src)
        end
    end
    return false
end

"""
is_aggressive_constprop(method::Union{Method,CodeInfo}) -> Bool

Expand Down
2 changes: 1 addition & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6877,7 +6877,7 @@ static jl_llvm_functions_t
FnAttrs.addAttribute(polly::PollySkipFnAttr);
#endif

if (jl_has_meta(stmts, jl_noinline_sym))
if (src->inlining == 2)
FnAttrs.addAttribute(Attribute::NoInline);

#ifdef JL_DEBUG_BUILD
Expand Down
18 changes: 16 additions & 2 deletions src/ircode.c
Original file line number Diff line number Diff line change
Expand Up @@ -434,12 +434,14 @@ static void jl_encode_value_(jl_ircode_state *s, jl_value_t *v, int as_literal)
}
}

static jl_code_info_flags_t code_info_flags(uint8_t pure, uint8_t propagate_inbounds, uint8_t inferred, uint8_t constprop)
static jl_code_info_flags_t code_info_flags(uint8_t pure, uint8_t propagate_inbounds, uint8_t inferred,
uint8_t inlining, uint8_t constprop)
{
jl_code_info_flags_t flags;
flags.bits.pure = pure;
flags.bits.propagate_inbounds = propagate_inbounds;
flags.bits.inferred = inferred;
flags.bits.inlining = inlining;
flags.bits.constprop = constprop;
return flags;
}
Expand Down Expand Up @@ -778,7 +780,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
1
};

jl_code_info_flags_t flags = code_info_flags(code->pure, code->propagate_inbounds, code->inferred, code->constprop);
jl_code_info_flags_t flags = code_info_flags(code->pure, code->propagate_inbounds, code->inferred,
code->inlining, code->constprop);
write_uint8(s.s, flags.packed);
write_uint8(s.s, code->purity.bits);
write_uint16(s.s, code->inlining_cost);
Expand Down Expand Up @@ -873,6 +876,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
jl_code_info_t *code = jl_new_code_info_uninit();
jl_code_info_flags_t flags;
flags.packed = read_uint8(s.s);
code->inlining = flags.bits.inlining;
code->constprop = flags.bits.constprop;
code->inferred = flags.bits.inferred;
code->propagate_inbounds = flags.bits.propagate_inbounds;
Expand Down Expand Up @@ -944,6 +948,16 @@ JL_DLLEXPORT uint8_t jl_ir_flag_inferred(jl_array_t *data)
return flags.bits.inferred;
}

// Return the `inlining` setting of `data`, which is either an uncompressed
// `jl_code_info_t` (field read directly) or a compressed IR byte array (setting
// decoded from the packed flags byte stored at offset 0).
JL_DLLEXPORT uint8_t jl_ir_flag_inlining(jl_array_t *data)
{
    if (!jl_is_code_info(data)) {
        // compressed form: must be a byte array whose first byte holds the packed flags
        assert(jl_typeis(data, jl_array_uint8_type));
        jl_code_info_flags_t packed_flags;
        packed_flags.packed = ((uint8_t*)data->data)[0];
        return packed_flags.bits.inlining;
    }
    return ((jl_code_info_t*)data)->inlining;
}

JL_DLLEXPORT uint8_t jl_ir_flag_pure(jl_array_t *data)
{
if (jl_is_code_info(data))
Expand Down
1 change: 1 addition & 0 deletions src/jl_exported_funcs.inc
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,7 @@
XX(jl_ios_get_nbyte_int) \
XX(jl_ir_flag_inferred) \
XX(jl_ir_inlining_cost) \
XX(jl_ir_flag_inlining) \
XX(jl_ir_flag_pure) \
XX(jl_ir_nslots) \
XX(jl_ir_slotflag) \
Expand Down
6 changes: 4 additions & 2 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2432,7 +2432,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_code_info_type =
jl_new_datatype(jl_symbol("CodeInfo"), core,
jl_any_type, jl_emptysvec,
jl_perm_symsvec(21,
jl_perm_symsvec(22,
"code",
"codelocs",
"ssavaluetypes",
Expand All @@ -2452,9 +2452,10 @@ void jl_init_types(void) JL_GC_DISABLED
"propagate_inbounds",
"pure",
"has_fcall",
"inlining",
"constprop",
"purity"),
jl_svec(21,
jl_svec(22,
jl_array_any_type,
jl_array_int32_type,
jl_any_type,
Expand All @@ -2475,6 +2476,7 @@ void jl_init_types(void) JL_GC_DISABLED
jl_bool_type,
jl_bool_type,
jl_uint8_type,
jl_uint8_type,
jl_uint8_type),
jl_emptysvec,
0, 1, 20);
Expand Down
2 changes: 2 additions & 0 deletions src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ typedef struct _jl_code_info_t {
uint8_t pure;
uint8_t has_fcall;
// uint8 settings
uint8_t inlining; // 0 = default; 1 = @inline; 2 = @noinline
uint8_t constprop; // 0 = use heuristic; 1 = aggressive; 2 = none
_jl_purity_overrides_t purity;
} jl_code_info_t;
Expand Down Expand Up @@ -1838,6 +1839,7 @@ JL_DLLEXPORT jl_value_t *jl_copy_ast(jl_value_t *expr JL_MAYBE_UNROOTED);
JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code);
JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t *metadata, jl_array_t *data);
JL_DLLEXPORT uint8_t jl_ir_flag_inferred(jl_array_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_flag_inlining(jl_array_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint8_t jl_ir_flag_pure(jl_array_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT uint16_t jl_ir_inlining_cost(jl_array_t *data) JL_NOTSAFEPOINT;
JL_DLLEXPORT ssize_t jl_ir_nslots(jl_array_t *data) JL_NOTSAFEPOINT;
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,7 @@ typedef struct {
uint8_t pure:1;
uint8_t propagate_inbounds:1;
uint8_t inferred:1;
uint8_t inlining:2; // 0 = default; 1 = @inline; 2 = @noinline
uint8_t constprop:2; // 0 = use heuristic; 1 = aggressive; 2 = none
} jl_code_info_flags_bitfield_t;

Expand Down
5 changes: 4 additions & 1 deletion src/method.c
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,9 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
if (ma == (jl_value_t*)jl_pure_sym)
li->pure = 1;
else if (ma == (jl_value_t*)jl_inline_sym)
li->inlining_cost = 0x10; // This corresponds to MIN_INLINE_COST
li->inlining = 1;
else if (ma == (jl_value_t*)jl_noinline_sym)
li->inlining = 2;
else if (ma == (jl_value_t*)jl_propagate_inbounds_sym)
li->propagate_inbounds = 1;
else if (ma == (jl_value_t*)jl_aggressive_constprop_sym)
Expand Down Expand Up @@ -477,6 +479,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
src->has_fcall = 0;
src->edges = jl_nothing;
src->constprop = 0;
src->inlining = 0;
src->purity.bits = 0;
return src;
}
Expand Down
8 changes: 5 additions & 3 deletions stdlib/Serialization/src/Serialization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ const TAGS = Any[

@assert length(TAGS) == 255

const ser_version = 20 # do not make changes without bumping the version #!
const ser_version = 21 # do not make changes without bumping the version #!

format_version(::AbstractSerializer) = ser_version
format_version(s::Serializer) = s.version
Expand Down Expand Up @@ -1027,8 +1027,7 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
nargs = deserialize(s)::Int32
isva = deserialize(s)::Bool
is_for_opaque_closure = false
constprop = 0x00
purity = 0x00
constprop = purity = 0x00
template_or_is_opaque = deserialize(s)
if isa(template_or_is_opaque, Bool)
is_for_opaque_closure = template_or_is_opaque
Expand Down Expand Up @@ -1194,6 +1193,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo})
if format_version(s) >= 20
ci.has_fcall = deserialize(s)
end
if format_version(s) >= 21
ci.inlining = deserialize(s)::UInt8
end
if format_version(s) >= 14
ci.constprop = deserialize(s)::UInt8
end
Expand Down
Loading