Skip to content

Commit

Permalink
Merge pull request #533 from JuliaParallel/jps/parser-fixes
Browse files Browse the repository at this point in the history
parser: Assorted improvements
  • Loading branch information
jpsamaroo authored Jun 19, 2024
2 parents 684d80c + 86f2f5a commit b2fd2ab
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 20 deletions.
63 changes: 43 additions & 20 deletions src/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ generated thunks.
macro par(exs...)
opts = exs[1:end-1]
ex = exs[end]
_par(ex; lazy=true, opts=opts)
return esc(_par(ex; lazy=true, opts=opts))
end

"""
Expand Down Expand Up @@ -348,7 +348,7 @@ also passes along any options in an `Options` struct. For example,
macro spawn(exs...)
opts = exs[1:end-1]
ex = exs[end]
_par(ex; lazy=false, opts=opts)
return esc(_par(ex; lazy=false, opts=opts))
end

struct ExpandedBroadcast{F} end
Expand All @@ -363,39 +363,62 @@ function replace_broadcast(fn::Symbol)
end

function _par(ex::Expr; lazy=true, recur=true, opts=())
if ex.head == :call && recur
f = replace_broadcast(ex.args[1])
if length(ex.args) >= 2 && Meta.isexpr(ex.args[2], :parameters)
args = ex.args[3:end]
kwargs = ex.args[2]
else
args = ex.args[2:end]
kwargs = Expr(:parameters)
f = nothing
body = nothing
arg1 = nothing
if recur && @capture(ex, f_(allargs__)) || @capture(ex, f_(allargs__) do cargs_ body_ end) || @capture(ex, allargs__->body_) || @capture(ex, arg1_[allargs__])
f = replace_broadcast(f)
if arg1 !== nothing
# Indexing (A[2,3])
f = Base.getindex
pushfirst!(allargs, arg1)
end
args = filter(arg->!Meta.isexpr(arg, :parameters), allargs)
kwargs = filter(arg->Meta.isexpr(arg, :parameters), allargs)
if !isempty(kwargs)
kwargs = only(kwargs).args
end
if body !== nothing
if f !== nothing
f = quote
($(args...); $(kwargs...))->$f($(args...); $(kwargs...)) do $cargs
$body
end
end
else
f = quote
($(args...); $(kwargs...))->begin
$body
end
end
end
end
opts = esc.(opts)
args_ex = _par.(args; lazy=lazy, recur=false)
kwargs_ex = _par.(kwargs.args; lazy=lazy, recur=false)
if lazy
return :(Dagger.delayed($(esc(f)), $Options(;$(opts...)))($(args_ex...); $(kwargs_ex...)))
return :(Dagger.delayed($f, $Options(;$(opts...)))($(args...); $(kwargs...)))
else
sync_var = esc(Base.sync_varname)
sync_var = Base.sync_varname
@gensym result
return quote
let args = ($(args_ex...),)
$result = $spawn($(esc(f)), $Options(;$(opts...)), args...; $(kwargs_ex...))
let
$result = $spawn($f, $Options(;$(opts...)), $(args...); $(kwargs...))
if $(Expr(:islocal, sync_var))
put!($sync_var, schedule(Task(()->wait($result))))
end
$result
end
end
end
elseif lazy
# Recurse into the expression
return Expr(ex.head, _par_inner.(ex.args, lazy=lazy, recur=recur, opts=opts)...)
else
return Expr(ex.head, _par.(ex.args, lazy=lazy, recur=recur, opts=opts)...)
throw(ArgumentError("Invalid Dagger task expression: $ex"))
end
end
_par(ex::Symbol; kwargs...) = esc(ex)
_par(ex; kwargs...) = ex
_par(ex; kwargs...) = throw(ArgumentError("Invalid Dagger task expression: $ex"))

_par_inner(ex; kwargs...) = ex
_par_inner(ex::Expr; kwargs...) = _par(ex; kwargs...)

"""
Dagger.spawn(f, args...; kwargs...) -> DTask
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ if PROGRAM_FILE != "" && realpath(PROGRAM_FILE) == @__FILE__
pushfirst!(LOAD_PATH, joinpath(@__DIR__, ".."))
using Pkg
Pkg.activate(@__DIR__)
Pkg.instantiate()

using ArgParse
s = ArgParseSettings(description = "Dagger Testsuite")
Expand Down
64 changes: 64 additions & 0 deletions test/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,70 @@ end
@test fetch(@spawn A .+ B) A .+ B
@test fetch(@spawn A .* B) A .* B
end
@testset "inner macro" begin
A = rand(4)
t = @spawn sum(@view A[2:3])
@test t isa Dagger.DTask
@test fetch(t) sum(@view A[2:3])
end
@testset "do block" begin
A = rand(4)

t = @spawn sum(A) do a
a + 1
end
@test t isa Dagger.DTask
@test fetch(t) sum(a->a+1, A)

t = @spawn sum(A; dims=1) do a
a + 1
end
@test t isa Dagger.DTask
@test fetch(t) sum(a->a+1, A; dims=1)

do_f = f -> f(42)
t = @spawn do_f() do x
x + 1
end
@test t isa Dagger.DTask
@test fetch(t) == 43
end
@testset "anonymous direct call" begin
A = rand(4)

t = @spawn A->sum(A)
@test t isa Dagger.DTask
@test fetch(t) == sum(A)

t = @spawn A->sum(A; dims=1)
@test t isa Dagger.DTask
@test fetch(t) == sum(A; dims=1)
end
@testset "getindex" begin
A = rand(4, 4)

t = @spawn A[1, 2]
@test t isa Dagger.DTask
@test fetch(t) == A[1, 2]

B = Dagger.@spawn rand(4, 4)
t = @spawn B[1, 2]
@test t isa Dagger.DTask
@test fetch(t) == fetch(B)[1, 2]

R = Ref(42)
t = @spawn R[]
@test t isa Dagger.DTask
@test fetch(t) == 42
end
@testset "invalid expression" begin
@test_throws LoadError eval(:(@spawn 1))
@test_throws LoadError eval(:(@spawn begin 1 end))
@test_throws LoadError eval(:(@spawn begin
1+1
1+1
end))
end
@testset "waiting" begin
a = @spawn sleep(1)
@test !isready(a)
Expand Down

0 comments on commit b2fd2ab

Please sign in to comment.