Skip to content

Commit

Permalink
[wasm] Jiterpreter multithreading support (#88279)
Browse files Browse the repository at this point in the history
* Enable jiterpreter traces in multithreaded builds
* Pre-allocate fixed number of function table entries populated with placeholder functions at startup, for traces, interp_entry and jit_call (size configurable with runtime options)
* Change layout of all jiterpreter opcodes to 4 ushorts: [opcode] [relative_fn_ptr] [trace_index_low] [trace_index_high]
* Jiterpreter opcode access is done through a packed struct instead of raw byte manipulation
* Update most jiterpreter state management to use atomics (via stdatomic.h)
* Move some jiterpreter state, like counters, from TS to shared storage in C that is managed with atomics
* Patch emscripten libraries to expose a reliable getter for the wasm memory object
* Partial thread safety for jiterpreter jit_call and interp_entry wrappers
* Removed trace transfer optimization because it's no longer useful
  • Loading branch information
kg authored Aug 15, 2023
1 parent c5aba1b commit 333c2c7
Show file tree
Hide file tree
Showing 20 changed files with 922 additions and 491 deletions.
93 changes: 71 additions & 22 deletions src/mono/mono/mini/interp/interp.c
Original file line number Diff line number Diff line change
Expand Up @@ -2659,23 +2659,25 @@ do_jit_call (ThreadContext *context, stackval *ret_sp, stackval *sp, InterpFrame

// We reuse do_jit_call's epilogue to do things like propagate thrown exceptions
// and sign-extend return values instead of inlining that logic into every thunk
goto epilogue;
// the dummy implementation sets a special value into thrown to indicate that
// we need to go through the slow path because this thread has no thunk yet
if (G_UNLIKELY (thrown == 999))
thrown = 0;
else
goto epilogue;
} else {
// FIXME: thread safety for the hit count
int count = cinfo->hit_count;
int old_count = mono_jiterp_increment_counter (&cinfo->hit_count);
// If our hit count just reached the threshold, we request that a thunk be jitted
// for this specific call site. It will go into a queue and wait until there
// are enough jit calls waiting to be compiled into one WASM module
if (count == mono_opt_jiterpreter_jit_call_trampoline_hit_count) {
if (old_count == mono_opt_jiterpreter_jit_call_trampoline_hit_count) {
mono_interp_jit_wasm_jit_call_trampoline (
rmethod->method, rmethod, cinfo,
initialize_arg_offsets(rmethod, mono_method_signature_internal (rmethod->method)),
mono_aot_mode == MONO_AOT_MODE_LLVMONLY_INTERP
);
} else {
int excess = count - mono_opt_jiterpreter_jit_call_queue_flush_threshold;
if (excess <= 0)
cinfo->hit_count++;
int excess = old_count - mono_opt_jiterpreter_jit_call_queue_flush_threshold;
// If our hit count just reached the flush threshold, that means that we
// previously requested compilation for this call site and it didn't
// happen yet. We will request a flush of the entire queue this one
Expand Down Expand Up @@ -2736,7 +2738,7 @@ do_jit_call (ThreadContext *context, stackval *ret_sp, stackval *sp, InterpFrame
* this is faster than mono_llvm_cpp_catch_exception by avoiding the use of
* emscripten invoke_vi to find and invoke jit_call_cb indirectly
*/
jiterpreter_do_jit_call(jit_call_cb, &cb_data, &thrown);
jiterpreter_do_jit_call (jit_call_cb, &cb_data, &thrown);
#else
/* Catch the exception thrown by the native code using a try-catch */
mono_llvm_cpp_catch_exception (jit_call_cb, &cb_data, &thrown);
Expand Down Expand Up @@ -7702,14 +7704,12 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;

#ifdef HOST_BROWSER
MINT_IN_CASE(MINT_TIER_NOP_JITERPRETER) {
ip += 3;
ip += JITERPRETER_OPCODE_SIZE;
MINT_IN_BREAK;
}

MINT_IN_CASE(MINT_TIER_PREPARE_JITERPRETER) {
if (mono_opt_jiterpreter_traces_enabled) {
// We may lose a race with another thread here so we need to use volatile and be careful
volatile guint16 *mutable_ip = (volatile guint16*)ip;
/*
* prepare_jiterpreter will update the trace's hit count and potentially either JIT it or
* disable this entry point based on whether it fails to JIT. the hit counting is necessary
Expand All @@ -7722,22 +7722,21 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
* JS workers in order to register them at the appropriate slots in the function pointer
* table. when growing the function pointer table we will also need to synchronize that.
*/
JiterpreterThunk prepare_result = mono_interp_tier_prepare_jiterpreter_fast(frame, ip);
JiterpreterThunk prepare_result = mono_interp_tier_prepare_jiterpreter_fast (frame, ip);
ptrdiff_t offset;
switch ((guint32)(void*)prepare_result) {
case JITERPRETER_TRAINING:
// jiterpreter still updating hit count before deciding to generate a trace,
// so skip this opcode.
ip += 3;
ip += JITERPRETER_OPCODE_SIZE;
break;
case JITERPRETER_NOT_JITTED:
// Patch opcode to disable it because this trace failed to JIT.
if (!mono_opt_jiterpreter_estimate_heat) {
mono_memory_barrier ();
*mutable_ip = MINT_TIER_NOP_JITERPRETER;
mono_memory_barrier ();
if (!mono_jiterp_patch_opcode ((volatile JiterpreterOpcode *)ip, MINT_TIER_PREPARE_JITERPRETER, MINT_TIER_NOP_JITERPRETER))
g_printf ("Failed to patch opcode at %x into a nop\n", (unsigned int)ip);
}
ip += 3;
ip += JITERPRETER_OPCODE_SIZE;
break;
default:
/*
Expand All @@ -7748,16 +7747,17 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
* (note that right now threading doesn't work, but it's worth being correct
* here so that implementing thread support will be easier later.)
*/
*mutable_ip = MINT_TIER_MONITOR_JITERPRETER;
if (!mono_jiterp_patch_opcode ((volatile JiterpreterOpcode *)ip, MINT_TIER_PREPARE_JITERPRETER, MINT_TIER_MONITOR_JITERPRETER))
g_printf ("Failed to patch opcode at %x into a monitor point\n", (unsigned int)ip);
// now execute the trace
// this isn't important for performance, but it makes it easier to use the
// jiterpreter early in automated tests where code only runs once
offset = prepare_result(frame, locals, &jiterpreter_call_info);
offset = prepare_result (frame, locals, &jiterpreter_call_info, ip);
ip = (guint16*) (((guint8*)ip) + offset);
break;
}
} else {
ip += 3;
ip += JITERPRETER_OPCODE_SIZE;
}

MINT_IN_BREAK;
Expand All @@ -7773,8 +7773,9 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
}

MINT_IN_CASE(MINT_TIER_ENTER_JITERPRETER) {
JiterpreterThunk thunk = (void*)READ32(ip + 1);
ptrdiff_t offset = thunk(frame, locals, &jiterpreter_call_info);
// The fn ptr is encoded in a guint16 relative to the index of the first trace fn ptr, so compute the actual ptr
JiterpreterThunk thunk = (JiterpreterThunk)(void *)(((JiterpreterOpcode *)ip)->relative_fn_ptr + mono_jiterp_first_trace_fn_ptr);
ptrdiff_t offset = thunk (frame, locals, &jiterpreter_call_info, ip);
ip = (guint16*) (((guint8*)ip) + offset);
MINT_IN_BREAK;
}
Expand Down Expand Up @@ -8957,4 +8958,52 @@ mono_jiterp_get_opcode_info (int opcode, int type)
}
}

EMSCRIPTEN_KEEPALIVE int
mono_jiterp_placeholder_trace (void *_frame, void *pLocals, JiterpreterCallInfo *cinfo, const guint16 *ip)
{
// If this is hit it most likely indicates that a trace is being invoked from a thread
// that has not jitted it yet. We want to jit it on this thread and install it at the
// correct location in the function pointer table.
const JiterpreterOpcode *opcode = (const JiterpreterOpcode *)ip;
if (opcode->relative_fn_ptr) {
int fn_ptr = opcode->relative_fn_ptr + mono_jiterp_first_trace_fn_ptr;
InterpFrame *frame = _frame;
MonoMethod *method = frame->imethod->method;
const guint16 *start_of_body = frame->imethod->jinfo->code_start;
int size_of_body = frame->imethod->jinfo->code_size;
// g_printf ("mono_jiterp_placeholder_trace index=%d fn_ptr=%d ip=%x\n", opcode->trace_index, fn_ptr, ip);
mono_interp_tier_prepare_jiterpreter (
frame, method, ip, (gint32)opcode->trace_index,
start_of_body, size_of_body, frame->imethod->is_verbose,
fn_ptr
);
}
// advance past the enter/monitor opcode and return to interp
return mono_interp_oplen [MINT_TIER_ENTER_JITERPRETER] * 2;
}

EMSCRIPTEN_KEEPALIVE void
mono_jiterp_placeholder_jit_call (void *ret_sp, void *sp, void *ftndesc, gboolean *thrown)
{
// g_print ("mono_jiterp_placeholder_jit_call\n");
*thrown = 999;
}

EMSCRIPTEN_KEEPALIVE void *
mono_jiterp_get_interp_entry_func (int table)
{
g_assert (table <= JITERPRETER_TABLE_LAST);

if (table >= JITERPRETER_TABLE_INTERP_ENTRY_INSTANCE_RET_0)
return entry_funcs_instance_ret [table - JITERPRETER_TABLE_INTERP_ENTRY_INSTANCE_RET_0];
else if (table >= JITERPRETER_TABLE_INTERP_ENTRY_INSTANCE_0)
return entry_funcs_instance [table - JITERPRETER_TABLE_INTERP_ENTRY_INSTANCE_0];
else if (table >= JITERPRETER_TABLE_INTERP_ENTRY_STATIC_RET_0)
return entry_funcs_static_ret [table - JITERPRETER_TABLE_INTERP_ENTRY_STATIC_RET_0];
else if (table >= JITERPRETER_TABLE_INTERP_ENTRY_STATIC_0)
return entry_funcs_static [table - JITERPRETER_TABLE_INTERP_ENTRY_STATIC_0];
else
g_assert_not_reached ();
}

#endif
Loading

0 comments on commit 333c2c7

Please sign in to comment.