Skip to content

Commit

Permalink
fix api and documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
nwlandry committed Feb 26, 2024
1 parent d5adfa7 commit 48e27f4
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 35 deletions.
48 changes: 27 additions & 21 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,7 @@
_parse_color_arg,
_update_lims,
)
from .layout import (
barycenter_spring_layout,
bipartite_spring_layout,
)
from .layout import barycenter_spring_layout, bipartite_spring_layout

__all__ = [
"draw",
Expand Down Expand Up @@ -848,6 +845,7 @@ def draw_simplices(
Returns
-------
ax
collections : a tuple of 2 collections:
* dyad_collection : matplotlib LineCollection
Collection containing the dyads
Expand Down Expand Up @@ -1516,8 +1514,7 @@ def draw_multilayer(
def draw_bipartite(
H,
ax=None,
node_pos=None,
edge_pos=None,
pos=None,
node_fc="white",
node_ec="black",
node_shape="o",
Expand Down Expand Up @@ -1549,14 +1546,11 @@ def draw_bipartite(
The hypergraph to draw.
ax : matplotlib.pyplot.axes, optional
Axis to draw on. If None (default), get the current axes.
node_pos : dict, optional
If passed, this dictionary of positions node_id:(x,y) is used for placing the
node markers. If None (default), use the `bipartite_spring_layout` to compute
the positions.
edge_pos : dict, optional
If passed, this dictionary of positions edge_id:(x,y) is used for placing the
edge markers. If None (default), use the `bipartite_spring_layout` to compute
the positions.
pos : tuple of two dicts, optional
The tuple should contains a (1) dictionary of positions node_id:(x,y) for
placing node markers, and (2) a dictionary of positions edge_id:(x,y) for
placing the edge markers. If None (default), use the
`bipartite_spring_layout` to compute the positions.
node_fc : str, dict, iterable, or NodeStat, optional
Color of the nodes. If str, use the same color for all nodes. If a dict, must
contain (node_id: color_str) pairs. If other iterable, assume the colors are
Expand Down Expand Up @@ -1595,7 +1589,8 @@ def draw_bipartite(
assume the colors are specified in the same order as the hyperedges are found
in DH.edges. If None (default), use the size of the hyperedges.
edge_marker_shape: str, optional
Marker used for the hyperedges. By default 's' (square marker).
Marker used for the hyperedges. By default 's' (square marker). If "", no marker is
displayed.
edge_marker_lw : int, float, dict, iterable, or EdgeStat, optional
Line width of the edge marker borders in pixels. If int or float, use the same width
for all edge marker borders. If a dict, must contain (edge_id: width) pairs. If
Expand All @@ -1609,8 +1604,8 @@ def draw_bipartite(
H.edges. If EdgeStat, use a monotonic linear interpolation defined between
min_edge_marker_size and max_edge_marker_size. By default, 7.
edge_marker_fc_cmap : colormap
Colormap for mapping edge marker colors. By default, "Reds". Ignored, if `edge_marker_fc` is
a str (single color).
Colormap for mapping edge marker colors. By default, "Reds".
Ignored, if `edge_marker_fc` is a str (single color) or an iterable of colors.
dyad_color : str, dict, iterable, optional
Color of the bipartite edges. If str, use the same color for all edges.
If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable,
Expand Down Expand Up @@ -1665,6 +1660,15 @@ def draw_bipartite(
Returns
-------
ax : matplotlib.pyplot.axes
The axes corresponding the visualization
collections : a tuple of 3 collections:
* node_collection : matplotlib PathCollection
Collection containing the nodes
* edge_marker_collection : matplotlib PathCollection
Collection containing the edge markers
* dyad_collection : matplotlib LineCollection if undirected,
list of FancyArrowPatches if not
Collection containing the edges
Raises
------
Expand All @@ -1674,8 +1678,7 @@ def draw_bipartite(
See Also
--------
draw
draw_nodes
draw_node_labels
draw_multilayer
"""
if isinstance(H, DiHypergraph):
Expand Down Expand Up @@ -1732,8 +1735,11 @@ def draw_bipartite(
"node_ec_cmap": settings["edge_marker_ec_cmap"],
}

if not node_pos or not edge_pos:
node_pos, edge_pos = bipartite_spring_layout(H)
if not pos:
pos = bipartite_spring_layout(H)

node_pos = pos[0]
edge_pos = pos[1]

if ax is None:
ax = plt.gca()
Expand Down
23 changes: 9 additions & 14 deletions xgi/drawing/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,13 @@ def _augmented_projection(H, weighted=False):

def bipartite_spring_layout(H, seed=None, k=None, **kwargs):
"""
Position the nodes using Fruchterman-Reingold force-directed
algorithm using an augmented version of the the graph projection
of the hypergraph (or simplicial complex), where phantom nodes
(barycenters) are created for each edge composed by more than two nodes.
If a simplicial complex is provided the results will be based on the
hypergraph constructed from its maximal simplices.
Position the nodes and edges using Fruchterman-Reingold force-directed
algorithm using the hypergraph converted to a bipartite network.
Parameters
----------
H : xgi Hypergraph or SimplicialComplex
A position will be assigned to every node in H.
H : Hypergraph
A position will be assigned to every node and edge in H.
seed : int, RandomState instance or None optional (default=None)
Set the random state for deterministic node layouts.
If int, `seed` is the seed used by the random number generator,
Expand All @@ -219,8 +215,10 @@ def bipartite_spring_layout(H, seed=None, k=None, **kwargs):
Returns
-------
pos : dict
A dictionary of positions keyed by node
pos : tuple of dicts
A tuple of two dictionaries:
the first is a dictionary of positions keyed by node
the second is a dictionary of positions keyed by edge
See Also
--------
Expand All @@ -234,14 +232,11 @@ def bipartite_spring_layout(H, seed=None, k=None, **kwargs):
>>> N = 50
>>> ps = [0.1, 0.01]
>>> H = xgi.random_hypergraph(N, ps)
>>> pos = xgi.barycenter_spring_layout(H)
>>> pos = xgi.bipartite_spring_layout(H)
"""
if seed is not None:
random.seed(seed)

if isinstance(H, SimplicialComplex):
H = convert.from_max_simplices(H)

G, nodedict, edgedict = to_bipartite_graph(H, index=True)

# Creating a dictionary for the position of the nodes with the standard spring
Expand Down

0 comments on commit 48e27f4

Please sign in to comment.