Skip to content

Commit

Permalink
Merge pull request #47994 from JuliaLang/avi/improve-semi-concrete-ac…
Browse files Browse the repository at this point in the history
…curacy

Currently semi-concrete interpretation can end up with more inaccurate
result than usual abstract interpretation because `src.ssavaluetypes`
are all widened when cached so semi-concrete interpretation can't use
extended lattice information of `SSAValue`s. This commit tries to fix it
by making the codegen system recognize extended lattice information such
as `Const` and `PartialStruct` and not-widening those information from
`src::CodeInfo` when caching.

I found there are other chances when semi-concrete interpretation can
end up with inaccurate results, but that is unrelated to this and should
be fixed separately.
  • Loading branch information
aviatesk authored Dec 27, 2022
2 parents a9e0545 + 7ace688 commit 03bdf15
Show file tree
Hide file tree
Showing 11 changed files with 89 additions and 22 deletions.
4 changes: 2 additions & 2 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -524,12 +524,12 @@ module IR
export CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot, Argument,
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode,
Const, PartialStruct
Const, PartialStruct, InterConditional, PartialOpaque

import Core: CodeInfo, MethodInstance, CodeInstance, GotoNode, GotoIfNot, ReturnNode,
NewvarNode, SSAValue, Slot, SlotNumber, TypedSlot, Argument,
PiNode, PhiNode, PhiCNode, UpsilonNode, LineInfoNode,
Const, PartialStruct
Const, PartialStruct, InterConditional, PartialOpaque

end

Expand Down
13 changes: 6 additions & 7 deletions base/compiler/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -179,29 +179,28 @@ function ir_to_codeinf!(opt::OptimizationState)
optdef = linfo.def
replace_code_newstyle!(src, opt.ir::IRCode, isa(optdef, Method) ? Int(optdef.nargs) : 0)
opt.ir = nothing
widen_all_consts!(src)
widencompileronly!(src)
src.rettype = widenconst(src.rettype)
src.inferred = true
# finish updating the result struct
validate_code_in_debug_mode(linfo, src, "optimized")
return src
end

# widen all Const elements in type annotations
function widen_all_consts!(src::CodeInfo)
# widen extended lattice elements in type annotations so that they are recognizable by the codegen system.
function widencompileronly!(src::CodeInfo)
ssavaluetypes = src.ssavaluetypes::Vector{Any}
for i = 1:length(ssavaluetypes)
ssavaluetypes[i] = widenconst(ssavaluetypes[i])
ssavaluetypes[i] = widencompileronly(ssavaluetypes[i])
end

for i = 1:length(src.code)
x = src.code[i]
if isa(x, PiNode)
src.code[i] = PiNode(x.val, widenconst(x.typ))
src.code[i] = PiNode(x.val, widencompileronly(x.typ))
end
end

src.rettype = widenconst(src.rettype)

return src
end

Expand Down
29 changes: 29 additions & 0 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ const AnyConditional = Union{Conditional,InterConditional}
Conditional(cnd::InterConditional) = Conditional(cnd.slot, cnd.thentype, cnd.elsetype)
InterConditional(cnd::Conditional) = InterConditional(cnd.slot, cnd.thentype, cnd.elsetype)

# TODO make `MustAlias` and `InterMustAlias` recognizable by the codegen system

