Add DaggerMPI subpackage for MPI integrations #356

Draft: wants to merge 3 commits into base: master
8 changes: 8 additions & 0 deletions lib/DaggerMPI/Project.toml
@@ -0,0 +1,8 @@
name = "DaggerMPI"
uuid = "37bfb287-2338-4693-8557-581796463535"
authors = ["Julian P Samaroo <jpsamaroo@jpsamaroo.me>"]
version = "0.1.0"

[deps]
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
MPI = "da04e1cc-30fd-572f-bb4f-1f8673147195"
171 changes: 171 additions & 0 deletions lib/DaggerMPI/src/DaggerMPI.jl
@@ -0,0 +1,171 @@
module DaggerMPI

using Dagger
using MPI

struct MPIProcessor{P,C} <: Dagger.Processor
proc::P
comm::MPI.Comm
color_algo::C
end

struct SimpleColoring end
function (sc::SimpleColoring)(comm, key)
return UInt64(rem(key, MPI.Comm_size(comm)))
end
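# Illustrative example (values are hypothetical): on a 4-rank communicator,
# SimpleColoring()(comm, 10) returns UInt64(10 % 4) == 0x2, so the task
# keyed by 10 runs on rank 2. Any callable with the shape
# (comm, key) -> UInt64 can be passed as `color_algo` to `initialize`.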

const MPI_PROCESSORS = Ref{Int}(-1)

const PREVIOUS_PROCESSORS = Set()


function initialize(comm::MPI.Comm=MPI.COMM_WORLD; color_algo=SimpleColoring())
@assert MPI_PROCESSORS[] == -1 "DaggerMPI already initialized"

# Force eager_thunk to run
fetch(Dagger.@spawn 1+1)

MPI.Init(; finalize_atexit=false)
procs = Dagger.get_processors(OSProc())
# Save the current processors so that `finalize` can restore them later
union!(PREVIOUS_PROCESSORS, procs)
i = 0
empty!(Dagger.PROCESSOR_CALLBACKS)
empty!(Dagger.OSPROC_PROCESSOR_CACHE)
for proc in procs
Dagger.add_processor_callback!("mpiprocessor_$i") do
return MPIProcessor(proc, comm, color_algo)
end
i += 1
end
MPI_PROCESSORS[] = i

# FIXME: Hack to populate new processors
Dagger.get_processors(OSProc())

return nothing
end

function finalize()
@assert MPI_PROCESSORS[] > -1 "DaggerMPI not yet initialized"
for i in 0:(MPI_PROCESSORS[]-1)
Dagger.delete_processor_callback!("mpiprocessor_$i")
end
empty!(Dagger.PROCESSOR_CALLBACKS)
empty!(Dagger.OSPROC_PROCESSOR_CACHE)
i = 1
for proc in PREVIOUS_PROCESSORS
Dagger.add_processor_callback!("old_processor_$i") do
return proc
end
i += 1
end
empty!(PREVIOUS_PROCESSORS)
MPI.Finalize()
MPI_PROCESSORS[] = -1
end

"References a value stored on some MPI rank."
struct MPIColoredValue{T}
color::UInt64
value::T
comm::MPI.Comm
end

Dagger.get_parent(proc::MPIProcessor) = Dagger.OSProc()
Dagger.default_enabled(proc::MPIProcessor) = true


"Busy-loop Irecv that yields to other tasks."
function recv_yield(src, tag, comm)
while true
(got, msg, stat) = MPI.Improbe(src, tag, comm, MPI.Status)
if got
count = MPI.Get_count(stat, UInt8)
buf = Array{UInt8}(undef, count)
req = MPI.Imrecv!(MPI.Buffer(buf), msg)
while true
finish = MPI.Test(req)
if finish
value = MPI.deserialize(buf)
return value
end
yield()
end
end
# TODO: Sigmoidal backoff
yield()
end
end
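# Note: sends in this module are posted with `MPI.isend`, and this function
# is the matching receive: `Improbe`/`Imrecv!` let the busy-loop call
# `yield()` so other Julia tasks can run instead of blocking inside MPI.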

function Dagger.execute!(proc::MPIProcessor, f, args...)
rank = MPI.Comm_rank(proc.comm)
tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash() >> 32))
color = proc.color_algo(proc.comm, tag)
if rank == color
@debug "[$rank] Executing $f on $tag"
return MPIColoredValue(color, Dagger.execute!(proc.proc, f, args...), proc.comm)
end
# Return nothing; we won't use this value anyway
@debug "[$rank] Skipped $f on $tag"
return MPIColoredValue(color, nothing, proc.comm)
end

function Dagger.move(from_proc::MPIProcessor, to_proc::MPIProcessor, x::Dagger.Chunk)
@assert from_proc.comm == to_proc.comm "Mixing different MPI communicators is not supported"
@assert Dagger.chunktype(x) <: MPIColoredValue
x_value = fetch(x)
rank = MPI.Comm_rank(from_proc.comm)
tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:input) >> 32))
other_tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:self) >> 32))
other = from_proc.color_algo(from_proc.comm, other_tag)
if x_value.color == rank == other
# We generated and will use this input
return Dagger.move(from_proc.proc, to_proc.proc, x_value.value)
elseif x_value.color == rank
# We generated this input
@debug "[$rank] Starting P2P send to [$other] from $tag to $other_tag"
MPI.isend(x_value.value, other, tag, from_proc.comm)
@debug "[$rank] Finished P2P send to [$other] from $tag to $other_tag"
return Dagger.move(from_proc.proc, to_proc.proc, x_value.value)
elseif other == rank
# We will use this input
@debug "[$rank] Starting P2P recv from $tag to $other_tag"
value = recv_yield(x_value.color, tag, from_proc.comm)
@debug "[$rank] Finished P2P recv from $tag to $other_tag"
return Dagger.move(from_proc.proc, to_proc.proc, value)
else
# We didn't generate and will not use this input
return nothing
end
end

function Dagger.move(from_proc::MPIProcessor, to_proc::Dagger.Processor, x::Dagger.Chunk)
@assert Dagger.chunktype(x) <: MPIColoredValue
x_value = fetch(x)
rank = MPI.Comm_rank(from_proc.comm)
tag = abs(Base.unsafe_trunc(Int32, Dagger.get_task_hash(:input) >> 32))
if rank == x_value.color
# FIXME: Broadcast send
@sync for other in 0:(MPI.Comm_size(from_proc.comm)-1)
other == rank && continue
@async begin
@debug "[$rank] Starting bcast send to [$other] on $tag"
MPI.isend(x_value.value, other, tag, from_proc.comm)
@debug "[$rank] Finished bcast send to [$other] on $tag"
end
end
return Dagger.move(from_proc.proc, to_proc, x_value.value)
else
@debug "[$rank] Starting bcast recv on $tag"
value = recv_yield(x_value.color, tag, from_proc.comm)
@debug "[$rank] Finished bcast recv on $tag"
return Dagger.move(from_proc.proc, to_proc, value)
end
end

function Dagger.move(from_proc::Dagger.Processor, to_proc::MPIProcessor, x::Dagger.Chunk)
@assert !(Dagger.chunktype(x) <: MPIColoredValue)
rank = MPI.Comm_rank(to_proc.comm)
return MPIColoredValue(rank, Dagger.move(from_proc, to_proc.proc, x), to_proc.comm)
end

end # module
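A minimal usage sketch, inferred from the API above (the `mpiexec` launch and the workload are illustrative, not part of this PR):

# Run with, e.g.: mpiexec -n 4 julia --project=lib/DaggerMPI script.jl
using Dagger, MPI, DaggerMPI

DaggerMPI.initialize()  # defaults to MPI.COMM_WORLD and SimpleColoring()

# Every rank builds the same task graph; each task executes only on the
# rank selected by the coloring algorithm, and values travel between
# ranks through the `Dagger.move` methods defined above.
a = Dagger.@spawn rand(100)
b = Dagger.@spawn sum(a)
@show fetch(b)

DaggerMPI.finalize()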
1 change: 1 addition & 0 deletions src/Dagger.jl
@@ -36,6 +36,7 @@ include("chunks.jl")
include("compute.jl")
include("utils/clock.jl")
include("utils/system_uuid.jl")
include("utils/uhash.jl")
include("sch/Sch.jl"); using .Sch

# Array computations
7 changes: 4 additions & 3 deletions src/chunks.jl
@@ -53,6 +53,7 @@ mutable struct Chunk{T, H, P<:Processor, S<:AbstractScope}
processor::P
scope::S
persist::Bool
hash::UInt
end

domain(c::Chunk) = c.domain
@@ -242,16 +243,16 @@ be used.

All other kwargs are passed directly to `MemPool.poolset`.
"""
function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, kwargs...) where {X,P,S}
function tochunk(x::X, proc::P=OSProc(), scope::S=AnyScope(); persist=false, cache=false, device=nothing, hash=UInt(0), kwargs...) where {X,P,S}
if device === nothing
device = if Sch.walk_storage_safe(x)
MemPool.GLOBAL_DEVICE[]
else
MemPool.CPURAMDevice()
end
end
ref = poolset(x; device, kwargs...)
Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist)
ref = poolset(move(OSProc(), proc, x); device, kwargs...)
Chunk{X,typeof(ref),P,S}(X, domain(x), ref, proc, scope, persist, hash)
end
tochunk(x::Union{Chunk, Thunk}, proc=nothing, scope=nothing; kwargs...) = x
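# Illustrative use of the new `hash` kwarg (hypothetical variables): a
# scheduler that already knows a task's unified hash can attach it to the
# result chunk, as `Sch.do_task` does below:
#   c = Dagger.tochunk(result, to_proc; hash=task_hash)
#   @assert c.hash == task_hash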

41 changes: 34 additions & 7 deletions src/processor.jl
@@ -288,19 +288,40 @@ end
# In-Thunk Helpers

"""
thunk_processor()
thunk_processor() -> Dagger.Processor

Get the current processor executing the current thunk.
"""
thunk_processor() = task_local_storage(:_dagger_processor)::Processor

"""
in_thunk()
in_thunk() -> Bool

Returns `true` if currently in a [`Thunk`](@ref) process, else `false`.
"""
in_thunk() = haskey(task_local_storage(), :_dagger_sch_uid)

"""
get_task_hash(kind::Symbol=:self) -> UInt

Returns the unified hash of the current task, or of one of its inputs. If
`kind == :self`, the hash is that of the current task; if `kind == :input`,
the hash is that of the input currently being moved to the task. The `:self`
hash is available during `Dagger.execute!` and `Dagger.move`, whereas the
`:input` hash is only available during `Dagger.move`. The hash is consistent
across Julia processes, provided all processes run the same Julia version on
the same architecture.
"""
function get_task_hash(kind::Symbol=:self)::UInt
if kind == :self
return task_local_storage(:_dagger_task_hash)::UInt
elseif kind == :input
return task_local_storage(:_dagger_input_hash)::UInt
else
throw(ArgumentError("Invalid task hash kind: $kind"))
end
end
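# Hypothetical usage sketch (`MyProc` is illustrative): custom processors
# can read these hashes while executing tasks or moving data:
#   function Dagger.execute!(proc::MyProc, f, args...)
#       h = Dagger.get_task_hash()          # :self, available in execute!
#       # ...
#   end
#   function Dagger.move(from::MyProc, to::MyProc, x::Dagger.Chunk)
#       h = Dagger.get_task_hash(:input)    # :input, only available in move
#       # ...
#   end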

"""
get_tls()

@@ -309,6 +330,8 @@ Gets all Dagger TLS variable as a `NamedTuple`.
get_tls() = (
sch_uid=task_local_storage(:_dagger_sch_uid),
sch_handle=task_local_storage(:_dagger_sch_handle),
task_hash=task_local_storage(:_dagger_task_hash),
input_hash=get(task_local_storage(), :_dagger_input_hash, nothing),
processor=thunk_processor(),
time_utilization=task_local_storage(:_dagger_time_utilization),
alloc_utilization=task_local_storage(:_dagger_alloc_utilization),
@@ -320,9 +343,13 @@ get_tls() = (
Sets all Dagger TLS variables from the `NamedTuple` `tls`.
"""
function set_tls!(tls)
task_local_storage(:_dagger_sch_uid, tls.sch_uid)
task_local_storage(:_dagger_sch_handle, tls.sch_handle)
task_local_storage(:_dagger_processor, tls.processor)
task_local_storage(:_dagger_time_utilization, tls.time_utilization)
task_local_storage(:_dagger_alloc_utilization, tls.alloc_utilization)
task_local_storage(:_dagger_sch_uid, get(tls, :sch_uid, nothing))
task_local_storage(:_dagger_sch_handle, get(tls, :sch_handle, nothing))
task_local_storage(:_dagger_task_hash, get(tls, :task_hash, nothing))
if haskey(tls, :input_hash) && tls.input_hash !== nothing
task_local_storage(:_dagger_input_hash, tls.input_hash)
end
task_local_storage(:_dagger_processor, get(tls, :processor, nothing))
task_local_storage(:_dagger_time_utilization, get(tls, :time_utilization, nothing))
task_local_storage(:_dagger_alloc_utilization, get(tls, :alloc_utilization, nothing))
end
37 changes: 25 additions & 12 deletions src/sch/Sch.jl
@@ -9,7 +9,7 @@ import Random: randperm

import ..Dagger
import ..Dagger: Context, Processor, Thunk, WeakThunk, ThunkFuture, ThunkFailedException, Chunk, OSProc, AnyScope
import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, default_enabled, get_processors, get_parent, execute!, rmprocs!, addprocs!, thunk_processor, constrain, cputhreadtime
import ..Dagger: order, dependents, noffspring, istask, inputs, unwrap_weak_checked, affinity, tochunk, timespan_start, timespan_finish, procs, move, chunktype, processor, default_enabled, get_processors, get_parent, execute!, rmprocs!, addprocs!, thunk_processor, constrain, cputhreadtime, uhash

const OneToMany = Dict{Thunk, Set{Thunk}}

@@ -613,6 +613,10 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
@assert !haskey(state.cache, task)
opts = merge(ctx.options, task.options)
sig = signature(task, state)
if task.hash == UInt(0)
# Compute the hash and cache it in the task
uhash(task, UInt(0); sig)
end

# Calculate scope
scope = if task.f isa Chunk
@@ -672,7 +676,7 @@ function schedule!(ctx, state, procs=procs_to_use(ctx))
# Schedule task onto proc
# FIXME: est_time_util = est_time_util isa MaxUtilization ? cap : est_time_util
push!(get!(()->Vector{Tuple{Thunk,<:Any,<:Any}}(), to_fire, (gproc, proc)), (task, est_time_util, est_alloc_util))
state.worker_time_pressure[gproc.pid][proc] += est_time_util
state.worker_time_pressure[gproc.pid][proc] = get(state.worker_time_pressure[gproc.pid], proc, UInt64(0)) + est_time_util
@goto pop_task
end
end
@@ -893,10 +897,12 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)

ids = Int[0]
data = Any[thunk.f]
hashes = Union{UInt,Nothing}[uhash(thunk.f, UInt(0))]
for (idx, x) in enumerate(thunk.inputs)
x = unwrap_weak_checked(x)
push!(ids, istask(x) ? x.id : -idx)
push!(data, istask(x) ? state.cache[x] : x)
push!(hashes, uhash(x, UInt(0)))
end
toptions = thunk.options !== nothing ? thunk.options : ThunkOptions()
options = merge(ctx.options, toptions)
@@ -906,9 +912,10 @@ function fire_tasks!(ctx, thunks::Vector{<:Tuple}, (gproc, proc), state)
sch_handle = SchedulerHandle(ThunkID(thunk.id, nothing), state.worker_chans[gproc.pid]...)

# TODO: De-dup common fields (log_sink, uid, etc.)
push!(to_send, Any[thunk.id, time_util, alloc_util, chunktype(thunk.f), data, thunk.get_result,
push!(to_send, Any[thunk.id, thunk.hash,
time_util, alloc_util, chunktype(thunk.f), data, thunk.get_result,
thunk.persist, thunk.cache, thunk.meta, options,
propagated, ids,
propagated, ids, hashes,
(log_sink=ctx.log_sink, profile=ctx.profile),
sch_handle, state.uid])
end
@@ -964,7 +971,7 @@ function do_tasks(to_proc, chan, tasks)
end
"Executes a single task on `to_proc`."
function do_task(to_proc, comm)
thunk_id, est_time_util, est_alloc_util, Tf, data, send_result, persist, cache, meta, options, propagated, ids, ctx_vars, sch_handle, uid = comm
thunk_id, task_hash, est_time_util, est_alloc_util, Tf, data, send_result, persist, cache, meta, options, propagated, ids, hashes, ctx_vars, sch_handle, sch_uid = comm
ctx = Context(Processor[]; log_sink=ctx_vars.log_sink, profile=ctx_vars.profile)

from_proc = OSProc()
@@ -1002,7 +1009,7 @@ function do_task(to_proc, comm)
lock(TASK_SYNC) do
while true
# Get current time utilization for the selected processor
time_dict = get!(()->Dict{Processor,Ref{UInt64}}(), PROCESSOR_TIME_UTILIZATION, uid)
time_dict = get!(()->Dict{Processor,Ref{UInt64}}(), PROCESSOR_TIME_UTILIZATION, sch_uid)
real_time_util = get!(()->Ref{UInt64}(UInt64(0)), time_dict, to_proc)

# Get current allocation utilization and capacity
@@ -1043,14 +1050,19 @@ function do_task(to_proc, comm)
# Initiate data transfers for function and arguments
transfer_time = Threads.Atomic{UInt64}(0)
transfer_size = Threads.Atomic{UInt64}(0)
_data, _ids = if meta
(Any[first(data)], Int[first(ids)]) # always fetch function
_data, _ids, _hashes = if meta
(Any[first(data)], Int[first(ids)], Union{UInt,Nothing}[first(hashes)]) # always fetch function
else
(data, ids)
(data, ids, hashes)
end
fetch_tasks = map(Iterators.zip(_data,_ids)) do (x, id)
fetch_tasks = map(Iterators.zip(_data, _ids, _hashes)) do (x, id, hash)
@async begin
timespan_start(ctx, :move, (;thunk_id, id), (;f, id, data=x))
Dagger.set_tls!((;
sch_uid=sch_uid,
input_hash=hash,
task_hash,
))
x = if x isa Chunk
value = lock(TASK_SYNC) do
if haskey(CHUNK_CACHE, x)
@@ -1123,8 +1135,9 @@ function do_task(to_proc, comm)
result_meta = try
# Set TLS variables
Dagger.set_tls!((;
sch_uid=uid,
sch_uid,
sch_handle=sch_handle,
task_hash,
processor=to_proc,
time_utilization=est_time_util,
alloc_utilization=est_alloc_util,
@@ -1149,7 +1162,7 @@ function do_task(to_proc, comm)

# Construct result
# TODO: We should cache this locally
send_result || meta ? res : tochunk(res, to_proc; device, persist, cache=persist ? true : cache)
send_result || meta ? res : tochunk(res, to_proc; device, persist, cache=persist ? true : cache, hash=task_hash)
catch ex
bt = catch_backtrace()
RemoteException(myid(), CapturedException(ex, bt))
2 changes: 1 addition & 1 deletion src/sch/util.jl
@@ -394,7 +394,7 @@ function estimate_task_costs(state, procs, task, inputs)
transfer_costs = Dict(proc=>impute_sum([affinity(chunk)[2] for chunk in filter(c->get_parent(processor(c))!=get_parent(proc), chunks)]) for proc in procs)

# Estimate total cost to move data and get task running after currently-scheduled tasks
costs = Dict(proc=>state.worker_time_pressure[get_parent(proc).pid][proc]+(tx_cost/tx_rate) for (proc, tx_cost) in transfer_costs)
costs = Dict(proc=>get(state.worker_time_pressure[get_parent(proc).pid], proc, UInt64(0))+(tx_cost/tx_rate) for (proc, tx_cost) in transfer_costs)

# Shuffle procs around, so equally-costly procs are equally considered
P = randperm(length(procs))
22 changes: 19 additions & 3 deletions src/thunk.jl
@@ -52,6 +52,7 @@ mutable struct Thunk
f::Any # usually a Function, but could be any callable
inputs::Vector{Any} # TODO: Use `ImmutableArray` in 1.8
id::Int
hash::UInt
get_result::Bool # whether the worker should send the result or only the metadata
meta::Bool
persist::Bool # don't `free!` result after computing
@@ -64,6 +65,7 @@ mutable struct Thunk
propagates::Tuple # which options we'll propagate
function Thunk(f, xs...;
id::Int=next_id(),
hash=UInt(0),
get_result::Bool=false,
meta::Bool=false,
persist::Bool=false,
@@ -85,17 +87,19 @@ mutable struct Thunk
xs = Any[xs...]
if options !== nothing
@assert isempty(kwargs)
new(f, xs, id, get_result, meta, persist, cache, cache_ref,
new(f, xs, id, hash, get_result, meta, persist, cache, cache_ref,
affinity, eager_ref, options, propagates)
else
new(f, xs, id, get_result, meta, persist, cache, cache_ref,
new(f, xs, id, hash, get_result, meta, persist, cache, cache_ref,
affinity, eager_ref, Sch.ThunkOptions(;kwargs...), propagates)
end
end
end
Serialization.serialize(io::AbstractSerializer, t::Thunk) =
throw(ArgumentError("Cannot serialize a Thunk"))

get_task_hash(t::Thunk) = t.hash

function affinity(t::Thunk)
if t.affinity !== nothing
return t.affinity
@@ -183,6 +187,7 @@ end
unwrap_weak_checked(t) = t
Base.show(io::IO, t::WeakThunk) = (print(io, "~"); Base.show(io, t.x.value))
Base.convert(::Type{WeakThunk}, t::Thunk) = WeakThunk(t)
get_task_hash(t::WeakThunk) = unwrap_weak_checked(t).hash

struct ThunkFailedException{E<:Exception} <: Exception
thunk::WeakThunk
@@ -223,12 +228,23 @@ function Base.fetch(t::EagerThunk; raw=false)
if raw
fetch(t.future; raw=true)
else
move(OSProc(), fetch(t.future))
value = fetch(t.future)
if value isa Chunk
return fetch(@async begin
Dagger.set_tls!((input_hash=value.hash,
task_hash=value.hash))
return move(OSProc(), value)
end)
else
return move(OSProc(), value)
end
end
end
function Base.show(io::IO, t::EagerThunk)
print(io, "EagerThunk ($(isready(t) ? "finished" : "running"))")
end
get_task_hash(t::EagerThunk) =
remotecall_fetch(d->get_task_hash(poolget(d)), t.thunk_ref.owner, t.thunk_ref)

"When finalized, cleans-up the associated `EagerThunk`."
mutable struct EagerThunkFinalizer
61 changes: 61 additions & 0 deletions src/utils/uhash.jl
@@ -0,0 +1,61 @@
# unified hash algorithm

uhash(x, h::UInt)::UInt = hash(x, h)
function uhash(x::Dagger.Thunk, h::UInt; sig=nothing)::UInt
value = hash(0xdead7453, h)
if x.hash != UInt(0)
return uhash(x.hash, value)
end
@assert sig !== nothing
tt = Any[]
for input in x.inputs
input = unwrap_weak_checked(input)
if input isa Dagger.Thunk && input.hash != UInt(0)
value = uhash(input.hash, value)
else
value = uhash(input, value)
push!(tt, typeof(input))
end
end
sig = (typeof(x.f), tt)
value = uhash_sig(sig, value)
x.hash = value
return value
end
uhash(x::Dagger.WeakThunk, h::UInt)::UInt =
uhash(Dagger.unwrap_weak_checked(x), h)
function uhash_sig((f, tt), h::UInt)::UInt
value = hash(0xdead5160, h)
ci_list = Base.code_typed(f, tt)
if length(ci_list) == 0
return hash(Union{}, hash(typeof(f), hash(tt, value)))
end
# tt must be concrete
ci = first(only(ci_list))::Core.CodeInfo
return uhash_code(ci, hash(typeof(f), hash(tt, value)))
end
function uhash_code(ci::Core.CodeInfo, h::UInt)::UInt
value = hash(0xdeadc0de, h)
for insn in ci.code
# Chain the hash through every instruction in the CodeInfo
value = uhash_insn(insn, value)
end
return value
end
function uhash_insn(insn::Expr, h::UInt)::UInt
value = hash(0xdeadeec54, h)
value = hash(insn.head, value)
for arg in insn.args
value = uhash_insn(arg, value)
end
return value
end
function uhash_insn(insn::GlobalRef, h::UInt)::UInt
value = hash(0xdead6147, h)
return hash(nameof(insn.mod), hash(insn.name, value))
end
uhash_insn(insn, h::UInt)::UInt = hash(insn, h)
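# Illustrative property of the fallback method above: for plain values,
# `uhash` reduces to `Base.hash`, so uhash(42, UInt(0)) == hash(42, UInt(0)).
# The specialized methods hash lowered code instead, so that equivalent
# tasks hash identically across processes running the same Julia build.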