Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial support for robustly migrating streaming tasks #568

Merged
merged 35 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
f11b2bf
Minor style cleanup
JamesWrigley Jun 25, 2024
92216fb
Use `DTaskFailedException` and increase the default timeout
JamesWrigley Aug 18, 2024
a146387
Initial support for robustly migrating streaming tasks
JamesWrigley Aug 18, 2024
ccca00e
Inherit the top-level testsets in the streaming tests
JamesWrigley Aug 20, 2024
3ac1a40
Replace `rand_finite()` with a deterministic `Producer` functor
JamesWrigley Aug 20, 2024
6168d06
fixup! Initial support for robustly migrating streaming tasks
JamesWrigley Sep 6, 2024
c194d2a
fixup! fixup! Initial support for robustly migrating streaming tasks
jpsamaroo Sep 13, 2024
a2771e1
task-tls: Refactor into DTaskTLS struct
jpsamaroo May 22, 2024
d0507b5
fixup! task-tls: Refactor into DTaskTLS struct
jpsamaroo Sep 13, 2024
a70c4e9
cancellation: Add cancel token support
jpsamaroo Sep 13, 2024
8ea068f
streaming: Handle cancellation
jpsamaroo Sep 13, 2024
42ad978
fixup! cancellation: Add cancel token support
jpsamaroo Sep 13, 2024
48fae56
fixup! fixup! fixup! Initial support for robustly migrating streaming…
jpsamaroo Sep 13, 2024
8d90e6c
Sch: Add unwrap_nested_exception for DTaskFailedException
jpsamaroo Sep 14, 2024
1989aeb
ProcessRingBuffer: Add length method
jpsamaroo Sep 14, 2024
ec4fe5c
fixup! fixup! cancellation: Add cancel token support
jpsamaroo Sep 14, 2024
c38f44a
streaming: Buffers and tasks per input/output
jpsamaroo Sep 14, 2024
94fbaaf
fixup! fixup! fixup! cancellation: Add cancel token support
jpsamaroo Sep 24, 2024
f40e33a
Sch: Trigger cancel token on task exit
jpsamaroo Sep 24, 2024
deb5c4d
Add task_id for DTask
jpsamaroo Sep 24, 2024
ffafe83
ProcessRingBuffer: Allow closure
jpsamaroo Sep 24, 2024
2403395
RemoteFetcher: Only collect values up to free buffer space
jpsamaroo Sep 24, 2024
92ffc3b
streaming: Close buffers on closing StreamStore
jpsamaroo Sep 24, 2024
bc1ebb4
task-tls: Tweaks and fixes, task_id helper
jpsamaroo Sep 24, 2024
66c089c
task-tls: Add task_cancel!
jpsamaroo Sep 24, 2024
68b0622
streaming: max_evals cannot be specified as 0
jpsamaroo Sep 24, 2024
022e398
streaming: Small tweaks to migration and cancellation
jpsamaroo Sep 24, 2024
c9c052d
dagdebug: Always yield to avoid heisenbugs
jpsamaroo Sep 24, 2024
00d1729
tests: Revamp streaming tests
jpsamaroo Sep 24, 2024
782c5ac
tests: Add offline mode
jpsamaroo Sep 24, 2024
65a42e2
dagdebug: Add JULIA_DAGGER_DEBUG config variable
jpsamaroo Oct 2, 2024
4edb446
cancellation: Add graceful vs. forced
jpsamaroo Oct 3, 2024
9051a41
cancellation: Wrap InterruptException in DTaskFailedException
jpsamaroo Oct 3, 2024
796bcf0
options: Add internal helper to strip all options
jpsamaroo Oct 3, 2024
0a86a70
streaming: Get tests passing
jpsamaroo Oct 3, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions src/Dagger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ else
end
import TaskLocalValues: TaskLocalValue

import TaskLocalValues: TaskLocalValue

if !isdefined(Base, :get_extension)
import Requires: @require
end
Expand All @@ -47,16 +49,16 @@ include("processor.jl")
include("threadproc.jl")
include("context.jl")
include("utils/processors.jl")
include("dtask.jl")
include("cancellation.jl")
include("task-tls.jl")
include("scopes.jl")
include("utils/scopes.jl")
include("dtask.jl")
include("queue.jl")
include("thunk.jl")
include("submission.jl")
include("chunks.jl")
include("memory-spaces.jl")
include("cancellation.jl")

