Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing things in drawing functions #476

Merged
merged 17 commits into from
Oct 25, 2023
Merged
32 changes: 30 additions & 2 deletions tests/drawing/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,8 +461,12 @@ def test_correct_number_of_collections_draw_multilayer(edgelist8):
def test_draw_dihypergraph(diedgelist2, edgelist8):
DH = xgi.DiHypergraph(diedgelist2)

fig, ax1 = plt.subplots()
fig1, ax1 = plt.subplots()
ax1 = xgi.draw_dihypergraph(DH, ax=ax1)
fig2, ax2 = plt.subplots()
ax2 = xgi.draw_dihypergraph(
DH, ax=ax2, node_fc="red", node_ec="blue", node_lw=2, node_size=20
)

# number of elements
assert len(ax1.lines) == 7 # number of source nodes
Expand All @@ -471,6 +475,30 @@ def test_draw_dihypergraph(diedgelist2, edgelist8):
DH.edges.filterby("size", 1)
) # hyperedges markers + nodes

# node face colors
assert np.all(
ax1.collections[-1].get_facecolor() == np.array([[1, 1, 1, 1]])
) # white
assert np.all(
ax2.collections[-1].get_facecolor() == np.array([[1, 0, 0, 1]])
) # red

# node edge colors
assert np.all(
ax1.collections[-1].get_edgecolor() == np.array([[0, 0, 0, 1]])
) # black
assert np.all(
ax2.collections[-1].get_edgecolor() == np.array([[0, 0, 1, 1]])
) # blue

# node_lw
assert np.all(ax1.collections[-1].get_linewidth() == np.array([1]))
assert np.all(ax2.collections[-1].get_linewidth() == np.array([2]))

# node_size
assert np.all(ax1.collections[-1].get_sizes() == np.array([15**2]))
assert np.all(ax2.collections[-1].get_sizes() == np.array([20**2]))

# zorder
for line, z in zip(ax1.lines, [1, 1, 1, 1, 0, 0, 0]): # lines for source nodes
assert line.get_zorder() == z
Expand All @@ -479,7 +507,7 @@ def test_draw_dihypergraph(diedgelist2, edgelist8):
for collection in ax1.collections:
assert collection.get_zorder() == 3 # nodes and hyperedges markers

plt.close()
plt.close("all")

# test toggle for edges
fig, ax2 = plt.subplots()
Expand Down
129 changes: 68 additions & 61 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,8 +202,6 @@ def draw(
"max_dyad_lw": 10,
"min_node_lw": 0,
"max_node_lw": 5,
"edge_fc_cmap": "crest_r", # for compatibility with simplices until update
"dyad_color_cmap": cm.Greys, # for compatibility with simplices until update
}

settings.update(kwargs)
Expand Down Expand Up @@ -432,7 +430,7 @@ def draw_nodes(
node_lw, settings["min_node_lw"], settings["max_node_lw"]
)

node_size = node_size**2
node_size = np.array(node_size) ** 2

