Skip to content

Commit

Permalink
[ENH] Add ability to draw graphs with a title (#71)
Browse files Browse the repository at this point in the history
* added a keyword argument to label graphs
* Add labels to graphs in some examples

---------
Signed-off-by: Aryan Roy <aryanroy5678@gmail.com>
Co-authored-by: Adam Li <adam2392@gmail.com>
  • Loading branch information
aryan26roy authored Apr 3, 2023
1 parent ac6333b commit c1b4107
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 7 deletions.
2 changes: 2 additions & 0 deletions docs/whats_new/_contributors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,5 @@
.. _Adam Li: https://github.com/adam2392
.. _Julien Siebert: https://github.com/siebert-julien
.. _Jaron Lee: https://github.com/jaron-lee
.. _Aryan Roy: https://github.com/aryan26roy

3 changes: 2 additions & 1 deletion docs/whats_new/v0.1.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Version 0.1

Changelog
---------

- |Feature| Add keyword argument for graph labels in :func:`pywhy_graphs.viz.draw`, by `Aryan Roy`_ (:pr:`71`)
- |Feature| Implement minimal m-separator function in :func:`pywhy_graphs.networkx.minimal_m_separator` with a BFS approach, by `Jaron Lee`_ (:pr:`53`)
- |Feature| Implement m-separation :func:`pywhy_graphs.networkx.m_separated` with the BallTree approach, by `Jaron Lee`_ (:pr:`48`)
- |Feature| Add support for undirected edges in m-separation :func:`pywhy_graphs.networkx.m_separated`, by `Jaron Lee`_ (:pr:`46`)
Expand All @@ -52,3 +52,4 @@ the project since version inception, including:
* `Adam Li`_
* `Julien Siebert`_
* `Jaron Lee`_
* `Aryan Roy`_
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,10 @@

# draw the graphs (i.e., generate a graphviz object that can be rendered)
# each time we call draw() we pass the layout position of G
dot_G = draw(G, pos=pos_G)
dot_admg = draw(admg, pos=pos_G)
dot_cpdag = draw(cpdag, pos=pos_G)
dot_pag = draw(pag, pos=pos_G)
dot_G = draw(G, name="A DiGraph", pos=pos_G)
dot_admg = draw(admg, name="An ADMG", pos=pos_G)
dot_cpdag = draw(cpdag, name="A CPDAG", pos=pos_G)
dot_pag = draw(pag, name="A PAG", pos=pos_G)

# render the graphs using graphviz render() function
dot_G.render(outfile="G.png", view=True, engine="neato")
Expand Down
19 changes: 17 additions & 2 deletions pywhy_graphs/viz/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,14 @@ def _draw_pag_edges(
return dot, found_circle_sibs


def draw(G, direction: Optional[str] = None, pos: Optional[dict] = None, shape="square", **attrs):
def draw(
G,
direction: Optional[str] = None,
pos: Optional[dict] = None,
name: Optional[str] = None,
shape="square",
**attrs,
):
"""Visualize the graph.
Parameters
Expand All @@ -67,6 +74,8 @@ def draw(G, direction: Optional[str] = None, pos: Optional[dict] = None, shape="
The positions of the nodes keyed by node with (x, y) coordinates as values.
By default None, which will
use the default layout from graphviz.
name : str, optional
Label for the generated graph.
shape : str
The shape of each node. By default 'square'. Can be 'circle', 'plaintext'.
attrs : dict
Expand All @@ -80,7 +89,13 @@ def draw(G, direction: Optional[str] = None, pos: Optional[dict] = None, shape="
"""
from graphviz import Digraph

dot = Digraph()
# make a dict to pass to the Digraph object
g_attr = {"label": name}

if name is not None:
dot = Digraph(graph_attr=g_attr)
else:
dot = Digraph()

# set direction from left to right if that's preferred
if direction == "LR":
Expand Down
24 changes: 24 additions & 0 deletions pywhy_graphs/viz/tests/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,3 +105,27 @@ def test_draw_with_ts_layout():
pos_G = timeseries_layout(G, variable_order=["x", "y", "z"], scale=10)

assert all(node in pos_G for node in G.nodes)


def test_draw_name_is_given():
"""
Ensure the generated graph contains the label provided by the user.
"""
# create a dummy graph x --> y <-- z and z --> x
graph = nx.DiGraph([("x", "y"), ("z", "y"), ("z", "x")])
# draw the graphs
dot = draw(graph, name="test")
# assert that the produced graph contains a label
assert "label=test" in dot.source


def test_draw_name_is_not_given():
"""
Ensure the generated graph does not contain a label.
"""
# create a dummy graph x --> y <-- z and z --> x
graph = nx.DiGraph([("x", "y"), ("z", "y"), ("z", "x")])
# draw the graphs
dot = draw(graph)
# assert that the produced graph does not contain a label
assert "label=" not in dot.source

0 comments on commit c1b4107

Please sign in to comment.