Skip to content

Automatically choose x,y for plots #148

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 62 additions & 19 deletions cf_xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,42 @@ def check_results(names, k):
)


def _possible_x_y_plot(obj, key):
"""Guesses a name for an x/y variable if possible."""
# in priority order
x_criteria = [
("coordinates", "longitude"),
("axes", "X"),
("coordinates", "time"),
("axes", "T"),
]
y_criteria = [
("coordinates", "vertical"),
("axes", "Z"),
("coordinates", "latitude"),
("axes", "Y"),
]

def _get_possible(accessor, criteria):
# is_scalar depends on NON_NUMPY_SUPPORTED_TYPES
# importing a private function seems better than
# maintaining that variable!
from xarray.core.utils import is_scalar

for attr, key in criteria:
value = getattr(accessor, attr).get(key)
if not value or len(value) > 1:
continue
if not is_scalar(accessor._obj[value[0]]):
return value[0]
return None

if key == "x":
return _get_possible(obj.cf, x_criteria)
elif key == "y":
return _get_possible(obj.cf, y_criteria)


class _CFWrappedClass:
"""
This class is used to wrap any class in _WRAPPED_CLASSES.
Expand Down Expand Up @@ -705,27 +741,34 @@ def _plot_decorator(self, func):

@functools.wraps(func)
def _plot_wrapper(*args, **kwargs):
if "x" in kwargs:
if kwargs["x"] in valid_keys:
xvar = self.accessor[kwargs["x"]]
else:
xvar = self._obj[kwargs["x"]]
if "positive" in xvar.attrs:
if xvar.attrs["positive"] == "down":
kwargs.setdefault("xincrease", False)
else:
kwargs.setdefault("xincrease", True)
def _process_x_or_y(kwargs, key):
if key not in kwargs:
kwargs[key] = _possible_x_y_plot(self._obj, key)

if "y" in kwargs:
if kwargs["y"] in valid_keys:
yvar = self.accessor[kwargs["y"]]
else:
yvar = self._obj[kwargs["y"]]
if "positive" in yvar.attrs:
if yvar.attrs["positive"] == "down":
kwargs.setdefault("yincrease", False)
value = kwargs.get(key)
if value:
if value in valid_keys:
var = self.accessor[value]
else:
kwargs.setdefault("yincrease", True)
var = self._obj[value]
if "positive" in var.attrs:
if var.attrs["positive"] == "down":
kwargs.setdefault(f"{key}increase", False)
else:
kwargs.setdefault(f"{key}increase", True)
return kwargs

is_line_plot = (func.__name__ == "line") or (
func.__name__ == "wrapper" and kwargs.get("hue")
)
if is_line_plot:
if not kwargs.get("hue"):
kwargs = _process_x_or_y(kwargs, "x")
if not kwargs.get("x"):
kwargs = _process_x_or_y(kwargs, "y")
else:
kwargs = _process_x_or_y(kwargs, "x")
kwargs = _process_x_or_y(kwargs, "y")

return func(*args, **kwargs)

Expand Down
67 changes: 63 additions & 4 deletions cf_xarray/tests/test_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,15 +293,20 @@ def test_dataarray_getitem():
assert_identical(air.cf["area_grid_cell"], air.cell_area.reset_coords(drop=True))


@pytest.mark.parametrize("obj", dataarrays)
def test_dataarray_plot(obj):
def test_dataarray_plot():

obj = airds.air

rv = obj.isel(time=1).cf.plot(x="X", y="Y")
rv = obj.isel(time=1).transpose("lon", "lat").cf.plot()
assert isinstance(rv, mpl.collections.QuadMesh)
assert all(v > 180 for v in rv.axes.get_xlim())
assert all(v < 200 for v in rv.axes.get_ylim())
plt.close()

rv = obj.isel(time=1).cf.plot.contourf(x="X", y="Y")
rv = obj.isel(time=1).transpose("lon", "lat").cf.plot.contourf()
assert isinstance(rv, mpl.contour.QuadContourSet)
assert all(v > 180 for v in rv.axes.get_xlim())
assert all(v < 200 for v in rv.axes.get_ylim())
plt.close()

rv = obj.cf.plot(x="X", y="Y", col="T")
Expand All @@ -316,6 +321,29 @@ def test_dataarray_plot(obj):
assert all([isinstance(line, mpl.lines.Line2D) for line in rv])
plt.close()

# set y automatically
rv = obj.isel(time=0, lon=1).cf.plot.line()
np.testing.assert_equal(rv[0].get_ydata(), obj.lat.data)
plt.close()

# don't set y automatically
rv = obj.isel(time=0, lon=1).cf.plot.line(x="lat")
np.testing.assert_equal(rv[0].get_xdata(), obj.lat.data)
plt.close()

# various line plots and automatic guessing
rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot.line()
np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
plt.close()

# rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot(hue="Y")
# np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
# plt.close()

rv = obj.cf.isel(T=1, Y=[0, 1, 2]).cf.plot.line()
np.testing.assert_equal(rv[0].get_xdata(), obj.lon.data)
plt.close()

obj = obj.copy(deep=True)
obj.time.attrs.clear()
rv = obj.cf.plot(x="X", y="Y", col="time")
Expand Down Expand Up @@ -714,3 +742,34 @@ def test_drop_dims(ds):
# Axis and coordinate
for cf_name in ["X", "longitude"]:
assert_identical(ds.drop_dims("lon"), ds.cf.drop_dims(cf_name))


def test_possible_x_y_plot():
from ..accessor import _possible_x_y_plot

# choose axes
assert _possible_x_y_plot(airds.air.isel(time=1), "x") == "lon"
assert _possible_x_y_plot(airds.air.isel(time=1), "y") == "lat"
assert _possible_x_y_plot(airds.air.isel(lon=1), "y") == "lat"
assert _possible_x_y_plot(airds.air.isel(lon=1), "x") == "time"

# choose coordinates over axes
assert _possible_x_y_plot(popds.UVEL, "x") == "ULONG"
assert _possible_x_y_plot(popds.UVEL, "y") == "ULAT"
assert _possible_x_y_plot(popds.TEMP, "x") == "TLONG"
assert _possible_x_y_plot(popds.TEMP, "y") == "TLAT"

assert _possible_x_y_plot(popds.UVEL.drop_vars("ULONG"), "x") == "nlon"

# choose X over T, Y over Z
def makeds(*dims):
coords = {dim: (dim, np.arange(3), {"axis": dim}) for dim in dims}
return xr.DataArray(np.zeros((3, 3)), dims=dims, coords=coords)

yzds = makeds("Y", "Z")
assert _possible_x_y_plot(yzds, "y") == "Z"
assert _possible_x_y_plot(yzds, "x") is None

xtds = makeds("X", "T")
assert _possible_x_y_plot(xtds, "y") is None
assert _possible_x_y_plot(xtds, "x") == "X"
1 change: 1 addition & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ What's New
v0.4.1 (unreleased)
===================

- Automatically set ``x`` or ``y`` for :py:attr:`DataArray.cf.plot`. By `Deepak Cherian`_.
- Added scripts to document :ref:`criteria` with tables. By `Mattia Almansi`_.
- Support for ``.drop()``, ``.drop_vars()``, ``.drop_sel()``, ``.drop_dims()``, ``.set_coords()``, ``.reset_coords()``. By `Mattia Almansi`_.
- Support for using ``standard_name`` in more functions. (:pr:`128`) By `Deepak Cherian`_
Expand Down