Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wasm] Jiterpreter multithreading support #88279

Merged
merged 8 commits into from
Aug 15, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 71 additions & 22 deletions src/mono/mono/mini/interp/interp.c
Original file line number Diff line number Diff line change
Expand Up @@ -2659,23 +2659,25 @@ do_jit_call (ThreadContext *context, stackval *ret_sp, stackval *sp, InterpFrame

// We reuse do_jit_call's epilogue to do things like propagate thrown exceptions
// and sign-extend return values instead of inlining that logic into every thunk
goto epilogue;
// the dummy implementation sets a special value into thrown to indicate that
// we need to go through the slow path because this thread has no thunk yet
if (G_UNLIKELY (thrown == 999))
thrown = 0;
else
goto epilogue;
} else {
// FIXME: thread safety for the hit count
int count = cinfo->hit_count;
int old_count = mono_jiterp_increment_counter (&cinfo->hit_count);
// If our hit count just reached the threshold, we request that a thunk be jitted
// for this specific call site. It will go into a queue and wait until there
// are enough jit calls waiting to be compiled into one WASM module
if (count == mono_opt_jiterpreter_jit_call_trampoline_hit_count) {
if (old_count == mono_opt_jiterpreter_jit_call_trampoline_hit_count) {
mono_interp_jit_wasm_jit_call_trampoline (
rmethod->method, rmethod, cinfo,
initialize_arg_offsets(rmethod, mono_method_signature_internal (rmethod->method)),
mono_aot_mode == MONO_AOT_MODE_LLVMONLY_INTERP
);
} else {
int excess = count - mono_opt_jiterpreter_jit_call_queue_flush_threshold;
if (excess <= 0)
cinfo->hit_count++;
int excess = old_count - mono_opt_jiterpreter_jit_call_queue_flush_threshold;
// If our hit count just reached the flush threshold, that means that we
// previously requested compilation for this call site and it didn't
// happen yet. We will request a flush of the entire queue this one
Expand Down Expand Up @@ -2736,7 +2738,7 @@ do_jit_call (ThreadContext *context, stackval *ret_sp, stackval *sp, InterpFrame
* this is faster than mono_llvm_cpp_catch_exception by avoiding the use of
* emscripten invoke_vi to find and invoke jit_call_cb indirectly
*/
jiterpreter_do_jit_call(jit_call_cb, &cb_data, &thrown);
jiterpreter_do_jit_call (jit_call_cb, &cb_data, &thrown);
#else
/* Catch the exception thrown by the native code using a try-catch */
mono_llvm_cpp_catch_exception (jit_call_cb, &cb_data, &thrown);
Expand Down Expand Up @@ -7670,14 +7672,12 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;

#ifdef HOST_BROWSER
MINT_IN_CASE(MINT_TIER_NOP_JITERPRETER) {
ip += 3;
ip += JITERPRETER_OPCODE_SIZE;
MINT_IN_BREAK;
}

MINT_IN_CASE(MINT_TIER_PREPARE_JITERPRETER) {
if (mono_opt_jiterpreter_traces_enabled) {
// We may lose a race with another thread here so we need to use volatile and be careful
volatile guint16 *mutable_ip = (volatile guint16*)ip;
/*
* prepare_jiterpreter will update the trace's hit count and potentially either JIT it or
* disable this entry point based on whether it fails to JIT. the hit counting is necessary
Expand All @@ -7690,22 +7690,21 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
* JS workers in order to register them at the appropriate slots in the function pointer
* table. when growing the function pointer table we will also need to synchronize that.
*/
JiterpreterThunk prepare_result = mono_interp_tier_prepare_jiterpreter_fast(frame, ip);
JiterpreterThunk prepare_result = mono_interp_tier_prepare_jiterpreter_fast (frame, ip);
ptrdiff_t offset;
switch ((guint32)(void*)prepare_result) {
case JITERPRETER_TRAINING:
// jiterpreter still updating hit count before deciding to generate a trace,
// so skip this opcode.
ip += 3;
ip += JITERPRETER_OPCODE_SIZE;
break;
case JITERPRETER_NOT_JITTED:
// Patch opcode to disable it because this trace failed to JIT.
if (!mono_opt_jiterpreter_estimate_heat) {
mono_memory_barrier ();
*mutable_ip = MINT_TIER_NOP_JITERPRETER;
mono_memory_barrier ();
if (!mono_jiterp_patch_opcode ((volatile JiterpreterOpcode *)ip, MINT_TIER_PREPARE_JITERPRETER, MINT_TIER_NOP_JITERPRETER))
g_printf ("Failed to patch opcode at %x into a nop\n", (unsigned int)ip);
}
ip += 3;
ip += JITERPRETER_OPCODE_SIZE;
break;
default:
/*
Expand All @@ -7716,16 +7715,17 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
* (note that right now threading doesn't work, but it's worth being correct
* here so that implementing thread support will be easier later.)
*/
*mutable_ip = MINT_TIER_MONITOR_JITERPRETER;
if (!mono_jiterp_patch_opcode ((volatile JiterpreterOpcode *)ip, MINT_TIER_PREPARE_JITERPRETER, MINT_TIER_MONITOR_JITERPRETER))
g_printf ("Failed to patch opcode at %x into a monitor point\n", (unsigned int)ip);
// now execute the trace
// this isn't important for performance, but it makes it easier to use the
// jiterpreter early in automated tests where code only runs once
offset = prepare_result(frame, locals, &jiterpreter_call_info);
offset = prepare_result (frame, locals, &jiterpreter_call_info, ip);
ip = (guint16*) (((guint8*)ip) + offset);
break;
}
} else {
ip += 3;
ip += JITERPRETER_OPCODE_SIZE;
}

MINT_IN_BREAK;
Expand All @@ -7741,8 +7741,9 @@ MINT_IN_CASE(MINT_BRTRUE_I8_SP) ZEROP_SP(gint64, !=); MINT_IN_BREAK;
}

MINT_IN_CASE(MINT_TIER_ENTER_JITERPRETER) {
JiterpreterThunk thunk = (void*)READ32(ip + 1);
ptrdiff_t offset = thunk(frame, locals, &jiterpreter_call_info);
// The fn ptr is encoded in a guint16 relative to the index of the first trace fn ptr, so compute the actual ptr
JiterpreterThunk thunk = (JiterpreterThunk)(void *)(((JiterpreterOpcode *)ip)->relative_fn_ptr + mono_jiterp_first_trace_fn_ptr);
ptrdiff_t offset = thunk (frame, locals, &jiterpreter_call_info, ip);
ip = (guint16*) (((guint8*)ip) + offset);
MINT_IN_BREAK;
}
Expand Down Expand Up @@ -8900,4 +8901,52 @@ mono_jiterp_get_opcode_info (int opcode, int type)
}
}

/*
 * mono_jiterp_dummy_trace:
 *
 *   Placeholder trace function. Reaching it most likely means a trace is being
 * invoked from a thread that has not jitted it yet, so the trace is prepared on
 * the current thread so that it can be installed at the correct location in the
 * function pointer table.
 *
 * _frame: the current interpreter frame (used as an InterpFrame *).
 * pLocals, cinfo: not used by this placeholder.
 * ip: address of the enter/monitor opcode that invoked this function.
 *
 * Returns the offset the interpreter should add to ip in order to step past
 * the enter/monitor opcode. The interpreter applies the value returned by a
 * trace as a byte offset; the "* 2" presumably converts an opcode length
 * expressed in guint16 units into bytes.
 */
EMSCRIPTEN_KEEPALIVE int
mono_jiterp_dummy_trace (void *_frame, void *pLocals, JiterpreterCallInfo *cinfo, const guint16 *ip)
{
	const JiterpreterOpcode *trace_op = (const JiterpreterOpcode *)ip;

	// A zero relative_fn_ptr means there is no function pointer slot recorded
	// in this opcode, so there is nothing to prepare on this thread.
	if (trace_op->relative_fn_ptr != 0) {
		InterpFrame *frame = (InterpFrame *)_frame;
		// Absolute function pointer table slot: the opcode stores it relative
		// to the first trace function pointer.
		int table_slot = trace_op->relative_fn_ptr + mono_jiterp_first_trace_fn_ptr;
		const guint16 *body_start = frame->imethod->jinfo->code_start;
		int body_size = frame->imethod->jinfo->code_size;
		MonoMethod *method = frame->imethod->method;

		// g_printf ("mono_jiterp_dummy_trace index=%d fn_ptr=%d ip=%x\n", trace_op->trace_index, table_slot, ip);
		mono_interp_tier_prepare_jiterpreter (
			frame, method, ip, (gint32)trace_op->trace_index,
			body_start, body_size, frame->imethod->is_verbose,
			table_slot
		);
	}

	// Step over the enter/monitor opcode and hand control back to the interpreter.
	return mono_interp_oplen [MINT_TIER_ENTER_JITERPRETER] * 2;
}

/*
 * mono_jiterp_dummy_jit_call:
 *
 *   Placeholder jit-call thunk. It performs no call at all; it only stores a
 * special marker value into *thrown. do_jit_call compares thrown against the
 * same value (999) to detect that this thread has no real thunk yet and that
 * it must go through the slow path instead.
 *
 * ret_sp, sp, ftndesc: not used by this placeholder.
 * thrown: receives the marker value.
 */
EMSCRIPTEN_KEEPALIVE void
mono_jiterp_dummy_jit_call (void *ret_sp, void *sp, void *ftndesc, gboolean *thrown)
{
	// Must stay in sync with the literal that do_jit_call checks for.
	const gboolean no_thunk_marker = 999;

	// g_print ("mono_jiterp_dummy_jit_call\n");
	*thrown = no_thunk_marker;
}

/*
 * mono_jiterp_get_interp_entry_func:
 *
 *   Map a jiterpreter table id onto the matching interp entry function.
 *
 * table: a JITERPRETER_TABLE_* id; asserted to be no greater than
 *   JITERPRETER_TABLE_LAST.
 *
 * The id is matched against the base id of each interp entry table group,
 * testing instance-with-return first, then instance, static-with-return and
 * finally static. The index within the selected array is the distance of the
 * id from that group's base (presumably the bases are declared in ascending
 * order static < static ret < instance < instance ret - TODO confirm against
 * the enum). An id below the first static entry trips g_assert_not_reached.
 */
EMSCRIPTEN_KEEPALIVE void *
mono_jiterp_get_interp_entry_func (int table)
{
	g_assert (table <= JITERPRETER_TABLE_LAST);

	// Each test below is only reached when every higher group has already
	// been ruled out, so a plain lower-bound comparison is sufficient.
	if (table >= JITERPRETER_TABLE_INTERP_ENTRY_INSTANCE_RET_0)
		return entry_funcs_instance_ret [table - JITERPRETER_TABLE_INTERP_ENTRY_INSTANCE_RET_0];

	if (table >= JITERPRETER_TABLE_INTERP_ENTRY_INSTANCE_0)
		return entry_funcs_instance [table - JITERPRETER_TABLE_INTERP_ENTRY_INSTANCE_0];

	if (table >= JITERPRETER_TABLE_INTERP_ENTRY_STATIC_RET_0)
		return entry_funcs_static_ret [table - JITERPRETER_TABLE_INTERP_ENTRY_STATIC_RET_0];

	if (table >= JITERPRETER_TABLE_INTERP_ENTRY_STATIC_0)
		return entry_funcs_static [table - JITERPRETER_TABLE_INTERP_ENTRY_STATIC_0];

	// The id does not belong to any interp entry table.
	g_assert_not_reached ();
}

#endif
Loading
Loading