# Task scheduling
include("compute.jl")
Expand All @@ -69,9 +71,9 @@ include("sch/Sch.jl"); using .Sch
include("datadeps.jl")

# Streaming
include("stream-buffers.jl")
include("stream-fetchers.jl")
include("stream.jl")
include("stream-buffers.jl")
include("stream-transfer.jl")

# Array computations
include("array/darray.jl")
Expand Down Expand Up @@ -152,6 +154,20 @@ function __init__()
ThreadProc(myid(), tid)
end
end

# Set up @dagdebug categories, if specified
try
if haskey(ENV, "JULIA_DAGGER_DEBUG")
empty!(DAGDEBUG_CATEGORIES)
for category in split(ENV["JULIA_DAGGER_DEBUG"], ",")
if category != ""
push!(DAGDEBUG_CATEGORIES, Symbol(category))
end
end
end
catch err
@warn "Error parsing JULIA_DAGGER_DEBUG" exception=err
end
end

end # module
2 changes: 0 additions & 2 deletions src/array/indexing.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import TaskLocalValues: TaskLocalValue

### getindex

struct GetIndex{T,N} <: ArrayOp{T,N}
Expand Down
55 changes: 52 additions & 3 deletions src/cancellation.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,51 @@
# DTask-level cancellation

mutable struct CancelToken
@atomic cancelled::Bool
@atomic graceful::Bool
event::Base.Event
end
CancelToken() = CancelToken(false, false, Base.Event())
function cancel!(token::CancelToken; graceful::Bool=true)
if !graceful
@atomic token.graceful = false
end
@atomic token.cancelled = true
notify(token.event)
return
end
function is_cancelled(token::CancelToken; must_force::Bool=false)
if token.cancelled[]
if must_force && token.graceful[]
# If we're only responding to forced cancellation, ignore graceful cancellations
return false
end
return true
end
return false
end
Base.wait(token::CancelToken) = wait(token.event)
# TODO: Enable this for safety
#Serialization.serialize(io::AbstractSerializer, ::CancelToken) =
# throw(ConcurrencyViolationError("Cannot serialize a CancelToken"))

const DTASK_CANCEL_TOKEN = TaskLocalValue{Union{CancelToken,Nothing}}(()->nothing)

function clone_cancel_token_remote(orig_token::CancelToken, wid::Integer)
remote_token = remotecall_fetch(wid) do
return poolset(CancelToken())
end
errormonitor_tracked("remote cancel_token communicator", Threads.@spawn begin
wait(orig_token)
@dagdebug nothing :cancel "Cancelling remote token on worker $wid"
MemPool.access_ref(remote_token) do remote_token
cancel!(remote_token)
end
end)
end

# Global-level cancellation

"""
cancel!(task::DTask; force::Bool=false, halt_sch::Bool=false)

Expand Down Expand Up @@ -48,7 +96,7 @@ function _cancel!(state, tid, force, halt_sch)
for task in state.ready
tid !== nothing && task.id != tid && continue
@dagdebug tid :cancel "Cancelling ready task"
state.cache[task] = InterruptException()
state.cache[task] = DTaskFailedException(task, task, InterruptException())
state.errored[task] = true
Sch.set_failed!(state, task)
end
Expand All @@ -58,7 +106,7 @@ function _cancel!(state, tid, force, halt_sch)
for task in keys(state.waiting)
tid !== nothing && task.id != tid && continue
@dagdebug tid :cancel "Cancelling waiting task"
state.cache[task] = InterruptException()
state.cache[task] = DTaskFailedException(task, task, InterruptException())
state.errored[task] = true
Sch.set_failed!(state, task)
end
Expand All @@ -80,11 +128,11 @@ function _cancel!(state, tid, force, halt_sch)
Tf === typeof(Sch.eager_thunk) && continue
istaskdone(task) && continue
any_cancelled = true
@dagdebug tid :cancel "Cancelling running task ($Tf)"
if force
@dagdebug tid :cancel "Interrupting running task ($Tf)"
Threads.@spawn Base.throwto(task, InterruptException())
else
@dagdebug tid :cancel "Cancelling running task ($Tf)"
# Tell the processor to just drop this task
task_occupancy = task_spec[4]
time_util = task_spec[2]
Expand All @@ -93,6 +141,7 @@ function _cancel!(state, tid, force, halt_sch)
push!(istate.cancelled, tid)
to_proc = istate.proc
put!(istate.return_queue, (myid(), to_proc, tid, (InterruptException(), nothing)))
cancel!(istate.cancel_tokens[tid]; graceful=false)
end
end
end
Expand Down
6 changes: 6 additions & 0 deletions src/options.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,12 @@ function with_options(f, options::NamedTuple)
end
with_options(f; options...) = with_options(f, NamedTuple(options))

function _without_options(f)
with(options_context => NamedTuple()) do
f()
end
end

"""
get_options(key::Symbol, default) -> Any
get_options(key::Symbol) -> Any
Expand Down
16 changes: 15 additions & 1 deletion src/sch/Sch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,7 @@ struct ProcessorInternalState
proc_occupancy::Base.RefValue{UInt32}
time_pressure::Base.RefValue{UInt64}
cancelled::Set{Int}
cancel_tokens::Dict{Int,Dagger.CancelToken}
done::Base.RefValue{Bool}
end
struct ProcessorState
Expand All @@ -1189,7 +1190,7 @@ function proc_states(f::Base.Callable, uid::UInt64)
end
end
proc_states(f::Base.Callable) =
proc_states(f, task_local_storage(:_dagger_sch_uid)::UInt64)
proc_states(f, Dagger.get_tls().sch_uid)

