diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index c1aedd570bc..bcb291dacd4 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -5,6 +5,7 @@ from ..core.alignment import broadcast from .facetgrid import _easy_facetgrid +from .plot import _PlotMethods from .utils import ( _add_colorbar, _get_nice_quiver_magnitude, @@ -622,3 +623,78 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): # Return .lines so colorbar creation works properly return hdl.lines + + +def _attach_to_plot_class(plotfunc): + """ + Set the function to the plot class and add a common docstring. + + Use this decorator when relying on DataArray.plot methods for + creating the Dataset plot. + + TODO: Reduce code duplication. + + * The goal is to reduce code duplication by moving all Dataset + specific plots to the DataArray side and use this thin wrapper to + handle the conversion between Dataset and DataArray. + * Improve docstring handling, maybe reword the DataArray versions to + explain Datasets better. + * Consider automatically adding all _PlotMethods to + _Dataset_PlotMethods. + + Parameters + ---------- + plotfunc : function + Function that returns a finished plot primitive. + """ + # Build on the original docstring: + original_doc = getattr(_PlotMethods, plotfunc.__name__, None) + commondoc = original_doc.__doc__ + if commondoc is not None: + doc_warning = ( + f"This docstring was copied from xr.DataArray.plot.{original_doc.__name__}." + " Some inconsistencies may exist." + ) + # Add indentation so it matches the original doc: + commondoc = f"\n\n {doc_warning}\n\n {commondoc}" + else: + commondoc = "" + plotfunc.__doc__ = ( + f" {plotfunc.__doc__}\n\n" + " The y DataArray will be used as base," + " any other variables are added as coords.\n\n" + f"{commondoc}" + ) + + @functools.wraps(plotfunc) + def plotmethod(self, *args, **kwargs): + return plotfunc(self._ds, *args, **kwargs) + + # Add to class _PlotMethods + setattr(_Dataset_PlotMethods, plotmethod.__name__, plotmethod) + + +def _temp_dataarray(ds, y, args, kwargs): + """Create a temporary datarray with extra coords.""" + from ..core.dataarray import DataArray + + # Base coords: + coords = dict(ds.coords) + + # Add extra coords to the DataArray: + all_args = args + tuple(kwargs.values()) + coords.update({v: ds[v] for v in all_args if ds.data_vars.get(v) is not None}) + + # The dataarray has to include all the dims. Broadcast to that shape + # and add the additional coords: + _y = ds[y].broadcast_like(ds) + + return DataArray(_y, coords=coords) + + +@_attach_to_plot_class +def line(ds, x, y, *args, **kwargs): + """Line plot Dataset data variables against each other.""" + da = _temp_dataarray(ds, y, args, kwargs) + + return da.plot.line(x, *args, **kwargs)