Skip to content

Commit

Permalink
Make emitted egal code more loopy
Browse files Browse the repository at this point in the history
The strategy here is to look at (data, padding) pairs and RLE
them into loops, so that repeated adjacent patterns use a loop
rather than getting unrolled. On the test case from #54109,
this makes compilation essentially instant, while also being
faster at runtime (turns out LLVM spends a massive amount of time
AND the answer is bad).

There's some obvious further enhancements possible here:
1. The `memcmp` constant is small. LLVM has a pass to inline these
   with better code. However, we don't have it turned on. We should
   consider vendoring it, though we may want to add some shorcutting
   to it to avoid having it iterate through each function.
2. This only does one level of sequence matching. It could be recursed
   to turn things into nested loops.

However, this solves the immediate issue, so hopefully it's a useful
start. Fixes #54109.
  • Loading branch information
Keno committed Apr 17, 2024
1 parent 7ba1b33 commit 0d54ad3
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 0 deletions.
127 changes: 127 additions & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3358,6 +3358,56 @@ static Value *emit_bitsunion_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1,
return phi;
}

struct egal_desc {
size_t offset;
size_t nrepeats;
size_t data_bytes;
size_t padding_bytes;
};

template <typename callback>
static void emit_masked_bits_compare(callback &emit_desc, jl_datatype_t *aty, egal_desc &current_desc)
{
// Memcmp, but with masked padding
size_t data_bytes = 0;
size_t padding_bytes = 0;
size_t nfields = jl_datatype_nfields(aty);
size_t total_size = jl_datatype_size(aty);
for (size_t i = 0; i < nfields; ++i) {
size_t offset = jl_field_offset(aty, i);
size_t fend = i == nfields - 1 ? total_size : jl_field_offset(aty, i + 1);
size_t fsz = jl_field_size(aty, i);
jl_datatype_t *fty = (jl_datatype_t*)jl_field_type(aty, i);
if (jl_field_isptr(aty, i) || !fty->layout->flags.haspadding) {
// The field has no internal padding
data_bytes += fsz;
if (offset + fsz == fend) {
// The field has no padding after. Merge this into the current
// comparison range and go to next field.
} else {
padding_bytes = fend - offset - fsz;
// Found padding. Either merge this into the current comparison
// range, or emit the old one and start a new one.
if (current_desc.data_bytes == data_bytes &&
current_desc.padding_bytes == padding_bytes) {
// Same as the previous range, just note that down, so we
// emit this as a loop.
current_desc.nrepeats += 1;
} else {
if (current_desc.nrepeats != 0)
emit_desc(current_desc);
current_desc.nrepeats = 1;
current_desc.data_bytes = data_bytes;
current_desc.padding_bytes = padding_bytes;
}
}
} else {
// The field may have internal padding. Recurse this.
emit_masked_bits_compare(emit_desc, fty, current_desc);
}
}
}

static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t arg2)
{
++EmittedBitsCompares;
Expand Down Expand Up @@ -3433,6 +3483,83 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a
}
return ctx.builder.CreateICmpEQ(answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
}
else if (sz > 512 && jl_struct_try_layout((jl_datatype_t*)arg1.typ)) {
Type *TInt8 = getInt8Ty(ctx.builder.getContext());
Type *TpInt8 = getInt8PtrTy(ctx.builder.getContext());
Type *TInt1 = getInt1Ty(ctx.builder.getContext());
Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) :
value_to_pointer(ctx, arg1).V;
Value *varg2 = arg2.ispointer() ? data_pointer(ctx, arg2) :
value_to_pointer(ctx, arg2).V;
varg1 = emit_pointer_from_objref(ctx, varg1);
varg2 = emit_pointer_from_objref(ctx, varg2);
varg1 = emit_bitcast(ctx, varg1, TpInt8);
varg2 = emit_bitcast(ctx, varg2, TpInt8);

Value *answer = nullptr;
auto emit_desc = [&](egal_desc desc) {
Value *ptr1 = varg1;
Value *ptr2 = varg2;
if (desc.offset != 0) {
ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.offset);
ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr2, desc.offset);
}

Value *new_ptr1 = ptr1;
Value *endptr1 = nullptr;
BasicBlock *postBB = nullptr;
BasicBlock *loopBB = nullptr;
PHINode *answerphi = nullptr;
if (desc.nrepeats != 1) {
// Set up loop
endptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.nrepeats * (desc.data_bytes + desc.padding_bytes));;

