From 50d18d7ba4118ff654784b4d5a60fc61a53b922c Mon Sep 17 00:00:00 2001 From: Maxime Lucas Date: Thu, 27 Jul 2023 16:20:17 -0400 Subject: [PATCH] Refactor draw module (#435) * refact: draw module * style: black and isort --- xgi/drawing/draw.py | 292 +++----------------------------------- xgi/drawing/draw_utils.py | 254 +++++++++++++++++++++++++++++++++ 2 files changed, 270 insertions(+), 276 deletions(-) create mode 100644 xgi/drawing/draw_utils.py diff --git a/xgi/drawing/draw.py b/xgi/drawing/draw.py index 238de0514..348528b48 100644 --- a/xgi/drawing/draw.py +++ b/xgi/drawing/draw.py @@ -19,6 +19,13 @@ from ..core import DiHypergraph, Hypergraph, SimplicialComplex from ..exception import XGIError from ..stats import IDStat +from .draw_utils import ( + _CCW_sort, + _color_arg_to_dict, + _draw_init, + _scalar_arg_to_dict, + _update_lims, +) from .layout import _augmented_projection, barycenter_spring_layout __all__ = [ @@ -112,8 +119,8 @@ def draw( If True, draw ids on the hyperedges. If a dict, must contain (edge_id: label) pairs. By default, False. aspect : {"auto", "equal"} or float, optional - Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed - directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full + Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed + directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full description at https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set_aspect.html **kwargs : optional args @@ -163,15 +170,7 @@ def draw( if edge_fc is None: edge_fc = H.edges.size - if pos is None: - pos = barycenter_spring_layout(H) - - if ax is None: - ax = plt.gca() - - ax.get_xaxis().set_ticks([]) - ax.get_yaxis().set_ticks([]) - ax.axis("off") + ax, pos = _draw_init(H, ax, pos) if not max_order: max_order = max_edge_order(H) @@ -311,15 +310,7 @@ def draw_nodes( settings.update(kwargs) - if pos is None: - pos = barycenter_spring_layout(H) - - if ax is None: - ax = plt.gca() - - ax.get_xaxis().set_ticks([]) - ax.get_yaxis().set_ticks([]) - ax.axis("off") + ax, pos = _draw_init(H, ax, pos) # Note Iterable covers lists, tuples, ranges, generators, np.ndarrays, etc node_fc = _color_arg_to_dict(node_fc, H.nodes, settings["node_fc_cmap"]) @@ -436,11 +427,7 @@ def draw_hyperedges( """ - if pos is None: - pos = barycenter_spring_layout(H) - - if ax is None: - ax = plt.gca() + ax, pos = _draw_init(H, ax, pos) if max_order is None: max_order = max_edge_order(H) @@ -448,10 +435,6 @@ def draw_hyperedges( if edge_fc is None: edge_fc = H.edges.size - ax.get_xaxis().set_ticks([]) - ax.get_yaxis().set_ticks([]) - ax.axis("off") - if settings is None: settings = { "min_dyad_lw": 2.0, @@ -604,19 +587,11 @@ def draw_simplices( if not max_order: max_order = max_edge_order(H_) - if pos is None: - pos = barycenter_spring_layout(H_) - - if ax is None: - ax = plt.gca() + ax, pos = _draw_init(H_, ax, pos) if edge_fc is None: edge_fc = H_.edges.size - ax.get_xaxis().set_ticks([]) - ax.get_yaxis().set_ticks([]) - ax.axis("off") - if settings is None: settings = { "min_dyad_lw": 2.0, @@ -699,215 +674,6 @@ def draw_simplices( return ax -def _scalar_arg_to_dict(scalar_arg, ids, min_val, max_val): - """Map different types of arguments for drawing style to a dict with scalar values. - - Parameters - ---------- - scalar_arg : int, float, dict, iterable, or NodeStat/EdgeStat - Attributes for drawing parameter. - ids : NodeView or EdgeView - This is the node or edge IDs that attributes get mapped to. - min_val : int or float - The minimum value of the drawing parameter. - max_val : int or float - The maximum value of the drawing parameter. - - Returns - ------- - dict - An ID: scalar dictionary. - - Raises - ------ - TypeError - If a int, float, list, dict, or NodeStat/EdgeStat is not passed. - """ - if isinstance(scalar_arg, str): - raise TypeError( - "Argument must be int, float, dict, iterable, " - f"or NodeStat/EdgeStat. Received {type(scalar_arg)}" - ) - - # Single argument - if isinstance(scalar_arg, (int, float)): - return {id: scalar_arg for id in ids} - - # IDStat - if isinstance(scalar_arg, IDStat): - vals = np.interp( - scalar_arg.asnumpy(), - [scalar_arg.min(), scalar_arg.max()], - [min_val, max_val], - ) - return dict(zip(ids, vals)) - - # Iterables of floats or ints - if isinstance(scalar_arg, Iterable): - if isinstance(scalar_arg, dict): - try: - return {id: float(scalar_arg[id]) for id in scalar_arg if id in ids} - except ValueError as e: - raise TypeError( - "The input dict must have values that can be cast to floats." - ) - - elif isinstance(scalar_arg, (list, ndarray)): - try: - return {id: float(scalar_arg[idx]) for idx, id in enumerate(ids)} - except ValueError as e: - raise TypeError( - "The input list or array must have values that can be cast to floats." - ) - else: - raise TypeError( - "Argument must be an dict, list, or numpy array of floats or ints." - ) - - raise TypeError( - "Argument must be int, float, dict, iterable, " - f"or NodeStat/EdgeStat. Received {type(scalar_arg)}" - ) - - -def _color_arg_to_dict(color_arg, ids, cmap): - """Map different types of arguments for drawing style to a dict with color values. - - Parameters - ---------- - color_arg : Several formats are accepted: - - Single color values - - * str - * 3- or 4-tuple - - Iterable of colors (each color specified as above) - - * numpy array - * list - * dict {id: color} pairs - - Iterable of numerical values (floats or ints) - - * list - * dict - * numpy array - - Stats - - * NodeStat - * EdgeStat - - Attributes for drawing parameter. - ids : NodeView or EdgeView - This is the node or edge IDs that attributes get mapped to. - cmap : ListedColormap or LinearSegmentedColormap - colormap to use for NodeStat/EdgeStat. - - Returns - ------- - dict - An ID: color dictionary. - - Raises - ------ - TypeError - If a string, dict, iterable, or NodeStat/EdgeStat is not passed. - - Notes - ----- - For the iterable of values, we do not accept tuples, - because there is the potential for ambiguity. - """ - - # single argument. Must be a string or a tuple of floats - if isinstance(color_arg, str) or ( - isinstance(color_arg, tuple) and isinstance(color_arg[0], float) - ): - return {id: color_arg for id in ids} - - # Iterables of colors. The values of these iterables must strings or tuples. As of now, - # there is not a check to verify that the tuples contain floats. - if isinstance(color_arg, Iterable): - if isinstance(color_arg, dict) and isinstance( - next(iter(color_arg.values())), (str, tuple, ndarray) - ): - return {id: color_arg[id] for id in color_arg if id in ids} - if isinstance(color_arg, (list, ndarray)) and isinstance( - color_arg[0], (str, tuple, ndarray) - ): - return {id: color_arg[idx] for idx, id in enumerate(ids)} - - # Stats or iterable of values - if isinstance(color_arg, (Iterable, IDStat)): - # set max and min of interpolation based on color map - if isinstance(cmap, ListedColormap): - minval = 0 - maxval = cmap.N - elif isinstance(cmap, LinearSegmentedColormap): - minval = 0.1 - maxval = 0.9 - else: - raise XGIError("Invalid colormap!") - - # handle the case of IDStat vs iterables - if isinstance(color_arg, IDStat): - vals = np.interp( - color_arg.asnumpy(), - [color_arg.min(), color_arg.max()], - [minval, maxval], - ) - return { - id: np.array(cmap(vals[i])).reshape(1, -1) for i, id in enumerate(ids) - } - - elif isinstance(color_arg, Iterable): - if isinstance(color_arg, dict) and isinstance( - next(iter(color_arg.values())), (int, float) - ): - v = list(color_arg.values()) - vals = np.interp(v, [np.min(v), np.max(v)], [minval, maxval]) - # because we have ids, we can't just assume that the keys of arg correspond to - # the ids. - return { - id: np.array(cmap(v)).reshape(1, -1) - for v, id in zip(vals, color_arg.keys()) - if id in ids - } - - if isinstance(color_arg, (list, ndarray)) and isinstance( - color_arg[0], (int, float) - ): - vals = np.interp( - color_arg, [np.min(color_arg), np.max(color_arg)], [minval, maxval] - ) - return { - id: np.array(cmap(vals[i])).reshape(1, -1) - for i, id in enumerate(ids) - } - else: - raise TypeError( - "Argument must be an dict, list, or numpy array of floats." - ) - - raise TypeError( - "Argument must be str, dict, iterable, or " - f"NodeStat/EdgeStat. Received {type(color_arg)}" - ) - - -def _CCW_sort(p): - """ - Sort the input 2D points counterclockwise. - """ - p = np.array(p) - mean = np.mean(p, axis=0) - d = p - mean - s = np.arctan2(d[:, 0], d[:, 1]) - return p[np.argsort(s), :] - - def draw_node_labels( H, pos, @@ -1141,24 +907,6 @@ def draw_hyperedge_labels( return text_items -def _update_lims(pos, ax): - """Update Axis limits based on node positions""" - - # compute axis limits - pos_arr = np.asarray([[x, y] for n, (x, y) in pos.items()]) - - maxx, maxy = np.max(pos_arr, axis=0) - minx, miny = np.min(pos_arr, axis=0) - w = maxx - minx - h = maxy - miny - - # update view after drawing - padx, pady = 0.05 * w, 0.05 * h - corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady) - ax.update_datalim(corners) - ax.autoscale_view() - - def _draw_hull(node_pos, ax, edges_ec, facecolor, alpha, zorder, radius): """Draw a convex hull encompassing the nodes in node_pos @@ -1283,8 +1031,8 @@ def draw_hypergraph_hull( radius : float, optional Radius of the convex hull in the vicinity of the nodes, by default 0.05. aspect : {"auto", "equal"} or float, optional - Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed - directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full + Set the aspect ratio of the axes scaling, i.e. y/x-scale. `aspect` is passed + directly to matplotlib's `ax.set_aspect()`. Default is `equal`. See full description at https://matplotlib.org/stable/api/_as_gen/matplotlib.axes.Axes.set_aspect.html **kwargs : optional args @@ -1337,15 +1085,7 @@ def draw_hypergraph_hull( settings.update(kwargs) - if pos is None: - pos = barycenter_spring_layout(H) - - if ax is None: - ax = plt.gca() - - ax.get_xaxis().set_ticks([]) - ax.get_yaxis().set_ticks([]) - ax.axis("off") + ax, pos = _draw_init(H, ax, pos) if not max_order: max_order = max_edge_order(H) diff --git a/xgi/drawing/draw_utils.py b/xgi/drawing/draw_utils.py new file mode 100644 index 000000000..18d3c82b0 --- /dev/null +++ b/xgi/drawing/draw_utils.py @@ -0,0 +1,254 @@ +"""Draw hypergraphs and simplicial complexes with matplotlib.""" + +from collections.abc import Iterable + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.colors import LinearSegmentedColormap, ListedColormap +from numpy import ndarray + +from ..exception import XGIError +from ..stats import IDStat +from .layout import barycenter_spring_layout + + +def _draw_init(H, ax, pos): + + if pos is None: + pos = barycenter_spring_layout(H) + + if ax is None: + ax = plt.gca() + + ax.get_xaxis().set_ticks([]) + ax.get_yaxis().set_ticks([]) + ax.axis("off") + + return ax, pos + + +def _update_lims(pos, ax): + """Update Axis limits based on node positions""" + + # compute axis limits + pos_arr = np.asarray([[x, y] for n, (x, y) in pos.items()]) + + maxx, maxy = np.max(pos_arr, axis=0) + minx, miny = np.min(pos_arr, axis=0) + w = maxx - minx + h = maxy - miny + + # update view after drawing + padx, pady = 0.05 * w, 0.05 * h + corners = (minx - padx, miny - pady), (maxx + padx, maxy + pady) + ax.update_datalim(corners) + ax.autoscale_view() + + +def _scalar_arg_to_dict(scalar_arg, ids, min_val, max_val): + """Map different types of arguments for drawing style to a dict with scalar values. + + Parameters + ---------- + scalar_arg : int, float, dict, iterable, or NodeStat/EdgeStat + Attributes for drawing parameter. + ids : NodeView or EdgeView + This is the node or edge IDs that attributes get mapped to. + min_val : int or float + The minimum value of the drawing parameter. + max_val : int or float + The maximum value of the drawing parameter. + + Returns + ------- + dict + An ID: scalar dictionary. + + Raises + ------ + TypeError + If a int, float, list, dict, or NodeStat/EdgeStat is not passed. + """ + if isinstance(scalar_arg, str): + raise TypeError( + "Argument must be int, float, dict, iterable, " + f"or NodeStat/EdgeStat. Received {type(scalar_arg)}" + ) + + # Single argument + if isinstance(scalar_arg, (int, float)): + return {id: scalar_arg for id in ids} + + # IDStat + if isinstance(scalar_arg, IDStat): + vals = np.interp( + scalar_arg.asnumpy(), + [scalar_arg.min(), scalar_arg.max()], + [min_val, max_val], + ) + return dict(zip(ids, vals)) + + # Iterables of floats or ints + if isinstance(scalar_arg, Iterable): + if isinstance(scalar_arg, dict): + try: + return {id: float(scalar_arg[id]) for id in scalar_arg if id in ids} + except ValueError as e: + raise TypeError( + "The input dict must have values that can be cast to floats." + ) + + elif isinstance(scalar_arg, (list, ndarray)): + try: + return {id: float(scalar_arg[idx]) for idx, id in enumerate(ids)} + except ValueError as e: + raise TypeError( + "The input list or array must have values that can be cast to floats." + ) + else: + raise TypeError( + "Argument must be an dict, list, or numpy array of floats or ints." + ) + + raise TypeError( + "Argument must be int, float, dict, iterable, " + f"or NodeStat/EdgeStat. Received {type(scalar_arg)}" + ) + + +def _color_arg_to_dict(color_arg, ids, cmap): + """Map different types of arguments for drawing style to a dict with color values. + + Parameters + ---------- + color_arg : Several formats are accepted: + + Single color values + + * str + * 3- or 4-tuple + + Iterable of colors (each color specified as above) + + * numpy array + * list + * dict {id: color} pairs + + Iterable of numerical values (floats or ints) + + * list + * dict + * numpy array + + Stats + + * NodeStat + * EdgeStat + + Attributes for drawing parameter. + ids : NodeView or EdgeView + This is the node or edge IDs that attributes get mapped to. + cmap : ListedColormap or LinearSegmentedColormap + colormap to use for NodeStat/EdgeStat. + + Returns + ------- + dict + An ID: color dictionary. + + Raises + ------ + TypeError + If a string, dict, iterable, or NodeStat/EdgeStat is not passed. + + Notes + ----- + For the iterable of values, we do not accept tuples, + because there is the potential for ambiguity. + """ + + # single argument. Must be a string or a tuple of floats + if isinstance(color_arg, str) or ( + isinstance(color_arg, tuple) and isinstance(color_arg[0], float) + ): + return {id: color_arg for id in ids} + + # Iterables of colors. The values of these iterables must strings or tuples. As of now, + # there is not a check to verify that the tuples contain floats. + if isinstance(color_arg, Iterable): + if isinstance(color_arg, dict) and isinstance( + next(iter(color_arg.values())), (str, tuple, ndarray) + ): + return {id: color_arg[id] for id in color_arg if id in ids} + if isinstance(color_arg, (list, ndarray)) and isinstance( + color_arg[0], (str, tuple, ndarray) + ): + return {id: color_arg[idx] for idx, id in enumerate(ids)} + + # Stats or iterable of values + if isinstance(color_arg, (Iterable, IDStat)): + # set max and min of interpolation based on color map + if isinstance(cmap, ListedColormap): + minval = 0 + maxval = cmap.N + elif isinstance(cmap, LinearSegmentedColormap): + minval = 0.1 + maxval = 0.9 + else: + raise XGIError("Invalid colormap!") + + # handle the case of IDStat vs iterables + if isinstance(color_arg, IDStat): + vals = np.interp( + color_arg.asnumpy(), + [color_arg.min(), color_arg.max()], + [minval, maxval], + ) + return { + id: np.array(cmap(vals[i])).reshape(1, -1) for i, id in enumerate(ids) + } + + elif isinstance(color_arg, Iterable): + if isinstance(color_arg, dict) and isinstance( + next(iter(color_arg.values())), (int, float) + ): + v = list(color_arg.values()) + vals = np.interp(v, [np.min(v), np.max(v)], [minval, maxval]) + # because we have ids, we can't just assume that the keys of arg correspond to + # the ids. + return { + id: np.array(cmap(v)).reshape(1, -1) + for v, id in zip(vals, color_arg.keys()) + if id in ids + } + + if isinstance(color_arg, (list, ndarray)) and isinstance( + color_arg[0], (int, float) + ): + vals = np.interp( + color_arg, [np.min(color_arg), np.max(color_arg)], [minval, maxval] + ) + return { + id: np.array(cmap(vals[i])).reshape(1, -1) + for i, id in enumerate(ids) + } + else: + raise TypeError( + "Argument must be an dict, list, or numpy array of floats." + ) + + raise TypeError( + "Argument must be str, dict, iterable, or " + f"NodeStat/EdgeStat. Received {type(color_arg)}" + ) + + +def _CCW_sort(p): + """ + Sort the input 2D points counterclockwise. + """ + p = np.array(p) + mean = np.mean(p, axis=0) + d = p - mean + s = np.arctan2(d[:, 0], d[:, 1]) + return p[np.argsort(s), :]