From aa4aafb07f3a622591d84d0d096169addec8bf68 Mon Sep 17 00:00:00 2001 From: Nicholas Landry Date: Fri, 3 Nov 2023 12:57:05 -0400 Subject: [PATCH] update --- tests/drawing/test_draw.py | 17 +-- xgi/drawing/draw.py | 251 ++++++++++++++++--------------------- xgi/drawing/layout.py | 3 +- 3 files changed, 111 insertions(+), 160 deletions(-) diff --git a/tests/drawing/test_draw.py b/tests/drawing/test_draw.py index ded692cc1..b1cdbe9c8 100644 --- a/tests/drawing/test_draw.py +++ b/tests/drawing/test_draw.py @@ -410,14 +410,10 @@ def test_draw_multilayer(edgelist8): assert edge_coll.get_zorder() == 2 # edges # node_fc - assert np.all( - node_coll.get_facecolor() == np.array([[1, 1, 1, 1]]) - ) # white + assert np.all(node_coll.get_facecolor() == np.array([[1, 1, 1, 1]])) # white # node_ec - assert np.all( - node_coll.get_edgecolor() == np.array([[0, 0, 0, 1]]) - ) # black + assert np.all(node_coll.get_edgecolor() == np.array([[0, 0, 0, 1]])) # black # node_lw assert np.all(node_coll.get_linewidth() == np.array([1])) @@ -499,14 +495,10 @@ def test_draw_multilayer(edgelist8): ) # node_fc - assert np.all( - node_coll4.get_facecolor() == np.array([[1, 0, 0, 1]]) - ) # red + assert np.all(node_coll4.get_facecolor() == np.array([[1, 0, 0, 1]])) # red # node_ec - assert np.all( - node_coll4.get_edgecolor() == np.array([[0, 0, 1, 1]]) - ) # blue + assert np.all(node_coll4.get_edgecolor() == np.array([[0, 0, 1, 1]])) # blue # node_lw assert np.all(node_coll4.get_linewidth() == np.array([1])) @@ -514,7 +506,6 @@ def test_draw_multilayer(edgelist8): # node_size assert np.all(node_coll4.get_sizes() == np.array([10**2])) - plt.close() diff --git a/xgi/drawing/draw.py b/xgi/drawing/draw.py index 1d82032d4..e8138c512 100644 --- a/xgi/drawing/draw.py +++ b/xgi/drawing/draw.py @@ -19,6 +19,7 @@ from .. import convert from ..algorithms import max_edge_order, unique_edge_sizes +from ..convert import to_bipartite_edgelist from ..core import DiHypergraph, Hypergraph, SimplicialComplex from ..exception import XGIError from ..utils import subfaces @@ -47,6 +48,7 @@ "draw_hypergraph_hull", "draw_multilayer", "draw_dihypergraph", + "draw_bipartite", ] @@ -373,7 +375,7 @@ def draw_nodes( * "max_node_size" (default: 30) * "min_node_lw" (default: 0) * "max_node_lw" (default: 5) - + kwargs : optional keywords See `draw_node_labels` for a description of optional keywords. @@ -569,7 +571,7 @@ def draw_hyperedges( Default parameters. Keys that may be useful to override default settings: * "min_dyad_lw" (default: 1) * "max_dyad_lw" (default: 10) - + kwargs : optional keywords See `draw_hyperedge_labels` for a description of optional keywords. @@ -800,7 +802,7 @@ def draw_simplices( Default parameters. Keys that may be useful to override default settings: * "min_dyad_lw" (default: 1) * "max_dyad_lw" (default: 10) - + kwargs : optional keywords See `draw_hyperedge_labels` for a description of optional keywords. @@ -1592,8 +1594,8 @@ def draw_multilayer( ydiff = np.max(ys) - np.min(ys) ymin = np.min(ys) - ydiff * 0.1 ymax = np.max(ys) + ydiff * 0.1 - xmin = np.min(xs) - xdiff * 0.1 #* (width / height) - xmax = np.max(xs) + xdiff * 0.1 #* (width / height) + xmin = np.min(xs) - xdiff * 0.1 # * (width / height) + xmax = np.max(xs) + xdiff * 0.1 # * (width / height) xx, yy = np.meshgrid([xmin, xmax], [ymin, ymax]) # plot layers @@ -2140,20 +2142,20 @@ def draw_bipartite( node_pos=None, edge_pos=None, ax=None, - dyad_color="black", - dyad_lw=1, node_fc="white", node_ec="black", - node_marker="o", + node_shape="o", node_lw=1, node_size=10, - edge_marker_fc="lightblue", + edge_marker_fc="blue", edge_marker_ec="black", - edge_marker="s", + edge_marker_shape="s", edge_marker_lw=1, edge_marker_size=10, + dyad_color="black", + dyad_lw=1, max_order=None, - settings=None, + rescale_sizes=True, **kwargs, ): """Draw a hypergraph as a bipartite network. @@ -2226,14 +2228,18 @@ def draw_bipartite( * max_node_size * min_node_lw * max_node_lw + * min_edge_marker_size + * max_edge_marker_size + * min_edge_marker_lw + * max_edge_marker_lw * node_fc_cmap * node_ec_cmap - * min_lines_lw - * max_lines_lw - * lines_fc_cmap - * edge_fc_cmap * edge_marker_fc_cmap * edge_marker_ec_cmap + * min_dyad_lw + * max_dyad_lw + * dyad_color_cmap + Returns ------- @@ -2254,153 +2260,106 @@ def draw_bipartite( if not isinstance(H, Hypergraph): raise XGIError("The input must be a Hypergraph") - if settings is None: - settings = { - "min_node_size": 10.0, - "max_node_size": 30.0, - "min_edge_marker_size": 10.0, - "max_edge_marker_size": 30.0, - "min_node_lw": 10.0, - "max_node_lw": 30.0, - "min_edge_marker_lw": 10.0, - "max_edge_marker_lw": 30.0, - "min_dyad_lw": 1.0, - "max_dyad_lw": 5.0, - "node_fc_cmap": cm.Reds, - "node_ec_cmap": cm.RdBu, - "dyad_fc_cmap": cm.Blues, - "edge_marker_fc_cmap": cm.Blues, - "edge_marker_ec_cmap": cm.Greys, - } + settings = { + "min_node_lw": 10.0, + "max_node_lw": 30.0, + "min_node_size": 10.0, + "max_node_size": 30.0, + "min_edge_marker_lw": 10.0, + "max_edge_marker_lw": 30.0, + "min_edge_marker_size": 10.0, + "max_edge_marker_size": 30.0, + "min_dyad_lw": 1.0, + "max_dyad_lw": 5.0, + "node_fc_cmap": cm.Reds, + "node_ec_cmap": cm.RdBu, + "dyad_color_cmap": cm.Blues, + "edge_marker_fc_cmap": cm.Greys, + "edge_marker_ec_cmap": cm.Blues, + } settings.update(kwargs) - if ax is None: - ax = plt.gca() - - ax.get_xaxis().set_ticks([]) - ax.get_yaxis().set_ticks([]) - ax.axis("off") + node_settings = { + "min_node_lw": settings["min_node_lw"], + "max_node_lw": settings["max_node_lw"], + "min_node_size": settings["min_node_size"], + "max_node_size": settings["max_node_size"], + "node_fc_cmap": settings["node_fc_cmap"], + "node_ec_cmap": settings["node_ec_cmap"], + } - if not max_order: - max_order = max_edge_order(H) + edge_marker_settings = { + "min_node_lw": settings["min_edge_marker_lw"], + "max_node_lw": settings["max_edge_marker_lw"], + "min_node_size": settings["min_edge_marker_size"], + "max_node_size": settings["max_edge_marker_size"], + "node_fc_cmap": settings["edge_marker_fc_cmap"], + "node_ec_cmap": settings["edge_marker_ec_cmap"], + } if not node_pos or not edge_pos: node_pos, edge_pos = bipartite_spring_layout(H) - dyad_lw = _scalar_arg_to_dict( - dyad_lw, H.edges, settings["min_dyad_lw"], settings["max_dyad_lw"] - ) - - if dyad_color is None: - dyad_color = H.edges.size - - dyad_color = _color_arg_to_dict(dyad_color, H.edges, settings["dyad_fc_cmap"]) - - if edge_marker_fc is None: - edge_marker_fc = H.edges.size - - edge_marker_fc = _color_arg_to_dict( - edge_marker_fc, H.edges, settings["edge_marker_fc_cmap"] - ) - - if edge_marker_ec is None: - edge_marker_ec = H.edges.size + if ax is None: + ax = plt.gca() - edge_marker_ec = _color_arg_to_dict( - edge_marker_ec, H.edges, settings["edge_marker_ec_cmap"] + ax, node_collection = draw_nodes( + H=H, + pos=node_pos, + ax=ax, + node_fc=node_fc, + node_ec=node_ec, + node_lw=node_lw, + node_size=node_size, + node_shape=node_shape, + zorder=2, + params=node_settings, + node_labels=None, + rescale_sizes=rescale_sizes, + **kwargs, ) - edge_marker_lw = _scalar_arg_to_dict( - edge_marker_lw, - H.edges, - settings["min_edge_marker_lw"], - settings["max_edge_marker_lw"], + ax, edge_marker_collection = draw_nodes( + H=H.dual(), + pos=edge_pos, + ax=ax, + node_fc=edge_marker_fc, + node_ec=edge_marker_ec, + node_lw=edge_marker_lw, + node_size=edge_marker_size, + node_shape=edge_marker_shape, + zorder=1, + params=edge_marker_settings, + node_labels=None, + rescale_sizes=rescale_sizes, + **kwargs, ) - edge_marker_size = _scalar_arg_to_dict( - edge_marker_size, - H.edges, - settings["min_edge_marker_size"], - settings["max_edge_marker_size"], - ) + dyads = to_bipartite_edgelist(H) + dyad_lw = _draw_arg_to_arr(dyad_lw) + # dyad_color, dyad_c_mapped = _parse_color_arg(dyad_color, list(dyads) + # check validity of input values + if np.any(dyad_lw < 0): + raise ValueError("dyad_lw cannot contain negative values.") - node_size = _scalar_arg_to_dict( - node_size, H.nodes, settings["min_node_size"], settings["max_node_size"] - ) - node_fc = _color_arg_to_dict(node_fc, H.nodes, settings["node_fc_cmap"]) - node_ec = _color_arg_to_dict(node_ec, H.nodes, settings["node_ec_cmap"]) - node_lw = _scalar_arg_to_dict( - node_lw, - H.nodes, - settings["min_node_lw"], - settings["max_node_lw"], - ) + # interpolate if needed + if rescale_sizes and isinstance(dyad_lw, np.ndarray): + dyad_lw = _interp_draw_arg( + dyad_lw, settings["min_dyad_lw"], settings["max_dyad_lw"] + ) - for id, e in H.edges.members(dtype=dict).items(): - d = len(e) - 1 - if d > 0: - x_edge, y_edge = edge_pos[id] - for n in e: - x_node, y_node = node_pos[n] - line = plt.Line2D( - [x_node, x_edge], - [y_node, y_edge], - color=dyad_color[id], - lw=dyad_lw[id], - zorder=max_order - d, - ) - ax.add_line(line) - - (xs, ys, s, c, ec, lw,) = zip( - *[ - ( - node_pos[i][0], - node_pos[i][1], - node_size[i], - node_fc[i], - node_ec[i], - node_lw[i], - ) - for i in H.nodes - ] - ) - ax.scatter( - x=xs, - y=ys, - marker=node_marker, - s=s, - c=c, - edgecolors=ec, - linewidths=lw, - zorder=max_order, - ) + dyad_pos = np.asarray([(node_pos[list(e)[0]], edge_pos[list(e)[1]]) for e in dyads]) - (xs, ys, s, c, ec, lw,) = zip( - *[ - ( - edge_pos[i][0], - edge_pos[i][1], - edge_marker_size[i], - edge_marker_fc[i], - edge_marker_ec[i], - edge_marker_lw[i], - ) - for i in H.edges - ] - ) - ax.scatter( - x=xs, - y=ys, - marker=edge_marker, - s=s, - c=c, - edgecolors=ec, - linewidths=lw, - zorder=max_order, + dyad_collection = LineCollection( + dyad_pos, + colors=dyad_color, + linewidths=dyad_lw, + antialiaseds=(1,), + cmap=settings["dyad_color_cmap"], + zorder=0, ) - # compute axis limits - _update_lims(node_pos, ax) + ax.add_collection(dyad_collection) - return ax + return ax, (node_collection, edge_marker_collection, dyad_collection) diff --git a/xgi/drawing/layout.py b/xgi/drawing/layout.py index 9804001f2..af34d00d9 100644 --- a/xgi/drawing/layout.py +++ b/xgi/drawing/layout.py @@ -7,8 +7,8 @@ from numpy.linalg import inv, svd from .. import convert -from ..core import SimplicialComplex from ..convert import to_bipartite_graph +from ..core import SimplicialComplex __all__ = [ "random_layout", @@ -19,6 +19,7 @@ "circular_layout", "spiral_layout", "barycenter_kamada_kawai_layout", + "bipartite_spring_layout", ]