Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix write dag #531

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
230 changes: 181 additions & 49 deletions ext/GraphVizSimpleExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,17 @@ else
end

import Dagger
import Dagger: Chunk, Thunk, Processor
import Dagger: Chunk, Thunk, DTask, Processor
import Dagger: show_logs
import Dagger: istask, dependents
import Dagger: unwrap_weak
import Dagger.TimespanLogging: Timespan

### DAG-based graphing

global _part_labels = Dict()

function write_node(ctx, io, t::Chunk, c)
function write_node(io, t::Chunk, c, ctx=nothing)
_part_labels[t]="part_$c"
c+1
end
Expand All @@ -32,27 +34,109 @@ function node_id(t::Chunk)
_part_labels[t]
end

"""
    node_name(t::Thunk)

Return the Graphviz node identifier for `t`: the prefix `n_` followed by `t.id`.
"""
node_name(t::Thunk) = string("n_", t.id)

"""
    node_name(t::Chunk)

Return the label stored for `t` in the global `_part_labels` table.
Indexing fails if no label has been stored for `t` yet.
"""
function node_name(t::Chunk)
    label = _part_labels[t]
    return label
end

"""
    node_name(name::String)

Return the Graphviz node identifier for a named node: `n_` followed by `name`.
"""
node_name(name::String) = string("n_", name)

"""
    node_name(id)

Fallback: return `n_` followed by the printed form of `id`.
"""
node_name(id) = string("n_", id)

# Modified version of the function from Dagger compute.jl
"""
    custom_dependents(node::Thunk)

Starting at `node`, follow each task's `syncdeps` and build a dictionary that maps
every reachable `Thunk` or `Chunk` to the set of `Thunk`s that list it as a
dependency. Every visited `Thunk` gets an entry, possibly with an empty set.

Dependencies for which `unwrap_weak` returns `nothing` are skipped, as are
dependencies that are neither tasks nor `Chunk`s.
"""
function custom_dependents(node::Thunk)
    deps = Dict{Union{Thunk,Chunk}, Set{Thunk}}()
    seen = Set{Thunk}()
    pending = Set{Thunk}()
    push!(pending, node)
    while !isempty(pending)
        current = pop!(pending)
        current in seen && continue
        # Make sure the task itself has an entry, even if nothing depends on it.
        get!(deps, current) do
            Set{Thunk}()
        end
        for raw in current.syncdeps
            dep = unwrap_weak(raw)
            dep === nothing && continue
            (istask(dep) || (dep isa Chunk)) || continue
            users = get!(deps, dep) do
                Set{Thunk}()
            end
            push!(users, current)
            # Only tasks are traversed further; Chunks are leaves here.
            if istask(dep) && !(dep in seen)
                push!(pending, dep)
            end
        end
        push!(seen, current)
    end
    return deps
end

# Building the DAG from a DTask requires unwrapping WeakRefs, which may yield
# `nothing` once the referenced task has been garbage collected; that part of the
# DAG is then not displayed. Because this makes the output unstable, the method
# throws unless the caller explicitly passes `stable=false`.
function write_dag(io, e::DTask, stable::Bool=true)
    stable && throw(ArgumentError("Writing DAG for DTask is not supported by default. Use the logs instead."))
    return write_dag(io, convert_to_thunk(e))
end

function write_dag(io, t::Thunk)
!istask(t) && return
deps = dependents(t)

# Chunk/Thunk nodes
deps = custom_dependents(t)
c=1
for k in keys(deps)
c = write_node(nothing, io, k, c)
c = write_node(io, k, c)
end
for (k, v) in deps
for dep in v
if isa(k, Union{Chunk, Thunk})
println(io, "$(node_id(k)) -> $(node_id(dep))")
println(io, "$(node_name(k)) -> $(node_name(dep))")
end
end
end

# Argument nodes (not Chunks/Thunks)
argmap = Dict{Int,Vector}()
getargs!(argmap, t)
argids = IdDict{Any,String}()
for id in keys(argmap)
for (argidx,arg) in argmap[id]
name = "arg_$(argidx)_to_$(id)"
if !isimmutable(arg)
if arg in keys(argids)
name = argids[arg]
else
argids[arg] = name
c = write_node(io, arg, c, name)
end
else
c = write_node(io, arg, c, name)
end
# Arg-to-compute edges
write_edge(io, name, id)
end
end
end

### Timespan-based graphing

