Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Band aid to make threaded loop a little easier to work with #21452

Merged
merged 1 commit into from
Apr 22, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 23 additions & 6 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,25 @@ on `threadid()`.
"""
nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))

# Only read/written by the main thread
const in_threaded_loop = Ref(false)

function _threadsfor(iter,lbody)
fun = gensym("_threadsfor")
lidx = iter.args[1] # index
range = iter.args[2]
quote
range = $(esc(range))
function $fun()
tid = threadid()
function threadsfor_fun(onethread=false)
r = range # Load into local variable
lenr = length(r)
# divide loop iterations among threads
len, rem = divrem(length(r), nthreads())
if onethread
tid = 1
len, rem = lenr, 0
else
tid = threadid()
len, rem = divrem(lenr, nthreads())
end
# not enough iterations for all the threads?
if len == 0
if tid > rem
Expand All @@ -55,7 +63,17 @@ function _threadsfor(iter,lbody)
$(esc(lbody))
end
end
ccall(:jl_threading_run, Ref{Void}, (Any,), $fun)
# Hack to make nested threaded loops kinda work
if threadid() != 1 || in_threaded_loop[]
# We are in a nested threaded loop
threadsfor_fun(true)
else
in_threaded_loop[] = true
# the ccall is not expected to throw
ccall(:jl_threading_run, Ref{Void}, (Any,), threadsfor_fun)
in_threaded_loop[] = false
end
nothing
end
end
"""
Expand All @@ -81,4 +99,3 @@ macro threads(args...)
throw(ArgumentError("unrecognized argument to @threads"))
end
end

15 changes: 14 additions & 1 deletion src/threading.c
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,20 @@ static jl_value_t *ti_run_fun(const jl_generic_fptr_t *fptr, jl_method_instance_
jl_call_fptr_internal(fptr, mfunc, args, nargs);
}
JL_CATCH {
return ptls->exception_in_transit;
// Lock this output since we know it'll likely happen on multiple threads
static jl_mutex_t lock;
JL_LOCK_NOGC(&lock);
jl_jmp_buf *old_buf = ptls->safe_restore;
jl_jmp_buf buf;
if (!jl_setjmp(buf, 0)) {
// Set up the safe_restore context so that the printing uses the thread safe version
ptls->safe_restore = &buf;
jl_printf(JL_STDERR, "\nError thrown in threaded loop on thread %d: ",
(int)ptls->tid);
jl_static_show(JL_STDERR, ptls->exception_in_transit);
}
ptls->safe_restore = old_buf;
JL_UNLOCK_NOGC(&lock);
}
return jl_nothing;
}
Expand Down
17 changes: 17 additions & 0 deletions test/threads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -415,3 +415,20 @@ function test_load_and_lookup_18020(n)
end
end
test_load_and_lookup_18020(10000)

# Nested threaded loops
# This may not be efficient/fully supported but should work without crashing.....
function test_nested_loops()
a = zeros(Int, 100, 100)
@threads for i in 1:100
@threads for j in 1:100
a[j, i] = i + j
end
end
for i in 1:100
for j in 1:100
@test a[j, i] == i + j
end
end
end
test_nested_loops()