Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Graphviz suggestions #257

Merged
merged 4 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 45 additions & 62 deletions ext/GraphPPLGraphVizExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ using GraphPPL, MetaGraphsNext, GraphViz
using GraphPPL.MetaGraphsNext
import MetaGraphsNext: nv

import GraphViz: load

"""
This abstract type represents a node traversal strategy for use with the `generate_dot` function.

Expand Down Expand Up @@ -372,6 +374,20 @@ function dot_string_to_pdf(dot_string::String, dst_pdf_file::String)::Bool
end
end

function get_displayed_label(properties::GraphPPL.FactorNodeProperties)
return GraphPPL.prettyname(properties)
end

function get_displayed_label(properties::GraphPPL.VariableNodeProperties)
if GraphPPL.is_constant(properties)
return GraphPPL.value(properties)
elseif !isnothing(GraphPPL.index(properties))
return string("<", GraphPPL.getname(properties), "<SUB><FONT POINT-SIZE=\"6\">", GraphPPL.index(properties), "</FONT></SUB>", ">")
else
return GraphPPL.getname(properties)
end
end

"""
Constructs the portion of the DOT string that specifies the nodes in the GraphViz visualization.
Specifically, by means of the simple iteration strategy specified by the `SimpleIteration` subtype.
Expand Down Expand Up @@ -399,17 +415,18 @@ function add_nodes!(
)
for vertex in MetaGraphsNext.vertices(model_graph.graph)

# index the label of model_namespace_variables with "vertex"
san_label = get_sanitized_node_name(global_namespace_dict[vertex])

# index the label of model_namespace_variables with "vertex"
label = MetaGraphsNext.label_for(model_graph.graph, vertex)

properties = model_graph[label].properties
displayed_label = get_displayed_label(properties)

