Skip to content

Commit

Permalink
Refactor draw module (#435)
Browse files Browse the repository at this point in the history
* refact: draw module

* style: black and isort
  • Loading branch information
maximelucas authored Jul 27, 2023
1 parent cbaec09 commit 50d18d7
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 276 deletions.
292 changes: 16 additions & 276 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,13 @@
from ..core import DiHypergraph, Hypergraph, SimplicialComplex
from ..exception import XGIError
from ..stats import IDStat
from .draw_utils import (
_CCW_sort,
_color_arg_to_dict,
_draw_init,
_scalar_arg_to_dict,
_update_lims,
)
from .layout import _augmented_projection, barycenter_spring_layout

__all__ = [
Expand Down Expand Up @@ -112,8 +119,8 @@ def draw(
If True, draw ids on the hyperedges. If a dict, must contain (edge_id: label)
pairs. By default, False.
aspect : {"auto", "equal"} or float, optional
Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed
directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full
Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed
directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full
description at
https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set_aspect.html
**kwargs : optional args
Expand Down Expand Up @@ -163,15 +170,7 @@ def draw(
if edge_fc is None:
edge_fc = H.edges.size

if pos is None:
pos = barycenter_spring_layout(H)

if ax is None:
ax = plt.gca()

ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
ax.axis("off")
ax, pos = _draw_init(H, ax, pos)

if not max_order:
max_order = max_edge_order(H)
Expand Down Expand Up @@ -311,15 +310,7 @@ def draw_nodes(

settings.update(kwargs)

if pos is None:
pos = barycenter_spring_layout(H)

if ax is None:
ax = plt.gca()

ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
ax.axis("off")
ax, pos = _draw_init(H, ax, pos)

# Note Iterable covers lists, tuples, ranges, generators, np.ndarrays, etc
node_fc = _color_arg_to_dict(node_fc, H.nodes, settings["node_fc_cmap"])
Expand Down Expand Up @@ -436,22 +427,14 @@ def draw_hyperedges(
"""

if pos is None:
pos = barycenter_spring_layout(H)

if ax is None:
ax = plt.gca()
ax, pos = _draw_init(H, ax, pos)

if max_order is None:
max_order = max_edge_order(H)

if edge_fc is None:
edge_fc = H.edges.size

ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
ax.axis("off")

if settings is None:
settings = {
"min_dyad_lw": 2.0,
Expand Down Expand Up @@ -604,19 +587,11 @@ def draw_simplices(
if not max_order:
max_order = max_edge_order(H_)

if pos is None:
pos = barycenter_spring_layout(H_)

if ax is None:
ax = plt.gca()
ax, pos = _draw_init(H_, ax, pos)

if edge_fc is None:
edge_fc = H_.edges.size

ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
ax.axis("off")

if settings is None:
settings = {
"min_dyad_lw": 2.0,
Expand Down Expand Up @@ -699,215 +674,6 @@ def draw_simplices(
return ax


def _scalar_arg_to_dict(scalar_arg, ids, min_val, max_val):
"""Map different types of arguments for drawing style to a dict with scalar values.
Parameters
----------
scalar_arg : int, float, dict, iterable, or NodeStat/EdgeStat
Attributes for drawing parameter.
ids : NodeView or EdgeView
This is the node or edge IDs that attributes get mapped to.
min_val : int or float
The minimum value of the drawing parameter.
max_val : int or float
The maximum value of the drawing parameter.
Returns
-------
dict
An ID: scalar dictionary.
Raises
------
TypeError
If a int, float, list, dict, or NodeStat/EdgeStat is not passed.
"""
if isinstance(scalar_arg, str):
raise TypeError(
"Argument must be int, float, dict, iterable, "
f"or NodeStat/EdgeStat. Received {type(scalar_arg)}"
)

# Single argument
if isinstance(scalar_arg, (int, float)):
return {id: scalar_arg for id in ids}

# IDStat
if isinstance(scalar_arg, IDStat):
vals = np.interp(
scalar_arg.asnumpy(),
[scalar_arg.min(), scalar_arg.max()],
[min_val, max_val],
)
return dict(zip(ids, vals))

# Iterables of floats or ints
if isinstance(scalar_arg, Iterable):
if isinstance(scalar_arg, dict):
try:
return {id: float(scalar_arg[id]) for id in scalar_arg if id in ids}
except ValueError as e:
raise TypeError(
"The input dict must have values that can be cast to floats."
)

elif isinstance(scalar_arg, (list, ndarray)):
try:
return {id: float(scalar_arg[idx]) for idx, id in enumerate(ids)}
except ValueError as e:
raise TypeError(
"The input list or array must have values that can be cast to floats."
)
else:
raise TypeError(
"Argument must be an dict, list, or numpy array of floats or ints."
)

raise TypeError(
"Argument must be int, float, dict, iterable, "
f"or NodeStat/EdgeStat. Received {type(scalar_arg)}"
)


def _color_arg_to_dict(color_arg, ids, cmap):
"""Map different types of arguments for drawing style to a dict with color values.
Parameters
----------
color_arg : Several formats are accepted:
Single color values
* str
* 3- or 4-tuple
Iterable of colors (each color specified as above)
* numpy array
* list
* dict {id: color} pairs
Iterable of numerical values (floats or ints)
* list
* dict
* numpy array
Stats
* NodeStat
* EdgeStat
Attributes for drawing parameter.
ids : NodeView or EdgeView
This is the node or edge IDs that attributes get mapped to.
cmap : ListedColormap or LinearSegmentedColormap
colormap to use for NodeStat/EdgeStat.
Returns
-------
dict
An ID: color dictionary.
Raises
------
TypeError
If a string, dict, iterable, or NodeStat/EdgeStat is not passed.
Notes
-----
For the iterable of values, we do not accept tuples,
because there is the potential for ambiguity.
"""

# single argument. Must be a string or a tuple of floats
if isinstance(color_arg, str) or (
isinstance(color_arg, tuple) and isinstance(color_arg[0], float)
):
return {id: color_arg for id in ids}

# Iterables of colors. The values of these iterables must strings or tuples. As of now,
# there is not a check to verify that the tuples contain floats.
if isinstance(color_arg, Iterable):
if isinstance(color_arg, dict) and isinstance(
next(iter(color_arg.values())), (str, tuple, ndarray)
):
return {id: color_arg[id] for id in color_arg if id in ids}
if isinstance(color_arg, (list, ndarray)) and isinstance(
color_arg[0], (str, tuple, ndarray)
):
return {id: color_arg[idx] for idx, id in enumerate(ids)}

# Stats or iterable of values
if isinstance(color_arg, (Iterable, IDStat)):
# set max and min of interpolation based on color map
if isinstance(cmap, ListedColormap):
minval = 0
maxval = cmap.N
elif isinstance(cmap, LinearSegmentedColormap):
minval = 0.1
maxval = 0.9
else:
raise XGIError("Invalid colormap!")

# handle the case of IDStat vs iterables
if isinstance(color_arg, IDStat):
vals = np.interp(
color_arg.asnumpy(),
[color_arg.min(), color_arg.max()],
[minval, maxval],
)
return {
id: np.array(cmap(vals[i])).reshape(1, -1) for i, id in enumerate(ids)
}

elif isinstance(color_arg, Iterable):
if isinstance(color_arg, dict) and isinstance(
next(iter(color_arg.values())), (int, float)
):
v = list(color_arg.values())
vals = np.interp(v, [np.min(v), np.max(v)], [minval, maxval])
# because we have ids, we can't just assume that the keys of arg correspond to
# the ids.
return {
id: np.array(cmap(v)).reshape(1, -1)
for v, id in zip(vals, color_arg.keys())
if id in ids
}

if isinstance(color_arg, (list, ndarray)) and isinstance(
color_arg[0], (int, float)
):
vals = np.interp(
color_arg, [np.min(color_arg), np.max(color_arg)], [minval, maxval]
)
return {
id: np.array(cmap(vals[i])).reshape(1, -1)
for i, id in enumerate(ids)
}
else:
raise TypeError(
"Argument must be an dict, list, or numpy array of floats."
)

raise TypeError(
"Argument must be str, dict, iterable, or "
f"NodeStat/EdgeStat. Received {type(color_arg)}"
)


def _CCW_sort(p):
"""
Sort the input 2D points counterclockwise.
"""
p = np.array(p)
mean = np.mean(p, axis=0)
d = p - mean
s = np.arctan2(d[:, 0], d[:, 1])
return p[np.argsort(s), :]


def draw_node_labels(
H,
pos,
Expand Down Expand Up @@ -1141,24 +907,6 @@ def draw_hyperedge_labels(
return text_items


def _update_lims(pos, ax):
"""Update Axis limits based on node positions"""

# compute axis limits
pos_arr = np.asarray([[x, y] for n, (x, y) in pos.items()])

maxx, maxy = np.max(pos_arr, axis=0)
minx, miny = np.min(pos_arr, axis=0)
w = maxx - minx
h = maxy - miny

# update view after drawing
padx, pady = 0.05 * w, 0.05 * h
corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady)
ax.update_datalim(corners)
ax.autoscale_view()


def _draw_hull(node_pos, ax, edges_ec, facecolor, alpha, zorder, radius):
"""Draw a convex hull encompassing the nodes in node_pos
Expand Down Expand Up @@ -1283,8 +1031,8 @@ def draw_hypergraph_hull(
radius : float, optional
Radius of the convex hull in the vicinity of the nodes, by default 0.05.
aspect : {"auto", "equal"} or float, optional
Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed
directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full
Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed
directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full
description at
https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set_aspect.html
**kwargs : optional args
Expand Down Expand Up @@ -1337,15 +1085,7 @@ def draw_hypergraph_hull(

settings.update(kwargs)

if pos is None:
pos = barycenter_spring_layout(H)

if ax is None:
ax = plt.gca()

ax.get_xaxis().set_ticks([])
ax.get_yaxis().set_ticks([])
ax.axis("off")
ax, pos = _draw_init(H, ax, pos)

if not max_order:
max_order = max_edge_order(H)
Expand Down
Loading

0 comments on commit 50d18d7

Please sign in to comment.