Skip to content

Commit

Permalink
Support array-like mask in heatmaps (#3803)
Browse files Browse the repository at this point in the history
* Support array-like mask in heatmaps

* Nit
  • Loading branch information
mariosasko authored Jan 26, 2025
1 parent e0c2431 commit eb0b5cc
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 8 deletions.
16 changes: 8 additions & 8 deletions seaborn/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,14 @@ def _matrix_mask(data, mask):
if mask is None:
mask = np.zeros(data.shape, bool)

if isinstance(mask, np.ndarray):
if isinstance(mask, pd.DataFrame):
# For DataFrame masks, ensure that semantic labels match data
if not mask.index.equals(data.index) \
and mask.columns.equals(data.columns):
err = "Mask must have the same index and columns as data."
raise ValueError(err)
elif hasattr(mask, "__array__"):
mask = np.asarray(mask)
# For array masks, ensure that shape matches data then convert
if mask.shape != data.shape:
raise ValueError("Mask must have the same shape as data.")
Expand All @@ -79,13 +86,6 @@ def _matrix_mask(data, mask):
columns=data.columns,
dtype=bool)

elif isinstance(mask, pd.DataFrame):
# For DataFrame masks, ensure that semantic labels match data
if not mask.index.equals(data.index) \
and mask.columns.equals(data.columns):
err = "Mask must have the same index and columns as data."
raise ValueError(err)

# Add any cells with missing data to the mask
# This works around an issue where `plt.pcolormesh` doesn't represent
# missing data properly
Expand Down
18 changes: 18 additions & 0 deletions tests/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,24 @@ def test_ndarray_input(self):
assert p.xlabel == ""
assert p.ylabel == ""

def test_array_like_input(self):
class ArrayLike:
def __init__(self, data):
self.data = data

def __array__(self, dtype=None, copy=None):
return np.asarray(self.data, dtype=dtype, copy=copy)

p = mat._HeatMapper(ArrayLike(self.x_norm), **self.default_kws)
npt.assert_array_equal(p.plot_data, self.x_norm)
pdt.assert_frame_equal(p.data, pd.DataFrame(self.x_norm))

npt.assert_array_equal(p.xticklabels, np.arange(8))
npt.assert_array_equal(p.yticklabels, np.arange(4))

assert p.xlabel == ""
assert p.ylabel == ""

def test_df_input(self):

p = mat._HeatMapper(self.df_norm, **self.default_kws)
Expand Down

0 comments on commit eb0b5cc

Please sign in to comment.