if isa(properties, GraphPPL.FactorNodeProperties)
write(io_buffer, " \"$(san_label)\" [shape=square, style=filled, fillcolor=lightgray];\n")
write(io_buffer, " \"$(san_label)\" [shape=square, style=filled, fillcolor=lightgray, label=$(displayed_label)];\n")
elseif isa(properties, GraphPPL.VariableNodeProperties)
write(io_buffer, " \"$(san_label)\" [shape=circle];\n")
write(io_buffer, " \"$(san_label)\" [shape=circle, label=$(displayed_label)];\n")
else
error("Unknown node type for label $(san_label)")
end
Expand Down Expand Up @@ -459,11 +476,12 @@ function add_nodes!(io_buffer::IOBuffer, model_graph::GraphPPL.Model, global_nam

label = MetaGraphsNext.label_for(model_graph.graph, v)
properties = model_graph[label].properties
displayed_label = get_displayed_label(properties)

if isa(properties, GraphPPL.FactorNodeProperties)
write(io_buffer, " \"$(san_label)\" [shape=square, style=filled, fillcolor=lightgray];\n")
write(io_buffer, " \"$(san_label)\" [shape=square, style=filled, fillcolor=lightgray, label=$(displayed_label)];\n")
elseif isa(properties, GraphPPL.VariableNodeProperties)
write(io_buffer, " \"$(san_label)\" [shape=circle];\n")
write(io_buffer, " \"$(san_label)\" [shape=circle, label=$(displayed_label)];\n")
else
error("Unknown node type for label $(san_label)")
end
Expand Down Expand Up @@ -596,55 +614,6 @@ function convert_strategy(strategy::Symbol)
end
end

"""
Constructs a DOT string from an input `GraphPPL.Model` for visualization with GraphViz.jl.
The DOT string includes configuration options for node appearance, edge length, layout, and more.

# Arguments:
- `model_graph::GraphPPL.Model`: The `GraphPPL.Model` structure containing the raw factor
graph to be visualized.
- `strategy::TraversalStrategy`: Specifies the traversal strategy for graph traversal.
Either `SimpleIteration()` or `BFSTraversal()`.
- `font_size::Int`: The font size of the node labels.
- `edge_length::Float64` (default is `1.0`): Controls the visual length of edges in the graph.
- `layout::String` (default is `"neato"`): The layout engine to be used by GraphViz for
arranging the nodes.
- `overlap::Bool`: Controls whether node overlap is allowed in the visualization.
- `width::Float64` (default is `10.0`): The width of the display window in inches.
- `height::Float64` (default is `10.0`): The height of the display window in inches.

# Returns:
- `String`: A DOT format string that can be used to generate a GraphViz visualization.
"""
function generate_dot(;
model_graph::GraphPPL.Model,
strategy::Symbol,
font_size::Int,
edge_length::Float64 = 1.0,
layout::String = "neato",
overlap::Bool,
width::Float64 = 10.0,
height::Float64 = 10.0
)
# convert user-specified symbolic expression to the associated type for eventual dispatch
traversal_strategy = convert_strategy(strategy)

# dispatch on the type of traversal strategy
_generate_dot(
model_graph = model_graph,
strategy = traversal_strategy,
font_size = font_size,
edge_length = edge_length,
layout = layout,
overlap = overlap,
width = width,
height = height
)

end



"""
# Constructs a DOT string from an input `GraphPPL.Model` for visualization with GraphViz.jl.
# The DOT string includes configuration options for node appearance, edge length, layout, and more.
Expand All @@ -665,16 +634,19 @@ end
# # Returns:
# - `String`: A DOT format string that can be used to generate a GraphViz visualization.
# """
function _generate_dot(;
model_graph::GraphPPL.Model,
strategy::TraversalStrategy,
font_size::Int,
function GraphViz.load(model_graph::GraphPPL.Model;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have attempted to call this method in a version of the v2 test file that I wrote as follows:

push!(LOAD_PATH, joinpath(@__DIR__, "../ext"))

using GraphPPLGraphVizExt: GraphViz


## CREATE AN RXINFER.JL MODEL:
# GraphPPL.jl export `@model` macro for model specification
# It accepts a regular Julia function and builds an FFG under the hood
@model function coin_model(y, a, b)
    # We endow θ parameter of our model with some prior
    θ ~ Beta(a, b)
    # or, in this particular case, the `Uniform(0.0, 1.0)` prior also works:
    # θ ~ Uniform(0.0, 1.0)

    # We assume that outcome of each coin flip is governed by the Bernoulli distribution
    for i in eachindex(y)
        y[i] ~ Bernoulli(θ)
    end
end

# condition the model on some observed data
conditioned = coin_model(a = 2.0, b = 7.0) | (y = [ true, false, true ], )

# `Create` the actual graph of the model conditioned on the data
rxi_model = RxInfer.create_model(conditioned)

gppl_model = RxInfer.getmodel(rxi_model)


gen_dot_result_coin_simple = GraphViz.load(
    gppl_model,
    strategy = :simple,
    font_size = 12,
    edge_length = 1.0,
    layout = "neato",
    overlap = false,
    width = 10.0,
    height = 10.0,
    show = false
)

As yet, my use of GraphViz.load is resulting in a persistent failure to correctly dispatch on the overloaded method:

$ julia test/visualization_tests_233_v2.jl
ERROR: LoadError: MethodError: no method matching var"#load#3"(::Symbol, ::Int64, ::Float64, ::String, ::Bool, ::Float64, ::Float64, ::Bool, ::Nothing, ::typeof(GraphViz.load), ::GraphPPL.Model{MetaGraph{Int64, SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}})
The function `#load#3` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  var"#load#3"(::Symbol, ::Int64, ::Float64, ::String, ::Bool, ::Float64, ::Float64, ::Bool, ::String, ::typeof(GraphViz.load), ::GraphPPL.Model)
   @ GraphPPLGraphVizExt ~/Desktop/GraphPPL.jl/ext/GraphPPLGraphVizExt.jl:637

Stacktrace:
 [1] load(model_graph::GraphPPL.Model{MetaGraph{Int64, SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}})
   @ GraphPPLGraphVizExt ~/Desktop/GraphPPL.jl/ext/GraphPPLGraphVizExt.jl:637
 [2] top-level scope
   @ ~/Desktop/GraphPPL.jl/test/visualization_tests_233_v2.jl:51

Perhaps I have just done something silly but I can't seem to shake this error, no matter how I re-define the positionality of the arguments.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update: Ok its my mistake, I pushed the fix, the type for the save_to variable was wrong.

I'm not sure about those lines

push!(LOAD_PATH, joinpath(@__DIR__, "../ext"))

using GraphPPLGraphVizExt: GraphViz

You should simply do

using GraphViz

and the package should be installed on your system. Otherwise the extension won't load properly. Does that fix the error?

You error however says something weird, so it definitely sees the method here in the stacktrace (I assume the extension loaded?)

load(model_graph::GraphPPL.Model{MetaGraph{Int64, SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}})
   @ GraphPPLGraphVizExt ~/Desktop/GraphPPL.jl/ext/GraphPPLGraphVizExt.jl:637

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the clarification @bvdmitri.

I have two things to discuss.

First Thing:

You seem to be saying that I can/should simply replace the following lines in test/visualization_tests_233_v2.jl:

push!(LOAD_PATH, joinpath(@__DIR__, "../ext"))
using GraphPPLGraphVizExt: GraphViz

with the simpler: using GraphViz . I did try this, it results in:

$ julia test/visualization_tests_233_v2.jl
ERROR: LoadError: MethodError: no method matching load(::GraphPPL.Model{MetaGraph{Int64, SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}}; strategy::Symbol)
The function `load` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  load(::IO) got unsupported keyword argument "strategy"
   @ GraphViz ~/.julia/packages/GraphViz/IsUMl/src/GraphViz.jl:90
  load(::FileIO.File{FileIO.DataFormat{:DOT}}) got unsupported keyword argument "strategy"
   @ GraphViz ~/.julia/packages/GraphViz/IsUMl/src/GraphViz.jl:89

Stacktrace:
 [1] top-level scope
   @ ~/Desktop/GraphPPL.jl/test/visualization_tests_233_v2.jl:50
in expression starting at / ... /test/visualization_tests_233_v2.jl:50

Thus it seems that Julia is unable to find the overloaded version of GraphViz.load if I don't explicitly use

using GraphPPLGraphVizExt: GraphViz

to specify that I want my overlaoded version from GraphPPLGraphVizExt. In order to do this, I think I need to append the extension to Julia's LOAD_PATH variable. That is why I have:

push!(LOAD_PATH, joinpath(@__DIR__, "../ext"))

My understanding is that it is necessary to append the ext directory to Julia's current list of available directories in the LOAD_PATH variable. Perhaps there is a much better/cleaner way to do this.

Second Thing:

When I attempt to run my existing test script:

using GraphPPL
using Test
using RxInfer
using Distributions
using Random
using Graphs
using MetaGraphsNext
using Dictionaries
using Cairo
using Fontconfig
using Compose
using GraphPlot

push!(LOAD_PATH, joinpath(@__DIR__, "../ext"))

using GraphPPLGraphVizExt: GraphViz

## CREATE AN RXINFER.JL MODEL:
# GraphPPL.jl export `@model` macro for model specification
# It accepts a regular Julia function and builds an FFG under the hood
@model function coin_model(y, a, b)
    # We endow θ parameter of our model with some prior
    θ ~ Beta(a, b)

    # We assume that outcome of each coin flip is governed by the Bernoulli distribution
    for i in eachindex(y)
        y[i] ~ Bernoulli(θ)
    end
end

# condition the model on some observed data
conditioned = coin_model(a = 2.0, b = 7.0) | (y = [ true, false, true ], )

# `Create` the actual graph of the model conditioned on the data
rxi_model = RxInfer.create_model(conditioned)

gppl_model = RxInfer.getmodel(rxi_model)

# generate visualization:
gen_dot_result_coin_simple = GraphViz.load(
    gppl_model, 
    strategy = :simple
)

println(gen_dot_result_coin_simple)

Julia can't seem to find the prettyname function:

$ julia test/visualization_tests_233_v2.jl
ERROR: LoadError: UndefVarError: `prettyname` not defined in `GraphPPL`
Stacktrace:
 [1] get_displayed_label(properties::GraphPPL.FactorNodeProperties{GraphPPL.NodeData})
   @ GraphPPLGraphVizExt ~/Desktop/GraphPPL.jl/ext/GraphPPLGraphVizExt.jl:377
 [2] add_nodes!(io_buffer::IOBuffer, model_graph::GraphPPL.Model{MetaGraph{Int64, SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}}, global_namespace_dict::Dict{Int64, Dict{Symbol, Any}}, ::GraphPPLGraphVizExt.SimpleIteration)
   @ GraphPPLGraphVizExt ~/Desktop/GraphPPL.jl/ext/GraphPPLGraphVizExt.jl:423
 [3] load(model_graph::GraphPPL.Model{MetaGraph{Int64, SimpleGraph{Int64}, GraphPPL.NodeLabel, GraphPPL.NodeData, GraphPPL.EdgeLabel, GraphPPL.Context, MetaGraphsNext.var"#4#8", Float64}, GraphPPL.PluginsCollection{Tuple{}}, RxInfer.ReactiveMPGraphPPLBackend{Static.False}}; strategy::Symbol, font_size::Int64, edge_length::Float64, layout::String, overlap::Bool, width::Float64, height::Float64, show::Bool, save_to::Nothing)
   @ GraphPPLGraphVizExt ~/Desktop/GraphPPL.jl/ext/GraphPPLGraphVizExt.jl:662
 [4] top-level scope
   @ ~/ ... /test/visualization_tests_233_v2.jl:48
in expression starting at / ... /test/visualization_tests_233_v2.jl:48

I have tried all kinds of things to resolve this issue, none have worked.

Perhaps I am systematically misunderstanding how Julia loads various modules and files for actual execution and this is the reason why I'm having all this difficulty. That is my guess! I apologise for any such deficiency on my behalf.

I did actually get an overloaded version of GraphViz.load working in my own branch (minus the prettyname) so I'm not sure why there are all these difficulties here.

I hope this is enough to go on for general comments and suggestions.

strategy::Symbol,
font_size::Int = 12,
edge_length::Float64 = 1.0,
layout::String = "neato",
overlap::Bool,
overlap::Bool = false,
width::Float64 = 10.0,
height::Float64 = 10.0
height::Float64 = 10.0,
show::Bool = false,
save_to::Union{String, Nothing} = nothing
)
traversal_strategy = convert_strategy(strategy)

# get the entire namespace dict
global_namespace_dict = get_namespace_variables_dict(model_graph)

Expand All @@ -688,15 +660,26 @@ function _generate_dot(;
write(io_buffer, " node [shape=circle, fontsize=$(font_size)];\n")

# Nodes
add_nodes!(io_buffer, model_graph, global_namespace_dict, strategy)
add_nodes!(io_buffer, model_graph, global_namespace_dict, traversal_strategy)

# Edges
add_edges!(io_buffer, model_graph, global_namespace_dict, strategy, edge_length)
add_edges!(io_buffer, model_graph, global_namespace_dict, traversal_strategy, edge_length)

write(io_buffer, "}\n\"\"\"")

final_dot = String(take!(io_buffer))

if !isnothing(save_to)
@info "Saving the DOT string to file: $(save_to)"
if !dot_string_to_pdf(final_dot, save_to)
@warn "Failed to save the DOT string to file: $(save_to)"
end
end

if show
return show_gv(final_dot)
end

return final_dot
end

Expand Down
4 changes: 4 additions & 0 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,10 @@
return FactorNodeProperties(fform = fform, neighbors = get(options, :neighbors, Tuple{NodeLabel, EdgeLabel, NodeData}[]))
end

getname(properties::FactorNodeProperties) = string(properties.fform)
prettyname(properties::FactorNodeProperties) = prettyname(properties.fform)
prettyname(fform::Any) = string(fform) # Can be overloaded for custom pretty names

Check warning on line 744 in src/graph_engine.jl

View check run for this annotation

Codecov / codecov/patch

src/graph_engine.jl#L742-L744

Added lines #L742 - L744 were not covered by tests

fform(properties::FactorNodeProperties) = properties.fform
neighbors(properties::FactorNodeProperties) = properties.neighbors
addneighbor!(properties::FactorNodeProperties, variable::NodeLabel, edge::EdgeLabel, data) =
Expand Down
Loading