diff --git a/base/threadingconstructs.jl b/base/threadingconstructs.jl index 895b667c14aa98..1237c1dba3bd0e 100644 --- a/base/threadingconstructs.jl +++ b/base/threadingconstructs.jl @@ -18,17 +18,25 @@ on `threadid()`. """ nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint))) +# Only read/write by the main thread +const in_threaded_loop = Ref(false) + function _threadsfor(iter,lbody) - fun = gensym("_threadsfor") lidx = iter.args[1] # index range = iter.args[2] quote range = $(esc(range)) - function $fun() - tid = threadid() + function threadsfor_fun(onethread=false) r = range # Load into local variable + lenr = length(r) # divide loop iterations among threads - len, rem = divrem(length(r), nthreads()) + if onethread + tid = 1 + len, rem = lenr, 0 + else + tid = threadid() + len, rem = divrem(lenr, nthreads()) + end # not enough iterations for all the threads? if len == 0 if tid > rem @@ -55,7 +63,17 @@ function _threadsfor(iter,lbody) $(esc(lbody)) end end - ccall(:jl_threading_run, Ref{Void}, (Any,), $fun) + # Hack to make nested threaded loops kinda work + if threadid() != 1 || in_threaded_loop[] + # We are in a nested threaded loop + threadsfor_fun(true) + else + in_threaded_loop[] = true + # the ccall is not expected to throw + ccall(:jl_threading_run, Ref{Void}, (Any,), threadsfor_fun) + in_threaded_loop[] = false + end + nothing end end """ @@ -81,4 +99,3 @@ macro threads(args...) throw(ArgumentError("unrecognized argument to @threads")) end end - diff --git a/src/threading.c b/src/threading.c index 81a47b55977e6f..f2fdaf96eb3e6e 100644 --- a/src/threading.c +++ b/src/threading.c @@ -316,7 +316,20 @@ static jl_value_t *ti_run_fun(const jl_generic_fptr_t *fptr, jl_method_instance_ jl_call_fptr_internal(fptr, mfunc, args, nargs); } JL_CATCH { - return ptls->exception_in_transit; + // Lock this output since we know it'll likely happen on multiple threads + static jl_mutex_t lock; + JL_LOCK_NOGC(&lock); + jl_jmp_buf *old_buf = ptls->safe_restore; + jl_jmp_buf buf; + if (!jl_setjmp(buf, 0)) { + // Set up the safe_restore context so that the printing uses the thread safe version + ptls->safe_restore = &buf; + jl_printf(JL_STDERR, "\nError thrown in threaded loop on thread %d: ", + (int)ptls->tid); + jl_static_show(JL_STDERR, ptls->exception_in_transit); + } + ptls->safe_restore = old_buf; + JL_UNLOCK_NOGC(&lock); } return jl_nothing; } diff --git a/test/threads.jl b/test/threads.jl index 9c587ede557e33..ff38963c09b04f 100644 --- a/test/threads.jl +++ b/test/threads.jl @@ -415,3 +415,20 @@ function test_load_and_lookup_18020(n) end end test_load_and_lookup_18020(10000) + +# Nested threaded loops +# This may not be efficient/fully supported but should work without crashing..... +function test_nested_loops() + a = zeros(Int, 100, 100) + @threads for i in 1:100 + @threads for j in 1:100 + a[j, i] = i + j + end + end + for i in 1:100 + for j in 1:100 + @test a[j, i] == i + j + end + end +end +test_nested_loops()