Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add Threads.foreach for convenient multithreaded Channel consumption #34543

Merged
merged 19 commits into from
May 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ Language changes
Compiler/Runtime improvements
-----------------------------


* All platforms can now use `@executable_path` within `jl_load_dynamic_library()`.
This allows executable-relative paths to be embedded within executables on all
platforms, not just macOS, from which the syntax is borrowed. ([#35627])
Expand All @@ -33,14 +32,17 @@ Build system changes

New library functions
---------------------

* New function `Base.kron!` and corresponding overloads for various matrix types for performing Kronecker product in-place. ([#31069]).
* New function `Base.Threads.foreach(f, channel::Channel)` for multithreaded `Channel` consumption. ([#34543]).

New library features
--------------------


Standard library changes
------------------------

* The `nextprod` function now accepts tuples and other array types for its first argument ([#35791]).
* The function `isapprox(x,y)` now accepts the `norm` keyword argument also for numeric (i.e., non-array) arguments `x` and `y` ([#35883]).
* `view`, `@view`, and `@views` now work on `AbstractString`s, returning a `SubString` when appropriate ([#35879]).
Expand Down
1 change: 1 addition & 0 deletions base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ include("threads.jl")
include("lock.jl")
include("channels.jl")
include("task.jl")
include("threads_overloads.jl")
include("weakkeydict.jl")

# Logging
Expand Down
8 changes: 8 additions & 0 deletions base/threadingconstructs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,3 +180,11 @@ macro spawn(expr)
end
end
end

# This is a stub that can be overloaded for downstream structures like `Channel`
# (the generic function is declared here with no methods of its own).
function foreach end

# Scheduling traits that can be employed for downstream overloads
# `AbstractSchedule` is the common supertype; overloads can dispatch on its subtypes.
abstract type AbstractSchedule end
# Field-less trait type selecting the "static" scheduling strategy.
struct StaticSchedule <: AbstractSchedule end
# Field-less trait type selecting the "fair" scheduling strategy.
struct FairSchedule <: AbstractSchedule end
51 changes: 51 additions & 0 deletions base/threads_overloads.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

"""
    Threads.foreach(f, channel::Channel;
                    schedule::Threads.AbstractSchedule=Threads.FairSchedule(),
                    ntasks=Threads.nthreads())

Apply `f` to every item taken from `channel`, like `foreach(f, channel)`, except that
the iteration over `channel` and the calls to `f` are distributed over `ntasks` tasks
created with `Threads.@spawn`. The call only returns once all of the internally
spawned tasks have completed.

The `schedule` keyword selects how each item is processed:

- `FairSchedule` (the default): tasks are spawned in a way that lets Julia's scheduler
  load-balance work items across threads more freely. This usually costs more per
  item, but may outperform `StaticSchedule` when running in concurrence with other
  multithreaded workloads.
- `StaticSchedule`: tasks are spawned in a way that has lower per-item overhead than
  `FairSchedule` but is less amenable to load-balancing. It may therefore suit
  fine-grained, uniform workloads better, while possibly performing worse than
  `FairSchedule` in concurrence with other multithreaded workloads.

!!! compat "Julia 1.6"
    This function requires Julia 1.6 or later.
"""
function Threads.foreach(f, channel::Channel;
                         schedule::Threads.AbstractSchedule=Threads.FairSchedule(),
                         ntasks=Threads.nthreads())
    run_item = _apply_for_schedule(schedule)
    # Raised by the first worker whose item processing throws, so that the
    # remaining workers stop pulling new items.
    aborted = Threads.Atomic{Bool}(false)
    @sync for _ in 1:ntasks
        Threads.@spawn try
            for item in channel
                $run_item(f, item)
                # The flag is only checked *after* `item` has been processed so
                # that an item already taken from the channel is never lost.
                # This is not fully comprehensive: a worker may still be blocked
                # in `take!` at `for item in channel`. A more robust mechanism
                # to avoid dropping items is still to be designed. See also:
                # https://github.com/JuliaLang/julia/pull/34543#discussion_r422695217
                if aborted[]
                    break
                end
            end
        catch
            aborted[] = true
            rethrow()
        end
    end
    return nothing
end

# Static scheduling: call `f(x)` directly on the calling task.
function _apply_for_schedule(::Threads.StaticSchedule)
    return (f, x) -> f(x)
end

# Fair scheduling: run `f(x)` in its own spawned task and wait for it to finish.
function _apply_for_schedule(::Threads.FairSchedule)
    return function (f, x)
        task = Threads.@spawn f(x)
        return wait(task)
    end
end
35 changes: 35 additions & 0 deletions test/threads_exec.jl
Original file line number Diff line number Diff line change
Expand Up @@ -845,3 +845,38 @@ fib34666(x) =
f(x)
end
@test fib34666(25) == 75025

# Test helper: push `1:k` through `Threads.foreach` and return a channel of `f(i)`.
# Even-numbered inputs are delayed by `delay` seconds before their result is emitted.
function jitter_channel(f, k, delay, ntasks, schedule)
    # Producer side: feeds 1:k through a channel of capacity 1.
    inputs = Channel(1) do ch
        for i in 1:k
            put!(ch, i)
        end
    end
    # Consumer side: the result channel has room for all `k` outputs.
    outputs = Channel(k) do ch
        function delayed_put(i)
            if iseven(i)
                sleep(delay)
            end
            put!(ch, f(i))
        end
        Threads.foreach(delayed_put, inputs; schedule=schedule, ntasks=ntasks)
    end
    return outputs
end

@testset "Threads.foreach(f, ::Channel)" begin
    k = 50
    delay = 0.01
    expected = sin.(1:k)

    # With a single consumer task the output order must match the input order,
    # whichever schedule is used.
    ordered_fair = collect(jitter_channel(sin, k, delay, 1, Threads.FairSchedule()))
    ordered_static = collect(jitter_channel(sin, k, delay, 1, Threads.StaticSchedule()))
    @test expected == ordered_fair
    @test expected == ordered_static

    # With several consumer tasks, the delay `jitter_channel` applies to even
    # inputs is expected to reorder the outputs; the set of results must still
    # be complete.
    unordered_fair = collect(jitter_channel(sin, k, delay, 10, Threads.FairSchedule()))
    unordered_static = collect(jitter_channel(sin, k, delay, 10, Threads.StaticSchedule()))
    @test expected != unordered_fair
    @test expected != unordered_static
    @test Set(expected) == Set(unordered_fair)
    @test Set(expected) == Set(unordered_static)

    # Both channels here are created without an explicit size; every item must
    # still make it from `inner` to `ys`.
    ys = Channel() do ys
        inner = Channel(xs -> foreach(i -> put!(xs, i), 1:3))
        Threads.foreach(x -> put!(ys, x), inner)
    end
    @test sort!(collect(ys)) == 1:3
end