pretty_time(ts::Timespan) = pretty_time(ts.finish-ts.start)
function pretty_time(t)
r(t) = round(t; digits=3)
pretty_time(ts::Timespan; digits::Integer=3) = pretty_time(ts.finish-ts.start; digits=digits)
function pretty_time(t; digits::Integer=3)
r(t) = round(t; digits)
if t > 1000^3
"$(r(t/(1000^3))) s"
elseif t > 1000^2
Expand Down Expand Up @@ -99,47 +183,53 @@ _proc_shape(ctx, proc::Processor) = get!(ctx.proc_to_shape, typeof(proc)) do
end
_proc_shape(ctx, ::Nothing) = "ellipse"

function write_node(ctx, io, t::Thunk, c)
function write_node(io, t::Thunk, c, ctx=nothing)
f = isa(t.f, Function) ? "$(t.f)" : "fn"
println(io, "n_$(t.id) [label=\"$f - $(t.id)\"];")
println(io, "$(node_name(t)) [label=\"$f - $(t.id)\"];")
c
end

dec(x) = Base.dec(x, 0, false)
function write_node(ctx, io, t, c, id=dec(hash(t)))
function write_node(io, t, c, ctx, id=dec(hash(t)))
l = replace(node_label(t), "\""=>"")
proc = node_proc(t)
color = _proc_color(ctx, proc)
shape = _proc_shape(ctx, proc)
println(io, "n_$id [label=\"$l\",color=\"$color\",shape=\"$shape\",penwidth=5];")
println(io, "$(node_name(id)) [label=\"$l\",color=\"$color\",shape=\"$shape\",penwidth=5];")
c
end

function write_node(ctx, io, ts::Timespan, c)
"""
    write_node(io, t, c, name::String)

Print a DOT node statement for `t` to `io`. The node identifier is
`node_name(name)` and the label is `node_label(t)` with double quotes removed.
Returns the counter `c` unchanged.
"""
function write_node(io, t, c, name::String)
    # Double quotes would terminate the DOT label early, so drop them.
    label = replace(node_label(t), "\"" => "")
    statement = string(node_name(name), " [label=\"", label, "\"];")
    println(io, statement)
    return c
end

function write_node(io, ts::Timespan, c, ctx; times_digits::Integer=3)
(;thunk_id, processor) = ts.id
(;f) = ts.timeline
f = isa(f, Function) ? "$f" : "fn"
t_comp = pretty_time(ts)
t_comp = pretty_time(ts; digits=times_digits)
color = _proc_color(ctx, processor)
shape = _proc_shape(ctx, processor)
# TODO: t_log = log(ts.finish - ts.start) / 5
ctx.id_to_proc[thunk_id] = processor
println(io, "n_$thunk_id [label=\"$f\n$t_comp\",color=\"$color\",shape=\"$shape\",penwidth=5];")
println(io, "$(node_name(thunk_id)) [label=\"$f\n$t_comp\",color=\"$color\",shape=\"$shape\",penwidth=5];")
# TODO: "\n Thunk $(ts.id)\nResult Type: $res_type\nResult Size: $sz_comp\",
c
end

