Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

datadeps: Add at-stencil helper #564

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
5 changes: 4 additions & 1 deletion docs/make.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ makedocs(;
"Scopes" => "scopes.md",
"Processors" => "processors.md",
"Task Queues" => "task-queues.md",
"Datadeps" => "datadeps.md",
"Datadeps" => [
"Basics" => "datadeps.md",
"Stencils" => "stencils.md",
],
"Option Propagation" => "propagation.md",
"Logging and Visualization" => [
"Logging: Basics" => "logging.md",
Expand Down
43 changes: 43 additions & 0 deletions docs/src/stencils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Stencil Operations



```julia
N = 27
nt = 3
tiles = zeros(Blocks(N, N), Bool, N*nt, N*nt)
outputs = zeros(Blocks(N, N), Bool, N*nt, N*nt)

# Create fun initial state
tiles[13, 14] = 1
tiles[14, 14] = 1
tiles[15, 14] = 1
tiles[15, 15] = 1
tiles[14, 16] = 1
@view(tiles[(2N+1):3N, (2N+1):3N]) .= rand(Bool, N, N)

import Dagger: @stencil, Wrap

anim = @animate for _ in 1:niters
Dagger.spawn_datadeps() do
@stencil begin
outputs[idx] = begin
nhood = @neighbors(tiles[idx], 1, Wrap())
neighs = sum(nhood) - tiles[idx]
if tiles[idx] && neighs < 2
0
elseif tiles[idx] && neighs > 3
0
elseif !tiles[idx] && neighs == 3
1
else
tiles[idx]
end
end
tiles[idx] = outputs[idx]
end
end
heatmap(Int.(collect(outputs)))
end
path = mp4(anim; fps=5, show_msg=true).filename
```
5 changes: 4 additions & 1 deletion src/Dagger.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ include("utils/dagdebug.jl")
include("utils/locked-object.jl")
include("utils/tasks.jl")

import MacroTools: @capture
import MacroTools: @capture, prewalk

include("options.jl")
include("processor.jl")
include("threadproc.jl")
Expand All @@ -65,6 +66,8 @@ include("sch/Sch.jl"); using .Sch

# Data dependency task queue
include("datadeps.jl")
include("utils/haloarray.jl")
include("stencil.jl")

# Array computations
include("array/darray.jl")
Expand Down
3 changes: 2 additions & 1 deletion src/array/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ function Base.setindex!(A::DArray{T,N}, value, idx::NTuple{N,Int}) where {T,N}
# Set the value
part = A.chunks[part_idx...]
space = memory_space(part)
scope = Dagger.scope(worker=root_worker_id(space))
# FIXME: Do this correctly w.r.t memory space of part
scope = Dagger.scope(worker=root_worker_id(space), threads=:)
return fetch(Dagger.@spawn scope=scope setindex!(part, value, offset_idx...))
end
Base.setindex!(A::DArray, value, idx::Integer...) =
Expand Down
4 changes: 2 additions & 2 deletions src/datadeps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -677,8 +677,8 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
# Is the data written previously or now?
arg, deps = unwrap_inout(arg)
arg = arg isa DTask ? fetch(arg; raw=true) : arg
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps, task)
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)"
if !type_may_alias(typeof(arg))
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (immutable)"
spec.args[idx] = pos => arg
continue
end
Expand Down
21 changes: 16 additions & 5 deletions src/scopes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,13 +325,20 @@ function to_scope(sc::NamedTuple)
else
nothing
end
all_threads = false
threads = if haskey(sc, :thread)
Int[sc.thread]
elseif haskey(sc, :threads)
Int[sc.threads...]
if sc.threads == Colon()
all_threads = true
nothing
else
Int[sc.threads...]
end
else
nothing
end
want_threads = all_threads || threads !== nothing

# Simple cases
if workers !== nothing && threads !== nothing
Expand All @@ -341,18 +348,22 @@ function to_scope(sc::NamedTuple)
end
return simplified_union_scope(subscopes)
elseif workers !== nothing && threads === nothing
subscopes = AbstractScope[ProcessScope(w) for w in workers]
return simplified_union_scope(subscopes)
subscopes = simplified_union_scope(AbstractScope[ProcessScope(w) for w in workers])
if all_threads
return constrain(subscopes, ProcessorTypeScope(ThreadProc))
else
return subscopes
end
end

# More complex cases that require querying the cluster
# FIXME: Use per-field scope taint
if workers === nothing
workers = procs()
workers = map(p->p.pid, filter(p->p isa OSProc, procs(Dagger.Sch.eager_context())))
end
subscopes = AbstractScope[]
for w in workers
if threads === nothing
if threads === nothing && want_threads
threads = map(c->c.tid,
filter(c->c isa ThreadProc,
collect(children(OSProc(w)))))
Expand Down
Loading
Loading