Skip to content

Commit 85518c8

Browse files
committed
atomics: optimize atomic modify operations (mostly)
Lacking inlining, but now expressing the direct invoke: this gets us within about 2x of a primitive atomicrmw add.
1 parent aa421ff commit 85518c8

13 files changed

+266
-146
lines changed

base/compiler/abstractinterpretation.jl

+13-10
Original file line numberDiff line numberDiff line change
@@ -1249,6 +1249,8 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
12491249
return abstract_apply(interp, argtypes, sv, max_methods)
12501250
elseif f === invoke
12511251
return abstract_invoke(interp, argtypes, sv)
1252+
elseif f === modifyfield!
1253+
return abstract_modifyfield!(interp, argtypes, sv)
12521254
end
12531255
return CallMeta(abstract_call_builtin(interp, f, fargs, argtypes, sv, max_methods), false)
12541256
elseif f === Core.kwfunc
@@ -1515,7 +1517,8 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15151517
return abstract_eval_special_value(interp, e, vtypes, sv)
15161518
end
15171519
e = e::Expr
1518-
if e.head === :call
1520+
ehead = e.head
1521+
if ehead === :call
15191522
ea = e.args
15201523
argtypes = collect_argtypes(interp, ea, vtypes, sv)
15211524
if argtypes === nothing
@@ -1525,7 +1528,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15251528
sv.stmt_info[sv.currpc] = callinfo.info
15261529
t = callinfo.rt
15271530
end
1528-
elseif e.head === :new
1531+
elseif ehead === :new
15291532
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
15301533
if isconcretetype(t) && !ismutabletype(t)
15311534
args = Vector{Any}(undef, length(e.args)-1)
@@ -1562,7 +1565,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15621565
end
15631566
end
15641567
end
1565-
elseif e.head === :splatnew
1568+
elseif ehead === :splatnew
15661569
t = instanceof_tfunc(abstract_eval_value(interp, e.args[1], vtypes, sv))[1]
15671570
if length(e.args) == 2 && isconcretetype(t) && !ismutabletype(t)
15681571
at = abstract_eval_value(interp, e.args[2], vtypes, sv)
@@ -1575,7 +1578,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15751578
t = PartialStruct(t, at.fields::Vector{Any})
15761579
end
15771580
end
1578-
elseif e.head === :new_opaque_closure
1581+
elseif ehead === :new_opaque_closure
15791582
t = Union{}
15801583
if length(e.args) >= 5
15811584
ea = e.args
@@ -1594,29 +1597,29 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
15941597
end
15951598
end
15961599
end
1597-
elseif e.head === :foreigncall
1600+
elseif ehead === :foreigncall
15981601
abstract_eval_value(interp, e.args[1], vtypes, sv)
15991602
t = sp_type_rewrap(e.args[2], sv.linfo, true)
16001603
for i = 3:length(e.args)
16011604
if abstract_eval_value(interp, e.args[i], vtypes, sv) === Bottom
16021605
t = Bottom
16031606
end
16041607
end
1605-
elseif e.head === :cfunction
1608+
elseif ehead === :cfunction
16061609
t = e.args[1]
16071610
isa(t, Type) || (t = Any)
16081611
abstract_eval_cfunction(interp, e, vtypes, sv)
1609-
elseif e.head === :method
1612+
elseif ehead === :method
16101613
t = (length(e.args) == 1) ? Any : Nothing
1611-
elseif e.head === :copyast
1614+
elseif ehead === :copyast
16121615
t = abstract_eval_value(interp, e.args[1], vtypes, sv)
16131616
if t isa Const && t.val isa Expr
16141617
# `copyast` makes copies of Exprs
16151618
t = Expr
16161619
end
1617-
elseif e.head === :invoke
1620+
elseif ehead === :invoke || ehead === :invoke_modify
16181621
error("type inference data-flow error: tried to double infer a function")
1619-
elseif e.head === :isdefined
1622+
elseif ehead === :isdefined
16201623
sym = e.args[1]
16211624
t = Bool
16221625
if isa(sym, SlotNumber)

