Skip to content

Commit

Permalink
Merge pull request #555 from JuliaParallel/jps/fetch-all
Browse files Browse the repository at this point in the history
Add fetch_all recursive helper
  • Loading branch information
jpsamaroo authored Jul 22, 2024
2 parents 2764b76 + bc042c9 commit 52a97dd
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 13 deletions.
28 changes: 15 additions & 13 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
version = "0.18.12"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
Expand All @@ -24,7 +25,21 @@ TaskLocalValues = "ed4db957-447d-4319-bfb6-7fa9ae7ecf34"
TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[weakdeps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[extensions]
GraphVizExt = "GraphViz"
GraphVizSimpleExt = "Colors"
JSON3Ext = "JSON3"
PlotsExt = ["DataFrames", "Plots"]

[compat]
Adapt = "4.0.4"
Colors = "0.12"
DataFrames = "1"
DataStructures = "0.18"
Expand All @@ -43,22 +58,9 @@ TaskLocalValues = "0.1"
TimespanLogging = "0.1"
julia = "1.8"

[extensions]
GraphVizExt = "GraphViz"
GraphVizSimpleExt = "Colors"
JSON3Ext = "JSON3"
PlotsExt = ["DataFrames", "Plots"]

[extras]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"

[weakdeps]
Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
GraphViz = "f526b714-d49f-11e8-06ff-31ed36ee7ee0"
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
2 changes: 2 additions & 0 deletions src/Dagger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ end
import TimespanLogging
import TimespanLogging: timespan_start, timespan_finish

import Adapt

include("lib/util.jl")
include("utils/dagdebug.jl")

Expand Down
2 changes: 2 additions & 0 deletions src/chunks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ is_task_or_chunk(c::Chunk) = true
Base.:(==)(c1::Chunk, c2::Chunk) = c1.handle == c2.handle
Base.hash(c::Chunk, x::UInt64) = hash(c.handle, hash(Chunk, x))

Adapt.adapt_storage(::FetchAdaptor, x::Chunk) = fetch(x)

collect_remote(chunk::Chunk) =
move(chunk.processor, OSProc(), poolget(chunk.handle))

Expand Down
14 changes: 14 additions & 0 deletions src/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,20 @@ function spawn(f, args...; kwargs...)
return task
end

struct FetchAdaptor end
Adapt.adapt_storage(::FetchAdaptor, x::DTask) = fetch(x)
Adapt.adapt_structure(::FetchAdaptor, A::AbstractArray) =
map(x->Adapt.adapt(FetchAdaptor(), x), A)

"""
Dagger.fetch_all(x)
Recursively fetches all `DTask`s and `Chunk`s in `x`, returning an equivalent
object. Useful for converting arbitrary Dagger-enabled objects into a
non-Dagger form.
"""
fetch_all(x) = Adapt.adapt(FetchAdaptor(), x)

persist!(t::Thunk) = (t.persist=true; t)
cache_result!(t::Thunk) = (t.cache=true; t)

Expand Down
15 changes: 15 additions & 0 deletions test/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,4 +378,19 @@ end
Dagger.@spawn error()
end
end
@testset "fetch_all" begin
ts = [Dagger.@spawn(1+1) for _ in 1:4]
@test Dagger.fetch_all(ts) == [2, 2, 2, 2]
cs = map(t->fetch(t; raw=true), ts)
@test Dagger.fetch_all(cs) == [2, 2, 2, 2]

ts = Tuple(Dagger.@spawn(1+1) for _ in 1:4)
@test Dagger.fetch_all(ts) == (2, 2, 2, 2)
cs = fetch.(ts; raw=true)
@test Dagger.fetch_all(cs) == (2, 2, 2, 2)

t = Dagger.@spawn 1+1
@test Dagger.fetch_all(t) == 2
@test Dagger.fetch_all(fetch(t; raw=true)) == 2
end
end

0 comments on commit 52a97dd

Please sign in to comment.