Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added edge_ec argument in draw to specify edge colors #575

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 24 additions & 3 deletions tests/drawing/test_draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,27 @@ def test_draw_hyperedges_fc_cmap(edgelist8):
plt.close()


def test_draw_hyperedges_ec(edgelist8):
# implemented in PR #575

H = xgi.Hypergraph(edgelist8)

colors = np.array([[0.6468274 , 0.80289262, 0.56592265, 0.4],
[0.17363177, 0.19076859, 0.44549087, 0.4],
[0.17363177, 0.19076859, 0.44549087, 0.4],
[0.17363177, 0.19076859, 0.44549087, 0.4],
[0.17363177, 0.19076859, 0.44549087, 0.4],
[0.17363177, 0.19076859, 0.44549087, 0.4]])

# edge stat color
fig, ax = plt.subplots()
ax, collections = xgi.draw_hyperedges(H,ax=ax, edge_ec=H.edges.size, edge_fc="w")
(_, edge_collection) = collections

assert np.all(edge_collection.get_edgecolor() == colors)
plt.close("all")


def test_draw_simplices(edgelist8):
with pytest.raises(XGIError):
H = xgi.Hypergraph(edgelist8)
Expand Down Expand Up @@ -679,16 +700,16 @@ def test_draw_undirected_dyads(edgelist8):
H = xgi.Hypergraph(edgelist8)

fig, ax = plt.subplots()
ax, dyad_collection = xgi.draw_undirected_dyads(H)
ax, dyad_collection = xgi.draw_undirected_dyads(H, ax=ax)
assert len(dyad_collection._paths) == 26 # number of lines

with pytest.raises(ValueError):
fig, ax = plt.subplots()
ax, dyad_collection = xgi.draw_undirected_dyads(H, dyad_lw=-1)
ax, dyad_collection = xgi.draw_undirected_dyads(H, dyad_lw=-1, ax=ax)

fig, ax = plt.subplots()
ax, dyad_collection = xgi.draw_undirected_dyads(
H, dyad_color=np.random.random(H.num_edges)
H, dyad_color=np.random.random(H.num_edges), ax=ax
)
assert len(np.unique(dyad_collection.get_color())) == 28
plt.close("all")
Expand Down
226 changes: 128 additions & 98 deletions tutorials/focus/Tutorial 5 - Plotting.ipynb

Large diffs are not rendered by default.

45 changes: 36 additions & 9 deletions tutorials/getting_started/XGI in 1 minute.ipynb

Large diffs are not rendered by default.

24 changes: 12 additions & 12 deletions tutorials/getting_started/XGI in 15 minutes.ipynb

Large diffs are not rendered by default.

71 changes: 49 additions & 22 deletions tutorials/getting_started/XGI in 5 minutes.ipynb

Large diffs are not rendered by default.

