Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added draw_bipartite #465

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
278 changes: 277 additions & 1 deletion xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@
_scalar_arg_to_dict,
_update_lims,
)
from .layout import _augmented_projection, barycenter_spring_layout
from .layout import (
_augmented_projection,
barycenter_spring_layout,
bipartite_spring_layout,
)

__all__ = [
"draw",
Expand All @@ -36,6 +40,7 @@
"draw_hypergraph_hull",
"draw_multilayer",
"draw_dihypergraph",
"draw_bipartite",
]


Expand Down Expand Up @@ -1767,3 +1772,274 @@ def draw_dihypergraph(
_update_lims(pos, ax)

return ax


def draw_bipartite(
H,
node_pos=None,
edge_pos=None,
ax=None,
dyad_color="black",
dyad_lw=1,
node_fc="white",
node_ec="black",
node_marker="o",
node_lw=1,
node_size=10,
edge_marker_fc="lightblue",
edge_marker_ec="black",
edge_marker="s",
edge_marker_lw=1,
edge_marker_size=10,
max_order=None,
settings=None,
**kwargs,
):
"""Draw a hypergraph as a bipartite network.

Parameters
----------
H : Hypergraph
The hypergraph to draw.
ax : matplotlib.pyplot.axes, optional
Axis to draw on. If None (default), get the current axes.
lines_fc : str, dict, iterable, optional
Color of the hyperedges (lines). If str, use the same color for all hyperedges.
If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable,
assume the colors are specified in the same order as the hyperedges are found
in DH.edges. If None (default), use the size of the hyperedges.
lines_lw : int, float, dict, iterable, optional
Line width of the hyperedges (lines). If int or float, use the same width for
all hyperedges. If a dict, must contain (hyperedge_id: width) pairs. If other
iterable, assume the widths are specified in the same order as the hyperedges
are found in DH.edges. By default, 1.5.
line_head_width : float, optional
Length of arrows' heads. By default, 0.05
node_fc : str, dict, iterable, or NodeStat, optional
Color of the nodes. If str, use the same color for all nodes. If a dict, must
contain (node_id: color_str) pairs. If other iterable, assume the colors are
specified in the same order as the nodes are found in H.nodes. If NodeStat, use
the colormap specified with node_fc_cmap. By default, "white".
node_ec : str, dict, iterable, or NodeStat, optional
Color of node borders. If str, use the same color for all nodes. If a dict,
must contain (node_id: color_str) pairs. If other iterable, assume the colors
are specified in the same order as the nodes are found in H.nodes. If NodeStat,
use the colormap specified with node_ec_cmap. By default, "black".
node_lw : int, float, dict, iterable, or NodeStat, optional
Line width of the node borders in pixels. If int or float, use the same width
for all node borders. If a dict, must contain (node_id: width) pairs. If
iterable, assume the widths are specified in the same order as the nodes are
found in H.nodes. If NodeStat, use a monotonic linear interpolation defined
between min_node_lw and max_node_lw. By default, 1.
node_size : int, float, dict, iterable, or NodeStat, optional
Radius of the nodes in pixels. If int or float, use the same radius for all
nodes. If a dict, must contain (node_id: radius) pairs. If iterable, assume
the radiuses are specified in the same order as the nodes are found in
H.nodes. If NodeStat, use a monotonic linear interpolation defined between
min_node_size and max_node_size. By default, 15.
edge_marker_toggle: bool, optional
If True then marker representing the hyperedges are drawn. By default True.
edge_marker_fc: str, dict, iterable, optional
Filling color of the hyperedges (markers). If str, use the same color for all hyperedges.
If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable,
assume the colors are specified in the same order as the hyperedges are found
in DH.edges. If None (default), use the size of the hyperedges.
edge_marker_ec: str, dict, iterable, optional
Edge color of the hyperedges (markers). If str, use the same color for all hyperedges.
If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable,
assume the colors are specified in the same order as the hyperedges are found
in DH.edges. If None (default), use the size of the hyperedges.
edge_marker: str, optional
Marker used for the hyperedges. By default 's' (square marker).
max_order : int, optional
Maximum of hyperedges to plot. If None (default), plots all orders.
node_labels : bool or dict, optional
If True, draw ids on the nodes. If a dict, must contain (node_id: label) pairs.
By default, False.
hyperedge_labels : bool or dict, optional
If True, draw ids on the hyperedges. If a dict, must contain (edge_id: label)
pairs. By default, False.
**kwargs : optional args
Alternate default values. Values that can be overwritten are the following:
* min_node_size
* max_node_size
* min_node_lw
* max_node_lw
* node_fc_cmap
* node_ec_cmap
* min_lines_lw
* max_lines_lw
* lines_fc_cmap
* edge_fc_cmap
* edge_marker_fc_cmap
* edge_marker_ec_cmap

Returns
-------
ax : matplotlib.pyplot.axes

Raises
------
XGIError
If something different than a DiHypergraph is passed.

See Also
--------
draw
draw_nodes
draw_node_labels

"""
if not isinstance(H, Hypergraph):
raise XGIError("The input must be a Hypergraph")

if settings is None:
settings = {
"min_node_size": 10.0,
"max_node_size": 30.0,
"min_edge_marker_size": 10.0,
"max_edge_marker_size": 30.0,
"min_node_lw": 10.0,
"max_node_lw": 30.0,
"min_edge_marker_lw": 10.0,
"max_edge_marker_lw": 30.0,
"min_dyad_lw": 1.0,
"max_dyad_lw": 5.0,
"node_fc_cmap": cm.Reds,
"node_ec_cmap": cm.RdBu,
"dyad_fc_cmap": cm.Blues,
"edge_marker_fc_cmap": cm.Blues,
"edge_marker_ec_cmap": cm.Greys,
}

settings.update(kwargs)

if ax is None:
ax = plt.gca()

ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
ax.axis("off")

if not max_order:
max_order = max_edge_order(H)

if not node_pos or not edge_pos:
node_pos, edge_pos = bipartite_spring_layout(H)

dyad_lw = _scalar_arg_to_dict(
dyad_lw, H.edges, settings["min_dyad_lw"], settings["max_dyad_lw"]
)

if dyad_color is None:
dyad_color = H.edges.size

dyad_color = _color_arg_to_dict(dyad_color, H.edges, settings["dyad_fc_cmap"])

if edge_marker_fc is None:
edge_marker_fc = H.edges.size

edge_marker_fc = _color_arg_to_dict(
edge_marker_fc, H.edges, settings["edge_marker_fc_cmap"]
)

if edge_marker_ec is None:
edge_marker_ec = H.edges.size

edge_marker_ec = _color_arg_to_dict(
edge_marker_ec, H.edges, settings["edge_marker_ec_cmap"]
)

edge_marker_lw = _scalar_arg_to_dict(
edge_marker_lw,
H.edges,
settings["min_edge_marker_lw"],
settings["max_edge_marker_lw"],
)

edge_marker_size = _scalar_arg_to_dict(
edge_marker_size,
H.edges,
settings["min_edge_marker_size"],
settings["max_edge_marker_size"],
)

node_size = _scalar_arg_to_dict(
node_size, H.nodes, settings["min_node_size"], settings["max_node_size"]
)
node_fc = _color_arg_to_dict(node_fc, H.nodes, settings["node_fc_cmap"])
node_ec = _color_arg_to_dict(node_ec, H.nodes, settings["node_ec_cmap"])
node_lw = _scalar_arg_to_dict(
node_lw,
H.nodes,
settings["min_node_lw"],
settings["max_node_lw"],
)

for id, e in H.edges.members(dtype=dict).items():
d = len(e) - 1
if d > 0:
x_edge, y_edge = edge_pos[id]
for n in e:
x_node, y_node = node_pos[n]
line = plt.Line2D(
[x_node, x_edge],
[y_node, y_edge],
color=dyad_color[id],
lw=dyad_lw[id],
zorder=max_order - d,
)
ax.add_line(line)

(xs, ys, s, c, ec, lw,) = zip(
*[
(
node_pos[i][0],
node_pos[i][1],
node_size[i],
node_fc[i],
node_ec[i],
node_lw[i],
)
for i in H.nodes
]
)
ax.scatter(
x=xs,
y=ys,
marker=node_marker,
s=s,
c=c,
edgecolors=ec,
linewidths=lw,
zorder=max_order,
)

(xs, ys, s, c, ec, lw,) = zip(
*[
(
edge_pos[i][0],
edge_pos[i][1],
edge_marker_size[i],
edge_marker_fc[i],
edge_marker_ec[i],
edge_marker_lw[i],
)
for i in H.edges
]
)
ax.scatter(
x=xs,
y=ys,
marker=edge_marker,
s=s,
c=c,
edgecolors=ec,
linewidths=lw,
zorder=max_order,
)

# compute axis limits
_update_lims(node_pos, ax)

return ax
66 changes: 66 additions & 0 deletions xgi/drawing/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@

from .. import convert
from ..algorithms import max_edge_order
from ..convert import to_bipartite_graph
from ..core import SimplicialComplex

__all__ = [
"random_layout",
"pairwise_spring_layout",
"barycenter_spring_layout",
"bipartite_spring_layout",
"weighted_barycenter_spring_layout",
"pca_transform",
"circular_layout",
Expand Down Expand Up @@ -267,6 +269,70 @@ def barycenter_spring_layout(
return pos


def bipartite_spring_layout(H, seed=None, k=None, **kwargs):
"""
Position the nodes using Fruchterman-Reingold force-directed
algorithm using an augmented version of the the graph projection
of the hypergraph (or simplicial complex), where phantom nodes
(barycenters) are created for each edge composed by more than two nodes.
If a simplicial complex is provided the results will be based on the
hypergraph constructed from its maximal simplices.

Parameters
----------
H : xgi Hypergraph or SimplicialComplex
A position will be assigned to every node in H.
seed : int, RandomState instance or None optional (default=None)
Set the random state for deterministic node layouts.
If int, `seed` is the seed used by the random number generator,
If None (default), random numbers are sampled from the
numpy random number generator without initialization.
k : float
The spring constant of the links. When k=None (default),
k = 1/sqrt(N). For more information, see the documentation
for the NetworkX spring_layout() function.
kwargs :
Optional arguments for the NetworkX spring_layout() function.
See https://networkx.org/documentation/stable/reference/generated/networkx.drawing.layout.spring_layout.html


Returns
-------
pos : dict
A dictionary of positions keyed by node

See Also
--------
random_layout
pairwise_spring_layout
weighted_barycenter_spring_layout

Examples
--------
>>> import xgi
>>> N = 50
>>> ps = [0.1, 0.01]
>>> H = xgi.random_hypergraph(N, ps)
>>> pos = xgi.barycenter_spring_layout(H)
"""
if seed is not None:
random.seed(seed)

if isinstance(H, SimplicialComplex):
H = convert.from_max_simplices(H)

G, nodedict, edgedict = to_bipartite_graph(H, index=True)

# Creating a dictionary for the position of the nodes with the standard spring
# layout
pos = nx.spring_layout(G, seed=seed, k=k, **kwargs)

node_pos = {nodedict[i]: pos[i] for i in nodedict}
edge_pos = {edgedict[i]: pos[i] for i in edgedict}

return node_pos, edge_pos


def weighted_barycenter_spring_layout(
H, return_phantom_graph=False, seed=None, k=None, **kwargs
):
Expand Down