Skip to content

Add closure parsing to at-spawn #423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions Manifest.toml
Original file line number Diff line number Diff line change
@@ -3,6 +3,12 @@
julia_version = "1.7.3"
manifest_format = "2.0"

[[deps.Adapt]]
deps = ["LinearAlgebra"]
git-tree-sha1 = "195c5505521008abea5aee4f96930717958eac6f"
uuid = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
version = "3.4.0"

[[deps.Artifacts]]
uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"

2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -3,6 +3,7 @@ uuid = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
version = "0.18.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ContextVariablesX = "6add18c4-b38d-439d-96f6-d6bc489c04c5"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
@@ -21,6 +22,7 @@ TimespanLogging = "a526e669-04d3-4846-9525-c66122c55f63"
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"

[compat]
Adapt = "1, 2, 3"
ContextVariablesX = "0.1"
DataStructures = "0.18"
MacroTools = "0.5"
2 changes: 2 additions & 0 deletions src/Dagger.jl
Original file line number Diff line number Diff line change
@@ -14,12 +14,14 @@ using UUIDs

import ContextVariablesX

import Adapt
using Requires
using MacroTools
using TimespanLogging

include("lib/util.jl")
include("utils/dagdebug.jl")
include("utils/find-thunk.jl")

# Distributed data
include("options.jl")
123 changes: 96 additions & 27 deletions src/thunk.jl
Original file line number Diff line number Diff line change
@@ -332,7 +332,7 @@ generated thunks.
macro par(exs...)
opts = exs[1:end-1]
ex = exs[end]
_par(ex; lazy=true, opts=opts)
generate_spawn(ex; lazy=true, opts=opts)
end

"""
@@ -346,7 +346,7 @@ See the docs for `@par` for more information and usage examples.
macro spawn(exs...)
opts = exs[1:end-1]
ex = exs[end]
_par(ex; lazy=false, opts=opts)
generate_spawn(ex; lazy=false, opts=opts)
end

struct ExpandedBroadcast{F} end
@@ -360,40 +360,109 @@ function replace_broadcast(fn::Symbol)
return fn
end

function _par(ex::Expr; lazy=true, recur=true, opts=())
if ex.head == :call && recur
f = replace_broadcast(ex.args[1])
if length(ex.args) >= 2 && Meta.isexpr(ex.args[2], :parameters)
args = ex.args[3:end]
kwargs = ex.args[2]
else
args = ex.args[2:end]
kwargs = Expr(:parameters)
function generate_spawn(ex::Expr; lazy=true, mode=nothing, opts=())
if mode === nothing
parse_idx = nothing
for (idx, opt) in enumerate(opts)
if Meta.isexpr(opt, :(=)) && opt.args[1] == :parse
if parse_idx !== nothing
throw(ArgumentError("`parse` can only be specified once"))
end
if !(opt.args[2] isa QuoteNode)
throw(ArgumentError("`parse` option value must be a constant Symbol"))
end
mode = opt.args[2].value
if !(mode in (:closure, :call)) # TODO: :recurse
throw(ArgumentError("Invalid parse mode: $(repr(mode))"))
end
parse_idx = idx
end
end
opts = esc.(opts)
args_ex = _par.(args; lazy=lazy, recur=false)
kwargs_ex = _par.(kwargs.args; lazy=lazy, recur=false)
if lazy
return :(Dagger.delayed($(esc(f)), $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...)))
if parse_idx !== nothing
opts = (opts[1:(parse_idx-1)]..., opts[(parse_idx+1):end]...)
end
end
if mode === nothing
# Automatically pick a mode
if Meta.isexpr(ex, :call)
mode = :call
else
sync_var = esc(Base.sync_varname)
@gensym result
return quote
let args = ($(args_ex...),)
$result = $spawn($(esc(f)), $Options(;$(opts...)), args...; $(kwargs_ex...))
if $(Expr(:islocal, sync_var))
put!($sync_var, schedule(Task(()->wait($result))))
mode = :closure
end
end
if mode == :call
if !Meta.isexpr(ex, :call)
throw(ArgumentError("When `parse=:call`, expression must be a function call"))
end
f = esc(replace_broadcast(ex.args[1]))
has_kw(ex) = length(ex.args) >= 2 &&
(Meta.isexpr(ex.args[2], :parameters) ||
any(iex->Meta.isexpr(iex, :kw), ex.args))
if has_kw(ex)
if Meta.isexpr(ex.args[2], :parameters)
args = ex.args[3:end]
kwargs = ex.args[2].args
else
kwargs = Expr[]
for argidx in length(ex.args):-1:2
arg = ex.args[argidx]
if Meta.isexpr(arg, :kw)
pushfirst!(kwargs, arg)
deleteat!(ex.args, argidx)
end
$result
end
args = ex.args[2:end]
end
else
args = ex.args[2:end]
kwargs = Expr[]
end
args = map(esc, args)
kwargs = map(esc, kwargs)
elseif mode == :closure
f = :(()->$(esc(ex)))
args = []
kwargs = Expr[]
#= TODO: Recurse through AST
elseif mode == :recur
if Meta.isexpr(ex, :(=))
return Expr(:(=), ex.args[1], generate_spawn(ex.args[2]; lazy, mode, opts))
elseif Meta.isexpr(ex, :block) ||
Meta.isexpr(ex, :tuple)
return Expr(ex.head, map(arg->generate_spawn(arg; lazy, mode, opts), ex.args)...)
elseif Meta.isexpr(ex, :if)
cond = ex.args[1]
cond = Expr(:call, :fetch, generate_spawn(cond; lazy, mode, opts))
return Expr(:if, cond, map(arg->generate_spawn(arg; lazy, mode, opts), ex.args[2:end])...)
elseif Meta.isexpr(ex, :call)
# FIXME: Handle recursive calls
return generate_spawn(ex; lazy, mode=:call, opts)
else
return ex
end
=#
end
opts = map(esc, opts)
if lazy
return quote
$delayed($(esc(f)), $Options(;$(opts...)))($(args...); $(kwargs...))
end
else
return Expr(ex.head, _par.(ex.args, lazy=lazy, recur=recur, opts=opts)...)
sync_var = esc(Base.sync_varname)
@gensym result
return quote
let
$result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...))
if $(Expr(:islocal, sync_var))
put!($sync_var, $schedule($Task(()->wait($result))))
end
$result
end
end
end
end
_par(ex::Symbol; kwargs...) = esc(ex)
_par(ex; kwargs...) = ex
generate_spawn(ex::Symbol; kwargs...) = ex
generate_spawn(ex; kwargs...) = ex

persist!(t::Thunk) = (t.persist=true; t)
cache_result!(t::Thunk) = (t.cache=true; t)
6 changes: 6 additions & 0 deletions src/utils/find-thunks.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
struct RecordAdaptor
tasks::Set{Any}
end
struct FetchAdaptor end
Adapt.adapt_storage(ra::RecordAdaptor, t::Thunk) = (push!(ra.tasks, t); t)
Adapt.adapt_storage(::FetchAdaptor, t::Thunk) = fetch(t)