From b060dbaab67734ee8953e014fb0b41a6d06881b2 Mon Sep 17 00:00:00 2001 From: Julian P Samaroo Date: Sat, 18 Sep 2021 13:32:03 -0500 Subject: [PATCH] Add threadpool support to runtime --- base/task.jl | 27 ++++++--- base/threadingconstructs.jl | 51 +++++++++++++--- src/jl_exported_data.inc | 1 + src/jl_exported_funcs.inc | 4 ++ src/julia.h | 1 + src/task.c | 10 ++++ src/threading.c | 115 ++++++++++++++++++++++++++++++------ 7 files changed, 173 insertions(+), 36 deletions(-) diff --git a/base/task.jl b/base/task.jl index 0d4e5da4ccfd4a..54780fd1b19e55 100644 --- a/base/task.jl +++ b/base/task.jl @@ -239,7 +239,9 @@ true """ istaskfailed(t::Task) = (load_state_acquire(t) === task_state_failed) -Threads.threadid(t::Task) = Int(ccall(:jl_get_task_tid, Int16, (Any,), t)+1) +Threads.threadid(t::Task) = Int(ccall(:jl_get_task_relative_tid, Int16, (Any,), t)+1) +Threads.rawthreadid(t::Task) = Int(ccall(:jl_get_task_tid, Int16, (Any,), t)+1) +Threads.threadpoolid(t::Task) = Int(ccall(:jl_get_task_poolid, Int16, (Any,), t)+1) task_result(t::Task) = t.result @@ -599,8 +601,9 @@ function list_deletefirst!(W::InvasiveLinkedListSynchronized{T}, t::T) where T end const StickyWorkqueue = InvasiveLinkedListSynchronized{Task} -global const Workqueues = [StickyWorkqueue()] +global const Workqueues = [StickyWorkqueue()] # default threadpool is first global const Workqueue = Workqueues[1] # default work queue is thread 1 +global const AllWorkqueues = [Workqueues] function __preinit_threads__() if length(Workqueues) < Threads.nthreads() resize!(Workqueues, Threads.nthreads()) @@ -613,7 +616,9 @@ end function enq_work(t::Task) (t._state === task_state_runnable && t.queue === nothing) || error("schedule: Task not runnable") + tp = Threads.threadpoolid(t) tid = Threads.threadid(t) + _tid = Threads.rawthreadid(t) # Note there are three reasons a Task might be put into a sticky queue # even if t.sticky == false: # 1. 
The Task's stack is currently being used by the scheduler for a certain thread.
@@ -627,23 +632,27 @@ function enq_work(t::Task)
             # set it to be sticky.
             # XXX: Ideally we would be able to unset this
             current_task().sticky = true
+            tp = Threads.threadpoolid()
             tid = Threads.threadid()
-            ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
+            _tid = Threads.rawthreadid()
+            ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, _tid-1)
         end
-        push!(Workqueues[tid], t)
+        push!(AllWorkqueues[tp][tid], t)
     else
         if ccall(:jl_enqueue_task, Cint, (Any,), t) != 0
             # if multiq is full, give to a random thread (TODO fix)
             if tid == 0
+                tp = Threads.threadpoolid()
                 tid = mod(time_ns() % Int, Threads.nthreads()) + 1
-                ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
+                ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1) # use freshly chosen tid (NOT stale _tid); assumes default pool, where relative == raw tid — TODO confirm
+                _tid = Threads.rawthreadid(t)
             end
-            push!(Workqueues[tid], t)
+            push!(AllWorkqueues[tp][tid], t)
         else
-            tid = 0
+            _tid = 0
         end
     end
-    ccall(:jl_wakeup_thread, Cvoid, (Int16,), (tid - 1) % Int16)
+    ccall(:jl_wakeup_thread, Cvoid, (Int16,), (_tid - 1) % Int16)
     return t
 end
 
@@ -819,7 +828,7 @@ end
 
 function wait()
     GC.safepoint()
-    W = Workqueues[Threads.threadid()]
+    W = AllWorkqueues[Threads.threadpoolid()][Threads.threadid()]
    poptask(W)
     result = try_yieldto(ensure_rescheduled)
     process_events()
diff --git a/base/threadingconstructs.jl b/base/threadingconstructs.jl
index e66af69b3e82f0..5df58ad500a3ed 100644
--- a/base/threadingconstructs.jl
+++ b/base/threadingconstructs.jl
@@ -3,24 +3,59 @@
 export threadid, nthreads, @threads, @spawn
 
 """
