diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 640bc2d5f9a..2918794b9eb 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -3,6 +3,7 @@ import functools import itertools import warnings +from typing import Iterable import numpy as np @@ -470,39 +471,39 @@ def add_quiverkey(self, u, v, **kwargs): # self._adjust_fig_for_guide(self.quiverkey.text) return self - def set_axis_labels(self, x_var=None, y_var=None): + def set_axis_labels(self, *axlabels): """Set axis labels on the left column and bottom row of the grid.""" - if x_var is not None: - if x_var in self.data.coords: - self._x_var = x_var - self.set_xlabels(label_from_attrs(self.data[x_var])) - else: - # x_var is a string - self.set_xlabels(x_var) - - if y_var is not None: - if y_var in self.data.coords: - self._y_var = y_var - self.set_ylabels(label_from_attrs(self.data[y_var])) - else: - self.set_ylabels(y_var) + from ..core.dataarray import DataArray + + for var, axis in zip(axlabels, ["x", "y", "z"]): + if var is not None: + if isinstance(var, DataArray): + getattr(self, f"set_{axis}labels")(label_from_attrs(var)) + else: + getattr(self, f"set_{axis}labels")(var) + return self - def set_xlabels(self, label=None, **kwargs): - """Label the x axis on the bottom row of the grid.""" + def _set_labels( + self, axis: str, axes: Iterable, label: None | str = None, **kwargs + ): if label is None: - label = label_from_attrs(self.data[self._x_var]) - for ax in self._bottom_axes: - ax.set_xlabel(label, **kwargs) + label = label_from_attrs(self.data[getattr(self, f"_{axis}_var")]) + for ax in axes: + getattr(ax, f"set_{axis}label")(label, **kwargs) return self - def set_ylabels(self, label=None, **kwargs): + def set_xlabels(self, label: None | str = None, **kwargs) -> None: + """Label the x axis on the bottom row of the grid.""" + self._set_labels("x", self._bottom_axes, label, **kwargs) + + def set_ylabels(self, label: None | str = None, **kwargs) -> None: """Label the y axis on the left column of the grid.""" - if label is None: - label = label_from_attrs(self.data[self._y_var]) - for ax in self._left_axes: - ax.set_ylabel(label, **kwargs) - return self + self._set_labels("y", self._left_axes, label, **kwargs) + + def set_zlabels(self, label: None | str = None, **kwargs) -> None: + """Label the z axis.""" + self._set_labels("z", self._left_axes, label, **kwargs) def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwargs): """