102 changes: 80 additions & 22 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def draw(
edge_fc_cmap="crest_r",
edge_vmin=None,
edge_vmax=None,
edge_ec=None,
alpha=0.4,
hull=False,
radius=0.05,
Expand Down Expand Up @@ -166,6 +167,23 @@ def draw(
Colormap used to map the edge colors. By default, "cres_r".
edge_vmin, edge_vmax : float, optional
Minimum and maximum for edge colormap scaling. By default, None.
edge_ec : color or list of colors or array-like or dict or EdgeStat, optional
Color of the hyperedges. The accepted formats are the same as
matplotlib's scatter, with the addition of dict and IDStat.
Formats with colors:
* single color as a string
* single color as 3- or 4-tuple
* list of colors of length len(ids)
* dict of colors containing the `ids` as keys

Formats with numerical values (will be mapped to colors):
* array of floats
* dict of floats containing the `ids` as keys
* IDStat containing the `ids` as keys

If None (default), color by edge size.
Numerical formats will be mapped to colors using edge_vmin, edge_vmax,
and edge_fc_cmap.
alpha : float, optional
The edge transparency. By default, 0.4.
hull : bool, optional
Expand Down Expand Up @@ -262,6 +280,7 @@ def draw(
edge_fc_cmap=edge_fc_cmap,
edge_vmin=edge_vmin,
edge_vmax=edge_vmax,
edge_ec=edge_ec,
max_order=max_order,
hyperedge_labels=hyperedge_labels,
rescale_sizes=rescale_sizes,
Expand All @@ -285,6 +304,7 @@ def draw(
edge_fc_cmap=edge_fc_cmap,
edge_vmin=edge_vmin,
edge_vmax=edge_vmax,
edge_ec=edge_ec,
max_order=max_order,
hyperedge_labels=hyperedge_labels,
hull=hull,
Expand Down Expand Up @@ -523,6 +543,7 @@ def draw_hyperedges(
edge_fc_cmap="crest_r",
edge_vmin=None,
edge_vmax=None,
edge_ec=None,
alpha=0.4,
max_order=None,
params=dict(),
Expand Down Expand Up @@ -566,13 +587,13 @@ def draw_hyperedges(
edge_fc : color or list of colors or array-like or dict or EdgeStat, optional
Color of the hyperedges. The accepted formats are the same as
matplotlib's scatter, with the addition of dict and IDStat.
Those with colors:
Formats with colors:
* single color as a string
* single color as 3- or 4-tuple
* list of colors of length len(ids)
* dict of colors containing the `ids` as keys

Those with numerical values (will be mapped to colors):
Formats with numerical values (will be mapped to colors):
* array of floats
* dict of floats containing the `ids` as keys
* IDStat containing the `ids` as keys
Expand All @@ -582,6 +603,23 @@ def draw_hyperedges(
Colormap used to map the edge colors. By default, "crest_r".
edge_vmin, edge_vmax : float, optional
Minimum and maximum for edge colormap scaling. By default, None.
edge_ec : color or list of colors or array-like or dict or EdgeStat, optional
Color of the hyperedges. The accepted formats are the same as
matplotlib's scatter, with the addition of dict and IDStat.
Formats with colors:
* single color as a string
* single color as 3- or 4-tuple
* list of colors of length len(ids)
* dict of colors containing the `ids` as keys

Formats with numerical values (will be mapped to colors):
* array of floats
* dict of floats containing the `ids` as keys
* IDStat containing the `ids` as keys

If None (default), color by edge size.
Numerical formats will be mapped to colors using edge_vmin, edge_vmax,
and edge_fc_cmap.
alpha : float, optional
The edge transparency. By default, 0.4.
max_order : int, optional
Expand Down Expand Up @@ -650,13 +688,18 @@ def draw_hyperedges(

if edge_fc is None: # color is proportional to size
edge_fc = edges.size
if edge_ec is None: # color is proportional to size
edge_ec = edges.size

# convert all formats to ndarray
dyad_lw = _draw_arg_to_arr(dyad_lw)

# parse colors
dyad_color, dyad_c_mapped = _parse_color_arg(dyad_color, list(dyads))
edge_fc, edge_c_mapped = _parse_color_arg(edge_fc, list(edges))
dyad_color, dyad_c_to_map = _parse_color_arg(dyad_color, list(dyads))
edge_fc, edge_c_to_map = _parse_color_arg(edge_fc, list(edges))
edge_ec, edge_ec_to_map = _parse_color_arg(edge_ec, list(edges))
# edge_c_to_map and dyad_c_to_map are True if the colors
# are input as numeric values that need to be mapped to colors

# check validity of input values
if np.any(dyad_lw < 0):
Expand All @@ -672,7 +715,7 @@ def draw_hyperedges(
dyad_pos = np.asarray([(pos[list(e)[0]], pos[list(e)[1]]) for e in dyads.members()])

# plot dyads
if dyad_c_mapped:
if dyad_c_to_map:
dyad_c_arr = dyad_color
dyad_colors = None
else:
Expand All @@ -682,7 +725,7 @@ def draw_hyperedges(
dyad_collection = LineCollection(
dyad_pos,
colors=dyad_colors,
array=dyad_c_arr, # colors if mapped, ie arr of floats
array=dyad_c_arr, # colors if to be mapped, ie arr of floats
linewidths=dyad_lw,
antialiaseds=(1,),
linestyle=dyad_style,
Expand All @@ -691,7 +734,7 @@ def draw_hyperedges(
)

# dyad_collection.set_cmap(dyad_color_cmap)
if dyad_c_mapped:
if dyad_c_to_map:
dyad_collection.set_clim(dyad_vmin, dyad_vmax)
# dyad_collection.set_zorder(max_order - 1) # edges go behind nodes
ax.add_collection(dyad_collection)
Expand All @@ -700,13 +743,27 @@ def draw_hyperedges(
ids_sorted = np.argsort(edges.size.aslist())[::-1]

# plot other hyperedges
if edge_c_mapped:

# prepare colors for PatchCollection format
if edge_c_to_map:
edge_fc_arr = edge_fc[ids_sorted]
edge_fc_colors = None
else:
edge_fc_arr = None
edge_fc_colors = edge_fc[ids_sorted] if len(edge_fc) > 1 else edge_fc


edge_ec = edge_ec[ids_sorted] if len(edge_ec) > 1 else edge_ec # reorder

if edge_ec_to_map: # edgecolors need to be manually mapped

# create scalarmappable to map floats to colors
# we use the same vmin, vmax, and cmap as for edge_fc
norm = mpl.colors.Normalize(vmin=edge_vmin, vmax=edge_vmax)
sm_edgecolors = cm.ScalarMappable(norm=norm, cmap=edge_fc_cmap)

edge_ec = sm_edgecolors.to_rgba(edge_ec) # map to colors

patches = []
for he in np.array(edges.members())[ids_sorted]:
d = len(he) - 1
Expand All @@ -733,13 +790,14 @@ def draw_hyperedges(
edge_collection = PatchCollection(
patches,
facecolors=edge_fc_colors,
array=edge_fc_arr,
array=edge_fc_arr, # will be mapped by PatchCollection
cmap=edge_fc_cmap,
edgecolors=edge_ec,
alpha=alpha,
zorder=max_order - 2, # below dyads
)
# edge_collection.set_cmap(edge_fc_cmap)
if edge_c_mapped:
if edge_c_to_map:
edge_collection.set_clim(edge_vmin, edge_vmax)
ax.add_collection(edge_collection)

Expand Down Expand Up @@ -1379,9 +1437,9 @@ def draw_multilayer(
raise ValueError("dyad_lw cannot contain negative values.")

# parse colors
dyad_color, dyad_c_mapped = _parse_color_arg(dyad_color, list(dyads))
edge_fc, edge_c_mapped = _parse_color_arg(edge_fc, list(edges))
layer_color, layer_c_mapped = _parse_color_arg(layer_color, orders)
dyad_color, dyad_c_to_map = _parse_color_arg(dyad_color, list(dyads))
edge_fc, edge_c_to_map = _parse_color_arg(edge_fc, list(edges))
layer_color, layer_c_to_map = _parse_color_arg(layer_color, orders)

node_size = np.array(node_size) ** 2

Expand All @@ -1402,7 +1460,7 @@ def draw_multilayer(
# draw surfaces corresponding to the different orders
zz = np.zeros(xx.shape) + d * sep

if layer_c_mapped:
if layer_c_to_map:
layer_c = None
else:
layer_c = layer_color[jj] if len(layer_color) > 1 else layer_color
Expand All @@ -1427,7 +1485,7 @@ def draw_multilayer(
]

# plot dyads
if dyad_c_mapped:
if dyad_c_to_map:
raise ValueError(
"dyad_color needs to be a color or list of colors, not numerical values."
)
Expand All @@ -1447,7 +1505,7 @@ def draw_multilayer(
ids_sorted = np.argsort(edges.size.aslist())[::-1]

# plot other hyperedges
if edge_c_mapped:
if edge_c_to_map:
edge_fc_arr = edge_fc[ids_sorted]
edge_fc_colors = None
else:
Expand All @@ -1474,7 +1532,7 @@ def draw_multilayer(
zorder=max_order - 2, # below dyads
)
edge_collection.set_cmap(edge_fc_cmap)
if edge_c_mapped:
if edge_c_to_map:
edge_collection.set_clim(edge_vmin, edge_vmax)
ax.add_collection3d(edge_collection)

Expand Down Expand Up @@ -1963,7 +2021,7 @@ def draw_undirected_dyads(
)

# parse colors
dyad_color, dyads_c_mapped = _parse_color_arg(dyad_color, H.edges)
dyad_color, dyads_c_to_map = _parse_color_arg(dyad_color, H.edges)

# The following two list comprehensions map colors assigned to a hyperedge to
# all of the bipartite edges, so that users need not specify colors for every
Expand All @@ -1986,7 +2044,7 @@ def draw_undirected_dyads(
)

# convert numbers to colors for FancyArrowPatch
if dyads_c_mapped:
if dyads_c_to_map:
norm = mpl.colors.Normalize()
m = cm.ScalarMappable(norm=norm, cmap=dyad_color_cmap)
dyad_color = m.to_rgba(dyad_color)
Expand Down Expand Up @@ -2160,10 +2218,10 @@ def draw_directed_dyads(
)

# parse colors
dyad_color, dyads_c_mapped = _parse_color_arg(dyad_color, H.edges)
dyad_color, dyads_c_to_map = _parse_color_arg(dyad_color, H.edges)

# convert numbers to colors for FancyArrowPatch
if dyads_c_mapped:
if dyads_c_to_map:
norm = mpl.colors.Normalize()
m = cm.ScalarMappable(norm=norm, cmap=dyad_color_cmap)
dyad_color = m.to_rgba(dyad_color)
Expand Down Expand Up @@ -2233,7 +2291,7 @@ def to_marker_edge(marker_size, marker):
else:
dlw = dyad_lw

if dyads_c_mapped:
if dyads_c_to_map:
d_color = dyad_color[edge_to_idx[e]]
else:
d_color = dyad_color
Expand Down
14 changes: 8 additions & 6 deletions xgi/drawing/draw_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def _parse_color_arg(colors, ids, id_kind="edges"):

This function is needed to handle the input formats not naturally
handled by matploltib's Collections: IDStat, dict, and arrays of
floats. All those are converted to arrays of floats and.
floats. All those numerical formats are converted to arrays of floats.

Parameters:
-----------
Expand All @@ -103,8 +103,8 @@ def _parse_color_arg(colors, ids, id_kind="edges"):
--------
colors : single color or ndarray
Processed color values for plotting.
colors_are_mapped : bool
True if the colors are mapped and need special handling. This
colors_to_map : bool
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can this just be called map_colors?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yea that would work too I guess, but the colors are only mapped later in the draw function

True if the colors need to be mapped and need special handling. This
is used in draw_hyperedges to deal with Collections.

Raises:
Expand All @@ -127,6 +127,7 @@ def _parse_color_arg(colors, ids, id_kind="edges"):

xsize = len(ids)

# convert all dict-like input formats to an array
if isinstance(colors, IDStat):
colors = colors.asdict()
if isinstance(colors, dict):
Expand All @@ -135,13 +136,14 @@ def _parse_color_arg(colors, ids, id_kind="edges"):
values = list(colors.values())
colors = np.array(values)

# see if input format needs to be mapped to colors (if numeric)
try: # see if the input format is compatible with PatchCollection's facecolor
colors = to_rgba_array(colors)
colors_are_mapped = False
colors_to_map = False
except:
try: # in case of array of floats (can be fed to PatchCollection with some care)
colors = np.asanyarray(colors, dtype=float)
colors_are_mapped = True
colors_to_map = True
except:
raise ValueError("Invalid input format for colors.")

Expand All @@ -150,7 +152,7 @@ def _parse_color_arg(colors, ids, id_kind="edges"):
f"The input color argument must be a single color or its length must match the number of plotted elements ({xsize})."
)

return colors, colors_are_mapped
return colors, colors_to_map


def _draw_arg_to_arr(arg):
Expand Down
Loading