Skip to content

Commit

Permalink
fix plotting with transposed nondim coords.
Browse files Browse the repository at this point in the history
Fixes #3138
  • Loading branch information
dcherian committed Oct 24, 2019
1 parent 652dd3c commit ed9948e
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 13 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ Bug fixes
- Sync with cftime by removing `dayofwk=-1` for cftime>=1.0.4.
By `Anderson Banihirwe <https://github.com/andersy005>`_.

- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`)
By `Deepak Cherian <https://github.com/dcherian>`_.

Documentation
~~~~~~~~~~~~~
Expand Down
25 changes: 12 additions & 13 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
import pandas as pd

from ..core.alignment import broadcast
from .facetgrid import _easy_facetgrid
from .utils import (
_add_colorbar,
Expand Down Expand Up @@ -666,17 +667,6 @@ def newplotfunc(
darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb
)

# better to pass the ndarrays directly to plotting functions
xval = darray[xlab].values
yval = darray[ylab].values

# check if we need to broadcast one dimension
if xval.ndim < yval.ndim:
xval = np.broadcast_to(xval, yval.shape)

if yval.ndim < xval.ndim:
yval = np.broadcast_to(yval, xval.shape)

# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
if imshow_rgb:
Expand All @@ -690,8 +680,17 @@ def newplotfunc(
elif darray[xlab].dims[-1] == darray.dims[0]:
darray = darray.transpose(transpose_coords=True)

# Pass the data as a masked ndarray too
zval = darray.to_masked_array(copy=False)
# better to pass the ndarrays directly to plotting functions
# Pass the data as a masked ndarray
if darray[xlab].ndim == 1 and darray[ylab].ndim == 1:
xval = darray[xlab].values
yval = darray[ylab].values
zval = darray.to_masked_array(copy=False)
else:
xval, yval, zval = map(
lambda x: x.values, broadcast(darray[xlab], darray[ylab], darray)
)
zval = np.ma.masked_array(zval, mask=pd.isnull(zval), copy=False)

# Replace pd.Intervals if contained in xval or yval.
xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__)
Expand Down
16 changes: 16 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2145,3 +2145,19 @@ def test_yticks_kwarg(self, da):
da.plot(yticks=np.arange(5))
expected = np.arange(5)
assert np.all(plt.gca().get_yticks() == expected)


@requires_matplotlib
@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour"])
def test_plot_transposed_nondim_coord(plotfunc):
x = np.linspace(0, 10, 101)
h = np.linspace(3, 7, 101)
s = np.linspace(0, 1, 51)
z = s[:, np.newaxis] * h[np.newaxis, :]
da = xr.DataArray(
np.sin(x) * np.cos(z),
dims=["s", "x"],
coords={"x": x, "s": s, "z": (("s", "x"), z), "zt": (("x", "s"), z.T)},
)
getattr(da.plot, plotfunc)(x="x", y="zt")
getattr(da.plot, plotfunc)(x="zt", y="x")

0 comments on commit ed9948e

Please sign in to comment.