Skip to content

Commit

Permalink
Generalize set_(x, y, z)labels in facetgrids (#6918)
Browse files Browse the repository at this point in the history
* Generalize set_xlabels

* Update facetgrid.py

* Add some typing and docstring fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Illviljan and pre-commit-ci[bot] authored Aug 17, 2022
1 parent fbaf815 commit 63d7eb9
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import functools
import itertools
import warnings
from typing import Iterable

import numpy as np

Expand Down Expand Up @@ -470,39 +471,39 @@ def add_quiverkey(self, u, v, **kwargs):
# self._adjust_fig_for_guide(self.quiverkey.text)
return self

def set_axis_labels(self, x_var=None, y_var=None):
def set_axis_labels(self, *axlabels):
"""Set axis labels on the left column and bottom row of the grid."""
if x_var is not None:
if x_var in self.data.coords:
self._x_var = x_var
self.set_xlabels(label_from_attrs(self.data[x_var]))
else:
# x_var is a string
self.set_xlabels(x_var)

if y_var is not None:
if y_var in self.data.coords:
self._y_var = y_var
self.set_ylabels(label_from_attrs(self.data[y_var]))
else:
self.set_ylabels(y_var)
from ..core.dataarray import DataArray

for var, axis in zip(axlabels, ["x", "y", "z"]):
if var is not None:
if isinstance(var, DataArray):
getattr(self, f"set_{axis}labels")(label_from_attrs(var))
else:
getattr(self, f"set_{axis}labels")(var)

return self

def set_xlabels(self, label=None, **kwargs):
"""Label the x axis on the bottom row of the grid."""
def _set_labels(
self, axis: str, axes: Iterable, label: None | str = None, **kwargs
):
if label is None:
label = label_from_attrs(self.data[self._x_var])
for ax in self._bottom_axes:
ax.set_xlabel(label, **kwargs)
label = label_from_attrs(self.data[getattr(self, f"_{axis}_var")])
for ax in axes:
getattr(ax, f"set_{axis}label")(label, **kwargs)
return self

def set_ylabels(self, label=None, **kwargs):
def set_xlabels(self, label: None | str = None, **kwargs) -> None:
"""Label the x axis on the bottom row of the grid."""
self._set_labels("x", self._bottom_axes, label, **kwargs)

def set_ylabels(self, label: None | str = None, **kwargs) -> None:
"""Label the y axis on the left column of the grid."""
if label is None:
label = label_from_attrs(self.data[self._y_var])
for ax in self._left_axes:
ax.set_ylabel(label, **kwargs)
return self
self._set_labels("y", self._left_axes, label, **kwargs)

def set_zlabels(self, label: None | str = None, **kwargs) -> None:
"""Label the z axis."""
self._set_labels("z", self._left_axes, label, **kwargs)

def set_titles(self, template="{coord} = {value}", maxchar=30, size=None, **kwargs):
"""
Expand Down

0 comments on commit 63d7eb9

Please sign in to comment.