diff --git a/cf_xarray/accessor.py b/cf_xarray/accessor.py index ac6dbe78..3be768ac 100644 --- a/cf_xarray/accessor.py +++ b/cf_xarray/accessor.py @@ -655,6 +655,42 @@ def check_results(names, k): ) +def _possible_x_y_plot(obj, key): + """Guesses a name for an x/y variable if possible.""" + # in priority order + x_criteria = [ + ("coordinates", "longitude"), + ("axes", "X"), + ("coordinates", "time"), + ("axes", "T"), + ] + y_criteria = [ + ("coordinates", "vertical"), + ("axes", "Z"), + ("coordinates", "latitude"), + ("axes", "Y"), + ] + + def _get_possible(accessor, criteria): + # is_scalar depends on NON_NUMPY_SUPPORTED_TYPES + # importing a private function seems better than + # maintaining that variable! + from xarray.core.utils import is_scalar + + for attr, key in criteria: + value = getattr(accessor, attr).get(key) + if not value or len(value) > 1: + continue + if not is_scalar(accessor._obj[value[0]]): + return value[0] + return None + + if key == "x": + return _get_possible(obj.cf, x_criteria) + elif key == "y": + return _get_possible(obj.cf, y_criteria) + + class _CFWrappedClass: """ This class is used to wrap any class in _WRAPPED_CLASSES. @@ -705,27 +741,34 @@ def _plot_decorator(self, func): @functools.wraps(func) def _plot_wrapper(*args, **kwargs): - if "x" in kwargs: - if kwargs["x"] in valid_keys: - xvar = self.accessor[kwargs["x"]] - else: - xvar = self._obj[kwargs["x"]] - if "positive" in xvar.attrs: - if xvar.attrs["positive"] == "down": - kwargs.setdefault("xincrease", False) - else: - kwargs.setdefault("xincrease", True) + def _process_x_or_y(kwargs, key): + if key not in kwargs: + kwargs[key] = _possible_x_y_plot(self._obj, key) - if "y" in kwargs: - if kwargs["y"] in valid_keys: - yvar = self.accessor[kwargs["y"]] - else: - yvar = self._obj[kwargs["y"]] - if "positive" in yvar.attrs: - if yvar.attrs["positive"] == "down": - kwargs.setdefault("yincrease", False) + value = kwargs.get(key) + if value: + if value in valid_keys: + var = self.accessor[value] else: - kwargs.setdefault("yincrease", True) + var = self._obj[value] + if "positive" in var.attrs: + if var.attrs["positive"] == "down": + kwargs.setdefault(f"{key}increase", False) + else: + kwargs.setdefault(f"{key}increase", True) + return kwargs + + is_line_plot = (func.__name__ == "line") or ( + func.__name__ == "wrapper" and kwargs.get("hue") + ) + if is_line_plot: + if not kwargs.get("hue"): + kwargs = _process_x_or_y(kwargs, "x") + if not kwargs.get("x"): + kwargs = _process_x_or_y(kwargs, "y") + else: + kwargs = _process_x_or_y(kwargs, "x") + kwargs = _process_x_or_y(kwargs, "y") return func(*args, **kwargs) diff --git a/cf_xarray/tests/test_accessor.py b/cf_xarray/tests/test_accessor.py index 216e6058..4fee67e5 100644 --- a/cf_xarray/tests/test_accessor.py +++ b/cf_xarray/tests/test_accessor.py @@ -293,15 +293,20 @@ def test_dataarray_getitem(): assert_identical(air.cf["area_grid_cell"], air.cell_area.reset_coords(drop=True)) -@pytest.mark.parametrize("obj", dataarrays) -def test_dataarray_plot(obj): +def test_dataarray_plot(): + + obj = airds.air - rv = obj.isel(time=1).cf.plot(x="X", y="Y") + rv = obj.isel(time=1).transpose("lon", "lat").cf.plot() assert isinstance(rv, mpl.collections.QuadMesh) + assert all(v > 180 for v in rv.axes.get_xlim()) + assert all(v < 200 for v in rv.axes.get_ylim()) plt.close() - rv = obj.isel(time=1).cf.plot.contourf(x="X", y="Y") + rv = obj.isel(time=1).transpose("lon", "lat").cf.plot.contourf() assert isinstance(rv, mpl.contour.QuadContourSet) + assert all(v > 180 for v in rv.axes.get_xlim()) + assert all(v < 200 for v in rv.axes.get_ylim()) plt.close() rv = obj.cf.plot(x="X", y="Y", col="T") @@ -316,6 +321,29 @@ def test_dataarray_plot(obj): assert all([isinstance(line, mpl.lines.Line2D) for line in rv]) plt.close() + # set y automatically + rv = obj.isel(time=0, lon=1).cf.plot.line() + np.testing.assert_equal(rv[0].get_ydata(), obj.lat.data) + plt.close() + + # don't set y automatically + rv = obj.isel(time=0, lon=1).cf.plot.line(x="lat") + np.testing.assert_equal(rv[0].get_xdata(), obj.lat.data) + plt.close() + + # various line plots and automatic guessing + rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot.line() + np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data) + plt.close() + + # rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot(hue="Y") + # np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data) + # plt.close() + + rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot.line() + np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data) + plt.close() + obj = obj.copy(deep=True) obj.time.attrs.clear() rv = obj.cf.plot(x="X", y="Y", col="time") @@ -714,3 +742,34 @@ def test_drop_dims(ds): # Axis and coordinate for cf_name in ["X", "longitude"]: assert_identical(ds.drop_dims("lon"), ds.cf.drop_dims(cf_name)) + + +def test_possible_x_y_plot(): + from ..accessor import _possible_x_y_plot + + # choose axes + assert _possible_x_y_plot(airds.air.isel(time=1), "x") == "lon" + assert _possible_x_y_plot(airds.air.isel(time=1), "y") == "lat" + assert _possible_x_y_plot(airds.air.isel(lon=1), "y") == "lat" + assert _possible_x_y_plot(airds.air.isel(lon=1), "x") == "time" + + # choose coordinates over axes + assert _possible_x_y_plot(popds.UVEL, "x") == "ULONG" + assert _possible_x_y_plot(popds.UVEL, "y") == "ULAT" + assert _possible_x_y_plot(popds.TEMP, "x") == "TLONG" + assert _possible_x_y_plot(popds.TEMP, "y") == "TLAT" + + assert _possible_x_y_plot(popds.UVEL.drop_vars("ULONG"), "x") == "nlon" + + # choose X over T, Y over Z + def makeds(*dims): + coords = {dim: (dim, np.arange(3), {"axis": dim}) for dim in dims} + return xr.DataArray(np.zeros((3, 3)), dims=dims, coords=coords) + + yzds = makeds("Y", "Z") + assert _possible_x_y_plot(yzds, "y") == "Z" + assert _possible_x_y_plot(yzds, "x") is None + + xtds = makeds("X", "T") + assert _possible_x_y_plot(xtds, "y") is None + assert _possible_x_y_plot(xtds, "x") == "X" diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 597455ce..04e92161 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -6,6 +6,7 @@ What's New v0.4.1 (unreleased) =================== +- Automatically set ``x`` or ``y`` for :py:attr:`DataArray.cf.plot`. By `Deepak Cherian`_. - Added scripts to document :ref:`criteria` with tables. By `Mattia Almansi`_. - Support for ``.drop()``, ``.drop_vars()``, ``.drop_sel()``, ``.drop_dims()``, ``.set_coords()``, ``.reset_coords()``. By `Mattia Almansi`_. - Support for using ``standard_name`` in more functions. (:pr:`128`) By `Deepak Cherian`_