task_tid_for_processor(::Processor) = nothing
task_tid_for_processor(proc::Dagger.ThreadProc) = proc.tid
Expand Down Expand Up @@ -1318,7 +1319,14 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re

# Execute the task and return its result
t = @task begin
# Set up cancellation
cancel_token = Dagger.CancelToken()
Dagger.DTASK_CANCEL_TOKEN[] = cancel_token
lock(istate.queue) do _
istate.cancel_tokens[thunk_id] = cancel_token
end
was_cancelled = false

result = try
do_task(to_proc, task)
catch err
Expand All @@ -1335,6 +1343,7 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
# Task was cancelled, so occupancy and pressure are
# already reduced
pop!(istate.cancelled, thunk_id)
delete!(istate.cancel_tokens, thunk_id)
was_cancelled = true
end
end
Expand All @@ -1352,6 +1361,9 @@ function start_processor_runner!(istate::ProcessorInternalState, uid::UInt64, re
else
rethrow(err)
end
finally
# Ensure that any spawned tasks get cleaned up
Dagger.cancel!(cancel_token)
end
end
lock(istate.queue) do _
Expand Down Expand Up @@ -1401,6 +1413,7 @@ function do_tasks(to_proc, return_queue, tasks)
Dict{Int,Vector{Any}}(),
Ref(UInt32(0)), Ref(UInt64(0)),
Set{Int}(),
Dict{Int,Dagger.CancelToken}(),
Ref(false))
runner = start_processor_runner!(istate, uid, return_queue)
@static if VERSION < v"1.9"
Expand Down Expand Up @@ -1640,6 +1653,7 @@ function do_task(to_proc, task_desc)
sch_handle,
processor=to_proc,
task_spec=task_desc,
cancel_token=Dagger.DTASK_CANCEL_TOKEN[],
))

res = Dagger.with_options(propagated) do
Expand Down
2 changes: 1 addition & 1 deletion src/sch/dynamic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ struct SchedulerHandle
end

"Gets the scheduler handle for the currently-executing thunk."
sch_handle() = task_local_storage(:_dagger_sch_handle)::SchedulerHandle
sch_handle() = Dagger.get_tls().sch_handle::SchedulerHandle

"Thrown when the scheduler halts before finishing processing the DAG."
struct SchedulerHaltedException <: Exception end
Expand Down
3 changes: 3 additions & 0 deletions src/sch/eager.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,6 @@ function _find_thunk(e::Dagger.DTask)
unwrap_weak_checked(EAGER_STATE[].thunk_dict[tid])
end
end
Dagger.task_id(t::Dagger.DTask) = lock(EAGER_ID_MAP) do id_map
id_map[t.uid]
end
2 changes: 2 additions & 0 deletions src/sch/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ unwrap_nested_exception(err::CapturedException) =
unwrap_nested_exception(err.ex)
unwrap_nested_exception(err::RemoteException) =
unwrap_nested_exception(err.captured)
unwrap_nested_exception(err::DTaskFailedException) =
unwrap_nested_exception(err.ex)
unwrap_nested_exception(err) = err

"Gets a `NamedTuple` of options propagated by `thunk`."
Expand Down
Loading