From 0d54ad3086b7fc61afa28b512b27668e1ddef2f5 Mon Sep 17 00:00:00 2001
From: Keno Fischer
Date: Wed, 17 Apr 2024 23:01:19 +0000
Subject: [PATCH] Make emitted egal code more loopy

The strategy here is to look at (data, padding) pairs and run-length
encode (RLE) them into loops, so that repeated adjacent patterns are
compared in a loop rather than getting fully unrolled. On the test case
from #54109, this makes compilation essentially instant, while also being
faster at runtime (it turns out LLVM spends a massive amount of time on
the unrolled comparison AND the resulting code is bad).

There are some obvious further enhancements possible here:

1. The emitted `memcmp` calls have small constant sizes. LLVM has a pass
   that can inline these with better code, but we don't have it turned on.
   We should consider vendoring it, though we may want to add some
   shortcutting to avoid having it iterate through every function.

2. This only does one level of sequence matching. It could be recursed to
   turn things into nested loops.

However, this solves the immediate issue, so hopefully it's a useful start.

Fixes #54109.
---
 src/codegen.cpp          | 127 +++++++++++++++++++++++++++++++++++++++
 test/compiler/codegen.jl |  42 +++++++++++++
 2 files changed, 169 insertions(+)

diff --git a/src/codegen.cpp b/src/codegen.cpp
index 5d1dda3e735dc1..2eb3f3a520ca90 100644
--- a/src/codegen.cpp
+++ b/src/codegen.cpp
@@ -3358,6 +3358,56 @@ static Value *emit_bitsunion_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1,
     return phi;
 }

+struct egal_desc {
+    size_t offset;
+    size_t nrepeats;
+    size_t data_bytes;
+    size_t padding_bytes;
+};
+
+template <typename callback>
+static void emit_masked_bits_compare(callback &emit_desc, jl_datatype_t *aty, egal_desc &current_desc)
+{
+    // Memcmp, but with masked padding
+    size_t data_bytes = 0;
+    size_t padding_bytes = 0;
+    size_t nfields = jl_datatype_nfields(aty);
+    size_t total_size = jl_datatype_size(aty);
+    for (size_t i = 0; i < nfields; ++i) {
+        size_t offset = jl_field_offset(aty, i);
+        size_t fend = i == nfields - 1 ? total_size : jl_field_offset(aty, i + 1);
+        size_t fsz = jl_field_size(aty, i);
+        jl_datatype_t *fty = (jl_datatype_t*)jl_field_type(aty, i);
+        if (jl_field_isptr(aty, i) || !fty->layout->flags.haspadding) {
+            // The field has no internal padding
+            data_bytes += fsz;
+            if (offset + fsz == fend) {
+                // The field has no padding after it. Merge this into the current
+                // comparison range and go to the next field.
+            } else {
+                padding_bytes = fend - offset - fsz;
+                // Found padding. Either merge this into the current comparison
+                // range, or emit the old one and start a new one.
+                if (current_desc.data_bytes == data_bytes &&
+                    current_desc.padding_bytes == padding_bytes) {
+                    // Same as the previous range, just note that down, so we
+                    // emit this as a loop.
+                    current_desc.nrepeats += 1;
+                } else {
+                    if (current_desc.nrepeats != 0)
+                        emit_desc(current_desc);
+                    current_desc.nrepeats = 1;
+                    current_desc.data_bytes = data_bytes;
+                    current_desc.padding_bytes = padding_bytes;
+                }
+            }
+        } else {
+            // The field may have internal padding. Recurse into it.
+            emit_masked_bits_compare(emit_desc, fty, current_desc);
+        }
+    }
+}
+
 static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t arg2)
 {
     ++EmittedBitsCompares;
@@ -3433,6 +3483,83 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a
         }
         return ctx.builder.CreateICmpEQ(answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
     }
+    else if (sz > 512 && jl_struct_try_layout((jl_datatype_t*)arg1.typ)) {
+        Type *TInt8 = getInt8Ty(ctx.builder.getContext());
+        Type *TpInt8 = getInt8PtrTy(ctx.builder.getContext());
+        Type *TInt1 = getInt1Ty(ctx.builder.getContext());
+        Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) :
+            value_to_pointer(ctx, arg1).V;
+        Value *varg2 = arg2.ispointer() ? data_pointer(ctx, arg2) :
+            value_to_pointer(ctx, arg2).V;
+        varg1 = emit_pointer_from_objref(ctx, varg1);
+        varg2 = emit_pointer_from_objref(ctx, varg2);
+        varg1 = emit_bitcast(ctx, varg1, TpInt8);
+        varg2 = emit_bitcast(ctx, varg2, TpInt8);
+
+        Value *answer = nullptr;
+        auto emit_desc = [&](egal_desc desc) {
+            Value *ptr1 = varg1;
+            Value *ptr2 = varg2;
+            if (desc.offset != 0) {
+                ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.offset);
+                ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr2, desc.offset);
+            }
+
+            Value *new_ptr1 = ptr1;
+            Value *endptr1 = nullptr;
+            BasicBlock *postBB = nullptr;
+            BasicBlock *loopBB = nullptr;
+            PHINode *answerphi = nullptr;
+            if (desc.nrepeats != 1) {
+                // Set up loop
+                endptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.nrepeats * (desc.data_bytes + desc.padding_bytes));
+
+                BasicBlock *currBB = ctx.builder.GetInsertBlock();
+                loopBB = BasicBlock::Create(ctx.builder.getContext(), "egal_loop", ctx.f);
+                postBB = BasicBlock::Create(ctx.builder.getContext(), "post", ctx.f);
+                ctx.builder.CreateBr(loopBB);
+
+                ctx.builder.SetInsertPoint(loopBB);
+                answerphi = ctx.builder.CreatePHI(TInt1, 2);
+                answerphi->addIncoming(answer ? answer : ConstantInt::get(TInt1, 1), currBB);
+                answer = answerphi;
+
+                PHINode *itr1 = ctx.builder.CreatePHI(ptr1->getType(), 2);
+                PHINode *itr2 = ctx.builder.CreatePHI(ptr2->getType(), 2);
+
+                new_ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr1, desc.data_bytes + desc.padding_bytes);
+                itr1->addIncoming(ptr1, currBB);
+                itr1->addIncoming(new_ptr1, loopBB);
+
+                Value *new_ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr2, desc.data_bytes + desc.padding_bytes);
+                itr2->addIncoming(ptr2, currBB);
+                itr2->addIncoming(new_ptr2, loopBB);
+
+                ptr1 = itr1;
+                ptr2 = itr2;
+            }
+
+            // Emit memcmp. TODO: LLVM has a pass to expand this for additional
+            // performance.
+            Value *this_answer = ctx.builder.CreateCall(prepare_call(memcmp_func),
+                { ptr1,
+                  ptr2,
+                  ConstantInt::get(ctx.types().T_size, desc.data_bytes) });
+            this_answer = ctx.builder.CreateICmpEQ(this_answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
+            answer = answer ? ctx.builder.CreateAnd(answer, this_answer) : this_answer;
+            if (endptr1) {
+                answerphi->addIncoming(answer, loopBB);
+                Value *loopend = ctx.builder.CreateICmpEQ(new_ptr1, endptr1);
+                ctx.builder.CreateCondBr(loopend, postBB, loopBB);
+                ctx.builder.SetInsertPoint(postBB);
+            }
+        };
+        egal_desc current_desc = {0};
+        emit_masked_bits_compare(emit_desc, (jl_datatype_t*)arg1.typ, current_desc);
+        assert(current_desc.nrepeats != 0);
+        emit_desc(current_desc);
+        return answer;
+    }
     else {
         jl_svec_t *types = sty->types;
         Value *answer = ConstantInt::get(getInt1Ty(ctx.builder.getContext()), 1);
diff --git a/test/compiler/codegen.jl b/test/compiler/codegen.jl
index 805e5c7acc817a..2603dc8618020a 100644
--- a/test/compiler/codegen.jl
+++ b/test/compiler/codegen.jl
@@ -873,3 +873,45 @@ if Sys.ARCH === :x86_64
         end
     end
 end
+
+# #54109 - Excessive LLVM time for egal
+struct DefaultOr54109{T}
+    x::T
+    default::Bool
+end
+
+@eval struct Torture1_54109
+    $((Expr(:(::), Symbol("x$i"), DefaultOr54109{Float64}) for i = 1:897)...)
+end
+Torture1_54109() = Torture1_54109((DefaultOr54109(1.0, false) for i = 1:897)...)
+
+@eval struct Torture2_54109
+    $((Expr(:(::), Symbol("x$i"), DefaultOr54109{Float64}) for i = 1:400)...)
+    $((Expr(:(::), Symbol("x$(i+400)"), DefaultOr54109{Int16}) for i = 1:400)...)
+end
+Torture2_54109() = Torture2_54109((DefaultOr54109(1.0, false) for i = 1:400)..., (DefaultOr54109(Int16(1), false) for i = 1:400)...)
+
+@noinline egal_any54109(x, @nospecialize(y::Any)) = x === Base.compilerbarrier(:type, y)
+
+let ir1 = get_llvm(egal_any54109, Tuple{Torture1_54109, Any}),
+    ir2 = get_llvm(egal_any54109, Tuple{Torture2_54109, Any})
+
+    # We can't really do timing on CI, so instead, let's look at the length of
+    # the optimized IR. The original version had tens of thousands of lines and
+    # was slower, so just check here that we only have < 500 lines. If somebody
+    # implements a better comparison that's larger than that, just re-benchmark
+    # this and adjust the threshold.
+
+    @test count(==('\n'), ir1) < 500
+    @test count(==('\n'), ir2) < 500
+end
+
+## For completeness, also test correctness, since we don't have a lot of
+## large-struct tests.
+
+# The two allocations of the same struct will likely have different padding,
+# but we want to make sure we find them egal anyway - a naive memcmp would
+# accidentally compare the padding bytes as well.
+@test egal_any54109(Torture1_54109(), Torture1_54109())
+@test egal_any54109(Torture2_54109(), Torture2_54109())
+@test !egal_any54109(Torture1_54109(), Torture1_54109((DefaultOr54109(2.0, false) for i = 1:897)...))
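
As an aside for readers of this patch, here is a small standalone sketch of
the run-length encoding idea that `emit_masked_bits_compare` applies to
(data, padding) pairs. It is an illustration only, not part of the patch:
the `desc` struct, the `rle` helper, and the hard-coded field layout are
invented for this example, and the 9/7 byte split assumes the usual 8-byte
size and alignment of `Float64` plus a 1-byte `Bool`.

    // rle_sketch.cpp -- illustration only, not part of the patch.
    #include <cstddef>
    #include <cstdio>
    #include <utility>
    #include <vector>

    // Mirrors the shape of egal_desc (offset omitted for brevity).
    struct desc {
        size_t nrepeats;      // how often the (data, padding) pattern repeats
        size_t data_bytes;    // bytes that must compare equal
        size_t padding_bytes; // undefined bytes that must be skipped
    };

    // Run-length encode a flat sequence of (data_bytes, padding_bytes) pairs:
    // adjacent identical pairs are merged into one descriptor with a repeat
    // count, which is what lets the emitted comparison become a loop instead
    // of being unrolled once per field.
    static std::vector<desc> rle(const std::vector<std::pair<size_t, size_t>> &pairs)
    {
        std::vector<desc> out;
        for (const auto &p : pairs) {
            if (!out.empty() && out.back().data_bytes == p.first &&
                out.back().padding_bytes == p.second) {
                out.back().nrepeats += 1; // same pattern as the previous field
            }
            else {
                out.push_back({1, p.first, p.second});
            }
        }
        return out;
    }

    int main()
    {
        // A struct of 897 DefaultOr54109{Float64} fields: each field contributes
        // 9 data bytes (Float64 + Bool) followed by 7 padding bytes.
        std::vector<std::pair<size_t, size_t>> fields(897, {9, 7});
        for (const desc &d : rle(fields))
            printf("%zu x (compare %zu bytes, skip %zu bytes)\n",
                   d.nrepeats, d.data_bytes, d.padding_bytes);
        // Prints a single line: "897 x (compare 9 bytes, skip 7 bytes)".
        return 0;
    }

Because all 897 fields of `Torture1_54109` produce the same (9, 7) pair, the
whole struct collapses to a single descriptor with `nrepeats == 897`, which
is why the emitted IR stays small no matter how many fields the struct has.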
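
For the comparison itself, the `emit_desc` lambda lowers each descriptor to
either a single `memcmp` (when `nrepeats == 1`) or a loop over the repeats.
The sketch below shows the same logic as plain C++, again only as an
illustration under the same layout assumptions: `masked_equal` is a made-up
helper, and unlike the emitted IR, which ANDs the per-iteration results
together, it simply returns early on the first mismatch; the observable
result is the same.

    // masked_compare_sketch.cpp -- illustration only, not part of the patch.
    #include <cstddef>
    #include <cstring>
    #include <vector>

    // Same shape as egal_desc in the patch.
    struct desc {
        size_t offset;        // where this run starts inside the struct
        size_t nrepeats;      // how many times the (data, padding) pattern repeats
        size_t data_bytes;    // bytes compared per repeat
        size_t padding_bytes; // bytes skipped per repeat
    };

    // C-level equivalent of the comparison emit_desc produces: for each
    // descriptor, compare data_bytes and then step over the padding, so the
    // padding contents never affect the result.
    static bool masked_equal(const char *a, const char *b, const std::vector<desc> &descs)
    {
        for (const desc &d : descs) {
            const char *pa = a + d.offset;
            const char *pb = b + d.offset;
            for (size_t i = 0; i < d.nrepeats; i++) {
                if (std::memcmp(pa, pb, d.data_bytes) != 0)
                    return false;                     // some data byte differs
                pa += d.data_bytes + d.padding_bytes; // skip the trailing padding
                pb += d.data_bytes + d.padding_bytes;
            }
        }
        return true;
    }

    int main()
    {
        // One 8-byte double plus one bool, padded out to 16 bytes on typical ABIs.
        struct S { double x; bool flag; };
        S s1{1.0, false}, s2{1.0, false};
        // Compare the 9 meaningful bytes once; ignore whatever is in the padding.
        std::vector<desc> layout = {{0, 1, 9, 7}};
        return masked_equal((const char *)&s1, (const char *)&s2, layout) ? 0 : 1;
    }

The padding bytes are never passed to `memcmp`, which is exactly the property
the correctness tests above rely on: two separately constructed
`Torture1_54109` values may differ in their padding but must still compare
egal.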