BasicBlock *currBB = ctx.builder.GetInsertBlock();
loopBB = BasicBlock::Create(ctx.builder.getContext(), "egal_loop", ctx.f);
postBB = BasicBlock::Create(ctx.builder.getContext(), "post", ctx.f);
ctx.builder.CreateBr(loopBB);

ctx.builder.SetInsertPoint(loopBB);
answerphi = ctx.builder.CreatePHI(TInt1, 2);
answerphi->addIncoming(answer ? answer : ConstantInt::get(TInt1, 1), currBB);
answer = answerphi;

PHINode *itr1 = ctx.builder.CreatePHI(ptr1->getType(), 2);
PHINode *itr2 = ctx.builder.CreatePHI(ptr2->getType(), 2);

new_ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr1, desc.data_bytes + desc.padding_bytes);
itr1->addIncoming(ptr1, currBB);
itr1->addIncoming(new_ptr1, loopBB);

Value *new_ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr2, desc.data_bytes + desc.padding_bytes);
itr2->addIncoming(ptr2, currBB);
itr2->addIncoming(new_ptr2, loopBB);

ptr1 = itr1;
ptr2 = itr2;
}

// Emit memcmp. TODO: LLVM has a pass to expand this for additional
// performance.
Value *this_answer = ctx.builder.CreateCall(prepare_call(memcmp_func),
{ ptr1,
ptr2,
ConstantInt::get(ctx.types().T_size, desc.data_bytes) });
this_answer = ctx.builder.CreateICmpEQ(this_answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
answer = answer ? ctx.builder.CreateAnd(answer, this_answer) : this_answer;
if (endptr1) {
answerphi->addIncoming(answer, loopBB);
Value *loopend = ctx.builder.CreateICmpEQ(new_ptr1, endptr1);
ctx.builder.CreateCondBr(loopend, postBB, loopBB);
ctx.builder.SetInsertPoint(postBB);
}
};
egal_desc current_desc = {0};
emit_masked_bits_compare(emit_desc, (jl_datatype_t*)arg1.typ, current_desc);
assert(current_desc.nrepeats != 0);
emit_desc(current_desc);
return answer;
}
else {
jl_svec_t *types = sty->types;
Value *answer = ConstantInt::get(getInt1Ty(ctx.builder.getContext()), 1);
Expand Down
42 changes: 42 additions & 0 deletions test/compiler/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -873,3 +873,45 @@ if Sys.ARCH === :x86_64
end
end
end

# #54109 - Excessive LLVM time for egal
struct DefaultOr54109{T}
x::T
default::Bool
end

@eval struct Torture1_54109
$((Expr(:(::), Symbol("x$i"), DefaultOr54109{Float64}) for i = 1:897)...)
end
Torture1_54109() = Torture1_54109((DefaultOr54109(1.0, false) for i = 1:897)...)

@eval struct Torture2_54109
$((Expr(:(::), Symbol("x$i"), DefaultOr54109{Float64}) for i = 1:400)...)
$((Expr(:(::), Symbol("x$(i+400)"), DefaultOr54109{Int16}) for i = 1:400)...)
end
Torture2_54109() = Torture2_54109((DefaultOr54109(1.0, false) for i = 1:400)..., (DefaultOr54109(Int16(1), false) for i = 1:400)...)

@noinline egal_any54109(x, @nospecialize(y::Any)) = x === Base.compilerbarrier(:type, y)

let ir1 = get_llvm(egal_any54109, Tuple{Torture1_54109, Any}),
ir2 = get_llvm(egal_any54109, Tuple{Torture2_54109, Any})

# We can't really do timing on CI, so instead, let's look at the length of
# the optimized IR. The original version had tens of thousands of lines and
# was slower, so just check here that we only have < 500 lines. If somebody,
# implements a better comparison that's larger than that, just re-benchmark
# this and adjust the threshold.

@test count(==('\n'), ir1) < 500
@test count(==('\n'), ir2) < 500
end

## For completeness, also test correctness, since we don't have a lot of
## large-struct tests.

# The two allocations of the same struct will likely have different padding,
# we want to make sure we find them egal anyway - a naive memcmp would
# accidentally look at it.
@test egal_any54109(Torture1_54109(), Torture1_54109())
@test egal_any54109(Torture2_54109(), Torture2_54109())
@test !egal_any54109(Torture1_54109(), Torture1_54109((DefaultOr54109(2.0, false) for i = 1:897)...))

0 comments on commit 0d54ad3

Please sign in to comment.