diff --git a/ext/GraphPPLGraphVizExt.jl b/ext/GraphPPLGraphVizExt.jl index 18f486a..476ffbf 100644 --- a/ext/GraphPPLGraphVizExt.jl +++ b/ext/GraphPPLGraphVizExt.jl @@ -6,6 +6,8 @@ using GraphPPL, MetaGraphsNext, GraphViz using GraphPPL.MetaGraphsNext import MetaGraphsNext: nv +import GraphViz: load + """ This abstract type represents a node traversal strategy for use with the `generate_dot` function. @@ -372,6 +374,20 @@ function dot_string_to_pdf(dot_string::String, dst_pdf_file::String)::Bool end end +function get_displayed_label(properties::GraphPPL.FactorNodeProperties) + return GraphPPL.prettyname(properties) +end + +function get_displayed_label(properties::GraphPPL.VariableNodeProperties) + if GraphPPL.is_constant(properties) + return GraphPPL.value(properties) + elseif !isnothing(GraphPPL.index(properties)) + return string("<", GraphPPL.getname(properties), "", GraphPPL.index(properties), "", ">") + else + return GraphPPL.getname(properties) + end +end + """ Constructs the portion of the DOT string that specifies the nodes in the GraphViz visualization. Specifically, by means of the simple iteration strategy specified by the `SimpleIteration` subtype. @@ -399,17 +415,18 @@ function add_nodes!( ) for vertex in MetaGraphsNext.vertices(model_graph.graph) - # index the label of model_namespace_variables with "vertex" san_label = get_sanitized_node_name(global_namespace_dict[vertex]) + # index the label of model_namespace_variables with "vertex" label = MetaGraphsNext.label_for(model_graph.graph, vertex) properties = model_graph[label].properties + displayed_label = get_displayed_label(properties) if isa(properties, GraphPPL.FactorNodeProperties) - write(io_buffer, " \"$(san_label)\" [shape=square, style=filled, fillcolor=lightgray];\n") + write(io_buffer, " \"$(san_label)\" [shape=square, style=filled, fillcolor=lightgray, label=$(displayed_label)];\n") elseif isa(properties, GraphPPL.VariableNodeProperties) - write(io_buffer, " \"$(san_label)\" [shape=circle];\n") + write(io_buffer, " \"$(san_label)\" [shape=circle, label=$(displayed_label)];\n") else error("Unknown node type for label $(san_label)") end @@ -459,11 +476,12 @@ function add_nodes!(io_buffer::IOBuffer, model_graph::GraphPPL.Model, global_nam label = MetaGraphsNext.label_for(model_graph.graph, v) properties = model_graph[label].properties + displayed_label = get_displayed_label(properties) if isa(properties, GraphPPL.FactorNodeProperties) - write(io_buffer, " \"$(san_label)\" [shape=square, style=filled, fillcolor=lightgray];\n") + write(io_buffer, " \"$(san_label)\" [shape=square, style=filled, fillcolor=lightgray, label=$(displayed_label)];\n") elseif isa(properties, GraphPPL.VariableNodeProperties) - write(io_buffer, " \"$(san_label)\" [shape=circle];\n") + write(io_buffer, " \"$(san_label)\" [shape=circle, label=$(displayed_label)];\n") else error("Unknown node type for label $(san_label)") end @@ -596,55 +614,6 @@ function convert_strategy(strategy::Symbol) end end -""" -Constructs a DOT string from an input `GraphPPL.Model` for visualization with GraphViz.jl. -The DOT string includes configuration options for node appearance, edge length, layout, and more. - -# Arguments: -- `model_graph::GraphPPL.Model`: The `GraphPPL.Model` structure containing the raw factor - graph to be visualized. -- `strategy::TraversalStrategy`: Specifies the traversal strategy for graph traversal. - Either `SimpleIteration()` or `BFSTraversal()`. -- `font_size::Int`: The font size of the node labels. -- `edge_length::Float64` (default is `1.0`): Controls the visual length of edges in the graph. -- `layout::String` (default is `"neato"`): The layout engine to be used by GraphViz for - arranging the nodes. -- `overlap::Bool`: Controls whether node overlap is allowed in the visualization. -- `width::Float64` (default is `10.0`): The width of the display window in inches. -- `height::Float64` (default is `10.0`): The height of the display window in inches. - -# Returns: -- `String`: A DOT format string that can be used to generate a GraphViz visualization. -""" -function generate_dot(; - model_graph::GraphPPL.Model, - strategy::Symbol, - font_size::Int, - edge_length::Float64 = 1.0, - layout::String = "neato", - overlap::Bool, - width::Float64 = 10.0, - height::Float64 = 10.0 -) - # convert user-specified symbolic expression to the associated type for eventual dispatch - traversal_strategy = convert_strategy(strategy) - - # dispatch on the type of traversal strategy - _generate_dot( - model_graph = model_graph, - strategy = traversal_strategy, - font_size = font_size, - edge_length = edge_length, - layout = layout, - overlap = overlap, - width = width, - height = height - ) - -end - - - """ # Constructs a DOT string from an input `GraphPPL.Model` for visualization with GraphViz.jl. # The DOT string includes configuration options for node appearance, edge length, layout, and more. @@ -665,16 +634,19 @@ end # # Returns: # - `String`: A DOT format string that can be used to generate a GraphViz visualization. # """ -function _generate_dot(; - model_graph::GraphPPL.Model, - strategy::TraversalStrategy, - font_size::Int, +function GraphViz.load(model_graph::GraphPPL.Model; + strategy::Symbol, + font_size::Int = 12, edge_length::Float64 = 1.0, layout::String = "neato", - overlap::Bool, + overlap::Bool = false, width::Float64 = 10.0, - height::Float64 = 10.0 + height::Float64 = 10.0, + show::Bool = false, + save_to::Union{String, Nothing} = nothing ) + traversal_strategy = convert_strategy(strategy) + # get the entire namespace dict global_namespace_dict = get_namespace_variables_dict(model_graph) @@ -688,15 +660,26 @@ function _generate_dot(; write(io_buffer, " node [shape=circle, fontsize=$(font_size)];\n") # Nodes - add_nodes!(io_buffer, model_graph, global_namespace_dict, strategy) + add_nodes!(io_buffer, model_graph, global_namespace_dict, traversal_strategy) # Edges - add_edges!(io_buffer, model_graph, global_namespace_dict, strategy, edge_length) + add_edges!(io_buffer, model_graph, global_namespace_dict, traversal_strategy, edge_length) write(io_buffer, "}\n\"\"\"") final_dot = String(take!(io_buffer)) + if !isnothing(save_to) + @info "Saving the DOT string to file: $(save_to)" + if !dot_string_to_pdf(final_dot, save_to) + @warn "Failed to save the DOT string to file: $(save_to)" + end + end + + if show + return show_gv(final_dot) + end + return final_dot end diff --git a/src/graph_engine.jl b/src/graph_engine.jl index 79d6358..b8aad38 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -739,6 +739,10 @@ function Base.convert(::Type{FactorNodeProperties}, fform, options::NodeCreation return FactorNodeProperties(fform = fform, neighbors = get(options, :neighbors, Tuple{NodeLabel, EdgeLabel, NodeData}[])) end +getname(properties::FactorNodeProperties) = string(properties.fform) +prettyname(properties::FactorNodeProperties) = prettyname(properties.fform) +prettyname(fform::Any) = string(fform) # Can be overloaded for custom pretty names + fform(properties::FactorNodeProperties) = properties.fform neighbors(properties::FactorNodeProperties) = properties.neighbors addneighbor!(properties::FactorNodeProperties, variable::NodeLabel, edge::EdgeLabel, data) =