From c67ed24abd8f56fbebc158924b7fd42a943f9eba Mon Sep 17 00:00:00 2001 From: Yichao Yu Date: Thu, 20 Apr 2017 11:30:28 -0400 Subject: [PATCH] Band aid to make threaded loop a little easier to work with * Print a warning if an error occurs in the threaded loop (Helps #17532) * Make recursive threaded loops "work" (Fix #18335). The proper fix will be tracked by #21017 --- base/threadingconstructs.jl | 29 +++++++++++++++++++++++------ src/threading.c | 15 ++++++++++++++- test/threads.jl | 17 +++++++++++++++++ 3 files changed, 54 insertions(+), 7 deletions(-) diff --git a/base/threadingconstructs.jl b/base/threadingconstructs.jl index 678e9c3d92c3b0..da47caead6f177 100644 --- a/base/threadingconstructs.jl +++ b/base/threadingconstructs.jl @@ -18,16 +18,24 @@ on `threadid()`. """ nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint))) +# Only read/write by the main thread +const in_threaded_loop = Ref(false) + function _threadsfor(iter,lbody) - fun = gensym("_threadsfor") lidx = iter.args[1] # index range = iter.args[2] quote - function $fun() - tid = threadid() + function threadsfor_fun(onethread=false) r = $(esc(range)) + lenr = length(r) # divide loop iterations among threads - len, rem = divrem(length(r), nthreads()) + if onethread + tid = 1 + len, rem = lenr, 0 + else + tid = threadid() + len, rem = divrem(lenr, nthreads()) + end # not enough iterations for all the threads? if len == 0 if tid > rem @@ -54,7 +62,17 @@ function _threadsfor(iter,lbody) $(esc(lbody)) end end - ccall(:jl_threading_run, Ref{Void}, (Any,), $fun) + # Hack to make nested threaded loops kinda work + if threadid() != 1 || in_threaded_loop[] + # We are in a nested threaded loop + threadsfor_fun(true) + else + in_threaded_loop[] = true + # the ccall is not expected to throw + ccall(:jl_threading_run, Ref{Void}, (Any,), threadsfor_fun) + in_threaded_loop[] = false + end + nothing end end """ @@ -80,4 +98,3 @@ macro threads(args...) throw(ArgumentError("unrecognized argument to @threads")) end end - diff --git a/src/threading.c b/src/threading.c index 81a47b55977e6f..f2fdaf96eb3e6e 100644 --- a/src/threading.c +++ b/src/threading.c @@ -316,7 +316,20 @@ static jl_value_t *ti_run_fun(const jl_generic_fptr_t *fptr, jl_method_instance_ jl_call_fptr_internal(fptr, mfunc, args, nargs); } JL_CATCH { - return ptls->exception_in_transit; + // Lock this output since we know it'll likely happen on multiple threads + static jl_mutex_t lock; + JL_LOCK_NOGC(&lock); + jl_jmp_buf *old_buf = ptls->safe_restore; + jl_jmp_buf buf; + if (!jl_setjmp(buf, 0)) { + // Set up the safe_restore context so that the printing uses the thread safe version + ptls->safe_restore = &buf; + jl_printf(JL_STDERR, "\nError thrown in threaded loop on thread %d: ", + (int)ptls->tid); + jl_static_show(JL_STDERR, ptls->exception_in_transit); + } + ptls->safe_restore = old_buf; + JL_UNLOCK_NOGC(&lock); } return jl_nothing; } diff --git a/test/threads.jl b/test/threads.jl index d0626fdbc3d79e..9be402deb60215 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -401,3 +401,20 @@ function test_load_and_lookup_18020(n) end end test_load_and_lookup_18020(10000) + +# Nested threaded loops +# This may not be efficient/fully supported but should work without crashing..... +function test_nested_loops() + a = zeros(Int, 100, 100) + @threads for i in 1:100 + @threads for j in 1:100 + a[j, i] = i + j + end + end + for i in 1:100 + for j in 1:100 + @test a[j, i] == i + j + end + end +end +test_nested_loops()