Skip to content

Commit

Permalink
Fix drawing warning (#512)
Browse files Browse the repository at this point in the history
* Fix drawing warning

* added unit tests

* fix test error

Using pytest best practices (https://docs.pytest.org/en/7.0.x/how-to/capture-warnings.html#additional-use-cases-of-warnings-in-tests)
  • Loading branch information
nwlandry authored Feb 20, 2024
1 parent 1e746bf commit 7ed5532
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
25 changes: 21 additions & 4 deletions tests/drawing/test_draw.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pytest
Expand Down Expand Up @@ -514,7 +516,7 @@ def test_draw_multilayer(edgelist8):
# node_size
assert np.all(node_coll4.get_sizes() == np.array([10**2]))

plt.close()
plt.close("all")


def test_draw_dihypergraph(diedgelist2, edgelist8):
Expand Down Expand Up @@ -587,16 +589,14 @@ def test_draw_dihypergraph(diedgelist2, edgelist8):
for patch in ax1.patches: # lines
assert patch.get_zorder() == 0

plt.close("all")

# test toggle for edges
fig, ax2 = plt.subplots()
ax2, collections = xgi.draw_dihypergraph(DH, edge_marker_toggle=False, ax=ax2)
node_coll, phantom_node_coll = collections
assert len(ax2.collections) == 1
assert phantom_node_coll is None

plt.close()
plt.close("all")

# test XGI ERROR raise
with pytest.raises(XGIError):
Expand All @@ -622,3 +622,20 @@ def test_draw_dihypergraph_with_str_labels_and_isolated_nodes():
assert len(node_coll4.get_offsets()) == 6 # number of original nodes
assert len(phantom_node_coll4.get_offsets()) == 2 # number of original edges
assert len(ax4.patches) == 7 # number of lines
plt.close()


def test_issue_499(edgelist8):
H = xgi.Hypergraph(edgelist8)

fig, ax = plt.subplots()

with warnings.catch_warnings():
warnings.simplefilter("error")
ax, collections = xgi.draw(H, ax=ax, node_fc="black")

with warnings.catch_warnings():
warnings.simplefilter("error")
ax, collections = xgi.draw(H, ax=ax, node_fc=["black"] * H.num_nodes)

plt.close("all")
11 changes: 7 additions & 4 deletions xgi/drawing/draw.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import seaborn as sb # for cmap "crest"
from matplotlib import cm
from matplotlib.colors import is_color_like
from matplotlib.patches import FancyArrowPatch
from mpl_toolkits.mplot3d.art3d import (
Line3DCollection,
Expand Down Expand Up @@ -433,10 +434,6 @@ def draw_nodes(
settings.update(params)
settings.update(kwargs)

# avoid matplotlib scatter UserWarning "Parameters 'cmap' will be ignored"
if isinstance(node_fc, str):
node_fc_cmap = None

ax, pos = _draw_init(H, ax, pos)

# convert pos to format convenient for scatter
Expand All @@ -450,6 +447,12 @@ def draw_nodes(
node_fc = _draw_arg_to_arr(node_fc)
node_lw = _draw_arg_to_arr(node_lw)

# avoid matplotlib scatter UserWarning "Parameters 'cmap' will be ignored"
if isinstance(node_fc, str) or (
isinstance(node_fc, np.ndarray) and is_color_like(node_fc[0])
):
node_fc_cmap = None

# check validity of input values
if np.any(node_size < 0):
raise ValueError("node_size cannot contain negative values.")
Expand Down

0 comments on commit 7ed5532

Please sign in to comment.