Make emitted egal code more loopy
The strategy here is to look at (data, padding) pairs and RLE
them into loops, so that repeated adjacent patterns use a loop
rather than getting unrolled. On the test case from #54109,
this makes compilation essentially instant, while also being
faster at runtime (it turns out LLVM spends a massive amount of time on the
fully unrolled comparison, and the code it produces is bad anyway).

There are some obvious further enhancements possible here:
1. The `memcmp` sizes here are small constants. LLVM has a pass to inline
   these with better code; however, we don't have it turned on. We should
   consider vendoring it, though we may want to add some shortcutting to it
   to avoid having it iterate through each function.
2. This only does one level of sequence matching. It could be recursed
   to turn things into nested loops.

However, this solves the immediate issue, so hopefully it's a useful
start. Fixes #54109.
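
As a rough illustration of that strategy (not the actual codegen, which is
`emit_masked_bits_compare` in src/codegen.cpp below), adjacent (data, padding)
pairs collapse into repeat descriptors; all names in this sketch are illustrative:

```julia
# Sketch: model the run-length encoding of (data_bytes, padding_bytes) pairs.
struct EgalDesc
    nrepeats::Int
    data_bytes::Int
    padding_bytes::Int
end

function rle_pairs(pairs)
    descs = EgalDesc[]
    for (data, pad) in pairs
        if !isempty(descs) && descs[end].data_bytes == data && descs[end].padding_bytes == pad
            # Same (data, padding) shape as the previous run: just count a repeat.
            descs[end] = EgalDesc(descs[end].nrepeats + 1, data, pad)
        else
            # Different shape: start a new descriptor.
            push!(descs, EgalDesc(1, data, pad))
        end
    end
    return descs
end

# 1000 identical (8 data, 8 padding) fields collapse into one descriptor, so the
# emitted egal check is a single loop instead of 1000 unrolled comparisons.
rle_pairs(fill((8, 8), 1000))   # => [EgalDesc(1000, 8, 8)]
```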
Keno committed Apr 24, 2024
1 parent 7ba1b33 commit 98e54a0
Showing 8 changed files with 243 additions and 19 deletions.
19 changes: 17 additions & 2 deletions base/reflection.jl
@@ -488,8 +488,8 @@ gc_alignment(T::Type) = gc_alignment(Core.sizeof(T))
Base.datatype_haspadding(dt::DataType) -> Bool
Return whether the fields of instances of this type are packed in memory,
with no intervening padding bits (defined as bits whose value does not uniquely
impact the egal test when applied to the struct fields).
with no intervening padding bits (defined as bits whose value does not impact
the semantic value of the instance itself).
Can be called on any `isconcretetype`.
"""
function datatype_haspadding(dt::DataType)
@@ -499,6 +499,21 @@ function datatype_haspadding(dt::DataType)
return flags & 1 == 1
end

"""
Base.datatype_isbitsegal(dt::DataType) -> Bool
Return whether egality of the (non-padding bits of the) in-memory representation
of an instance of this type implies semantic egality of the instance itself.
This may not be the case if the type contains references to other values whose
egality is independent of their identity (e.g. immutable structs, some types, etc.).
"""
function datatype_isbitsegal(dt::DataType)
@_foldable_meta
dt.layout == C_NULL && throw(UndefRefError())
flags = unsafe_load(convert(Ptr{DataTypeLayout}, dt.layout)).flags
return (flags & (1<<5)) != 0
end

"""
Base.datatype_nfields(dt::DataType) -> UInt32
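For illustration, the two reflection queries above behave as follows on a couple
of simple types (a sketch assuming this commit; exact flag values depend on the
platform's field layout):

```julia
# An Int8 followed by an Int64 leaves 7 unused alignment bytes between the
# fields on typical 64-bit platforms.
struct Padded
    a::Int8
    b::Int64
end

Base.datatype_haspadding(Padded)        # expected true: interior padding bytes
Base.datatype_isbitsegal(Padded)        # expected true: the non-padding bits decide egality

# A String is compared by contents, not by the pointer stored in the field, so
# bit-equality of the representation does not imply egality here.
Base.datatype_isbitsegal(Tuple{String}) # expected false
Base.datatype_haspadding(Tuple{String}) # expected false: no padding, just not bits-egal
```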
6 changes: 3 additions & 3 deletions src/builtins.c
@@ -115,7 +115,7 @@ static int NOINLINE compare_fields(const jl_value_t *a, const jl_value_t *b, jl_
continue; // skip this field (it is #undef)
}
}
if (!ft->layout->flags.haspadding) {
if (!ft->layout->flags.haspadding && ft->layout->flags.isbitsegal) {
if (!bits_equal(ao, bo, ft->layout->size))
return 0;
}
@@ -284,7 +284,7 @@ inline int jl_egal__bits(const jl_value_t *a JL_MAYBE_UNROOTED, const jl_value_t
if (sz == 0)
return 1;
size_t nf = jl_datatype_nfields(dt);
if (nf == 0 || !dt->layout->flags.haspadding)
if (nf == 0 || (!dt->layout->flags.haspadding && dt->layout->flags.isbitsegal))
return bits_equal(a, b, sz);
return compare_fields(a, b, dt);
}
@@ -394,7 +394,7 @@ static uintptr_t immut_id_(jl_datatype_t *dt, jl_value_t *v, uintptr_t h) JL_NOT
if (sz == 0)
return ~h;
size_t f, nf = jl_datatype_nfields(dt);
if (nf == 0 || (!dt->layout->flags.haspadding && dt->layout->npointers == 0)) {
if (nf == 0 || (!dt->layout->flags.haspadding && dt->layout->flags.isbitsegal && dt->layout->npointers == 0)) {
// operate element-wise if there are unused bits inside,
// otherwise just take the whole data block at once
// a few select pointers (notably symbol) also have special hash values
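Conceptually, the checks above now require both flags before taking the
whole-block fast path for comparison and hashing; a minimal Julia model of the
gate (the helper name is hypothetical):

```julia
# A raw byte comparison (or whole-block hash) is only sound when the layout has
# no padding bits and bit-equality of the stored bytes implies egality.
bits_comparable(T::DataType) = !Base.datatype_haspadding(T) && Base.datatype_isbitsegal(T)

bits_comparable(Int)            # true:  plain bits, compare/hash the bytes directly
bits_comparable(Tuple{String})  # false: must fall back to field-by-field comparison
```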
3 changes: 2 additions & 1 deletion src/cgutils.cpp
@@ -2200,7 +2200,8 @@ static jl_cgval_t typed_store(jl_codectx_t &ctx,
}
else if (!isboxed) {
assert(jl_is_concrete_type(jltype));
needloop = ((jl_datatype_t*)jltype)->layout->flags.haspadding;
needloop = ((jl_datatype_t*)jltype)->layout->flags.haspadding ||
!((jl_datatype_t*)jltype)->layout->flags.isbitsegal;
Value *SameType = emit_isa(ctx, cmp, jltype, Twine()).first;
if (SameType != ConstantInt::getTrue(ctx.builder.getContext())) {
BasicBlock *SkipBB = BasicBlock::Create(ctx.builder.getContext(), "skip_xchg", ctx.f);
137 changes: 136 additions & 1 deletion src/codegen.cpp
@@ -3358,6 +3358,58 @@ static Value *emit_bitsunion_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1,
return phi;
}

struct egal_desc {
size_t offset;
size_t nrepeats;
size_t data_bytes;
size_t padding_bytes;
};

template <typename callback>
static size_t emit_masked_bits_compare(callback &emit_desc, jl_datatype_t *aty, egal_desc &current_desc)
{
// Memcmp, but with masked padding
size_t data_bytes = 0;
size_t padding_bytes = 0;
size_t nfields = jl_datatype_nfields(aty);
size_t total_size = jl_datatype_size(aty);
for (size_t i = 0; i < nfields; ++i) {
size_t offset = jl_field_offset(aty, i);
size_t fend = i == nfields - 1 ? total_size : jl_field_offset(aty, i + 1);
size_t fsz = jl_field_size(aty, i);
jl_datatype_t *fty = (jl_datatype_t*)jl_field_type(aty, i);
if (jl_field_isptr(aty, i) || !fty->layout->flags.haspadding) {
// The field has no internal padding
data_bytes += fsz;
if (offset + fsz == fend) {
// The field has no padding after. Merge this into the current
// comparison range and go to next field.
} else {
padding_bytes = fend - offset - fsz;
// Found padding. Either merge this into the current comparison
// range, or emit the old one and start a new one.
if (current_desc.data_bytes == data_bytes &&
current_desc.padding_bytes == padding_bytes) {
// Same as the previous range, just note that down, so we
// emit this as a loop.
current_desc.nrepeats += 1;
} else {
if (current_desc.nrepeats != 0)
emit_desc(current_desc);
current_desc.nrepeats = 1;
current_desc.data_bytes = data_bytes;
current_desc.padding_bytes = padding_bytes;
}
data_bytes = 0;
}
} else {
// The field may have internal padding. Recurse this.
data_bytes += emit_masked_bits_compare(emit_desc, fty, current_desc);
}
}
return data_bytes;
}

static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t arg2)
{
++EmittedBitsCompares;
@@ -3396,7 +3448,7 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a
if (at->isAggregateType()) { // Struct or Array
jl_datatype_t *sty = (jl_datatype_t*)arg1.typ;
size_t sz = jl_datatype_size(sty);
if (sz > 512 && !sty->layout->flags.haspadding) {
if (sz > 512 && !sty->layout->flags.haspadding && sty->layout->flags.isbitsegal) {
Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) :
value_to_pointer(ctx, arg1).V;
Value *varg2 = arg2.ispointer() ? data_pointer(ctx, arg2) :
@@ -3433,6 +3485,89 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a
}
return ctx.builder.CreateICmpEQ(answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
}
else if (sz > 512 && jl_struct_try_layout(sty) && sty->layout->flags.isbitsegal) {
Type *TInt8 = getInt8Ty(ctx.builder.getContext());
Type *TpInt8 = getInt8PtrTy(ctx.builder.getContext());
Type *TInt1 = getInt1Ty(ctx.builder.getContext());
Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) :
value_to_pointer(ctx, arg1).V;
Value *varg2 = arg2.ispointer() ? data_pointer(ctx, arg2) :
value_to_pointer(ctx, arg2).V;
varg1 = emit_pointer_from_objref(ctx, varg1);
varg2 = emit_pointer_from_objref(ctx, varg2);
varg1 = emit_bitcast(ctx, varg1, TpInt8);
varg2 = emit_bitcast(ctx, varg2, TpInt8);

Value *answer = nullptr;
auto emit_desc = [&](egal_desc desc) {
Value *ptr1 = varg1;
Value *ptr2 = varg2;
if (desc.offset != 0) {
ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.offset);
ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr2, desc.offset);
}

Value *new_ptr1 = ptr1;
Value *endptr1 = nullptr;
BasicBlock *postBB = nullptr;
BasicBlock *loopBB = nullptr;
PHINode *answerphi = nullptr;
if (desc.nrepeats != 1) {
// Set up loop
endptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.nrepeats * (desc.data_bytes + desc.padding_bytes));;

BasicBlock *currBB = ctx.builder.GetInsertBlock();
loopBB = BasicBlock::Create(ctx.builder.getContext(), "egal_loop", ctx.f);
postBB = BasicBlock::Create(ctx.builder.getContext(), "post", ctx.f);
ctx.builder.CreateBr(loopBB);

ctx.builder.SetInsertPoint(loopBB);
answerphi = ctx.builder.CreatePHI(TInt1, 2);
answerphi->addIncoming(answer ? answer : ConstantInt::get(TInt1, 1), currBB);
answer = answerphi;

PHINode *itr1 = ctx.builder.CreatePHI(ptr1->getType(), 2);
PHINode *itr2 = ctx.builder.CreatePHI(ptr2->getType(), 2);

new_ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr1, desc.data_bytes + desc.padding_bytes);
itr1->addIncoming(ptr1, currBB);
itr1->addIncoming(new_ptr1, loopBB);

Value *new_ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr2, desc.data_bytes + desc.padding_bytes);
itr2->addIncoming(ptr2, currBB);
itr2->addIncoming(new_ptr2, loopBB);

ptr1 = itr1;
ptr2 = itr2;
}

// Emit memcmp. TODO: LLVM has a pass to expand this for additional
// performance.
Value *this_answer = ctx.builder.CreateCall(prepare_call(memcmp_func),
{ ptr1,
ptr2,
ConstantInt::get(ctx.types().T_size, desc.data_bytes) });
this_answer = ctx.builder.CreateICmpEQ(this_answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
answer = answer ? ctx.builder.CreateAnd(answer, this_answer) : this_answer;
if (endptr1) {
answerphi->addIncoming(answer, loopBB);
Value *loopend = ctx.builder.CreateICmpEQ(new_ptr1, endptr1);
ctx.builder.CreateCondBr(loopend, postBB, loopBB);
ctx.builder.SetInsertPoint(postBB);
}
};
egal_desc current_desc = {0};
size_t trailing_data_bytes = emit_masked_bits_compare(emit_desc, sty, current_desc);
assert(current_desc.nrepeats != 0);
emit_desc(current_desc);
if (trailing_data_bytes != 0) {
current_desc.nrepeats = 1;
current_desc.data_bytes = trailing_data_bytes;
current_desc.padding_bytes = 0;
emit_desc(current_desc);
}
return answer;
}
else {
jl_svec_t *types = sty->types;
Value *answer = ConstantInt::get(getInt1Ty(ctx.builder.getContext()), 1);
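As a concrete, illustrative example of what this new path produces (field sizes
assume a typical 64-bit layout):

```julia
struct Elem
    x::Int32   # 4 data bytes
    y::Int8    # 1 data byte, then 3 trailing padding bytes (sizeof(Elem) == 8)
end
const Big = NTuple{128, Elem}   # 1024 bytes: larger than 512 and has padding

# emit_masked_bits_compare collapses all 128 elements into one descriptor,
# roughly egal_desc{nrepeats = 128, data_bytes = 5, padding_bytes = 3}, so
# `===` on two Big values should compile to a single loop doing a 5-byte memcmp
# per element at an 8-byte stride rather than 128 unrolled comparisons.
```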
33 changes: 24 additions & 9 deletions src/datatype.c
@@ -180,6 +180,7 @@ static jl_datatype_layout_t *jl_get_layout(uint32_t sz,
uint32_t npointers,
uint32_t alignment,
int haspadding,
int isbitsegal,
int arrayelem,
jl_fielddesc32_t desc[],
uint32_t pointers[]) JL_NOTSAFEPOINT
@@ -226,6 +227,7 @@
flddesc->nfields = nfields;
flddesc->alignment = alignment;
flddesc->flags.haspadding = haspadding;
flddesc->flags.isbitsegal = isbitsegal;
flddesc->flags.fielddesc_type = fielddesc_type;
flddesc->flags.arrayelem_isboxed = arrayelem == 1;
flddesc->flags.arrayelem_isunion = arrayelem == 2;
@@ -504,6 +506,7 @@ void jl_get_genericmemory_layout(jl_datatype_t *st)
int isunboxed = jl_islayout_inline(eltype, &elsz, &al) && (kind != (jl_value_t*)jl_atomic_sym || jl_is_datatype(eltype));
int isunion = isunboxed && jl_is_uniontype(eltype);
int haspadding = 1; // we may want to eventually actually compute this more precisely
int isbitsegal = 0;
int nfields = 0; // aka jl_is_layout_opaque
int npointers = 1;
int zi;
@@ -562,7 +565,7 @@ void jl_get_genericmemory_layout(jl_datatype_t *st)
else
arrayelem = 0;
assert(!st->layout);
st->layout = jl_get_layout(elsz, nfields, npointers, al, haspadding, arrayelem, NULL, pointers);
st->layout = jl_get_layout(elsz, nfields, npointers, al, haspadding, isbitsegal, arrayelem, NULL, pointers);
st->zeroinit = zi;
//st->has_concrete_subtype = 1;
//st->isbitstype = 0;
@@ -673,6 +676,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
size_t alignm = 1;
int zeroinit = 0;
int haspadding = 0;
int isbitsegal = 1;
int homogeneous = 1;
int needlock = 0;
uint32_t npointers = 0;
@@ -687,19 +691,30 @@ void jl_compute_field_offsets(jl_datatype_t *st)
throw_ovf(should_malloc, desc, st, fsz);
desc[i].isptr = 0;
if (jl_is_uniontype(fld)) {
haspadding = 1;
fsz += 1; // selector byte
zeroinit = 1;
// TODO: Some unions could be bits comparable.
isbitsegal = 0;
}
else {
uint32_t fld_npointers = ((jl_datatype_t*)fld)->layout->npointers;
if (((jl_datatype_t*)fld)->layout->flags.haspadding)
haspadding = 1;
if (!((jl_datatype_t*)fld)->layout->flags.isbitsegal)
isbitsegal = 0;
if (i >= nfields - st->name->n_uninitialized && fld_npointers &&
fld_npointers * sizeof(void*) != fsz) {
// field may be undef (may be uninitialized and contains pointer),
// and contains non-pointer fields of non-zero sizes.
haspadding = 1;
// For field types that contain pointers, we allow inlinealloc
// as long as the field type itself is always fully initialized.
// In such a case, we use the first pointer in the inlined field
// as the #undef marker (if it is zero, we treat the whole inline
// struct as #undef). However, we do not zero-initialize the whole
// struct, so the non-pointer parts of the inline allocation may
// be arbitrary, but still need to compare egal (because all #undef
// representations are egal). Because of this, we cannot bitscompare
// them.
// TODO: Consider zero-initializing the whole struct.
isbitsegal = 0;
}
if (!zeroinit)
zeroinit = ((jl_datatype_t*)fld)->zeroinit;
@@ -715,8 +730,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
zeroinit = 1;
npointers++;
if (!jl_pointer_egal(fld)) {
// this somewhat poorly named flag says whether some of the bits can be non-unique
haspadding = 1;
isbitsegal = 0;
}
}
if (isatomic && fsz > MAX_ATOMIC_SIZE)
@@ -777,7 +791,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
}
}
assert(ptr_i == npointers);
st->layout = jl_get_layout(sz, nfields, npointers, alignm, haspadding, 0, desc, pointers);
st->layout = jl_get_layout(sz, nfields, npointers, alignm, haspadding, isbitsegal, 0, desc, pointers);
if (should_malloc) {
free(desc);
if (npointers)
@@ -931,7 +945,7 @@ JL_DLLEXPORT jl_datatype_t *jl_new_primitivetype(jl_value_t *name, jl_module_t *
bt->ismutationfree = 1;
bt->isidentityfree = 1;
bt->isbitstype = (parameters == jl_emptysvec);
bt->layout = jl_get_layout(nbytes, 0, 0, alignm, 0, 0, NULL, NULL);
bt->layout = jl_get_layout(nbytes, 0, 0, alignm, 0, 1, 0, NULL, NULL);
bt->instance = NULL;
return bt;
}
@@ -954,6 +968,7 @@ JL_DLLEXPORT jl_datatype_t * jl_new_foreign_type(jl_sym_t *name,
layout->alignment = sizeof(void *);
layout->npointers = haspointers;
layout->flags.haspadding = 1;
layout->flags.isbitsegal = 0;
layout->flags.fielddesc_type = 3;
layout->flags.padding = 0;
layout->flags.arrayelem_isboxed = 0;
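The #undef case described in the comment above can be illustrated with a
hypothetical pair of types (whether the rule actually kicks in depends on the
field being inline-allocated):

```julia
mutable struct Marker end   # mutable, so references to it are compared by identity

struct Inner
    m::Marker   # the first pointer doubles as the #undef marker once inlined
    n::Int      # plain bits, left arbitrary when the value is #undef
end

struct Outer
    x::Inner          # may be inline-allocated and left #undef ...
    Outer(x) = new(x)
    Outer() = new()   # ... by this incomplete constructor
end

# Two #undef `x` fields must compare egal even though the bytes of `n` may
# differ, so `Outer` cannot be bits-compared.
Base.datatype_isbitsegal(Outer)   # expected false
Base.datatype_haspadding(Outer)   # expected false: the #undef case is now tracked separately
```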
5 changes: 4 additions & 1 deletion src/julia.h
@@ -574,7 +574,10 @@ typedef struct {
// metadata bit only for GenericMemory eltype layout
uint16_t arrayelem_isboxed : 1;
uint16_t arrayelem_isunion : 1;
uint16_t padding : 11;
// If set, this type's egality can be determined entirely by comparing
// the non-padding bits of this datatype.
uint16_t isbitsegal : 1;
uint16_t padding : 10;
} flags;
// union {
// jl_fielddesc8_t field8[nfields];