From 13ec3ceb57c5392d0e556d1d35b26966c6b8f16d Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Mon, 21 Aug 2023 09:06:00 -0400 Subject: [PATCH] Separate foreign threads into a :foreign threadpool (#50912) Co-authored-by: Gabriel Baraldi Co-authored-by: Dilum Aluthge (cherry picked from commit 8be469e275a455ca894fdc5fad8a80aafb359544) --- base/partr.jl | 7 +++++++ base/task.jl | 2 +- base/threadingconstructs.jl | 14 ++++++++++---- src/partr.c | 2 +- src/threading.c | 2 +- 5 files changed, 20 insertions(+), 7 deletions(-) diff --git a/base/partr.jl b/base/partr.jl index a02272ceab202..c77a24bdcc003 100644 --- a/base/partr.jl +++ b/base/partr.jl @@ -95,6 +95,7 @@ end function multiq_insert(task::Task, priority::UInt16) tpid = ccall(:jl_get_task_threadpoolid, Int8, (Any,), task) + @assert tpid > -1 heap_p = multiq_size(tpid) tp = tpid + 1 @@ -131,6 +132,9 @@ function multiq_deletemin() tid = Threads.threadid() tp = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1) + 1 + if tp == 0 # Foreign thread + return nothing + end tpheaps = heaps[tp] @label retry @@ -182,6 +186,9 @@ end function multiq_check_empty() tid = Threads.threadid() tp = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1) + 1 + if tp == 0 # Foreign thread + return true + end for i = UInt32(1):length(heaps[tp]) if heaps[tp][i].ntasks != 0 return false diff --git a/base/task.jl b/base/task.jl index db2f7e22bce67..137b0f7c4a3f6 100644 --- a/base/task.jl +++ b/base/task.jl @@ -794,7 +794,7 @@ function enq_work(t::Task) else @label not_sticky tp = Threads.threadpool(t) - if Threads.threadpoolsize(tp) == 1 + if tp === :foreign || Threads.threadpoolsize(tp) == 1 # There's only one thread in the task's assigned thread pool; # use its work queue. tid = (tp === :interactive) ? 1 : Threads.threadpoolsize(:interactive)+1 diff --git a/base/threadingconstructs.jl b/base/threadingconstructs.jl index 7a70132a9dccc..a5a1294be049b 100644 --- a/base/threadingconstructs.jl +++ b/base/threadingconstructs.jl @@ -63,6 +63,8 @@ function _tpid_to_sym(tpid::Int8) return :interactive elseif tpid == 1 return :default + elseif tpid == -1 + return :foreign else throw(ArgumentError("Unrecognized threadpool id $tpid")) end @@ -73,6 +75,8 @@ function _sym_to_tpid(tp::Symbol) return Int8(0) elseif tp === :default return Int8(1) + elseif tp == :foreign + return Int8(-1) else throw(ArgumentError("Unrecognized threadpool name `$(repr(tp))`")) end @@ -81,7 +85,7 @@ end """ Threads.threadpool(tid = threadid()) -> Symbol -Returns the specified thread's threadpool; either `:default` or `:interactive`. +Returns the specified thread's threadpool; either `:default`, `:interactive`, or `:foreign`. """ function threadpool(tid = threadid()) tpid = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1) @@ -108,6 +112,8 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the function threadpoolsize(pool::Symbol = :default) if pool === :default || pool === :interactive tpid = _sym_to_tpid(pool) + elseif pool == :foreign + error("Threadpool size of `:foreign` is indeterminant") else error("invalid threadpool specified") end @@ -151,7 +157,7 @@ function threading_run(fun, static) else # TODO: this should be the current pool (except interactive) if there # are ever more than two pools. - ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, _sym_to_tpid(:default)) + @assert ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, _sym_to_tpid(:default)) == 1 end tasks[i] = t schedule(t) @@ -357,10 +363,10 @@ end function _spawn_set_thrpool(t::Task, tp::Symbol) tpid = _sym_to_tpid(tp) - if _nthreads_in_pool(tpid) == 0 + if tpid == -1 || _nthreads_in_pool(tpid) == 0 tpid = _sym_to_tpid(:default) end - ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, tpid) + @assert ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, tpid) == 1 nothing end diff --git a/src/partr.c b/src/partr.c index 428389db7f218..0f3b581f5122f 100644 --- a/src/partr.c +++ b/src/partr.c @@ -70,7 +70,7 @@ JL_DLLEXPORT int jl_set_task_tid(jl_task_t *task, int16_t tid) JL_NOTSAFEPOINT JL_DLLEXPORT int jl_set_task_threadpoolid(jl_task_t *task, int8_t tpid) JL_NOTSAFEPOINT { - if (tpid < 0 || tpid >= jl_n_threadpools) + if (tpid < -1 || tpid >= jl_n_threadpools) return 0; task->threadpoolid = tpid; return 1; diff --git a/src/threading.c b/src/threading.c index e2eb686e3061a..4faa8a0a2dc46 100644 --- a/src/threading.c +++ b/src/threading.c @@ -332,7 +332,7 @@ JL_DLLEXPORT int8_t jl_threadpoolid(int16_t tid) JL_NOTSAFEPOINT if (tid < n) return (int8_t)i; } - return 0; // everything else uses threadpool 0 (though does not become part of any threadpool) + return -1; // everything else uses threadpool -1 (does not belong to any threadpool) } jl_ptls_t jl_init_threadtls(int16_t tid)