datadeps: Optimize using task results via deferrals #567

Draft · wants to merge 2 commits into master

202 changes: 180 additions & 22 deletions src/datadeps.jl
@@ -22,8 +22,6 @@ end
Deps(x, deps...) = Deps(x, deps)

struct DataDepsTaskQueue <: AbstractTaskQueue
# The queue above us
upper_queue::AbstractTaskQueue
# The set of tasks that have already been seen
seen_tasks::Union{Vector{Pair{DTaskSpec,DTask}},Nothing}
# The data-dependency graph of all tasks
@@ -39,14 +37,14 @@ struct DataDepsTaskQueue <: AbstractTaskQueue
# The fields following only apply when aliasing==true
aliasing::Bool

function DataDepsTaskQueue(upper_queue;
function DataDepsTaskQueue(;
traversal::Symbol=:inorder,
scheduler::Symbol=:naive,
aliasing::Bool=true)
seen_tasks = Pair{DTaskSpec,DTask}[]
g = SimpleDiGraph()
task_to_id = Dict{DTask,Int}()
return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler,
return new(seen_tasks, g, task_to_id, traversal, scheduler,
aliasing)
end
end
@@ -137,6 +135,35 @@ struct DataDepsNonAliasingState
args_owner, args_readers)
end
end
struct DeferralState
# The map of DTask to DTaskSpec
task_to_spec::Dict{DTask,DTaskSpec}

# The sets of not-yet-completed tasks upstream/downstream of a given task
waiting_upstream::Dict{DTask,Set{DTask}}
waiting_downstream::Dict{DTask,Set{DTask}}
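    # Invariant: these two maps mirror each other; if `u` is in
    # `waiting_upstream[t]`, then `t` is in `waiting_downstream[u]`, so a
    # finishing upstream can directly find the tasks it unblocks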

# The condition that finished tasks will trigger
cond::Threads.Condition

# The set of tasks that are ready to be executed
ready::Set{DTask}

# The Julia tasks waiting on Futures to trigger
tasks::IdDict{DTask,Task}

function DeferralState()
task_to_spec = Dict{DTask,DTaskSpec}()
waiting_upstream = Dict{DTask,Set{DTask}}()
waiting_downstream = Dict{DTask,Set{DTask}}()
cond = Threads.Condition()
ready = Set{DTask}()
tasks = IdDict{DTask,Task}()
return new(task_to_spec,
waiting_upstream, waiting_downstream,
cond, ready, tasks)
end
end
struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState}}
# Whether aliasing is being analyzed
aliasing::Bool
@@ -147,18 +174,23 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
# The mapping of memory space to remote argument copies
remote_args::Dict{MemorySpace,IdDict{Any,Any}}

# The state for deferring tasks
deferral_state::DeferralState

# The aliasing analysis state
alias_state::State

function DataDepsState(aliasing::Bool)
dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}[]
remote_args = Dict{MemorySpace,IdDict{Any,Any}}()
deferral_state = DeferralState()
if aliasing
state = DataDepsAliasingState()
else
state = DataDepsNonAliasingState()
end
return new{typeof(state)}(aliasing, dependencies, remote_args, state)
return new{typeof(state)}(aliasing, dependencies, remote_args,
deferral_state, state)
end
end

@@ -232,6 +264,117 @@ function is_writedep(arg, deps, task::DTask)
return any(dep->dep[3], deps)
end

"Unwraps a completed `DTask` to get its `Chunk`."
fetch_ready(arg) = arg
function fetch_ready(arg::DTask)
@assert istaskdone(arg) "Task is not yet ready"
return fetch(arg; raw=true)
end
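# For example (hypothetical usage): with `t = Dagger.@spawn 1+1` already
# finished, `fetch_ready(t)` returns the result's underlying `Chunk`, while
# `fetch_ready([1, 2, 3])` returns the array unchanged.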
"Checks if a `DTask` is waiting on unfinished tasks, and must be deferred."
function must_defer!(state::DataDepsState, spec::DTaskSpec, task::DTask)
upstreams = Set{DTask}()
for (_, arg) in spec.args
arg, _ = unwrap_inout(arg)
if arg isa DTask && !istaskdone(arg)
push!(upstreams, arg)
end
end
if !isempty(upstreams)
dstate = state.deferral_state
dstate.task_to_spec[task] = spec
dstate.waiting_upstream[task] = upstreams
for upstream in upstreams
push!(get!(Set{DTask}, dstate.waiting_downstream, upstream), task)
try_watch_deferred!(dstate, upstream)
end
@dagdebug nothing :spawn_datadeps "Deferring task UID $(task.uid) ($(length(upstreams)) upstreams)"
return true
else
@dagdebug nothing :spawn_datadeps "Not deferring task UID $(task.uid)"
return false
end
end
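# Sketch of the bookkeeping (hypothetical tasks `t1`, `t2`): if `t2` takes the
# still-running `t1` as an argument, `must_defer!` records
# `waiting_upstream[t2] == Set([t1])`, adds `t2` to `waiting_downstream[t1]`,
# and asks `try_watch_deferred!` to observe `t1`'s completion; `t2` is then
# launched later by `wait_deferrals!`.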
"Tries to start watching for the completion of `task`."
function try_watch_deferred!(dstate, task)
if haskey(dstate.tasks, task)
# Already watching this task
return
end
if !istaskstarted(task)
# Can't start watching this task yet directly
return
end
if !haskey(dstate.waiting_downstream, task)
# No downstreams are watching
return
end

# Start watching
ctx = Sch.eager_context()
sch_state = Sch.EAGER_STATE[]
future = ThunkFuture()
thunk_id = Sch.ThunkID(lock(Sch.EAGER_ID_MAP) do id_map
id_map[task.uid]
end, task.thunk_ref)
dstate.tasks[task] = Sch.errormonitor_tracked("datadeps future listener", Threads.@spawn begin
wait(future)
lock(dstate.cond) do
push!(dstate.ready, task)
notify(dstate.cond)
end
end)
Sch._register_future!(ctx, sch_state, current_task(), 0, (future, thunk_id, false))
@dagdebug nothing :spawn_datadeps "Waiting on task UID $(task.uid)"
end
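# The watch is two-part: a `ThunkFuture` is registered with the scheduler for
# `task`'s thunk, and a tracked Julia `Task` waits on that future, then moves
# `task` into `dstate.ready` and notifies `dstate.cond` so that
# `wait_deferrals!` wakes up.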
"Waits for deferred tasks to become unblocked."
function wait_deferrals!(state::DataDepsState)
dstate = state.deferral_state
specs_tasks = Vector{Pair{DTaskSpec,DTask}}()

@label wait

if isempty(dstate.waiting_upstream)
return specs_tasks
end

# Wait for upstreams to become ready
ready = lock(dstate.cond) do
while isempty(dstate.ready)
wait(dstate.cond)
end
ready = copy(dstate.ready)
empty!(dstate.ready)
return ready
end

# Find deferred downstreams that are now ready
for upstream in ready
@assert istaskdone(upstream)
downstreams = dstate.waiting_downstream[upstream]
delete!(dstate.waiting_downstream, upstream)
for downstream in downstreams
@dagdebug nothing :spawn_datadeps "Upstream task UID $(upstream.uid) unblocking downstream task UID $(downstream.uid)"
pop!(dstate.waiting_upstream[downstream], upstream)
if isempty(dstate.waiting_upstream[downstream])
@dagdebug nothing :spawn_datadeps "Undeferring task UID $(downstream.uid)"
delete!(dstate.waiting_upstream, downstream)
push!(specs_tasks, dstate.task_to_spec[downstream] => downstream)
delete!(dstate.task_to_spec, downstream)
end
end
delete!(dstate.tasks, upstream)
end

if isempty(specs_tasks)
@goto wait
end

# Sort tasks by sequential ordering
sort!(specs_tasks, by=task->task[2].uid)

return specs_tasks
end
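# In short: block until at least one upstream finishes, release every deferred
# task whose upstream set has drained, and return the released `spec => task`
# pairs in ascending task-UID order for launching.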

# Aliasing state setup
function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
# Populate task dependencies
@@ -243,7 +386,7 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
arg, deps = unwrap_inout(arg)

# Unwrap the Chunk underlying any DTask arguments
arg = arg isa DTask ? fetch(arg; raw=true) : arg
arg = fetch_ready(arg)

# Skip non-aliasing arguments
type_may_alias(typeof(arg)) || continue
@@ -397,9 +540,7 @@ end
# Make a copy of each piece of data on each worker
# memory_space => {arg => copy_of_arg}
function generate_slot!(state::DataDepsState, dest_space, data)
if data isa DTask
data = fetch(data; raw=true)
end
data = fetch_ready(data)
orig_space = memory_space(data)
to_proc = first(processors(dest_space))
from_proc = first(processors(orig_space))
@@ -471,9 +612,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
@warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1
end

# Round-robin assign tasks to processors
upper_queue = get_options(:task_queue)

traversal = queue.traversal
if traversal == :inorder
# As-is
@@ -518,7 +656,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
end

state = DataDepsState(queue.aliasing)
astate = state.alias_state
sstate = DataDepsSchedulerState()
for proc in all_procs
space = only(memory_spaces(proc))
@@ -530,7 +667,24 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
write_num = 1
proc_idx = 1
pressures = Dict{Processor,Int}()
for (spec, task) in queue.seen_tasks[task_order]
specs_tasks = queue.seen_tasks[task_order]
while !isempty(specs_tasks)
write_num, proc_idx = distribute_tasks_launch!(queue, state, sstate, all_procs, exec_spaces, specs_tasks, write_num, proc_idx, pressures)
specs_tasks = wait_deferrals!(state)
end
distribute_tasks_copy_back!(queue, state, sstate, all_procs, exec_spaces, specs_tasks, write_num, proc_idx, pressures)
end
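# `distribute_tasks!` thus alternates between launching whatever is currently
# launchable and waiting on deferrals until no deferred tasks remain; only
# then does the copy-back of written arguments run.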
function distribute_tasks_launch!(queue::DataDepsTaskQueue, state::DataDepsState, sstate::DataDepsSchedulerState, all_procs, exec_spaces, specs_tasks, write_num, proc_idx, pressures)
astate = state.alias_state

upper_queue = get_options(:task_queue)

for (spec, task) in specs_tasks
# Check if this task must be deferred
if must_defer!(state, spec, task)
continue
end

# Populate all task dependencies
populate_task_info!(state, spec, task)

@@ -601,9 +755,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
args = Base.mapany(spec.args) do arg
pos, data = arg
data, _ = unwrap_inout(data)
if data isa DTask
data = fetch(data; raw=true)
end
data = fetch_ready(data)
return pos => tochunk(data)
end
f_chunk = tochunk(spec.f)
@@ -676,7 +828,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
for (idx, (pos, arg)) in enumerate(task_args)
# Is the data written previously or now?
arg, deps = unwrap_inout(arg)
arg = arg isa DTask ? fetch(arg; raw=true) : arg
arg = fetch_ready(arg)
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps, task)
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)"
spec.args[idx] = pos => arg
@@ -748,7 +900,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
syncdeps = get(Set{Any}, spec.options, :syncdeps)
for (idx, (_, arg)) in enumerate(task_args)
arg, deps = unwrap_inout(arg)
arg = arg isa DTask ? fetch(arg; raw=true) : arg
arg = fetch_ready(arg)
type_may_alias(typeof(arg)) || continue
if queue.aliasing
for (dep_mod, _, writedep) in deps
@@ -777,11 +929,12 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
task_scope = our_scope
spec.options = merge(spec.options, (;syncdeps, scope=task_scope))
enqueue!(upper_queue, spec=>task)
try_watch_deferred!(state.deferral_state, task)

# Update read/write tracking for arguments
for (idx, (_, arg)) in enumerate(task_args)
arg, deps = unwrap_inout(arg)
arg = arg isa DTask ? fetch(arg; raw=true) : arg
arg = fetch_ready(arg)
type_may_alias(typeof(arg)) || continue
if queue.aliasing
for (dep_mod, _, writedep) in deps
@@ -810,6 +963,11 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
proc_idx = mod1(proc_idx + 1, length(all_procs))
end

return write_num, proc_idx
end
function distribute_tasks_copy_back!(queue::DataDepsTaskQueue, state::DataDepsState, sstate::DataDepsSchedulerState, all_procs, exec_spaces, specs_tasks, write_num, proc_idx, pressures)
astate = state.alias_state

# Copy args from remote to local
if queue.aliasing
# We need to replay the writes from all tasks in-order (skipping any
@@ -952,13 +1110,13 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool
if launch_wait
result = spawn_bulk() do
queue = DataDepsTaskQueue(get_options(:task_queue);
queue = DataDepsTaskQueue(;
traversal, scheduler, aliasing)
with_options(f; task_queue=queue)
distribute_tasks!(queue)
end
else
queue = DataDepsTaskQueue(get_options(:task_queue);
queue = DataDepsTaskQueue(;
traversal, scheduler, aliasing)
result = with_options(f; task_queue=queue)
distribute_tasks!(queue)
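
For context, a minimal sketch of the usage pattern this PR optimizes
(hypothetical user code; `t` and `B` are illustrative names, and `In`/`Out`
are Dagger's documented datadeps annotations):

using Dagger

t = Dagger.@spawn rand(1000)   # eager task; its result feeds the region
B = zeros(1000)

Dagger.spawn_datadeps() do
    # Previously, passing the DTask `t` as an argument forced a blocking
    # fetch while the region was being scheduled; with this PR, the
    # dependent task is deferred and launched once `t` completes.
    Dagger.@spawn copyto!(Out(B), In(t))
end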