From 1ddd3b2acc9c876171099f37794fb9ecde25dac0 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Fri, 22 Oct 2021 15:05:06 -0400 Subject: [PATCH 1/8] to_numpy in facetgrid --- xarray/plot/facetgrid.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 28dd82e76f5..77ab5155153 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -175,11 +175,11 @@ def __init__( ) # Set up the lists of names for the row and column facet variables - col_names = list(data[col].values) if col else [] - row_names = list(data[row].values) if row else [] + col_names = list(data[col].to_numpy()) if col else [] + row_names = list(data[row].to_numpy()) if row else [] if single_group: - full = [{single_group: x} for x in data[single_group].values] + full = [{single_group: x} for x in data[single_group].to_numpy()] empty = [None for x in range(nrow * ncol - len(full))] name_dicts = full + empty else: @@ -253,7 +253,7 @@ def map_dataarray(self, func, x, y, **kwargs): raise ValueError("cbar_ax not supported by FacetGrid.") cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - func, self.data.values, **kwargs + func, self.data.to_numpy(), **kwargs ) self._cmap_extend = cmap_params.get("extend") @@ -349,7 +349,7 @@ def map_dataset( if hue and meta_data["hue_style"] == "continuous": cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - func, self.data[hue].values, **kwargs + func, self.data[hue].to_numpy(), **kwargs ) kwargs["meta_data"]["cmap_params"] = cmap_params kwargs["meta_data"]["cbar_kwargs"] = cbar_kwargs @@ -425,7 +425,7 @@ def _adjust_fig_for_guide(self, guide): def add_legend(self, **kwargs): self.figlegend = self.fig.legend( handles=self._mappables[-1], - labels=list(self._hue_var.values), + labels=list(self._hue_var.to_numpy()), title=self._hue_label, loc="center right", **kwargs, @@ -625,7 +625,7 @@ def map(self, func, *args, **kwargs): if namedict is not None: data = self.data.loc[namedict] plt.sca(ax) - innerargs = [data[a].values for a in args] + innerargs = [data[a].to_numpy() for a in args] maybe_mappable = func(*innerargs, **kwargs) # TODO: better way to verify that an artist is mappable? # https://stackoverflow.com/questions/33023036/is-it-possible-to-detect-if-a-matplotlib-artist-is-a-mappable-suitable-for-use-w#33023522 From c467046d076b5ea2111f1566a603d1014bf7e7fd Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Fri, 22 Oct 2021 15:05:22 -0400 Subject: [PATCH 2/8] to_numpy in 2D plots --- xarray/plot/plot.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index dffdde25db4..31605d45e19 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -1152,10 +1152,6 @@ def newplotfunc( else: dims = (yval.dims[0], xval.dims[0]) - # better to pass the ndarrays directly to plotting functions - xval = xval.to_numpy() - yval = yval.to_numpy() - # May need to transpose for correct x, y labels # xlab may be the name of a coord, we have to check for dim names if imshow_rgb: @@ -1168,8 +1164,13 @@ def newplotfunc( if dims != darray.dims: darray = darray.transpose(*dims, transpose_coords=True) + # better to pass the ndarrays directly to plotting functions + xval = xval.to_numpy() + yval = yval.to_numpy() + zarray = darray.as_numpy() + # Pass the data as a masked ndarray too - zval = darray.to_masked_array(copy=False) + zval = zarray.to_masked_array(copy=False) # Replace pd.Intervals if contained in xval or yval. xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__) From d7d1cc2d1a53541988e35a43d772f637adec01a5 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 28 Oct 2021 17:18:46 -0400 Subject: [PATCH 3/8] test for faceted line plots --- xarray/tests/test_units.py | 20 +++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index 7bde6ce8b9f..aa5ed49a07b 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -5614,7 +5614,7 @@ def test_units_in_line_plot_labels(self): assert ax.get_ylabel() == "pressure [pascal]" assert ax.get_xlabel() == "x [meters]" - def test_units_in_2d_plot_labels(self): + def test_units_in_2d_plot_colorbar_label(self): arr = np.ones((2, 3)) * unit_registry.Pa da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure") @@ -5622,3 +5622,21 @@ def test_units_in_2d_plot_labels(self): ax = da.plot.contourf(ax=ax, cbar_ax=cax, add_colorbar=True) assert cax.get_ylabel() == "pressure [pascal]" + + def test_units_facetgrid_plot_labels(self): + arr = np.ones((2, 3)) * unit_registry.Pa + da = xr.DataArray(data=arr, dims=["x", "y"], name="pressure") + + fig, (ax, cax) = plt.subplots(1, 2) + fgrid = da.plot.line(x="x", col="y") + + assert fgrid.axes[0, 0].get_ylabel() == "pressure [pascal]" + + @pytest.mark.xfail + def test_units_facetgrid_2d_plot_colorbar_labels(self): + arr = np.ones((2, 3, 4, 5)) * unit_registry.Pa + da = xr.DataArray(data=arr, dims=["x", "y", "z", "w"], name="pressure") + + ax = da.plot.imshow(x="x", y="y", col="w") + + # assert cax.get_ylabel() == "pressure [pascal]" From 3067a26f4cac0a85db8bc9e666896856ee10d7f8 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 28 Oct 2021 17:33:38 -0400 Subject: [PATCH 4/8] test and fix for faceted imshow --- xarray/plot/plot.py | 2 +- xarray/tests/test_units.py | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 31605d45e19..62dc78bd3d9 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -1079,7 +1079,7 @@ def newplotfunc( # Matplotlib does not support normalising RGB data, so do it here. # See eg. https://github.com/matplotlib/matplotlib/pull/10220 if robust or vmax is not None or vmin is not None: - darray = _rescale_imshow_rgb(darray, vmin, vmax, robust) + darray = _rescale_imshow_rgb(darray.as_numpy(), vmin, vmax, robust) vmin, vmax, robust = None, None, False if subplot_kws is None: diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index aa5ed49a07b..ae5f13c621e 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -5632,11 +5632,12 @@ def test_units_facetgrid_plot_labels(self): assert fgrid.axes[0, 0].get_ylabel() == "pressure [pascal]" - @pytest.mark.xfail - def test_units_facetgrid_2d_plot_colorbar_labels(self): + def test_units_facetgrid_2d_imshow_plot_colorbar_labels(self): arr = np.ones((2, 3, 4, 5)) * unit_registry.Pa da = xr.DataArray(data=arr, dims=["x", "y", "z", "w"], name="pressure") - ax = da.plot.imshow(x="x", y="y", col="w") + da.plot.imshow(x="x", y="y", col="w") - # assert cax.get_ylabel() == "pressure [pascal]" + print(fgrid.axes) + + assert fgrid.axes[0, 0].get_ylabel() == "pressure [pascal]" From fca0bf27540ebbe42a982dc68ce51497f91b0bf5 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 28 Oct 2021 17:48:09 -0400 Subject: [PATCH 5/8] check label of colorbar in faceted 2d plots --- xarray/tests/test_units.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/xarray/tests/test_units.py b/xarray/tests/test_units.py index ae5f13c621e..8be20c5f81c 100644 --- a/xarray/tests/test_units.py +++ b/xarray/tests/test_units.py @@ -5636,8 +5636,13 @@ def test_units_facetgrid_2d_imshow_plot_colorbar_labels(self): arr = np.ones((2, 3, 4, 5)) * unit_registry.Pa da = xr.DataArray(data=arr, dims=["x", "y", "z", "w"], name="pressure") - da.plot.imshow(x="x", y="y", col="w") + da.plot.imshow(x="x", y="y", col="w") # no colorbar to check labels of - print(fgrid.axes) + def test_units_facetgrid_2d_contourf_plot_colorbar_labels(self): + arr = np.ones((2, 3, 4)) * unit_registry.Pa + da = xr.DataArray(data=arr, dims=["x", "y", "z"], name="pressure") - assert fgrid.axes[0, 0].get_ylabel() == "pressure [pascal]" + fig, (ax1, ax2, ax3, cax) = plt.subplots(1, 4) + fgrid = da.plot.contourf(x="x", y="y", col="z") + + assert fgrid.cbar.ax.get_ylabel() == "pressure [pascal]" From d2f060b0ff399a291545ba286624862a7a82e728 Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 28 Oct 2021 17:53:26 -0400 Subject: [PATCH 6/8] whatsnew --- doc/whats-new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 1ade92f8588..33dae5af7cc 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -80,6 +80,9 @@ Bug fixes By `Jimmy Westling `_. - Numbers are properly formatted in a plot's title (:issue:`5788`, :pull:`5789`). By `Maxime Liquet `_. +- Faceted plots will no longer raise a `pint.UnitStrippedWarning` when a `pint.Quantity` array is plotted, + and will correctly display the units of the data in the colorbar (if there is one) (:pull:`5886`). + By `Tom Nicholas `_. Documentation ~~~~~~~~~~~~~ From 820c0fb1283e8ebd94ee122d3e80959c577582fc Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 28 Oct 2021 18:16:39 -0400 Subject: [PATCH 7/8] don't take values twice for masked arrays --- xarray/plot/plot.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 6cd94323c3f..f41e934c12d 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -1161,10 +1161,9 @@ def newplotfunc( # better to pass the ndarrays directly to plotting functions xval = xval.to_numpy() yval = yval.to_numpy() - zarray = darray.as_numpy() - # Pass the data as a masked ndarray too - zval = zarray.to_masked_array(copy=False) + # Pass the data as a masked ndarray too + zval = darray.to_masked_array(copy=False) # Replace pd.Intervals if contained in xval or yval. xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__) From baad9de14dc2f95a777f88a33f1815ffd9ab2dfc Mon Sep 17 00:00:00 2001 From: Thomas Nicholas Date: Thu, 28 Oct 2021 18:17:32 -0400 Subject: [PATCH 8/8] remove trailing whitespace --- xarray/plot/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index f41e934c12d..60f132d07e1 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -1162,7 +1162,7 @@ def newplotfunc( xval = xval.to_numpy() yval = yval.to_numpy() - # Pass the data as a masked ndarray too + # Pass the data as a masked ndarray too zval = darray.to_masked_array(copy=False) # Replace pd.Intervals if contained in xval or yval.