Skip to content

Commit

Permalink
fixup! threading: support more than nthreads at runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
vtjnash committed Jun 21, 2022
1 parent 9bb34d6 commit 11044e7
Show file tree
Hide file tree
Showing 10 changed files with 264 additions and 201 deletions.
6 changes: 4 additions & 2 deletions src/ccall.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1547,7 +1547,8 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
assert(lrt == getVoidTy(ctx.builder.getContext()));
assert(!isVa && !llvmcall && nccallargs == 0);
JL_GC_POP();
emit_gc_safepoint(ctx);
ctx.builder.CreateCall(prepare_call(gcroot_flush_func));
emit_gc_safepoint(ctx.builder, get_current_ptls(ctx), ctx.tbaa().tbaa_const);
return ghostValue(ctx, jl_nothing_type);
}
else if (is_libjulia_func("jl_get_ptls_states")) {
Expand Down Expand Up @@ -1650,7 +1651,8 @@ static jl_cgval_t emit_ccall(jl_codectx_t &ctx, jl_value_t **args, size_t nargs)
ctx.builder.SetInsertPoint(checkBB);
ctx.builder.CreateLoad(
getSizeTy(ctx.builder.getContext()),
ctx.builder.CreateConstInBoundsGEP1_32(getSizeTy(ctx.builder.getContext()), get_current_signal_page(ctx), -1),
ctx.builder.CreateConstInBoundsGEP1_32(getSizeTy(ctx.builder.getContext()),
get_current_signal_page_from_ptls(ctx.builder, get_current_ptls(ctx), ctx.tbaa().tbaa_const), -1),
true);
ctx.builder.CreateBr(contBB);
ctx.f->getBasicBlockList().push_back(contBB);
Expand Down
122 changes: 19 additions & 103 deletions src/cgutils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

STATISTIC(EmittedPointerFromObjref, "Number of emitted pointer_from_objref calls");
STATISTIC(EmittedPointerBitcast, "Number of emitted pointer bitcasts");
STATISTIC(EmittedNthPtrAddr, "Number of emitted nth pointer address instructions");
STATISTIC(EmittedTypeof, "Number of emitted typeof instructions");
STATISTIC(EmittedErrors, "Number of emitted errors");
STATISTIC(EmittedConditionalErrors, "Number of emitted conditional errors");
Expand Down Expand Up @@ -42,7 +41,6 @@ STATISTIC(EmittedCPointerChecks, "Number of C pointer checks emitted");
STATISTIC(EmittedAllocObjs, "Number of object allocations emitted");
STATISTIC(EmittedWriteBarriers, "Number of write barriers emitted");
STATISTIC(EmittedNewStructs, "Number of new structs emitted");
STATISTIC(EmittedSignalFences, "Number of signal fences emitted");
STATISTIC(EmittedDeferSignal, "Number of deferred signals emitted");

static Value *track_pjlvalue(jl_codectx_t &ctx, Value *V)
Expand Down Expand Up @@ -973,40 +971,18 @@ static void emit_memcpy(jl_codectx_t &ctx, Value *dst, MDNode *tbaa_dst, const j
emit_memcpy_llvm(ctx, dst, tbaa_dst, data_pointer(ctx, src), src.tbaa, sz, align, is_volatile);
}

static Value *emit_nthptr_addr(jl_codectx_t &ctx, Value *v, ssize_t n, bool gctracked = true)
{
++EmittedNthPtrAddr;
return ctx.builder.CreateInBoundsGEP(
ctx.types().T_prjlvalue,
emit_bitcast(ctx, maybe_decay_tracked(ctx, v), ctx.types().T_pprjlvalue),
ConstantInt::get(getSizeTy(ctx.builder.getContext()), n));
}

static Value *emit_nthptr_addr(jl_codectx_t &ctx, Value *v, Value *idx)
static LoadInst *emit_nthptr_recast(jl_codectx_t &ctx, Value *v, Value *idx, MDNode *tbaa, Type *type)
{
++EmittedNthPtrAddr;
return ctx.builder.CreateInBoundsGEP(
// p = (jl_value_t**)v; *(type*)&p[n]
Value *vptr = ctx.builder.CreateInBoundsGEP(
ctx.types().T_prjlvalue,
emit_bitcast(ctx, maybe_decay_tracked(ctx, v), ctx.types().T_pprjlvalue),
idx);
LoadInst *load = ctx.builder.CreateLoad(type, emit_bitcast(ctx, vptr, PointerType::get(type, 0)));
tbaa_decorate(tbaa, load);
return load;
}

static LoadInst *emit_nthptr_recast(jl_codectx_t &ctx, Value *v, Value *idx, MDNode *tbaa, Type *type)
{
// p = (jl_value_t**)v; *(type*)&p[n]
Value *vptr = emit_nthptr_addr(ctx, v, idx);
return cast<LoadInst>(tbaa_decorate(tbaa, ctx.builder.CreateLoad(type,
emit_bitcast(ctx, vptr, PointerType::get(type, 0)))));
}

static LoadInst *emit_nthptr_recast(jl_codectx_t &ctx, Value *v, ssize_t n, MDNode *tbaa, Type *type)
{
// p = (jl_value_t**)v; *(type*)&p[n]
Value *vptr = emit_nthptr_addr(ctx, v, n);
return cast<LoadInst>(tbaa_decorate(tbaa, ctx.builder.CreateLoad(type,
emit_bitcast(ctx, vptr, PointerType::get(type, 0)))));
}

static Value *boxed(jl_codectx_t &ctx, const jl_cgval_t &v);
static Value *emit_typeof(jl_codectx_t &ctx, Value *v, bool maybenull);

Expand Down Expand Up @@ -1179,8 +1155,12 @@ static Value *emit_datatype_isprimitivetype(jl_codectx_t &ctx, Value *dt)

static Value *emit_datatype_name(jl_codectx_t &ctx, Value *dt)
{
Value *vptr = emit_nthptr_addr(ctx, dt, (ssize_t)(offsetof(jl_datatype_t, name) / sizeof(char*)));
return tbaa_decorate(ctx.tbaa().tbaa_const, ctx.builder.CreateAlignedLoad(ctx.types().T_prjlvalue, vptr, Align(sizeof(void*))));
unsigned n = offsetof(jl_datatype_t, name) / sizeof(char*);
Value *vptr = ctx.builder.CreateInBoundsGEP(
ctx.types().T_pjlvalue,
emit_bitcast(ctx, maybe_decay_tracked(ctx, dt), ctx.types().T_ppjlvalue),
ConstantInt::get(getSizeTy(ctx.builder.getContext()), n));
return tbaa_decorate(ctx.tbaa().tbaa_const, ctx.builder.CreateAlignedLoad(ctx.types().T_pjlvalue, vptr, Align(sizeof(void*))));
}

// --- generating various error checks ---
Expand Down Expand Up @@ -1492,8 +1472,8 @@ static std::pair<Value*, bool> emit_isa(jl_codectx_t &ctx, const jl_cgval_t &x,
// so the isa test reduces to a comparison of the typename by pointer
return std::make_pair(
ctx.builder.CreateICmpEQ(
mark_callee_rooted(ctx, emit_datatype_name(ctx, emit_typeof_boxed(ctx, x))),
mark_callee_rooted(ctx, literal_pointer_val(ctx, (jl_value_t*)dt->name))),
emit_datatype_name(ctx, emit_typeof_boxed(ctx, x)),
literal_pointer_val(ctx, (jl_value_t*)dt->name)),
false);
}
if (jl_is_uniontype(intersected_type) &&
Expand Down Expand Up @@ -3396,10 +3376,10 @@ static void emit_cpointercheck(jl_codectx_t &ctx, const jl_cgval_t &x, const std
emit_typecheck(ctx, mark_julia_type(ctx, t, true, jl_any_type), (jl_value_t*)jl_datatype_type, msg);

Value *istype =
ctx.builder.CreateICmpEQ(mark_callee_rooted(ctx, emit_datatype_name(ctx, t)),
mark_callee_rooted(ctx, literal_pointer_val(ctx, (jl_value_t*)jl_pointer_typename)));
BasicBlock *failBB = BasicBlock::Create(ctx.builder.getContext(),"fail",ctx.f);
BasicBlock *passBB = BasicBlock::Create(ctx.builder.getContext(),"pass");
ctx.builder.CreateICmpEQ(emit_datatype_name(ctx, t),
literal_pointer_val(ctx, (jl_value_t*)jl_pointer_typename));
BasicBlock *failBB = BasicBlock::Create(ctx.builder.getContext(), "fail", ctx.f);
BasicBlock *passBB = BasicBlock::Create(ctx.builder.getContext(), "pass");
ctx.builder.CreateCondBr(istype, passBB, failBB);
ctx.builder.SetInsertPoint(failBB);

Expand Down Expand Up @@ -3847,8 +3827,7 @@ static jl_cgval_t emit_new_struct(jl_codectx_t &ctx, jl_value_t *ty, size_t narg

static void emit_signal_fence(jl_codectx_t &ctx)
{
++EmittedSignalFences;
ctx.builder.CreateFence(AtomicOrdering::SequentiallyConsistent, SyncScope::SingleThread);
emit_signal_fence(ctx.builder);
}

static Value *emit_defer_signal(jl_codectx_t &ctx)
Expand All @@ -3861,69 +3840,6 @@ static Value *emit_defer_signal(jl_codectx_t &ctx)
return ctx.builder.CreateInBoundsGEP(ctx.types().T_sigatomic, ptls, ArrayRef<Value*>(offset), "jl_defer_signal");
}

static void emit_gc_safepoint(jl_codectx_t &ctx)
{
ctx.builder.CreateCall(prepare_call(gcroot_flush_func));
emit_signal_fence(ctx);
ctx.builder.CreateLoad(getSizeTy(ctx.builder.getContext()), get_current_signal_page(ctx), true);
emit_signal_fence(ctx);
}

static Value *emit_gc_state_set(jl_codectx_t &ctx, Value *state, Value *old_state)
{
Type *T_int8 = state->getType();
Value *ptls = emit_bitcast(ctx, get_current_ptls(ctx), getInt8PtrTy(ctx.builder.getContext()));
Constant *offset = ConstantInt::getSigned(getInt32Ty(ctx.builder.getContext()), offsetof(jl_tls_states_t, gc_state));
Value *gc_state = ctx.builder.CreateInBoundsGEP(T_int8, ptls, ArrayRef<Value*>(offset), "gc_state");
if (old_state == nullptr) {
old_state = ctx.builder.CreateLoad(T_int8, gc_state);
cast<LoadInst>(old_state)->setOrdering(AtomicOrdering::Monotonic);
}
ctx.builder.CreateAlignedStore(state, gc_state, Align(sizeof(void*)))->setOrdering(AtomicOrdering::Release);
if (auto *C = dyn_cast<ConstantInt>(old_state))
if (C->isZero())
return old_state;
if (auto *C = dyn_cast<ConstantInt>(state))
if (!C->isZero())
return old_state;
BasicBlock *passBB = BasicBlock::Create(ctx.builder.getContext(), "safepoint", ctx.f);
BasicBlock *exitBB = BasicBlock::Create(ctx.builder.getContext(), "after_safepoint", ctx.f);
Constant *zero8 = ConstantInt::get(T_int8, 0);
ctx.builder.CreateCondBr(ctx.builder.CreateAnd(ctx.builder.CreateICmpNE(old_state, zero8), // if (old_state && !state)
ctx.builder.CreateICmpEQ(state, zero8)),
passBB, exitBB);
ctx.builder.SetInsertPoint(passBB);
emit_gc_safepoint(ctx);
ctx.builder.CreateBr(exitBB);
ctx.builder.SetInsertPoint(exitBB);
return old_state;
}

static Value *emit_gc_unsafe_enter(jl_codectx_t &ctx)
{
Value *state = ConstantInt::get(getInt8Ty(ctx.builder.getContext()), 0);
return emit_gc_state_set(ctx, state, nullptr);
}

static Value *emit_gc_unsafe_leave(jl_codectx_t &ctx, Value *state)
{
Value *old_state = ConstantInt::get(state->getType(), 0);
return emit_gc_state_set(ctx, state, old_state);
}

//static Value *emit_gc_safe_enter(jl_codectx_t &ctx)
//{
// Value *state = ConstantInt::get(getInt8Ty(ctx.builder.getContext()), JL_GC_STATE_SAFE);
// return emit_gc_state_set(ctx, state, nullptr);
//}
//
//static Value *emit_gc_safe_leave(jl_codectx_t &ctx, Value *state)
//{
// Value *old_state = ConstantInt::get(state->getType(), JL_GC_STATE_SAFE);
// return emit_gc_state_set(ctx, state, old_state);
//}



#ifndef JL_NDEBUG
static int compare_cgparams(const jl_cgparams_t *a, const jl_cgparams_t *b)
Expand Down
40 changes: 8 additions & 32 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,10 +583,10 @@ static const auto jlpgcstack_func = new JuliaFunction{
nullptr,
};

static const auto jladoptthread = new JuliaFunction{
XSTR(jl_adopt_thread),
static const auto jladoptthread_func = new JuliaFunction{
"julia.get_pgcstack_or_new",
jlpgcstack_func->_type,
nullptr,
jlpgcstack_func->_attrs,
};


Expand Down Expand Up @@ -1477,11 +1477,9 @@ static Value *global_binding_pointer(jl_codectx_t &ctx, jl_module_t *m, jl_sym_t
static jl_cgval_t emit_checked_var(jl_codectx_t &ctx, Value *bp, jl_sym_t *name, bool isvol, MDNode *tbaa);
static jl_cgval_t emit_sparam(jl_codectx_t &ctx, size_t i);
static Value *emit_condition(jl_codectx_t &ctx, const jl_cgval_t &condV, const std::string &msg);
static void allocate_gc_frame(jl_codectx_t &ctx, BasicBlock *b0);
static Value *get_current_task(jl_codectx_t &ctx);
static Value *get_current_ptls(jl_codectx_t &ctx);
static Value *get_last_age_field(jl_codectx_t &ctx);
static Value *get_current_signal_page(jl_codectx_t &ctx);
static void CreateTrap(IRBuilder<> &irbuilder, bool create_new_block = true);
static CallInst *emit_jlcall(jl_codectx_t &ctx, Function *theFptr, Value *theF,
const jl_cgval_t *args, size_t nargs, JuliaFunction *trampoline);
Expand Down Expand Up @@ -5321,21 +5319,17 @@ JL_GCC_IGNORE_STOP
// --- generate function bodies ---

// gc frame emission
static void allocate_gc_frame(jl_codectx_t &ctx, BasicBlock *b0)
static void allocate_gc_frame(jl_codectx_t &ctx, BasicBlock *b0, bool or_new=false)
{
// allocate a placeholder gc instruction
// this will require the runtime, but it gets deleted later if unused
ctx.topalloca = ctx.builder.CreateCall(prepare_call(jlpgcstack_func));
ctx.topalloca = ctx.builder.CreateCall(prepare_call(or_new ? jladoptthread_func : jlpgcstack_func));
ctx.pgcstack = ctx.topalloca;
}

static Value *get_current_task(jl_codectx_t &ctx)
{
const int ptls_offset = offsetof(jl_task_t, gcstack);
return ctx.builder.CreateInBoundsGEP(
ctx.types().T_pjlvalue, emit_bitcast(ctx, ctx.pgcstack, ctx.types().T_ppjlvalue),
ConstantInt::get(getSizeTy(ctx.builder.getContext()), -(ptls_offset / sizeof(void *))),
"current_task");
return get_current_task_from_pgcstack(ctx.builder, ctx.pgcstack);
}

// Get PTLS through current task.
Expand All @@ -5355,15 +5349,6 @@ static Value *get_last_age_field(jl_codectx_t &ctx)
"world_age");
}

// Get signal page through current task.
static Value *get_current_signal_page(jl_codectx_t &ctx)
{
// return ctx.builder.CreateCall(prepare_call(reuse_signal_page_func));
Value *ptls = get_current_ptls(ctx);
int nthfield = offsetof(jl_tls_states_t, safepoint) / sizeof(void *);
return emit_nthptr_recast(ctx, ptls, nthfield, ctx.tbaa().tbaa_const, getSizePtrTy(ctx.builder.getContext()));
}

static Function *emit_tojlinvoke(jl_code_instance_t *codeinst, Module *M, jl_codegen_params_t &params)
{
++EmittedToJLInvokes;
Expand Down Expand Up @@ -5649,18 +5634,11 @@ static Function* gen_cfun_wrapper(
ctx.builder.SetInsertPoint(b0);
DebugLoc noDbg;
ctx.builder.SetCurrentDebugLocation(noDbg);
allocate_gc_frame(ctx, b0);
allocate_gc_frame(ctx, b0, true);

Value *make_tls = ctx.builder.CreateIsNull(ctx.pgcstack);
ctx.pgcstack = emit_guarded_test(ctx, make_tls, ctx.pgcstack, [&] {
return ctx.builder.CreateCall(prepare_call(jladoptthread), {});
});
Value *world_age_field = get_last_age_field(ctx);
Value *last_age = tbaa_decorate(ctx.tbaa().tbaa_gcframe,
ctx.builder.CreateAlignedLoad(getSizeTy(ctx.builder.getContext()), world_age_field, Align(sizeof(size_t))));
Value *last_gc_state = ConstantInt::get(getInt8Ty(ctx.builder.getContext()), JL_GC_STATE_SAFE);
// if we called jl_adopt_thread, we must end this cfunction back in the safe-state
last_gc_state = ctx.builder.CreateSelect(make_tls, last_gc_state, emit_gc_unsafe_enter(ctx));

Value *world_v = ctx.builder.CreateAlignedLoad(getSizeTy(ctx.builder.getContext()),
prepare_global_in(jl_Module, jlgetworld_global), Align(sizeof(size_t)));
Expand Down Expand Up @@ -6033,8 +6011,6 @@ static Function* gen_cfun_wrapper(
}

ctx.builder.CreateStore(last_age, world_age_field);
if (!sig.retboxed)
emit_gc_unsafe_leave(ctx, last_gc_state);
ctx.builder.CreateRet(r);

ctx.builder.SetCurrentDebugLocation(noDbg);
Expand Down Expand Up @@ -8391,7 +8367,7 @@ static void init_jit_functions(void)
add_named_global(jl_write_barrier_func, (void*)NULL);
add_named_global(jl_write_barrier_binding_func, (void*)NULL);
add_named_global(jldlsym_func, &jl_load_and_lookup);
add_named_global(jladoptthread, &jl_adopt_thread);
add_named_global("jl_adopt_thread", &jl_adopt_thread);
add_named_global(jlgetcfunctiontrampoline_func, &jl_get_cfunction_trampoline);
add_named_global(jlgetnthfieldchecked_func, &jl_get_nth_field_checked);
add_named_global(diff_gc_total_bytes_func, &jl_gc_diff_total_bytes);
Expand Down
Loading

0 comments on commit 11044e7

Please sign in to comment.