"""
alias::MustAlias
Expand Down Expand Up @@ -668,6 +670,33 @@ function tmeet(lattice::OptimizerLattice, @nospecialize(v), @nospecialize(t::Typ
tmeet(widenlattice(lattice), v, t)
end

"""
is_core_extended_info(t) -> Bool
Check if extended lattice element `t` is recognizable by the runtime/codegen system.
See also the implementation of `jl_widen_core_extended_info` in jltypes.c.
"""
function is_core_extended_info(@nospecialize t)
isa(t, Type) && return true
isa(t, Const) && return true
isa(t, PartialStruct) && return true
isa(t, InterConditional) && return true
# TODO isa(t, InterMustAlias) && return true
isa(t, PartialOpaque) && return true
return false
end

"""
widencompileronly(t) -> wt::Any
Widen the extended lattice element `x` so that `wt` is recognizable by the runtime/codegen system.
"""
function widencompileronly(@nospecialize t)
is_core_extended_info(t) && return t
return widenconst(t)
end

"""
widenconst(x) -> t::Type
Expand Down
2 changes: 1 addition & 1 deletion base/opaque_closure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ function Core.OpaqueClosure(ir::IRCode, env...;
src.slotnames = fill(:none, nargs+1)
src.slottypes = copy(ir.argtypes)
Core.Compiler.replace_code_newstyle!(src, ir, nargs+1)
Core.Compiler.widen_all_consts!(src)
Core.Compiler.widencompileronly!(src)
src.inferred = true
# NOTE: we need ir.argtypes[1] == typeof(env)

Expand Down
6 changes: 4 additions & 2 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,8 @@ static jl_cgval_t emit_llvmcall(jl_codectx_t &ctx, jl_value_t **args, size_t nar
return jl_cgval_t();
}
if (jl_is_ssavalue(args[2]) && !jl_is_long(ctx.source->ssavaluetypes)) {
jl_value_t *rtt = jl_arrayref((jl_array_t*)ctx.source->ssavaluetypes, ((jl_ssavalue_t*)args[2])->id - 1);
jl_value_t *rtt = jl_widen_core_extended_info(
jl_arrayref((jl_array_t*)ctx.source->ssavaluetypes, ((jl_ssavalue_t*)args[2])->id - 1));
if (jl_is_type_type(rtt))
rt = jl_tparam0(rtt);
}
Expand All @@ -776,7 +777,8 @@ static jl_cgval_t emit_llvmcall(jl_codectx_t &ctx, jl_value_t **args, size_t nar
}
}
if (jl_is_ssavalue(args[3]) && !jl_is_long(ctx.source->ssavaluetypes)) {
jl_value_t *att = jl_arrayref((jl_array_t*)ctx.source->ssavaluetypes, ((jl_ssavalue_t*)args[3])->id - 1);
jl_value_t *att = jl_widen_core_extended_info(
jl_arrayref((jl_array_t*)ctx.source->ssavaluetypes, ((jl_ssavalue_t*)args[3])->id - 1));
if (jl_is_type_type(att))
at = jl_tparam0(att);
}
Expand Down
15 changes: 8 additions & 7 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4501,7 +4501,7 @@ static void emit_phinode_assign(jl_codectx_t &ctx, ssize_t idx, jl_value_t *r)
jl_value_t *ssavalue_types = (jl_value_t*)ctx.source->ssavaluetypes;
jl_value_t *phiType = NULL;
if (jl_is_array(ssavalue_types)) {
phiType = jl_array_ptr_ref(ssavalue_types, idx);
phiType = jl_widen_core_extended_info(jl_array_ptr_ref(ssavalue_types, idx));
} else {
phiType = (jl_value_t*)jl_any_type;
}
Expand Down Expand Up @@ -4610,7 +4610,7 @@ static void emit_ssaval_assign(jl_codectx_t &ctx, ssize_t ssaidx_0based, jl_valu
// e.g. sometimes the information is inconsistent after inlining getfield on a Tuple
jl_value_t *ssavalue_types = (jl_value_t*)ctx.source->ssavaluetypes;
if (jl_is_array(ssavalue_types)) {
jl_value_t *declType = jl_array_ptr_ref(ssavalue_types, ssaidx_0based);
jl_value_t *declType = jl_widen_core_extended_info(jl_array_ptr_ref(ssavalue_types, ssaidx_0based));
if (declType != slot.typ) {
slot = update_julia_type(ctx, slot, declType);
}
Expand Down Expand Up @@ -4949,7 +4949,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
}
if (jl_is_pinode(expr)) {
Value *skip = NULL;
return convert_julia_type(ctx, emit_expr(ctx, jl_fieldref_noalloc(expr, 0)), jl_fieldref_noalloc(expr, 1), &skip);
return convert_julia_type(ctx, emit_expr(ctx, jl_fieldref_noalloc(expr, 0)), jl_widen_core_extended_info(jl_fieldref_noalloc(expr, 1)), &skip);
}
if (!jl_is_expr(expr)) {
jl_value_t *val = expr;
Expand Down Expand Up @@ -4987,13 +4987,13 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
else if (head == jl_invoke_sym) {
assert(ssaidx_0based >= 0);
jl_value_t *expr_t = jl_is_long(ctx.source->ssavaluetypes) ? (jl_value_t*)jl_any_type :
jl_array_ptr_ref(ctx.source->ssavaluetypes, ssaidx_0based);
jl_widen_core_extended_info(jl_array_ptr_ref(ctx.source->ssavaluetypes, ssaidx_0based));
return emit_invoke(ctx, ex, expr_t);
}
else if (head == jl_invoke_modify_sym) {
assert(ssaidx_0based >= 0);
jl_value_t *expr_t = jl_is_long(ctx.source->ssavaluetypes) ? (jl_value_t*)jl_any_type :
jl_array_ptr_ref(ctx.source->ssavaluetypes, ssaidx_0based);
jl_widen_core_extended_info(jl_array_ptr_ref(ctx.source->ssavaluetypes, ssaidx_0based));
return emit_invoke_modify(ctx, ex, expr_t);
}
else if (head == jl_call_sym) {
Expand All @@ -5003,7 +5003,8 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
// TODO: this case is needed for the call to emit_expr in emit_llvmcall
expr_t = (jl_value_t*)jl_any_type;
else {
expr_t = jl_is_long(ctx.source->ssavaluetypes) ? (jl_value_t*)jl_any_type : jl_array_ptr_ref(ctx.source->ssavaluetypes, ssaidx_0based);
expr_t = jl_is_long(ctx.source->ssavaluetypes) ? (jl_value_t*)jl_any_type :
jl_widen_core_extended_info(jl_array_ptr_ref(ctx.source->ssavaluetypes, ssaidx_0based));
is_promotable = ctx.ssavalue_usecount.at(ssaidx_0based) == 1;
}
jl_cgval_t res = emit_call(ctx, ex, expr_t, is_promotable);
Expand Down Expand Up @@ -7114,7 +7115,7 @@ static jl_llvm_functions_t
}
jl_varinfo_t &vi = (ctx.phic_slots.emplace(i, jl_varinfo_t(ctx.builder.getContext())).first->second =
jl_varinfo_t(ctx.builder.getContext()));
jl_value_t *typ = jl_array_ptr_ref(src->ssavaluetypes, i);
jl_value_t *typ = jl_widen_core_extended_info(jl_array_ptr_ref(src->ssavaluetypes, i));
vi.used = true;
vi.isVolatile = true;
vi.value = mark_julia_type(ctx, (Value*)NULL, false, typ);
Expand Down
6 changes: 4 additions & 2 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,10 @@ static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
if (jl_is_pinode(e)) {
jl_value_t *val = eval_value(jl_fieldref_noalloc(e, 0), s);
#ifndef JL_NDEBUG
JL_GC_PUSH1(&val);
jl_typeassert(val, jl_fieldref_noalloc(e, 1));
jl_value_t *typ = NULL;
JL_GC_PUSH2(&val, &typ);
typ = jl_widen_core_extended_info(jl_fieldref_noalloc(e, 1));
jl_typeassert(val, typ);
JL_GC_POP();
#endif
return val;
Expand Down
22 changes: 22 additions & 0 deletions src/jltypes.c
Original file line number Diff line number Diff line change
Expand Up @@ -2015,6 +2015,28 @@ void jl_reinstantiate_inner_types(jl_datatype_t *t) // can throw!
}
}

// Widens "core" extended lattice element `t` to the native `Type` representation.
// The implementation of this function should sync with those of the corresponding `widenconst`s.
JL_DLLEXPORT jl_value_t *jl_widen_core_extended_info(jl_value_t *t)
{
jl_value_t* tt = jl_typeof(t);
if (tt == (jl_value_t*)jl_const_type) {
jl_value_t* val = jl_fieldref_noalloc(t, 0);
if (jl_isa(val, (jl_value_t*)jl_type_type))
return (jl_value_t*)jl_wrap_Type(val);
else
return jl_typeof(val);
}
else if (tt == (jl_value_t*)jl_partial_struct_type)
return (jl_value_t*)jl_fieldref_noalloc(t, 0);
else if (tt == (jl_value_t*)jl_interconditional_type)
return (jl_value_t*)jl_bool_type;
else if (tt == (jl_value_t*)jl_partial_opaque_type)
return (jl_value_t*)jl_fieldref_noalloc(t, 0);
else
return t;
}

// initialization -------------------------------------------------------------

static jl_tvar_t *tvar(const char *name)
Expand Down
1 change: 0 additions & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -1842,7 +1842,6 @@ JL_DLLEXPORT jl_value_t *jl_compress_argnames(jl_array_t *syms);
JL_DLLEXPORT jl_array_t *jl_uncompress_argnames(jl_value_t *syms);
JL_DLLEXPORT jl_value_t *jl_uncompress_argname_n(jl_value_t *syms, size_t i);


JL_DLLEXPORT int jl_is_operator(char *sym);
JL_DLLEXPORT int jl_is_unary_operator(char *sym);
JL_DLLEXPORT int jl_is_unary_and_binary_operator(char *sym);
Expand Down
1 change: 1 addition & 0 deletions src/julia_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -722,6 +722,7 @@ void jl_init_main_module(void);
JL_DLLEXPORT int jl_is_submodule(jl_module_t *child, jl_module_t *parent) JL_NOTSAFEPOINT;
jl_array_t *jl_get_loaded_modules(void);
JL_DLLEXPORT int jl_datatype_isinlinealloc(jl_datatype_t *ty, int pointerfree);
JL_DLLEXPORT jl_value_t *jl_widen_core_extended_info(jl_value_t *t);

void jl_eval_global_expr(jl_module_t *m, jl_expr_t *ex, int set_type);
jl_value_t *jl_toplevel_eval_flex(jl_module_t *m, jl_value_t *e, int fast, int expanded);
Expand Down
12 changes: 12 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4689,3 +4689,15 @@ end
@test Base.return_types(empty_nt_keys, (Any,)) |> only === Tuple{}
g() = empty_nt_values(Base.inferencebarrier(Tuple{}))
@test g() == () # Make sure to actually run this to test this in the inference world age

let # jl_widen_core_extended_info
for (extended, widened) = [(Core.Const(42), Int),
(Core.Const(Int), Type{Int}),
(Core.Const(Vector), Type{Vector}),
(Core.PartialStruct(Some{Any}, Any["julia"]), Some{Any}),
(Core.InterConditional(2, Int, Nothing), Bool)]
@test @ccall(jl_widen_core_extended_info(extended::Any)::Any) ===
Core.Compiler.widenconst(extended) ===
widened
end
end

0 comments on commit 03bdf15

Please sign in to comment.