diff --git a/xgi/drawing/draw.py b/xgi/drawing/draw.py index c6d2d84d7..5601678a1 100644 --- a/xgi/drawing/draw.py +++ b/xgi/drawing/draw.py @@ -24,7 +24,11 @@ _scalar_arg_to_dict, _update_lims, ) -from .layout import _augmented_projection, barycenter_spring_layout +from .layout import ( + _augmented_projection, + barycenter_spring_layout, + bipartite_spring_layout, +) __all__ = [ "draw", @@ -36,6 +40,7 @@ "draw_hypergraph_hull", "draw_multilayer", "draw_dihypergraph", + "draw_bipartite", ] @@ -1767,3 +1772,274 @@ def draw_dihypergraph( _update_lims(pos, ax) return ax + + +def draw_bipartite( + H, + node_pos=None, + edge_pos=None, + ax=None, + dyad_color="black", + dyad_lw=1, + node_fc="white", + node_ec="black", + node_marker="o", + node_lw=1, + node_size=10, + edge_marker_fc="lightblue", + edge_marker_ec="black", + edge_marker="s", + edge_marker_lw=1, + edge_marker_size=10, + max_order=None, + settings=None, + **kwargs, +): + """Draw a hypergraph as a bipartite network. + + Parameters + ---------- + H : Hypergraph + The hypergraph to draw. + ax : matplotlib.pyplot.axes, optional + Axis to draw on. If None (default), get the current axes. + lines_fc : str, dict, iterable, optional + Color of the hyperedges (lines). If str, use the same color for all hyperedges. + If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable, + assume the colors are specified in the same order as the hyperedges are found + in DH.edges. If None (default), use the size of the hyperedges. + lines_lw : int, float, dict, iterable, optional + Line width of the hyperedges (lines). If int or float, use the same width for + all hyperedges. If a dict, must contain (hyperedge_id: width) pairs. If other + iterable, assume the widths are specified in the same order as the hyperedges + are found in DH.edges. By default, 1.5. + line_head_width : float, optional + Length of arrows' heads. By default, 0.05 + node_fc : str, dict, iterable, or NodeStat, optional + Color of the nodes. If str, use the same color for all nodes. If a dict, must + contain (node_id: color_str) pairs. If other iterable, assume the colors are + specified in the same order as the nodes are found in H.nodes. If NodeStat, use + the colormap specified with node_fc_cmap. By default, "white". + node_ec : str, dict, iterable, or NodeStat, optional + Color of node borders. If str, use the same color for all nodes. If a dict, + must contain (node_id: color_str) pairs. If other iterable, assume the colors + are specified in the same order as the nodes are found in H.nodes. If NodeStat, + use the colormap specified with node_ec_cmap. By default, "black". + node_lw : int, float, dict, iterable, or NodeStat, optional + Line width of the node borders in pixels. If int or float, use the same width + for all node borders. If a dict, must contain (node_id: width) pairs. If + iterable, assume the widths are specified in the same order as the nodes are + found in H.nodes. If NodeStat, use a monotonic linear interpolation defined + between min_node_lw and max_node_lw. By default, 1. + node_size : int, float, dict, iterable, or NodeStat, optional + Radius of the nodes in pixels. If int or float, use the same radius for all + nodes. If a dict, must contain (node_id: radius) pairs. If iterable, assume + the radiuses are specified in the same order as the nodes are found in + H.nodes. If NodeStat, use a monotonic linear interpolation defined between + min_node_size and max_node_size. By default, 15. + edge_marker_toggle: bool, optional + If True then marker representing the hyperedges are drawn. By default True. + edge_marker_fc: str, dict, iterable, optional + Filling color of the hyperedges (markers). If str, use the same color for all hyperedges. + If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable, + assume the colors are specified in the same order as the hyperedges are found + in DH.edges. If None (default), use the size of the hyperedges. + edge_marker_ec: str, dict, iterable, optional + Edge color of the hyperedges (markers). If str, use the same color for all hyperedges. + If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable, + assume the colors are specified in the same order as the hyperedges are found + in DH.edges. If None (default), use the size of the hyperedges. + edge_marker: str, optional + Marker used for the hyperedges. By default 's' (square marker). + max_order : int, optional + Maximum of hyperedges to plot. If None (default), plots all orders. + node_labels : bool or dict, optional + If True, draw ids on the nodes. If a dict, must contain (node_id: label) pairs. + By default, False. + hyperedge_labels : bool or dict, optional + If True, draw ids on the hyperedges. If a dict, must contain (edge_id: label) + pairs. By default, False. + **kwargs : optional args + Alternate default values. Values that can be overwritten are the following: + * min_node_size + * max_node_size + * min_node_lw + * max_node_lw + * node_fc_cmap + * node_ec_cmap + * min_lines_lw + * max_lines_lw + * lines_fc_cmap + * edge_fc_cmap + * edge_marker_fc_cmap + * edge_marker_ec_cmap + + Returns + ------- + ax : matplotlib.pyplot.axes + + Raises + ------ + XGIError + If something different than a DiHypergraph is passed. + + See Also + -------- + draw + draw_nodes + draw_node_labels + + """ + if not isinstance(H, Hypergraph): + raise XGIError("The input must be a Hypergraph") + + if settings is None: + settings = { + "min_node_size": 10.0, + "max_node_size": 30.0, + "min_edge_marker_size": 10.0, + "max_edge_marker_size": 30.0, + "min_node_lw": 10.0, + "max_node_lw": 30.0, + "min_edge_marker_lw": 10.0, + "max_edge_marker_lw": 30.0, + "min_dyad_lw": 1.0, + "max_dyad_lw": 5.0, + "node_fc_cmap": cm.Reds, + "node_ec_cmap": cm.RdBu, + "dyad_fc_cmap": cm.Blues, + "edge_marker_fc_cmap": cm.Blues, + "edge_marker_ec_cmap": cm.Greys, + } + + settings.update(kwargs) + + if ax is None: + ax = plt.gca() + + ax.get_xaxis().set_ticks([]) + ax.get_yaxis().set_ticks([]) + ax.axis("off") + + if not max_order: + max_order = max_edge_order(H) + + if not node_pos or not edge_pos: + node_pos, edge_pos = bipartite_spring_layout(H) + + dyad_lw = _scalar_arg_to_dict( + dyad_lw, H.edges, settings["min_dyad_lw"], settings["max_dyad_lw"] + ) + + if dyad_color is None: + dyad_color = H.edges.size + + dyad_color = _color_arg_to_dict(dyad_color, H.edges, settings["dyad_fc_cmap"]) + + if edge_marker_fc is None: + edge_marker_fc = H.edges.size + + edge_marker_fc = _color_arg_to_dict( + edge_marker_fc, H.edges, settings["edge_marker_fc_cmap"] + ) + + if edge_marker_ec is None: + edge_marker_ec = H.edges.size + + edge_marker_ec = _color_arg_to_dict( + edge_marker_ec, H.edges, settings["edge_marker_ec_cmap"] + ) + + edge_marker_lw = _scalar_arg_to_dict( + edge_marker_lw, + H.edges, + settings["min_edge_marker_lw"], + settings["max_edge_marker_lw"], + ) + + edge_marker_size = _scalar_arg_to_dict( + edge_marker_size, + H.edges, + settings["min_edge_marker_size"], + settings["max_edge_marker_size"], + ) + + node_size = _scalar_arg_to_dict( + node_size, H.nodes, settings["min_node_size"], settings["max_node_size"] + ) + node_fc = _color_arg_to_dict(node_fc, H.nodes, settings["node_fc_cmap"]) + node_ec = _color_arg_to_dict(node_ec, H.nodes, settings["node_ec_cmap"]) + node_lw = _scalar_arg_to_dict( + node_lw, + H.nodes, + settings["min_node_lw"], + settings["max_node_lw"], + ) + + for id, e in H.edges.members(dtype=dict).items(): + d = len(e) - 1 + if d > 0: + x_edge, y_edge = edge_pos[id] + for n in e: + x_node, y_node = node_pos[n] + line = plt.Line2D( + [x_node, x_edge], + [y_node, y_edge], + color=dyad_color[id], + lw=dyad_lw[id], + zorder=max_order - d, + ) + ax.add_line(line) + + (xs, ys, s, c, ec, lw,) = zip( + *[ + ( + node_pos[i][0], + node_pos[i][1], + node_size[i], + node_fc[i], + node_ec[i], + node_lw[i], + ) + for i in H.nodes + ] + ) + ax.scatter( + x=xs, + y=ys, + marker=node_marker, + s=s, + c=c, + edgecolors=ec, + linewidths=lw, + zorder=max_order, + ) + + (xs, ys, s, c, ec, lw,) = zip( + *[ + ( + edge_pos[i][0], + edge_pos[i][1], + edge_marker_size[i], + edge_marker_fc[i], + edge_marker_ec[i], + edge_marker_lw[i], + ) + for i in H.edges + ] + ) + ax.scatter( + x=xs, + y=ys, + marker=edge_marker, + s=s, + c=c, + edgecolors=ec, + linewidths=lw, + zorder=max_order, + ) + + # compute axis limits + _update_lims(node_pos, ax) + + return ax diff --git a/xgi/drawing/layout.py b/xgi/drawing/layout.py index 9d3033dd3..4d073bc63 100644 --- a/xgi/drawing/layout.py +++ b/xgi/drawing/layout.py @@ -8,12 +8,14 @@ from .. import convert from ..algorithms import max_edge_order +from ..convert import to_bipartite_graph from ..core import SimplicialComplex __all__ = [ "random_layout", "pairwise_spring_layout", "barycenter_spring_layout", + "bipartite_spring_layout", "weighted_barycenter_spring_layout", "pca_transform", "circular_layout", @@ -267,6 +269,70 @@ def barycenter_spring_layout( return pos +def bipartite_spring_layout(H, seed=None, k=None, **kwargs): + """ + Position the nodes using Fruchterman-Reingold force-directed + algorithm using an augmented version of the the graph projection + of the hypergraph (or simplicial complex), where phantom nodes + (barycenters) are created for each edge composed by more than two nodes. + If a simplicial complex is provided the results will be based on the + hypergraph constructed from its maximal simplices. + + Parameters + ---------- + H : xgi Hypergraph or SimplicialComplex + A position will be assigned to every node in H. + seed : int, RandomState instance or None optional (default=None) + Set the random state for deterministic node layouts. + If int, `seed` is the seed used by the random number generator, + If None (default), random numbers are sampled from the + numpy random number generator without initialization. + k : float + The spring constant of the links. When k=None (default), + k = 1/sqrt(N). For more information, see the documentation + for the NetworkX spring_layout() function. + kwargs : + Optional arguments for the NetworkX spring_layout() function. + See https://networkx.org/documentation/stable/reference/generated/networkx.drawing.layout.spring_layout.html + + + Returns + ------- + pos : dict + A dictionary of positions keyed by node + + See Also + -------- + random_layout + pairwise_spring_layout + weighted_barycenter_spring_layout + + Examples + -------- + >>> import xgi + >>> N = 50 + >>> ps = [0.1, 0.01] + >>> H = xgi.random_hypergraph(N, ps) + >>> pos = xgi.barycenter_spring_layout(H) + """ + if seed is not None: + random.seed(seed) + + if isinstance(H, SimplicialComplex): + H = convert.from_max_simplices(H) + + G, nodedict, edgedict = to_bipartite_graph(H, index=True) + + # Creating a dictionary for the position of the nodes with the standard spring + # layout + pos = nx.spring_layout(G, seed=seed, k=k, **kwargs) + + node_pos = {nodedict[i]: pos[i] for i in nodedict} + edge_pos = {edgedict[i]: pos[i] for i in edgedict} + + return node_pos, edge_pos + + def weighted_barycenter_spring_layout( H, return_phantom_graph=False, seed=None, k=None, **kwargs ):