# plot
node_collection = ax.scatter(
Expand Down Expand Up @@ -593,10 +591,8 @@ def draw_hyperedges(
dyad_lw = _draw_arg_to_arr(dyad_lw)

# parse colors
dyad_color, dyad_c_mapped = _parse_color_arg(
dyad_color, dyad_color_cmap, list(dyads)
)
edge_fc, edge_c_mapped = _parse_color_arg(edge_fc, edge_fc_cmap, list(edges))
dyad_color, dyad_c_mapped = _parse_color_arg(dyad_color, list(dyads))
edge_fc, edge_c_mapped = _parse_color_arg(edge_fc, list(edges))

# check validity of input values
if np.any(dyad_lw < 0):
Expand Down Expand Up @@ -1076,14 +1072,17 @@ def _draw_hull(node_pos, ax, edges_ec, facecolor, alpha, zorder, radius):
Parameters
----------
node_pos : np.array
nx2 dimensional array containing positions of the nodes
Array of dimension (n, 2) containing node positions
ax : matplotlib.pyplot.axes
Axis to plot on
edges_ec : str
Color of the border of the convex hull
facecolor : str
Filling color of the convex hull
alpha : float
Transparency of the convex hull
zorder : float
Vertical order on which to plot
radius : float
Radius of the convex hull in the vicinity of the nodes.

Expand Down Expand Up @@ -1435,6 +1434,13 @@ def draw_multilayer(
-------
ax : matplotlib Axes3DSubplot
The subplot with the multilayer network visualization.


Notes
-----
The effect of the `sep` parameter is limited by the `height` of the figure.
If `sep` is larger than a certain value depending on `height`, no additional
effect will be seen.
"""
settings = {
"min_node_size": 10.0,
Expand Down Expand Up @@ -1604,13 +1610,14 @@ def draw_multilayer(
def draw_dihypergraph(
DH,
ax=None,
lines_fc=None,
lines_lw=1.5,
line_head_width=0.05,
node_fc="white",
node_ec="black",
node_lw=1,
node_size=15,
node_fc_cmap="Reds",
lines_fc=None,
lines_lw=1.5,
line_head_width=0.05,
edge_marker_toggle=True,
edge_marker_fc=None,
edge_marker_ec=None,
Expand All @@ -1621,6 +1628,7 @@ def draw_dihypergraph(
node_labels=False,
hyperedge_labels=False,
settings=None,
rescale_sizes=True,
**kwargs,
):
"""Draw a directed hypergraph
Expand All @@ -1631,18 +1639,6 @@ def draw_dihypergraph(
The directed hypergraph to draw.
ax : matplotlib.pyplot.axes, optional
Axis to draw on. If None (default), get the current axes.
lines_fc : str, dict, iterable, optional
Color of the hyperedges (lines). If str, use the same color for all hyperedges.
If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable,
assume the colors are specified in the same order as the hyperedges are found
in DH.edges. If None (default), use the size of the hyperedges.
lines_lw : int, float, dict, iterable, optional
Line width of the hyperedges (lines). If int or float, use the same width for
all hyperedges. If a dict, must contain (hyperedge_id: width) pairs. If other
iterable, assume the widths are specified in the same order as the hyperedges
are found in DH.edges. By default, 1.5.
line_head_width : float, optional
Length of arrows' heads. By default, 0.05
node_fc : str, dict, iterable, or NodeStat, optional
Color of the nodes. If str, use the same color for all nodes. If a dict, must
contain (node_id: color_str) pairs. If other iterable, assume the colors are
Expand All @@ -1665,6 +1661,21 @@ def draw_dihypergraph(
the radiuses are specified in the same order as the nodes are found in
H.nodes. If NodeStat, use a monotonic linear interpolation defined between
min_node_size and max_node_size. By default, 15.
node_fc_cmap : colormap
Colormap for mapping node colors. By default, "Reds". Ignored, if `node_fc` is
a str (single color).
lines_fc : str, dict, iterable, optional
Color of the hyperedges (lines). If str, use the same color for all hyperedges.
If a dict, must contain (hyperedge_id: color_str) pairs. If other iterable,
assume the colors are specified in the same order as the hyperedges are found
in DH.edges. If None (default), use the size of the hyperedges.
lines_lw : int, float, dict, iterable, optional
Line width of the hyperedges (lines). If int or float, use the same width for
all hyperedges. If a dict, must contain (hyperedge_id: width) pairs. If other
iterable, assume the widths are specified in the same order as the hyperedges
are found in DH.edges. By default, 1.5.
line_head_width : float, optional
Length of arrows' heads. By default, 0.05
edge_marker_toggle: bool, optional
If True then marker representing the hyperedges are drawn. By default True.
edge_marker_fc: str, dict, iterable, optional
Expand Down Expand Up @@ -1693,14 +1704,11 @@ def draw_dihypergraph(
* max_node_size
* min_node_lw
* max_node_lw
* node_fc_cmap
* node_ec_cmap
* min_lines_lw
* max_lines_lw
* lines_fc_cmap
* edge_fc_cmap
* edge_marker_fc_cmap
* edge_marker_ec_cmap

Returns
-------
Expand All @@ -1721,36 +1729,39 @@ def draw_dihypergraph(
if not isinstance(DH, DiHypergraph):
raise XGIError("The input must be a DiHypergraph")

if settings is None:
settings = {
"min_node_size": 10.0,
"max_node_size": 30.0,
"min_node_lw": 1.0,
"max_node_lw": 5.0,
"node_fc_cmap": cm.Reds,
"node_ec_cmap": cm.Greys,
"min_lines_lw": 2.0,
"max_lines_lw": 10.0,
"lines_fc_cmap": cm.Blues,
"edge_marker_fc_cmap": cm.Blues,
"edge_marker_ec_cmap": cm.Greys,
}
settings = {
"min_node_size": 5,
"max_node_size": 30,
"min_node_lw": 0,
"max_node_lw": 5,
"min_lines_lw": 2.0,
"max_lines_lw": 10.0,
"lines_fc_cmap": plt.cm.Blues,
"edge_marker_fc_cmap": plt.cm.Blues,
}

settings.update(kwargs)

if ax is None:
ax = plt.gca()

ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
ax.axis("off")

# convert to hypergraph in order to use the augmented projection function
H_conv = convert.convert_to_hypergraph(DH)

(
ax,
_,
) = _draw_init(H_conv, ax, True)

if not max_order:
max_order = max_edge_order(H_conv)

# convert all formats to ndarray
node_size = _draw_arg_to_arr(node_size)

# interpolate if needed
if rescale_sizes and isinstance(node_size, np.ndarray):
node_size = _interp_draw_arg(
node_size, settings["min_node_size"], settings["max_node_size"]
)

lines_lw = _scalar_arg_to_dict(
lines_lw, H_conv.edges, settings["min_lines_lw"], settings["max_lines_lw"]
)
Expand All @@ -1767,17 +1778,6 @@ def draw_dihypergraph(
edge_marker_fc, H_conv.edges, settings["edge_marker_fc_cmap"]
)

if edge_marker_ec is None:
edge_marker_ec = H_conv.edges.size

edge_marker_ec = _color_arg_to_dict(
edge_marker_ec, H_conv.edges, settings["edge_marker_ec_cmap"]
)

node_size = _scalar_arg_to_dict(
node_size, H_conv.nodes, settings["min_node_size"], settings["max_node_size"]
)

G_aug = _augmented_projection(H_conv)
for dyad in H_conv.edges.filterby("size", 2).members():
try:
Expand Down Expand Up @@ -1812,8 +1812,13 @@ def draw_dihypergraph(
# the following to avoid the point of the arrow overlapping the node
distance = np.hypot(dx, dy)
direction_vector = np.array([dx, dy]) / distance
size = (
node_size
if not isinstance(node_size, np.ndarray)
else node_size[node]
)
shortened_distance = (
distance - node_size[node] * 0.003
distance - size * 0.003
) # Calculate the shortened length
dx = direction_vector[0] * shortened_distance
dy = direction_vector[1] * shortened_distance
Expand All @@ -1836,7 +1841,7 @@ def draw_dihypergraph(
marker=edge_marker,
s=edge_marker_size**2,
c=edge_marker_fc[id],
edgecolors=edge_marker_ec[id],
edgecolors=edge_marker_ec,
linewidths=edge_marker_lw,
zorder=max_order,
)
Expand All @@ -1854,17 +1859,19 @@ def draw_dihypergraph(
label_kwds["font_size_edges"] = 6
draw_hyperedge_labels(H_conv, pos, hyperedge_labels, ax_edges=ax, **label_kwds)

draw_nodes(
ax, node_collection = draw_nodes(
H=H_conv,
pos=pos,
ax=ax,
node_fc=node_fc,
node_ec=node_ec,
node_lw=node_lw,
node_size=node_size,
# node_shape=node_shape,
zorder=max_order,
params=settings,
node_labels=node_labels,
# rescale_sizes=rescale_sizes,
**kwargs,
)

Expand Down
4 changes: 1 addition & 3 deletions xgi/drawing/draw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def _update_lims(pos, ax):
ax.autoscale_view()


def _parse_color_arg(colors, cmap, ids, id_kind="edges"):
def _parse_color_arg(colors, ids, id_kind="edges"):
"""
Parse and process color arguments for plotting.

Expand All @@ -106,8 +106,6 @@ def _parse_color_arg(colors, cmap, ids, id_kind="edges"):
* array of floats
* dict of floats containing the `ids` as keys
* IDStat containing the `ids` as keys
cmap : matplotlib colormap or None
The colormap to use for mapping numerical values to colors.
ids : array-like or None
The IDs of the elements being plotted.
id_kind : str, optional
Expand Down
3 changes: 3 additions & 0 deletions xgi/stats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def __repr__(self):
out += ")"
return out

def __len__(self):
return len(self.view)

@property
def name(self):
"""Name of this stat.
Expand Down