Skip to content

Commit

Permalink
fixup! Initial support for robustly migrating streaming tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
JamesWrigley committed Sep 6, 2024
1 parent 3ac1a40 commit 6168d06
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 20 deletions.
80 changes: 60 additions & 20 deletions src/stream.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@ mutable struct StreamStore{T,B}
buffers::Dict{Int,B}
buffer_amount::Int
open::Bool
migrating::Bool
lock::Threads.Condition
StreamStore{T,B}(buffer_amount::Integer) where {T,B} =
new{T,B}(zeros(Int, 0), Dict{Int,B}(), buffer_amount,
true, Threads.Condition())
true, false, Threads.Condition())
end

function tid_to_uid(thunk_id)
Expand Down Expand Up @@ -175,11 +176,24 @@ end

remove_waiters!(stream::Stream, waiter::Integer) = remove_waiters!(stream, Int[waiter])

function migrate_streamingfunction!(sf::StreamingFunction, w::Integer=myid())
current_worker = sf.stream.store_ref.handle.owner
if myid() != current_worker
return remotecall_fetch(migrate_streamingfunction!, current_worker, sf, w)
end

sf.stream.store.migrating = true
@lock sf.status_event wait(sf.status_event) # Wait for the streaming function to finish
end

function migrate_stream!(stream::Stream, w::Integer=myid())
# Perform migration of the StreamStore
# MemPool will block access to the new ref until the migration completes
# FIXME: Do this with MemPool.access_ref, in case stream was already migrated
if stream.store_ref.handle.owner != w
thunk_id = STREAM_THUNK_ID[]
@dagdebug thunk_id :stream "Beginning migration..."

new_store_ref = MemPool.migrate!(stream.store_ref.handle, w;
pre_migration=store->begin
# Lock store to prevent any further modifications
Expand All @@ -197,6 +211,9 @@ function migrate_stream!(stream::Stream, w::Integer=myid())
put!(store.buffers[id], item)
end
end

# Ensure that the 'migrating' flag is not set
store.migrating = false
end,
post_migration=store->begin
# Unlock the store
Expand All @@ -206,6 +223,8 @@ function migrate_stream!(stream::Stream, w::Integer=myid())
if w == myid()
stream.store = MemPool.access_ref(identity, new_store_ref; local_only=true)
end

@dagdebug thunk_id :stream "Migration complete"
end
end

Expand Down Expand Up @@ -289,11 +308,25 @@ struct StreamingFunction{F, S}
f::F
stream::S
max_evals::Int
status_event::Threads.Event
migration_complete::Threads.Event
end

chunktype(sf::StreamingFunction{F}) where F = F

function (sf::StreamingFunction)(args...; kwargs...)
ret = :migrating
while ret === :migrating
worker_id = sf.stream.store_ref.handle.owner
ret = if worker_id == myid()
_run_streamingfunction(args...; kwargs...)
else
remotecall_fetch(_run_streamingfunction, worker_id, args...; kwargs...)
end
end
end

function _run_streamingfunction(args...; kwargs...)
@nospecialize sf args kwargs
result = nothing
thunk_id = Sch.sch_handle().thunk_id.id
Expand All @@ -309,10 +342,9 @@ function (sf::StreamingFunction)(args...; kwargs...)
end
end

# Migrate our output stream to this worker
# Migrate our output stream store to this worker
if sf.stream isa Stream
migrate_stream!(sf.stream)
@dagdebug thunk_id :stream "Migration complete"
end

try
Expand All @@ -327,27 +359,31 @@ function (sf::StreamingFunction)(args...; kwargs...)
end
return stream!(sf, uid, (args...,), kwarg_names, kwarg_values)
finally
# Remove ourself as a waiter for upstream Streams
streams = Set{Stream}()
for (idx, arg) in enumerate(args)
if arg isa Stream
push!(streams, arg)
if !sf.stream.store.migrated
# Remove ourself as a waiter for upstream Streams
streams = Set{Stream}()
for (idx, arg) in enumerate(args)
if arg isa Stream
push!(streams, arg)
end
end
end
for (idx, (pos, arg)) in enumerate(kwargs)
if arg isa Stream
push!(streams, arg)
for (idx, (pos, arg)) in enumerate(kwargs)
if arg isa Stream
push!(streams, arg)
end
end
end
for stream in streams
@dagdebug thunk_id :stream "dropping waiter"
remove_waiters!(stream, uid)
@dagdebug thunk_id :stream "dropped waiter"
for stream in streams
@dagdebug thunk_id :stream "dropping waiter"
remove_waiters!(stream, uid)
@dagdebug thunk_id :stream "dropped waiter"
end

# Ensure downstream tasks also terminate
@dagdebug thunk_id :stream "closed stream"
close(sf.stream)
end

# Ensure downstream tasks also terminate
@dagdebug thunk_id :stream "closed stream"
close(sf.stream)
notify(sf.status_event)
end
end

Expand All @@ -358,6 +394,10 @@ function stream!(sf::StreamingFunction, uid,
counter = 0

while sf.max_evals < 0 || counter < sf.max_evals
if sf.stream.store.migrating
return :migrating
end

# Get values from Stream args/kwargs
stream_args = _stream_take_values!(args, uid)
stream_kwarg_values = _stream_take_values!(kwarg_values, uid)
Expand Down
40 changes: 40 additions & 0 deletions test/streaming.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import MemPool: access_ref

@everywhere begin
"""
A functor to produce a certain number of outputs.
Expand Down Expand Up @@ -77,6 +79,44 @@ end
@testset "Basics" begin
master_scope = Dagger.scope(worker=myid())

@test test_finishes("Migration") do
if nprocs() == 1
@warn "Skipping migration test because it requires at least 1 extra worker"
return
end

# Start streaming locally
mailbox = RemoteChannel()
producer = Producer(Inf, mailbox)
x = Dagger.spawn_streaming() do
Dagger.spawn(producer, Dagger.Options(; scope=master_scope))
end

# Wait for the stream to get started
while producer.count < 2
sleep(0.1)
end

# Migrate to another worker
access_ref(x.thunk_ref) do thunk
access_ref(thunk.f.handle) do streaming_function
Dagger.migrate_stream!(streaming_function.stream, workers()[1])
end
end

# Wait a bit for the stream to get started again on the other node
sleep(0.5)

# Stop it
put!(mailbox, :exit)
fetch(x)

final_count = take!(mailbox)
@info "Counts:" producer.count final_count
end

return

@test test_finishes("Single task") do
local x
Dagger.spawn_streaming() do
Expand Down

0 comments on commit 6168d06

Please sign in to comment.