base/compiler/optimize.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ function statement_cost(ex::Expr, line::Int, src::Union{CodeInfo, IRCode}, sptyp
491491
return 0
492492
end
493493
return error_path ? params.inline_error_path_cost : params.inline_nonleaf_penalty
494-
elseif head === :foreigncall || head === :invoke
494+
elseif head === :foreigncall || head === :invoke || head == :invoke_modify
495495
# Calls whose "return type" is Union{} do not actually return:
496496
# they are errors. Since these are not part of the typical
497497
# run-time of the function, we omit them from

base/compiler/ssair/inlining.jl

+16
Original file line numberDiff line numberDiff line change
@@ -1141,6 +1141,22 @@ function process_simple!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, sta
11411141
ir.stmts[idx][:inst] = res
11421142
return nothing
11431143
end
1144+
if (sig.f === modifyfield! || sig.ft ⊑ typeof(modifyfield!)) && 5 <= length(stmt.args) <= 6
1145+
let info = ir.stmts[idx][:info]
1146+
info isa MethodResultPure && (info = info.info)
1147+
info isa ConstCallInfo && (info = info.call)
1148+
info isa MethodMatchInfo || return nothing
1149+
length(info.results) == 1 || return nothing
1150+
match = info.results[1]::MethodMatch
1151+
match.fully_covers || return nothing
1152+
case = compileable_specialization(state.et, match)
1153+
case === nothing && return nothing
1154+
stmt.head = :invoke_modify
1155+
pushfirst!(stmt.args, case)
1156+
ir.stmts[idx][:inst] = stmt
1157+
end
1158+
return nothing
1159+
end
11441160

11451161
check_effect_free!(ir, stmt, calltype, idx)
11461162

base/compiler/ssair/ir.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,8 @@ function getindex(x::UseRef)
403403
end
404404

405405
function is_relevant_expr(e::Expr)
406-
return e.head in (:call, :invoke, :new, :splatnew, :(=), :(&),
406+
return e.head in (:call, :invoke, :invoke_modify,
407+
:new, :splatnew, :(=), :(&),
407408
:gc_preserve_begin, :gc_preserve_end,
408409
:foreigncall, :isdefined, :copyast,
409410
:undefcheck, :throw_undef_if_not,

base/compiler/tfuncs.jl

+31-1
Original file line numberDiff line numberDiff line change
@@ -939,10 +939,40 @@ function modifyfield!_tfunc(o, f, op, v)
939939
@nospecialize
940940
T = _fieldtype_tfunc(o, isconcretetype(o), f)
941941
T === Bottom && return Bottom
942-
# note: we could sometimes refine this to a PartialStruct if we analyzed `op(o.f, v)::T`
943942
PT = Const(Pair)
944943
return instanceof_tfunc(apply_type_tfunc(PT, T, T))[1]
945944
end
945+
# Abstract interpretation of a `modifyfield!(o, f, op, v, [order])` call.
#
# `argtypes[1]` is the type of the callee itself; the object, field, operator and
# value types follow at positions 2 through 5. The baseline return type comes from
# `modifyfield!_tfunc`. When the operator and value are both available, the inner
# call `op(o.f, v)` is inferred as well, which serves two purposes:
#   * its result type (intersected with the widened field type) may sharpen the
#     returned pair into a `PartialStruct`, or prove that the call cannot return;
#   * its call info is forwarded in the returned `CallMeta`.
function abstract_modifyfield!(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
    nargs = length(argtypes)
    # Arity screening; a trailing Vararg leaves the real argument count open.
    if nargs > 0 && isvarargtype(argtypes[nargs])
        nargs > 7 && return CallMeta(Bottom, false)
        nargs <= 3 && return CallMeta(Any, false)
    elseif !(5 <= nargs <= 6)
        return CallMeta(Bottom, false)
    end
    objtype = unwrapva(argtypes[2])
    fldtype = unwrapva(argtypes[3])
    rt = modifyfield!_tfunc(objtype, fldtype, Any, Any)
    # Without both `op` and `v`, or when the field access is already known to
    # fail, the tfunc result is the final answer and no call info is recorded.
    (nargs >= 5 && rt !== Bottom) || return CallMeta(rt, false)
    # we may be able to refine this to a PartialStruct by analyzing `op(o.f, v)::T`
    # as well as compute the info for the method matches
    optype = unwrapva(argtypes[4])
    valtype = unwrapva(argtypes[5])
    oldtype = getfield_tfunc(objtype, fldtype)
    # temporarily disable `call_result_unused` check for this call
    push!(sv.ssavalue_uses[sv.currpc], sv.currpc)
    opcall = abstract_call(interp, nothing, Any[optype, oldtype, valtype], sv, #=max_methods=# 1)
    pop!(sv.ssavalue_uses[sv.currpc], sv.currpc)
    newtype = tmeet(opcall.rt, widenconst(oldtype))
    if newtype === Bottom
        rt = Bottom
    elseif isconcretetype(rt) && has_nontrivial_const_info(newtype)
        # isconcrete condition required to form a PartialStruct
        rt = PartialStruct(rt, Any[oldtype, newtype])
    end
    return CallMeta(rt, opcall.info)
end
946976
replacefield!_tfunc(o, f, x, v, success_order, failure_order) = (@nospecialize; replacefield!_tfunc(o, f, x, v))
947977
replacefield!_tfunc(o, f, x, v, success_order) = (@nospecialize; replacefield!_tfunc(o, f, x, v))
948978
function replacefield!_tfunc(o, f, x, v)

base/compiler/validation.jl

+5-3
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
const VALID_EXPR_HEADS = IdDict{Symbol,UnitRange{Int}}(
55
:call => 1:typemax(Int),
66
:invoke => 2:typemax(Int),
7+
:invoke_modify => 3:typemax(Int),
78
:static_parameter => 1:1,
89
:(&) => 1:1,
910
:(=) => 2:2,
@@ -78,7 +79,7 @@ end
7879

7980
function _validate_val!(@nospecialize(x), errors, ssavals::BitSet)
8081
if isa(x, Expr)
81-
if x.head === :call || x.head === :invoke
82+
if x.head === :call || x.head === :invoke || x.head === :invoke_modify
8283
f = x.args[1]
8384
if f isa GlobalRef && (f.name === :cglobal) && x.head === :call
8485
# TODO: these are not yet linearized
@@ -138,7 +139,8 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_
138139
end
139140
validate_val!(lhs)
140141
validate_val!(rhs)
141-
elseif head === :call || head === :invoke || head === :gc_preserve_end || head === :meta ||
142+
elseif head === :call || head === :invoke || x.head === :invoke_modify ||
143+
head === :gc_preserve_end || head === :meta ||
142144
head === :inbounds || head === :foreigncall || head === :cfunction ||
143145
head === :const || head === :enter || head === :leave || head === :pop_exception ||
144146
head === :method || head === :global || head === :static_parameter ||
@@ -238,7 +240,7 @@ end
238240

239241
function is_valid_rvalue(@nospecialize(x))
240242
is_valid_argument(x) && return true
241-
if isa(x, Expr) && x.head in (:new, :splatnew, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
243+
if isa(x, Expr) && x.head in (:new, :splatnew, :the_exception, :isdefined, :call, :invoke, :invoke_modify, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
242244
return true
243245
end
244246
return false

src/ast.c

+2
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ extern "C" {
2828

2929
// head symbols for each expression type
3030
jl_sym_t *call_sym; jl_sym_t *invoke_sym;
31+
jl_sym_t *invoke_modify_sym;
3132
jl_sym_t *empty_sym; jl_sym_t *top_sym;
3233
jl_sym_t *module_sym; jl_sym_t *slot_sym;
3334
jl_sym_t *export_sym; jl_sym_t *import_sym;
@@ -345,6 +346,7 @@ void jl_init_common_symbols(void)
345346
empty_sym = jl_symbol("");
346347
call_sym = jl_symbol("call");
347348
invoke_sym = jl_symbol("invoke");
349+
invoke_modify_sym = jl_symbol("invoke_modify");
348350
foreigncall_sym = jl_symbol("foreigncall");
349351
cfunction_sym = jl_symbol("cfunction");
350352
quote_sym = jl_symbol("quote");

src/cgutils.cpp

+30-19
Original file line numberDiff line numberDiff line change
@@ -1547,17 +1547,23 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
15471547
Value *parent, // for the write barrier, NULL if no barrier needed
15481548
bool isboxed, AtomicOrdering Order, AtomicOrdering FailOrder, unsigned alignment,
15491549
bool needlock, bool issetfield, bool isreplacefield, bool isswapfield, bool ismodifyfield,
1550-
bool maybe_null_if_boxed, const std::string &fname)
1550+
bool maybe_null_if_boxed, const jl_cgval_t *modifyop, const std::string &fname)
15511551
{
15521552
auto newval = [&](const jl_cgval_t &lhs) {
1553-
jl_cgval_t argv[3] = { cmp, lhs, rhs };
1554-
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
1555-
argv[0] = mark_julia_type(ctx, callval, true, jl_any_type);
1556-
if (!jl_subtype(argv[0].typ, jltype)) {
1557-
emit_typecheck(ctx, argv[0], jltype, fname + "typed_store");
1558-
argv[0] = update_julia_type(ctx, argv[0], jltype);
1559-
}
1560-
return argv[0];
1553+
const jl_cgval_t argv[3] = { cmp, lhs, rhs };
1554+
jl_cgval_t ret;
1555+
if (modifyop) {
1556+
ret = emit_invoke(ctx, *modifyop, argv, 3, (jl_value_t*)jl_any_type);
1557+
}
1558+
else {
1559+
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
1560+
ret = mark_julia_type(ctx, callval, true, jl_any_type);
1561+
}
1562+
if (!jl_subtype(ret.typ, jltype)) {
1563+
emit_typecheck(ctx, ret, jltype, fname + "typed_store");
1564+
ret = update_julia_type(ctx, ret, jltype);
1565+
}
1566+
return ret;
15611567
};
15621568
assert(!needlock || parent != nullptr);
15631569
Type *elty = isboxed ? T_prjlvalue : julia_type_to_llvm(ctx, jltype);
@@ -1570,7 +1576,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
15701576
else if (isreplacefield) {
15711577
Value *Success = emit_f_is(ctx, cmp, ghostValue(jltype));
15721578
Success = ctx.builder.CreateZExt(Success, T_int8);
1573-
jl_cgval_t argv[2] = {ghostValue(jltype), mark_julia_type(ctx, Success, false, jl_bool_type)};
1579+
const jl_cgval_t argv[2] = {ghostValue(jltype), mark_julia_type(ctx, Success, false, jl_bool_type)};
15741580
jl_datatype_t *rettyp = jl_apply_cmpswap_type(jltype);
15751581
return emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
15761582
}
@@ -1579,7 +1585,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
15791585
}
15801586
else { // modifyfield
15811587
jl_cgval_t oldval = ghostValue(jltype);
1582-
jl_cgval_t argv[2] = { oldval, newval(oldval) };
1588+
const jl_cgval_t argv[2] = { oldval, newval(oldval) };
15831589
jl_datatype_t *rettyp = jl_apply_modify_type(jltype);
15841590
return emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
15851591
}
@@ -1862,7 +1868,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
18621868
}
18631869
}
18641870
if (ismodifyfield) {
1865-
jl_cgval_t argv[2] = { oldval, rhs };
1871+
const jl_cgval_t argv[2] = { oldval, rhs };
18661872
jl_datatype_t *rettyp = jl_apply_modify_type(jltype);
18671873
oldval = emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
18681874
}
@@ -1881,7 +1887,7 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
18811887
oldval = mark_julia_type(ctx, instr, isboxed, jltype);
18821888
if (isreplacefield) {
18831889
Success = ctx.builder.CreateZExt(Success, T_int8);
1884-
jl_cgval_t argv[2] = {oldval, mark_julia_type(ctx, Success, false, jl_bool_type)};
1890+
const jl_cgval_t argv[2] = {oldval, mark_julia_type(ctx, Success, false, jl_bool_type)};
18851891
jl_datatype_t *rettyp = jl_apply_cmpswap_type(jltype);
18861892
oldval = emit_new_struct(ctx, (jl_value_t*)rettyp, 2, argv);
18871893
}
@@ -3269,7 +3275,7 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
32693275
jl_cgval_t rhs, jl_cgval_t cmp,
32703276
bool checked, bool wb, AtomicOrdering Order, AtomicOrdering FailOrder,
32713277
bool needlock, bool issetfield, bool isreplacefield, bool isswapfield, bool ismodifyfield,
3272-
const std::string &fname)
3278+
const jl_cgval_t *modifyop, const std::string &fname)
32733279
{
32743280
if (!sty->name->mutabl && checked) {
32753281
std::string msg = fname + "immutable struct of type "
@@ -3309,9 +3315,14 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
33093315
if (ismodifyfield) {
33103316
if (needlock)
33113317
emit_lockstate_value(ctx, strct, false);
3312-
jl_cgval_t argv[3] = { cmp, oldval, rhs };
3313-
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
3314-
rhs = mark_julia_type(ctx, callval, true, jl_any_type);
3318+
const jl_cgval_t argv[3] = { cmp, oldval, rhs };
3319+
if (modifyop) {
3320+
rhs = emit_invoke(ctx, *modifyop, argv, 3, (jl_value_t*)jl_any_type);
3321+
}
3322+
else {
3323+
Value *callval = emit_jlcall(ctx, jlapplygeneric_func, nullptr, argv, 3, JLCALL_F_CC);
3324+
rhs = mark_julia_type(ctx, callval, true, jl_any_type);
3325+
}
33153326
if (!jl_subtype(rhs.typ, jfty)) {
33163327
emit_typecheck(ctx, rhs, jfty, fname);
33173328
rhs = update_julia_type(ctx, rhs, jfty);
@@ -3364,7 +3375,7 @@ static jl_cgval_t emit_setfield(jl_codectx_t &ctx,
33643375
return typed_store(ctx, addr, NULL, rhs, cmp, jfty, strct.tbaa, nullptr,
33653376
wb ? maybe_bitcast(ctx, data_pointer(ctx, strct), T_pjlvalue) : nullptr,
33663377
isboxed, Order, FailOrder, align,
3367-
needlock, issetfield, isreplacefield, isswapfield, ismodifyfield, maybe_null, fname);
3378+
needlock, issetfield, isreplacefield, isswapfield, ismodifyfield, maybe_null, modifyop, fname);
33683379
}
33693380
}
33703381

@@ -3543,7 +3554,7 @@ static jl_cgval_t emit_new_struct(jl_codectx_t &ctx, jl_value_t *ty, size_t narg
35433554
else
35443555
need_wb = false;
35453556
emit_typecheck(ctx, rhs, jl_svecref(sty->types, i), "new");
3546-
emit_setfield(ctx, sty, strctinfo, i, rhs, jl_cgval_t(), false, need_wb, AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic, false, true, false, false, false, "");
3557+
emit_setfield(ctx, sty, strctinfo, i, rhs, jl_cgval_t(), false, need_wb, AtomicOrdering::NotAtomic, AtomicOrdering::NotAtomic, false, true, false, false, false, nullptr, "");
35473558
}
35483559
return strctinfo;
35493560
}

0 commit comments

Comments
 (0)