diff --git a/src/datadeps.jl b/src/datadeps.jl
index 3c1110820..346cd40c6 100644
--- a/src/datadeps.jl
+++ b/src/datadeps.jl
@@ -22,8 +22,6 @@ end
 Deps(x, deps...) = Deps(x, deps)
 
 struct DataDepsTaskQueue <: AbstractTaskQueue
-    # The queue above us
-    upper_queue::AbstractTaskQueue
     # The set of tasks that have already been seen
     seen_tasks::Union{Vector{Pair{DTaskSpec,DTask}},Nothing}
     # The data-dependency graph of all tasks
@@ -39,14 +37,14 @@ struct DataDepsTaskQueue <: AbstractTaskQueue
     # The fields following only apply when aliasing==true
     aliasing::Bool
 
-    function DataDepsTaskQueue(upper_queue;
+    function DataDepsTaskQueue(;
                                traversal::Symbol=:inorder,
                                scheduler::Symbol=:naive,
                                aliasing::Bool=true)
         seen_tasks = Pair{DTaskSpec,DTask}[]
         g = SimpleDiGraph()
         task_to_id = Dict{DTask,Int}()
-        return new(upper_queue, seen_tasks, g, task_to_id, traversal, scheduler,
+        return new(seen_tasks, g, task_to_id, traversal, scheduler,
                    aliasing)
     end
 end
@@ -137,6 +135,35 @@ struct DataDepsNonAliasingState
                    args_owner, args_readers)
     end
 end
+struct DeferralState
+    # The map of DTask to DTaskSpec
+    task_to_spec::Dict{DTask,DTaskSpec}
+
+    # The set of not yet fulfilled tasks upstream/downstream of a given task
+    waiting_upstream::Dict{DTask,Set{DTask}}
+    waiting_downstream::Dict{DTask,Set{DTask}}
+
+    # The condition that finished tasks will trigger
+    cond::Threads.Condition
+
+    # The set of tasks that are ready to be executed
+    ready::Set{DTask}
+
+    # The Julia tasks waiting on Futures to trigger
+    tasks::IdDict{DTask,Task}
+
+    function DeferralState()
+        task_to_spec = Dict{DTask,DTaskSpec}()
+        waiting_upstream = Dict{DTask,Set{DTask}}()
+        waiting_downstream = Dict{DTask,Set{DTask}}()
+        cond = Threads.Condition()
+        ready = Set{DTask}()
+        tasks = IdDict{DTask,Task}()
+        return new(task_to_spec,
+                   waiting_upstream, waiting_downstream,
+                   cond, ready, tasks)
+    end
+end
 struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState}}
     # Whether aliasing is being analyzed
    aliasing::Bool
@@ -147,18 +174,23 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
     # The mapping of memory space to remote argument copies
     remote_args::Dict{MemorySpace,IdDict{Any,Any}}
 
+    # The state for deferring tasks
+    deferral_state::DeferralState
+
     # The aliasing analysis state
     alias_state::State
 
     function DataDepsState(aliasing::Bool)
         dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}[]
         remote_args = Dict{MemorySpace,IdDict{Any,Any}}()
+        deferral_state = DeferralState()
         if aliasing
             state = DataDepsAliasingState()
         else
             state = DataDepsNonAliasingState()
         end
-        return new{typeof(state)}(aliasing, dependencies, remote_args, state)
+        return new{typeof(state)}(aliasing, dependencies, remote_args,
+                                  deferral_state, state)
     end
 end
 
@@ -232,6 +264,117 @@ function is_writedep(arg, deps, task::DTask)
     return any(dep->dep[3], deps)
 end
 
+"Unwraps a completed `DTask` to get its `Chunk`."
+fetch_ready(arg) = arg
+function fetch_ready(arg::DTask)
+    @assert istaskdone(arg) "Task is not yet ready"
+    return fetch(arg; raw=true)
+end
+"Checks if a `DTask` is waiting on unfinished tasks, and must be deferred."
+function must_defer!(state::DataDepsState, spec::DTaskSpec, task::DTask)
+    upstreams = Set{DTask}()
+    for (_, arg) in spec.args
+        arg, _ = unwrap_inout(arg)
+        if arg isa DTask && !istaskdone(arg)
+            push!(upstreams, arg)
+        end
+    end
+    if !isempty(upstreams)
+        dstate = state.deferral_state
+        dstate.task_to_spec[task] = spec
+        dstate.waiting_upstream[task] = upstreams
+        for upstream in upstreams
+            push!(get!(Set{DTask}, dstate.waiting_downstream, upstream), task)
+            try_watch_deferred!(dstate, upstream)
+        end
+        @dagdebug nothing :spawn_datadeps "Deferring task UID $(task.uid) ($(length(upstreams)) upstreams)"
+        return true
+    else
+        @dagdebug nothing :spawn_datadeps "Not deferring task UID $(task.uid)"
+        return false
+    end
+end
+"Tries to start watching for the completion of `task`."
+function try_watch_deferred!(dstate, task)
+    if haskey(dstate.tasks, task)
+        # Already watching this task
+        return
+    end
+    if !istaskstarted(task)
+        # Can't start watching this task yet directly
+        return
+    end
+    if !haskey(dstate.waiting_downstream, task)
+        # No downstreams are watching
+        return
+    end
+
+    # Start watching
+    ctx = Sch.eager_context()
+    sch_state = Sch.EAGER_STATE[]
+    future = ThunkFuture()
+    thunk_id = Sch.ThunkID(lock(Sch.EAGER_ID_MAP) do id_map
+        id_map[task.uid]
+    end, task.thunk_ref)
+    dstate.tasks[task] = Sch.errormonitor_tracked("datadeps future listener", Threads.@spawn begin
+        wait(future)
+        lock(dstate.cond) do
+            push!(dstate.ready, task)
+            notify(dstate.cond)
+        end
+    end)
+    Sch._register_future!(ctx, sch_state, current_task(), 0, (future, thunk_id, false))
+    @dagdebug nothing :spawn_datadeps "Waiting on task UID $(task.uid)"
+end
+"Waits for deferred tasks to become unblocked."
+function wait_deferrals!(state::DataDepsState)
+    dstate = state.deferral_state
+    specs_tasks = Vector{Pair{DTaskSpec,DTask}}()
+
+    @label wait
+
+    if isempty(dstate.waiting_upstream)
+        return specs_tasks
+    end
+
+    # Wait for upstreams to become ready
+    ready = lock(dstate.cond) do
+        while isempty(dstate.ready)
+            wait(dstate.cond)
+        end
+        ready = copy(dstate.ready)
+        empty!(dstate.ready)
+        return ready
+    end
+
+    # Find deferred downstreams that are now ready
+    for upstream in ready
+        @assert istaskdone(upstream)
+        downstreams = dstate.waiting_downstream[upstream]
+        delete!(dstate.waiting_downstream, upstream)
+        for downstream in downstreams
+            @dagdebug nothing :spawn_datadeps "Upstream task UID $(upstream.uid) unblocking downstream task UID $(downstream.uid)"
+            pop!(dstate.waiting_upstream[downstream], upstream)
+            if isempty(dstate.waiting_upstream[downstream])
+                @dagdebug nothing :spawn_datadeps "Undeferring task UID $(downstream.uid)"
+                delete!(dstate.waiting_upstream, downstream)
+                push!(specs_tasks, dstate.task_to_spec[downstream] => downstream)
+                delete!(dstate.task_to_spec, downstream)
+            end
+        end
+        delete!(dstate.tasks, upstream)
+    end
+
+    if isempty(specs_tasks)
+        @goto wait
+    end
+
+    # Sort tasks by sequential ordering
+    sort!(specs_tasks, by=task->task[2].uid)
+
+    return specs_tasks
+end
+
 # Aliasing state setup
 function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
     # Populate task dependencies
@@ -243,7 +386,7 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
         arg, deps = unwrap_inout(arg)
 
         # Unwrap the Chunk underlying any DTask arguments
-        arg = arg isa DTask ? fetch(arg; raw=true) : arg
+        arg = fetch_ready(arg)
 
         # Skip non-aliasing arguments
         type_may_alias(typeof(arg)) || continue
@@ -397,9 +540,7 @@ end
 # Make a copy of each piece of data on each worker
 # memory_space => {arg => copy_of_arg}
 function generate_slot!(state::DataDepsState, dest_space, data)
-    if data isa DTask
-        data = fetch(data; raw=true)
-    end
+    data = fetch_ready(data)
     orig_space = memory_space(data)
     to_proc = first(processors(dest_space))
     from_proc = first(processors(orig_space))
@@ -471,9 +612,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
         @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1
     end
 
-    # Round-robin assign tasks to processors
-    upper_queue = get_options(:task_queue)
-
     traversal = queue.traversal
     if traversal == :inorder
         # As-is
@@ -518,7 +656,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
     end
 
     state = DataDepsState(queue.aliasing)
-    astate = state.alias_state
     sstate = DataDepsSchedulerState()
     for proc in all_procs
         space = only(memory_spaces(proc))
@@ -530,7 +667,24 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
     write_num = 1
     proc_idx = 1
     pressures = Dict{Processor,Int}()
-    for (spec, task) in queue.seen_tasks[task_order]
+    specs_tasks = queue.seen_tasks[task_order]
+    while !isempty(specs_tasks)
+        write_num, proc_idx = distribute_tasks_launch!(queue, state, sstate, all_procs, exec_spaces, specs_tasks, write_num, proc_idx, pressures)
+        specs_tasks = wait_deferrals!(state)
+    end
+    distribute_tasks_copy_back!(queue, state, sstate, all_procs, exec_spaces, specs_tasks, write_num, proc_idx, pressures)
+end
+function distribute_tasks_launch!(queue::DataDepsTaskQueue, state::DataDepsState, sstate::DataDepsSchedulerState, all_procs, exec_spaces, specs_tasks, write_num, proc_idx, pressures)
+    astate = state.alias_state
+
+    upper_queue = get_options(:task_queue)
+
+    for (spec, task) in specs_tasks
+        # Check if this task must be deferred
+        if must_defer!(state, spec, task)
+            continue
+        end
+
         # Populate all task dependencies
         populate_task_info!(state, spec, task)
 
@@ -601,9 +755,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
             args = Base.mapany(spec.args) do arg
                 pos, data = arg
                 data, _ = unwrap_inout(data)
-                if data isa DTask
-                    data = fetch(data; raw=true)
-                end
+                data = fetch_ready(data)
                 return pos => tochunk(data)
             end
             f_chunk = tochunk(spec.f)
@@ -676,7 +828,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
         for (idx, (pos, arg)) in enumerate(task_args)
             # Is the data written previously or now?
             arg, deps = unwrap_inout(arg)
-            arg = arg isa DTask ? fetch(arg; raw=true) : arg
+            arg = fetch_ready(arg)
             if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps, task)
                 @dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)"
                 spec.args[idx] = pos => arg
@@ -748,7 +900,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
         syncdeps = get(Set{Any}, spec.options, :syncdeps)
         for (idx, (_, arg)) in enumerate(task_args)
             arg, deps = unwrap_inout(arg)
-            arg = arg isa DTask ? fetch(arg; raw=true) : arg
+            arg = fetch_ready(arg)
             type_may_alias(typeof(arg)) || continue
             if queue.aliasing
                 for (dep_mod, _, writedep) in deps
@@ -777,11 +929,12 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
         task_scope = our_scope
         spec.options = merge(spec.options, (;syncdeps, scope=task_scope))
         enqueue!(upper_queue, spec=>task)
+        try_watch_deferred!(state.deferral_state, task)
 
         # Update read/write tracking for arguments
         for (idx, (_, arg)) in enumerate(task_args)
             arg, deps = unwrap_inout(arg)
-            arg = arg isa DTask ? fetch(arg; raw=true) : arg
+            arg = fetch_ready(arg)
             type_may_alias(typeof(arg)) || continue
             if queue.aliasing
                 for (dep_mod, _, writedep) in deps
@@ -810,6 +963,11 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
         proc_idx = mod1(proc_idx + 1, length(all_procs))
     end
 
+    return write_num, proc_idx
+end
+function distribute_tasks_copy_back!(queue::DataDepsTaskQueue, state::DataDepsState, sstate::DataDepsSchedulerState, all_procs, exec_spaces, specs_tasks, write_num, proc_idx, pressures)
+    astate = state.alias_state
+
     # Copy args from remote to local
     if queue.aliasing
         # We need to replay the writes from all tasks in-order (skipping any
@@ -952,13 +1110,13 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
     launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool
     if launch_wait
         result = spawn_bulk() do
-            queue = DataDepsTaskQueue(get_options(:task_queue);
+            queue = DataDepsTaskQueue(;
                                       traversal, scheduler, aliasing)
             with_options(f; task_queue=queue)
             distribute_tasks!(queue)
         end
     else
-        queue = DataDepsTaskQueue(get_options(:task_queue);
+        queue = DataDepsTaskQueue(;
                                   traversal, scheduler, aliasing)
         result = with_options(f; task_queue=queue)
         distribute_tasks!(queue)
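
For context on how the deferral path above is exercised: a minimal usage sketch, assuming the existing Dagger.spawn_datadeps/Dagger.@spawn API and the In dependency wrapper (the rand/sum calls and variable names are illustrative only, not part of this patch). Previously, an unfinished DTask argument would block the distributor inside fetch(arg; raw=true); with this patch, must_defer! parks the dependent task and wait_deferrals! re-queues it once the upstream completes.

    using Dagger

    # An upstream task spawned outside the datadeps region; it may still be
    # running when distribute_tasks! starts draining the queue.
    upstream = Dagger.@spawn rand(4, 4)

    Dagger.spawn_datadeps() do
        # `upstream` is an unfinished DTask argument: must_defer! defers this
        # task until `upstream` completes, then wait_deferrals! re-queues it.
        Dagger.@spawn sum(In(upstream))
    end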
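The upstream/downstream bookkeeping that must_defer! and wait_deferrals! maintain can also be modeled in isolation. A simplified, hypothetical sketch follows, with integer ids standing in for DTasks and without the Threads.Condition wake-up that the real code uses to block the distributor; all names here are illustrative:

    # task => its unfinished upstreams; upstream => tasks blocked on it
    struct ToyDeferrals
        waiting_upstream::Dict{Int,Set{Int}}
        waiting_downstream::Dict{Int,Set{Int}}
    end
    ToyDeferrals() = ToyDeferrals(Dict{Int,Set{Int}}(), Dict{Int,Set{Int}}())

    # Record that `task` cannot launch until every id in `ups` finishes.
    function defer!(d::ToyDeferrals, task::Int, ups::Set{Int})
        d.waiting_upstream[task] = copy(ups)
        for up in ups
            push!(get!(Set{Int}, d.waiting_downstream, up), task)
        end
    end

    # Mark `up` finished and return the newly-unblocked tasks, sorted by id
    # (mirroring the uid sort that preserves sequential launch order).
    function finish!(d::ToyDeferrals, up::Int)
        ready = Int[]
        for task in pop!(d.waiting_downstream, up, Set{Int}())
            pop!(d.waiting_upstream[task], up)
            if isempty(d.waiting_upstream[task])
                delete!(d.waiting_upstream, task)
                push!(ready, task)
            end
        end
        return sort!(ready)
    end

    # Task 3 waits on tasks 1 and 2; it becomes ready only after both finish.
    d = ToyDeferrals()
    defer!(d, 3, Set([1, 2]))
    @assert finish!(d, 1) == Int[]
    @assert finish!(d, 2) == [3]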