function write_edge(ctx, io, ts_move::Timespan, logs, inputname=nothing, inputarg=nothing)
function write_edge(io, ts_move::Timespan, logs, ctx, inputname=nothing, inputarg=nothing; times_digits::Integer=3)
(;thunk_id, id) = ts_move.id
(;f,) = ts_move.timeline
t_move = pretty_time(ts_move)
t_move = pretty_time(ts_move; digits=times_digits)
if id > 0
print(io, "n_$id -> n_$thunk_id [label=\"Move: $t_move")
print(io, "$(node_name(id)) -> $(node_name(thunk_id)) [label=\"Move: $t_move")
color_src = _proc_color(ctx, id)
else
@assert inputname !== nothing
@assert inputarg !== nothing
print(io, "n_$inputname -> n_$thunk_id [label=\"Move: $t_move")
print(io, "$(node_name(inputname)) -> $(node_name(thunk_id)) [label=\"Move: $t_move")
proc = node_proc(inputarg)
color_src = _proc_color(ctx, proc)
end
Expand All @@ -148,62 +238,97 @@ function write_edge(ctx, io, ts_move::Timespan, logs, inputname=nothing, inputar
println(io, "\",color=\"$color_src;0.5:$color_dst\",penwidth=2];")
end

write_edge(ctx, io, from::String, to::String) = println(io, "n_$from -> n_$to;")
write_edge(io, from::String, to::String, ctx=nothing) = println(io, "$(node_name(from)) -> $(node_name(to));")
write_edge(io, from::String, to::Int, ctx=nothing) = println(io, "$(node_name(from)) -> $(node_name(to));")

# A Thunk needs no conversion.
function convert_to_thunk(t::Thunk)
    return t
end

# A DTask is resolved through `Dagger.Sch._find_thunk`.
function convert_to_thunk(t::DTask)
    return Dagger.Sch._find_thunk(t)
end

# Fallback: values with no more specific `getargs!` method contribute nothing
# and `d` is left untouched.
function getargs!(d, t)
    return nothing
end

# Record the non-task inputs of `t` in `d`, then recurse into every input.
# After the call, `d[t.id]` holds `(position, value)` tuples for each input of `t`
# for which `istask` is false; `position` is the 1-based index from `enumerate`.
function getargs!(d, t::Thunk)
# Each entry of `t.inputs` is reduced to its last element.
raw_inputs = map(last, t.inputs)
# Keep only inputs that are not tasks, paired with their position.
d[t.id] = [filter(x->!istask(x[2]), collect(enumerate(raw_inputs)))...,]
# Visit every input; dispatch on `getargs!` decides whether one contributes entries.
foreach(i->getargs!(d, i), raw_inputs)
end
function write_dag(io, t, logs::Vector)

# A DTask is converted with `convert_to_thunk` and handled by the `getargs!`
# method for the converted value.
getargs!(d, e::DTask) = getargs!(d, convert_to_thunk(e))

# DTask is not used in the current implementation, as it would be unstable, and the logs provide all the necessary information
function write_dag(io, logs::Vector, t::Union{Thunk, DTask, Nothing}=nothing; times_digits::Integer=3)
ctx = (proc_to_color = Dict{Processor,String}(),
proc_colors = Colors.distinguishable_colors(128),
proc_color_idx = Ref{Int}(1),
proc_to_shape = Dict{Type,String}(),
proc_shapes = ("ellipse","box","triangle"),
proc_shape_idx = Ref{Int}(1),
id_to_proc = Dict{Int,Processor}())
argmap = Dict{Int,Vector}()
getargs!(argmap, t)

c = 1
# Compute nodes
for ts in filter(x->x.category==:compute, logs)
c = write_node(ctx, io, ts, c)
c = write_node(io, ts, c, ctx; times_digits=times_digits)
end
# Argument nodes
argnodemap = Dict{Int,Vector{String}}()

# Argument nodes & edges
argmap = Dict{Int,Vector}()
argids = IdDict{Any,String}()
for id in keys(argmap)
nodes = String[]
arg_c = 1
for (argidx,arg) in argmap[id]
name = "arg_$(argidx)_to_$(id)"
if (isa(t, Thunk)) # Then can get info from the Thunk
getargs!(argmap, t)
argnodemap = Dict{Int,Vector{String}}()
for id in keys(argmap)
nodes = String[]
arg_c = 1
for (argidx,arg) in argmap[id]
name = "arg_$(argidx)_to_$(id)"
if !isimmutable(arg)
if arg in keys(argids)
name = argids[arg]
else
argids[arg] = name
c = write_node(io, arg, c, ctx, name)
end
push!(nodes, name)
else
c = write_node(io, arg, c, ctx, name)
push!(nodes, name)
end
# Arg-to-compute edges
for ts in filter(x->x.category==:move &&
x.id.thunk_id==id &&
x.id.id==-argidx, logs)
write_edge(io, ts, logs, ctx, name, arg; times_digits=times_digits)
end
arg_c += 1
end
argnodemap[id] = nodes
end
else # Rely on the logs only
for ts in filter(x->x.category==:move && x.id.id < 0, logs)
(;thunk_id, id) = ts.id
arg = ts.timeline[2]
name = "arg_$(-id)_to_$(thunk_id)"
if !isimmutable(arg)
if arg in keys(argids)
name = argids[arg]
else
argids[arg] = name
c = write_node(ctx, io, arg, c, name)
c = write_node(io, arg, c, ctx, name)
end
push!(nodes, name)
else
c = write_node(ctx, io, arg, c, name)
push!(nodes, name)
c = write_node(io, arg, c, ctx, name)
end

# Arg-to-compute edges
for ts in filter(x->x.category==:move &&
x.id.thunk_id==id &&
x.id.id==-argidx, logs)
write_edge(ctx, io, ts, logs, name, arg)
end
arg_c += 1
write_edge(io, ts, logs, ctx, name, arg; times_digits=times_digits)
end
argnodemap[id] = nodes
end

# Move edges
for ts in filter(x->x.category==:move && x.id.id>0, logs)
write_edge(ctx, io, ts, logs)
write_edge(io, ts, logs, ctx; times_digits=times_digits)
end
#= FIXME: Legend (currently it's laid out horizontally)
println(io, """
Expand All @@ -224,21 +349,28 @@ function write_dag(io, t, logs::Vector)
=#
end

function _show_plan(io::IO, t)
function _show_plan(io::IO, t::Union{Thunk,DTask})
println(io, """strict digraph {
graph [layout=dot,rankdir=LR];""")
write_dag(io, t)
println(io, "}")
end
function _show_plan(io::IO, t::Thunk, logs::Vector{Timespan})
function _show_plan(io::IO, logs::Vector; times_digits::Integer=3)
println(io, """strict digraph {
graph [layout=dot,rankdir=LR];""")
write_dag(io, t, logs)
write_dag(io, logs; times_digits)
println(io, "}")
end
"""
    _show_plan(io::IO, t::Union{Thunk,DTask}, logs::Vector{Timespan}; times_digits=3)

Write a complete Graphviz `strict digraph` to `io`, using both the task `t` and
the recorded `logs`. The body of the graph is produced by
`write_dag(io, logs, t; times_digits)`.

`times_digits` is forwarded unchanged to `write_dag`.
"""
function _show_plan(io::IO, t::Union{Thunk,DTask}, logs::Vector{Timespan}; times_digits::Integer=3)
    println(io, """strict digraph {
    graph [layout=dot,rankdir=LR];""")
    write_dag(io, logs, t; times_digits)
    # Close the digraph with a bare brace; the previous "}/" emitted a stray
    # slash after the graph, producing invalid DOT output.
    println(io, "}")
end

# `Dagger.show_logs` methods for the `:graphviz_simple` visualization mode.
# Each one forwards to the matching `_show_plan` method.
# Task only: no keyword options are accepted.
Dagger.show_logs(io::IO, t::Union{Thunk,DTask}, ::Val{:graphviz_simple}) = _show_plan(io, t)
# Logs only: `times_digits` is forwarded to `_show_plan`.
Dagger.show_logs(io::IO, logs::Vector{Timespan}, ::Val{:graphviz_simple}; times_digits::Integer=3) = _show_plan(io, logs; times_digits=times_digits)
# Task plus logs: `times_digits` is forwarded to `_show_plan`.
Dagger.show_logs(io::IO, t::Union{Thunk,DTask}, logs::Vector{Timespan}, ::Val{:graphviz_simple}; times_digits::Integer=3) = _show_plan(io, t, logs; times_digits=times_digits)

show_logs(io::IO, t::Thunk, ::Val{:graphviz_simple}) = _show_plan(io, t)
show_logs(io::IO, logs::Vector{Timespan}, ::Val{:graphviz_simple}) = _show_plan(io, logs)
show_logs(io::IO, t::Thunk, logs::Vector{Timespan}, ::Val{:graphviz_simple}) = _show_plan(io, t, logs)

end
6 changes: 3 additions & 3 deletions src/visualization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ Returns a string representation of the logs of a task `t` and/or logs object `lo
"""
function show_logs end

show_logs(io::IO, logs, vizmode::Symbol; options...) =
show_logs(io, logs, Val{vizmode}(); options...)
show_logs(io::IO, arg, vizmode::Symbol; options...) =
show_logs(io, arg, Val{vizmode}(); options...)
show_logs(io::IO, t, logs, vizmode::Symbol; options...) =
show_logs(io, t, Val{vizmode}(); options...)
show_logs(io, t, logs, Val{vizmode}(); options...)
show_logs(io::IO, ::T, ::Val{vizmode}; options...) where {T,vizmode} =
throw(ArgumentError("show_logs: Task/logs type `$T` not supported for visualization mode `$(repr(vizmode))`"))
show_logs(io::IO, ::T, ::Logs, ::Val{vizmode}; options...) where {T,Logs,vizmode} =
Expand Down