-    Threads.threadid()
+    Threads.threadid() -> Int
 
-Get the ID number of the current thread of execution. The master thread has ID `1`.
+Get the ID number of the current thread of execution within the current
+threadpool. The master thread has ID `1`.
 """
-threadid() = Int(ccall(:jl_threadid, Int16, ())+1)
+threadid() = Int(ccall(:jl_relative_threadid, Int16, ())+1)
 
-# Inclusive upper bound on threadid()
 """
-    Threads.nthreads()
+    Threads.rawthreadid() -> Int
 
-Get the number of threads available to the Julia process. This is the inclusive upper bound
-on [`threadid()`](@ref).
+Get the ID number of the current thread of execution within the Julia session.
+The master thread has ID `1`.
+"""
+rawthreadid() = Int(ccall(:jl_threadid, Int16, ())+1)
+
+"""
+    Threads.nthreads(tp::Int=Threads.threadpoolid()) -> Int
+
+Get the number of threads available in the specified threadpool. This is the
+inclusive upper bound on [`threadid()`](@ref).
 
 See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
 [`LinearAlgebra`](@ref man-linalg) standard library, and `nprocs()` in the
 [`Distributed`](@ref man-distributed) standard library.
 """
-nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))
+nthreads(tp::Int=Threads.threadpoolid()) =
+    Int(ccall(:jl_num_threads, Cint, (Cint,), tp-1))
+
+"""
+    Threads.spawn_threadpool(n::Int) -> Int
+
+Spawns a new threadpool of size `n`, and returns the threadpool ID.
+"""
+function spawn_threadpool(n::Int)
+    wq = [Base.StickyWorkqueue() for _ in 1:n]
+    push!(Base.AllWorkqueues, wq)
+    # C signature is (size_t nthreads, int exclusive); returns a 0-based pool id,
+    # so convert to the 1-based numbering used by threadpoolid()
+    return ccall(:jl_start_threads_dedicated, Cint, (Csize_t, Cint), n, 0) + 1
+end
+
+"""
+    Threads.threadpoolid() -> Int
+
+Returns the threadpool ID that the current thread resides in. The default
+threadpool has ID `1`.
+"""
+threadpoolid() = Int(ccall(:jl_threadpoolid, Cint, ())+1)
+
+"""
+    Threads.npools() -> Int
+
+Returns the number of threadpools currently configured.
+""" +npools() = Int(unsafe_load(cglobal(:jl_threadpools, Cint))) function threading_run(func) ccall(:jl_enter_threaded_region, Cvoid, ()) diff --git a/src/jl_exported_data.inc b/src/jl_exported_data.inc index b8d5ae0e35b29f..13c9c2bc00c735 100644 --- a/src/jl_exported_data.inc +++ b/src/jl_exported_data.inc @@ -94,6 +94,7 @@ XX(jl_string_type) \ XX(jl_symbol_type) \ XX(jl_task_type) \ + XX(jl_threadpools) \ XX(jl_top_module) \ XX(jl_true) \ XX(jl_tuple_typename) \ diff --git a/src/jl_exported_funcs.inc b/src/jl_exported_funcs.inc index 877c603c7ac3ed..1c52f3568b6c4f 100644 --- a/src/jl_exported_funcs.inc +++ b/src/jl_exported_funcs.inc @@ -261,6 +261,8 @@ XX(jl_get_root_symbol) \ XX(jl_get_safe_restore) \ XX(jl_get_size) \ + XX(jl_get_task_poolid) \ + XX(jl_get_task_relative_tid) \ XX(jl_get_task_tid) \ XX(jl_gettimeofday) \ XX(jl_get_tls_world_age) \ @@ -458,6 +460,7 @@ XX(jl_spawn) \ XX(jl_specializations_get_linfo) \ XX(jl_specializations_lookup) \ + XX(jl_start_threads_dedicated) \ XX(jl_static_show) \ XX(jl_static_show_func_sig) \ XX(jl_stderr_obj) \ @@ -496,6 +499,7 @@ XX(jl_test_cpu_feature) \ XX(jl_threadid) \ XX(jl_threading_enabled) \ + XX(jl_threadpoolid) \ XX(jl_throw) \ XX(jl_throw_out_of_memory_error) \ XX(jl_too_few_args) \ diff --git a/src/julia.h b/src/julia.h index e53b33bef674da..8ae409c6a3dea6 100644 --- a/src/julia.h +++ b/src/julia.h @@ -1583,6 +1583,7 @@ JL_DLLEXPORT jl_sym_t *jl_get_UNAME(void) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_sym_t *jl_get_ARCH(void) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_value_t *jl_get_libllvm(void) JL_NOTSAFEPOINT; extern JL_DLLIMPORT int jl_n_threads; +extern JL_DLLIMPORT int jl_threadpools; // environment entries JL_DLLEXPORT jl_value_t *jl_environ(int i); diff --git a/src/task.c b/src/task.c index 88d4eac0863c9a..e86619a6ae69fc 100644 --- a/src/task.c +++ b/src/task.c @@ -1353,6 +1353,16 @@ JL_DLLEXPORT int16_t jl_get_task_tid(jl_task_t *t) JL_NOTSAFEPOINT return t->tid; } +JL_DLLEXPORT int16_t 
jl_get_task_relative_tid(jl_task_t *t) JL_NOTSAFEPOINT +{ + return jl_tid_to_relative(t->tid); +} + +JL_DLLEXPORT int16_t jl_get_task_poolid(jl_task_t *t) JL_NOTSAFEPOINT +{ + return jl_tid_to_poolid(t->tid); +} + #ifdef _OS_WINDOWS_ #if defined(_CPU_X86_) diff --git a/src/threading.c b/src/threading.c index ffe53c07b45ee3..74a5bf3fb1085a 100644 --- a/src/threading.c +++ b/src/threading.c @@ -287,17 +287,46 @@ void jl_pgcstack_getkey(jl_get_pgcstack_func **f, jl_pgcstack_key_t *k) #endif jl_ptls_t *jl_all_tls_states JL_GLOBALLY_ROOTED; +int *jl_threadpool_map; +int *jl_threadpool_sizes; uint8_t jl_measure_compile_time_enabled = 0; uint64_t jl_cumulative_compile_time = 0; -// return calling thread's ID -// Also update the suspended_threads list in signals-mach when changing the -// type of the thread id. +// return calling thread's absolute ID JL_DLLEXPORT int16_t jl_threadid(void) { return jl_current_task->tid; } +// return calling thread's relative ID (with respect to its threadpool) +int16_t jl_tid_to_relative(int16_t rawtid) +{ + if (rawtid < 0) + return rawtid; + int poolid = jl_threadpool_map[rawtid]; + int tp_offset = 0; + for (int tp = 0; tp < poolid; tp++) { + tp_offset += jl_threadpool_sizes[tp]; + } + return rawtid - tp_offset; +} +JL_DLLEXPORT int16_t jl_relative_threadid(void) +{ + return jl_tid_to_relative(jl_current_task->tid); +} + +int16_t jl_tid_to_poolid(int16_t tid) +{ + int tp_offset = 0; + for (int tp = 0; tp < jl_threadpools; tp++) { + int tp_size = jl_threadpool_sizes[tp]; + if (tp_offset + tp_size > tid) + return tp; + tp_offset += tp_size; + } + return -1; +} + jl_ptls_t jl_init_threadtls(int16_t tid) { jl_ptls_t ptls = (jl_ptls_t)calloc(1, sizeof(jl_tls_states_t)); @@ -467,18 +496,61 @@ void jl_init_threading(void) } if (jl_n_threads <= 0) jl_n_threads = 1; + int jl_extra_threads = 8; // FIXME: ENV[NUM_EXTRA_THREADS_NAME] + int jl_max_threadpools = 8; // FIXME: ENV[NUM_THREADPOOLS_NAME] #ifndef __clang_analyzer__ - jl_all_tls_states = 
(jl_ptls_t*)calloc(jl_n_threads, sizeof(void*));
+    jl_all_tls_states = (jl_ptls_t*)calloc(jl_n_threads+jl_extra_threads, sizeof(void*));
 #endif
+    jl_threadpools = 0;
+    jl_threadpool_map = (int*)calloc(jl_max_threadpools, sizeof(int));
+    jl_threadpool_sizes = (int*)calloc(jl_max_threadpools, sizeof(int));
 }
 
 static uv_barrier_t thread_init_done;
 
+int jl_start_threads_(size_t nthreads, size_t cur_n_threads, uv_barrier_t *barrier, int exclusive)
+{
+    int cpumasksize = uv_cpumask_size();
+    if (cpumasksize < (int)(cur_n_threads + nthreads)) // also handles error case; mask is indexed up to cur_n_threads+nthreads-1
+        cpumasksize = cur_n_threads + nthreads;
+    char *mask = (char*)alloca(cpumasksize); memset(mask, 0, cpumasksize); // alloca memory is uninitialized
+    int i;
+    uv_thread_t uvtid;
+
+    int tp = jl_threadpools;
+    jl_threadpools++;
+    if (tp > 0) {
+        jl_threadpool_sizes[tp] = nthreads;
+    } else {
+        jl_threadpool_sizes[tp] = jl_n_threads;
+    }
+
+    uv_barrier_init(barrier, nthreads + 1); // +1: the calling thread also waits below
+
+    for (i = cur_n_threads; i < cur_n_threads+nthreads; ++i) {
+        jl_threadarg_t *t = (jl_threadarg_t*)malloc_s(sizeof(jl_threadarg_t)); // ownership will be passed to the thread
+        t->tid = i;
+        t->barrier = barrier;
+        uv_thread_create(&uvtid, jl_threadfun, t);
+        jl_threadpool_map[i] = tp;
+        if (exclusive) {
+            mask[i] = 1;
+            uv_thread_setaffinity(&uvtid, mask, NULL, cpumasksize);
+            mask[i] = 0;
+        }
+        uv_thread_detach(&uvtid);
+    }
+
+    uv_barrier_wait(barrier);
+
+    return tp;
+}
+
 void jl_start_threads(void)
 {
     int cpumasksize = uv_cpumask_size();
     char *cp;
-    int i, exclusive;
+    int exclusive;
     uv_thread_t uvtid;
     if (cpumasksize < jl_n_threads) // also handles error case
         cpumasksize = jl_n_threads;
@@ -509,22 +581,27 @@ void jl_start_threads(void)
     size_t nthreads = jl_n_threads;
 
     // create threads
-    uv_barrier_init(&thread_init_done, nthreads);
+    jl_start_threads_(nthreads-1, 1, &thread_init_done, exclusive);
+}
 
-    for (i = 1; i < nthreads; ++i) {
-        jl_threadarg_t *t = (jl_threadarg_t*)malloc_s(sizeof(jl_threadarg_t)); // ownership will be passed to the thread
-        t->tid = i;
-        t->barrier = &thread_init_done;
-        uv_thread_create(&uvtid, jl_threadfun, t);
-        if (exclusive) {
-            mask[i] = 1;
-            uv_thread_setaffinity(&uvtid, mask, NULL, cpumasksize);
-            mask[i] = 0;
-        }
-        uv_thread_detach(&uvtid);
-    }
+JL_DLLEXPORT int jl_start_threads_dedicated(size_t nthreads, int exclusive)
+{
+    uv_barrier_t *tbar = (uv_barrier_t*)malloc_s(sizeof(uv_barrier_t)); // heap: workers may still be inside uv_barrier_wait when we return; intentionally leaked
+    return jl_start_threads_(nthreads, jl_n_threads, tbar, exclusive);
+}
+
+JL_DLLEXPORT int jl_num_threads(int tp)
+{
+    if (jl_threadpools == 0)
+        // Pre-init
+        return jl_n_threads;
+    assert(0 <= tp && tp < jl_threadpools);
+    return jl_threadpool_sizes[tp];
+}
 
-    uv_barrier_wait(&thread_init_done);
+JL_DLLEXPORT int jl_threadpoolid(void)
+{
+    return jl_threadpool_map[jl_current_task->tid];
 }
 
 unsigned volatile _threadedregion; // HACK: keep track of whether it is safe to do IO