Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix plotting with transposed nondim coords. #3441

Merged
merged 12 commits into from
Dec 4, 2019
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ New Features

Bug fixes
~~~~~~~~~
- Fix plotting with transposed 2D non-dimensional coordinates. (:issue:`3138`, :pull:`3441`)
By `Deepak Cherian <https://github.com/dcherian>`_.


Documentation
Expand Down
25 changes: 18 additions & 7 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,10 +672,22 @@ def newplotfunc(

# check if we need to broadcast one dimension
if xval.ndim < yval.ndim:
xval = np.broadcast_to(xval, yval.shape)
dims = darray[ylab].dims
if xval.shape[0] == yval.shape[0]:
xval = np.broadcast_to(xval[:, np.newaxis], yval.shape)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a cleaner way of doing this?

else:
xval = np.broadcast_to(xval[np.newaxis, :], yval.shape)

if yval.ndim < xval.ndim:
yval = np.broadcast_to(yval, xval.shape)
elif yval.ndim < xval.ndim:
dims = darray[xlab].dims
if yval.shape[0] == xval.shape[0]:
yval = np.broadcast_to(yval[:, np.newaxis], xval.shape)
else:
yval = np.broadcast_to(yval[np.newaxis, :], xval.shape)
elif xval.ndim == 2:
dims = darray[xlab].dims
else:
dims = (darray[ylab].dims[0], darray[xlab].dims[0])

# May need to transpose for correct x, y labels
# xlab may be the name of a coord, we have to check for dim names
Expand All @@ -685,10 +697,9 @@ def newplotfunc(
# we transpose to (y, x, color) to make this work.
yx_dims = (ylab, xlab)
dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims)
if dims != darray.dims:
darray = darray.transpose(*dims, transpose_coords=True)
elif darray[xlab].dims[-1] == darray.dims[0]:
darray = darray.transpose(transpose_coords=True)

if dims != darray.dims:
darray = darray.transpose(*dims, transpose_coords=True)

# Pass the data as a masked ndarray too
zval = darray.to_masked_array(copy=False)
Expand Down
29 changes: 29 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ def test2d_1d_2d_coordinates_contourf(self):
)

a.plot.contourf(x="time", y="depth")
a.plot.contourf(x="depth", y="time")

def test3d(self):
self.darray.plot()
Expand Down Expand Up @@ -2149,3 +2150,31 @@ def test_yticks_kwarg(self, da):
da.plot(yticks=np.arange(5))
expected = np.arange(5)
assert np.all(plt.gca().get_yticks() == expected)


@requires_matplotlib
@pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour"])
def test_plot_transposed_nondim_coord(plotfunc):
x = np.linspace(0, 10, 101)
h = np.linspace(3, 7, 101)
s = np.linspace(0, 1, 51)
z = s[:, np.newaxis] * h[np.newaxis, :]
da = xr.DataArray(
np.sin(x) * np.cos(z),
dims=["s", "x"],
coords={"x": x, "s": s, "z": (("s", "x"), z), "zt": (("x", "s"), z.T)},
)
getattr(da.plot, plotfunc)(x="x", y="zt")
getattr(da.plot, plotfunc)(x="zt", y="x")


@requires_matplotlib
@pytest.mark.parametrize("plotfunc", ["pcolormesh", "imshow"])
def test_plot_transposes_properly(plotfunc):
# test that we aren't mistakenly transposing when the 2 dimensions have equal sizes.
da = xr.DataArray([np.sin(2 * np.pi / 10 * np.arange(10))] * 10, dims=("y", "x"))
hdl = getattr(da.plot, plotfunc)(x="x", y="y")
# get_array doesn't work for contour, contourf. It returns the colormap intervals.
# pcolormesh returns 1D array but imshow returns a 2D array so it is necessary
# to ravel() on the LHS
assert np.all(hdl.get_array().ravel() == da.to_masked_array().ravel())