Skip to content

Commit

Permalink
Add typing to plot methods (#7052)
Browse files Browse the repository at this point in the history
* add plot methods statically and add typing to plot tests

* whats-new update

* fix copy-paste typo

* correct plot signatures

* add *some* typing to plot methods

* annotate darray in plot tests

* correct typing of plot returns

* fix plotting overloads

* add correct overloads to dataset_plot

* update whats-new

* rename xr.plot.plot module since it shadows the xr.plot.plot method

* move accessor to its own module

* move DSPlotAccessor to accessor module

* fix DSPlotAccessor import

* add explanation to import statement

* add breaking change to whats-new

* remove unused `rtol` argument from plot

* make most arguments of plotmethods kwargs only

* fix wrong return types

* add breaking kwarg change to whats-new

* support for aspect='auto' or 'equal

* typing support for Dataset FacetGrid

* deprecate positional arguments for all plot methods

* add deprecation to whats-new

* add FacetGrid generic type

* fix mypy 0.981 complaints

* fix index errors in plots

* add overloads to scatter

* deprecate scatter args

* add scatter to accessors and fix docstrings

* undo some breaking changes

* fix the docstrings and some typing

* fix typing of scatter accessor funcs

* align docstrings with signature and complete typing

* add remaining typing

* align more docstrings

* re add ValueError for scatter plots with u, v

* fix whats-new conflict

* fix some typing errors

* more typing fixes

* fix last mypy complaints

* try fixing facetgrid examples

* fix py3.8 problems

* update plotting.rst

* update api

* update plot docstring

* add a tip about yincrease in imshow

* set default for x/yincrease in docstring

* simplify typing

* add deprecation date as comment

* update whats-new to new release

* fix whats-new
  • Loading branch information
headtr1ck authored Oct 16, 2022
1 parent 50301ac commit da9c1d1
Show file tree
Hide file tree
Showing 21 changed files with 4,992 additions and 2,181 deletions.
4 changes: 2 additions & 2 deletions ci/requirements/min-all-deps.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ channels:
- conda-forge
- nodefaults
dependencies:
# MINIMUM VERSIONS POLICY: see doc/installing.rst
# MINIMUM VERSIONS POLICY: see doc/user-guide/installing.rst
# Run ci/min_deps_check.py to verify that this file respects the policy.
# When upgrading python, numpy, or pandas, must also change
# doc/installing.rst and setup.py.
# doc/user-guide/installing.rst, doc/user-guide/plotting.rst and setup.py.
- python=3.8
- boto3=1.18
- bottleneck=1.3
Expand Down
5 changes: 0 additions & 5 deletions doc/api-hidden.rst
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,6 @@
plot.scatter
plot.surface

plot.FacetGrid.map_dataarray
plot.FacetGrid.set_titles
plot.FacetGrid.set_ticks
plot.FacetGrid.map

CFTimeIndex.all
CFTimeIndex.any
CFTimeIndex.append
Expand Down
2 changes: 2 additions & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ DataArray
DataArray.plot.line
DataArray.plot.pcolormesh
DataArray.plot.step
DataArray.plot.scatter
DataArray.plot.surface


Expand All @@ -719,6 +720,7 @@ Faceting
plot.FacetGrid.map_dataarray
plot.FacetGrid.map_dataarray_line
plot.FacetGrid.map_dataset
plot.FacetGrid.map_plot1d
plot.FacetGrid.set_axis_labels
plot.FacetGrid.set_ticks
plot.FacetGrid.set_titles
Expand Down
50 changes: 34 additions & 16 deletions doc/user-guide/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Matplotlib must be installed before xarray can plot.

To use xarray's plotting capabilities with time coordinates containing
``cftime.datetime`` objects
`nc-time-axis <https://github.com/SciTools/nc-time-axis>`_ v1.2.0 or later
`nc-time-axis <https://github.com/SciTools/nc-time-axis>`_ v1.3.0 or later
needs to be installed.

For more extensive plotting applications consider the following projects:
Expand Down Expand Up @@ -106,7 +106,13 @@ The simplest way to make a plot is to call the :py:func:`DataArray.plot()` metho
@savefig plotting_1d_simple.png width=4in
air1d.plot()
Xarray uses the coordinate name along with metadata ``attrs.long_name``, ``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available) to label the axes. The names ``long_name``, ``standard_name`` and ``units`` are copied from the `CF-conventions spec <https://cfconventions.org/Data/cf-conventions/cf-conventions-1.7/build/ch03s03.html>`_. When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``. The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``.
Xarray uses the coordinate name along with metadata ``attrs.long_name``,
``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available)
to label the axes.
The names ``long_name``, ``standard_name`` and ``units`` are copied from the
`CF-conventions spec <https://cfconventions.org/Data/cf-conventions/cf-conventions-1.7/build/ch03s03.html>`_.
When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``.
The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``.

.. ipython:: python
Expand Down Expand Up @@ -340,7 +346,10 @@ The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes d
y="lat", hue="lon", xincrease=False, yincrease=False
)
In addition, one can use ``xscale, yscale`` to set axes scaling; ``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits. These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``, ``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively.
In addition, one can use ``xscale, yscale`` to set axes scaling;
``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits.
These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``,
``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively.


Two Dimensions
Expand All @@ -350,7 +359,8 @@ Two Dimensions
Simple Example
================

The default method :py:meth:`DataArray.plot` calls :py:func:`xarray.plot.pcolormesh` by default when the data is two-dimensional.
The default method :py:meth:`DataArray.plot` calls :py:func:`xarray.plot.pcolormesh`
by default when the data is two-dimensional.

.. ipython:: python
:okwarning:
Expand Down Expand Up @@ -585,7 +595,10 @@ Faceting here refers to splitting an array along one or two dimensions and
plotting each group.
Xarray's basic plotting is useful for plotting two dimensional arrays. What
about three or four dimensional arrays? That's where facets become helpful.
The general approach to plotting here is called “small multiples”, where the same kind of plot is repeated multiple times, and the specific use of small multiples to display the same relationship conditioned on one or more other variables is often called a “trellis plot”.
The general approach to plotting here is called “small multiples”, where the
same kind of plot is repeated multiple times, and the specific use of small
multiples to display the same relationship conditioned on one or more other
variables is often called a “trellis plot”.

Consider the temperature data set. There are 4 observations per day for two
years which makes for 2920 values along the time dimension.
Expand Down Expand Up @@ -670,8 +683,8 @@ Faceted plotting supports other arguments common to xarray 2d plots.
@savefig plot_facet_robust.png
g = hasoutliers.plot.pcolormesh(
"lon",
"lat",
x="lon",
y="lat",
col="time",
col_wrap=3,
robust=True,
Expand Down Expand Up @@ -711,7 +724,7 @@ they have been plotted.
.. ipython:: python
:okwarning:
g = t.plot.imshow("lon", "lat", col="time", col_wrap=3, robust=True)
g = t.plot.imshow(x="lon", y="lat", col="time", col_wrap=3, robust=True)
for i, ax in enumerate(g.axes.flat):
ax.set_title("Air Temperature %d" % i)
Expand All @@ -727,7 +740,8 @@ they have been plotted.
axis labels, axis ticks and plot titles. See :py:meth:`~xarray.plot.FacetGrid.set_titles`,
:py:meth:`~xarray.plot.FacetGrid.set_xlabels`, :py:meth:`~xarray.plot.FacetGrid.set_ylabels` and
:py:meth:`~xarray.plot.FacetGrid.set_ticks` for more information.
Plotting functions can be applied to each subset of the data by calling :py:meth:`~xarray.plot.FacetGrid.map_dataarray` or to each subplot by calling :py:meth:`~xarray.plot.FacetGrid.map`.
Plotting functions can be applied to each subset of the data by calling
:py:meth:`~xarray.plot.FacetGrid.map_dataarray` or to each subplot by calling :py:meth:`~xarray.plot.FacetGrid.map`.

TODO: add an example of using the ``map`` method to plot dataset variables
(e.g., with ``plt.quiver``).
Expand Down Expand Up @@ -777,7 +791,8 @@ Additionally, the boolean kwarg ``add_guide`` can be used to prevent the display
@savefig ds_discrete_legend_hue_scatter.png
ds.plot.scatter(x="A", y="B", hue="w", hue_style="discrete")
The ``markersize`` kwarg lets you vary the point's size by variable value. You can additionally pass ``size_norm`` to control how the variable's values are mapped to point sizes.
The ``markersize`` kwarg lets you vary the point's size by variable value.
You can additionally pass ``size_norm`` to control how the variable's values are mapped to point sizes.

.. ipython:: python
:okwarning:
Expand All @@ -794,7 +809,8 @@ Faceting is also possible
ds.plot.scatter(x="A", y="B", col="x", row="z", hue="w", hue_style="discrete")
For more advanced scatter plots, we recommend converting the relevant data variables to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``.
For more advanced scatter plots, we recommend converting the relevant data variables
to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``.

Quiver
~~~~~~
Expand All @@ -816,7 +832,8 @@ where ``u`` and ``v`` denote the x and y direction components of the arrow vecto
@savefig ds_facet_quiver.png
ds.plot.quiver(x="x", y="y", u="A", v="B", col="w", row="z", scale=4)
``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer.
``scale`` is required for faceted quiver plots.
The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer.

Streamplot
~~~~~~~~~~
Expand All @@ -830,7 +847,8 @@ Visualizing vector fields is also supported with streamline plots:
ds.isel(w=1, z=1).plot.streamplot(x="x", y="y", u="A", v="B")
where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines. Again, faceting is also possible:
where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines.
Again, faceting is also possible:

.. ipython:: python
:okwarning:
Expand Down Expand Up @@ -983,7 +1001,7 @@ instead of the default ones:
)
@savefig plotting_example_2d_irreg.png width=4in
da.plot.pcolormesh("lon", "lat")
da.plot.pcolormesh(x="lon", y="lat")
Note that in this case, xarray still follows the pixel centered convention.
This might be undesirable in some cases, for example when your data is defined
Expand All @@ -996,7 +1014,7 @@ this convention when plotting on a map:
import cartopy.crs as ccrs
ax = plt.subplot(projection=ccrs.PlateCarree())
da.plot.pcolormesh("lon", "lat", ax=ax)
da.plot.pcolormesh(x="lon", y="lat", ax=ax)
ax.scatter(lon, lat, transform=ccrs.PlateCarree())
ax.coastlines()
@savefig plotting_example_2d_irreg_map.png width=4in
Expand All @@ -1009,7 +1027,7 @@ You can however decide to infer the cell boundaries and use the
:okwarning:
ax = plt.subplot(projection=ccrs.PlateCarree())
da.plot.pcolormesh("lon", "lat", ax=ax, infer_intervals=True)
da.plot.pcolormesh(x="lon", y="lat", ax=ax, infer_intervals=True)
ax.scatter(lon, lat, transform=ccrs.PlateCarree())
ax.coastlines()
@savefig plotting_example_2d_irreg_map_infer.png width=4in
Expand Down
12 changes: 10 additions & 2 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,22 @@ v2022.10.1 (unreleased)
New Features
~~~~~~~~~~~~

- Add static typing to plot accessors (:issue:`6949`, :pull:`7052`).
By `Michael Niklas <https://github.com/headtr1ck>`_.

Breaking changes
~~~~~~~~~~~~~~~~

- Many arguments of plotmethods have been made keyword-only.
- ``xarray.plot.plot`` module renamed to ``xarray.plot.dataarray_plot`` to prevent
shadowing of the ``plot`` method. (:issue:`6949`, :pull:`7052`).
By `Michael Niklas <https://github.com/headtr1ck>`_.

Deprecations
~~~~~~~~~~~~

- Positional arguments for all plot methods have been deprecated (:issue:`6949`, :pull:`7052`).
By `Michael Niklas <https://github.com/headtr1ck>`_.

Bug fixes
~~~~~~~~~
Expand Down Expand Up @@ -64,8 +72,8 @@ New Features
the z argument. (:pull:`6778`)
By `Jimmy Westling <https://github.com/illviljan>`_.
- Include the variable name in the error message when CF decoding fails to allow
for easier identification of problematic variables (:issue:`7145`,
:pull:`7147`). By `Spencer Clark <https://github.com/spencerkclark>`_.
for easier identification of problematic variables (:issue:`7145`, :pull:`7147`).
By `Spencer Clark <https://github.com/spencerkclark>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ module = [
"importlib_metadata.*",
"iris.*",
"matplotlib.*",
"mpl_toolkits.*",
"Nio.*",
"nc_time_axis.*",
"numbagg.*",
Expand Down
1 change: 0 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ float_to_top = true
default_section = THIRDPARTY
known_first_party = xarray


[aliases]
test = pytest

Expand Down
14 changes: 8 additions & 6 deletions xarray/core/alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
if TYPE_CHECKING:
from .dataarray import DataArray
from .dataset import Dataset
from .types import JoinOptions, T_DataArray, T_DataArrayOrSet, T_Dataset
from .types import JoinOptions, T_DataArray, T_Dataset, T_DataWithCoords

DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords)

Expand Down Expand Up @@ -944,8 +944,8 @@ def _get_broadcast_dims_map_common_coords(args, exclude):


def _broadcast_helper(
arg: T_DataArrayOrSet, exclude, dims_map, common_coords
) -> T_DataArrayOrSet:
arg: T_DataWithCoords, exclude, dims_map, common_coords
) -> T_DataWithCoords:

from .dataarray import DataArray
from .dataset import Dataset
Expand Down Expand Up @@ -976,14 +976,16 @@ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset:

# remove casts once https://github.com/python/mypy/issues/12800 is resolved
if isinstance(arg, DataArray):
return cast("T_DataArrayOrSet", _broadcast_array(arg))
return cast("T_DataWithCoords", _broadcast_array(arg))
elif isinstance(arg, Dataset):
return cast("T_DataArrayOrSet", _broadcast_dataset(arg))
return cast("T_DataWithCoords", _broadcast_dataset(arg))
else:
raise ValueError("all input must be Dataset or DataArray objects")


def broadcast(*args, exclude=None):
# TODO: this typing is too restrictive since it cannot deal with mixed
# DataArray and Dataset types...? Is this a problem?
def broadcast(*args: T_DataWithCoords, exclude=None) -> tuple[T_DataWithCoords, ...]:
"""Explicitly broadcast any number of DataArray or Dataset objects against
one another.
Expand Down
4 changes: 2 additions & 2 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

from ..coding.calendar_ops import convert_calendar, interp_calendar
from ..coding.cftimeindex import CFTimeIndex
from ..plot.plot import _PlotMethods
from ..plot.accessor import DataArrayPlotAccessor
from ..plot.utils import _get_units_from_attrs
from . import alignment, computation, dtypes, indexing, ops, utils
from ._reductions import DataArrayReductions
Expand Down Expand Up @@ -4189,7 +4189,7 @@ def _inplace_binary_op(self: T_DataArray, other: Any, f: Callable) -> T_DataArra
def _copy_attrs_from(self, other: DataArray | Dataset | Variable) -> None:
self.attrs = other.attrs

plot = utils.UncachedAccessor(_PlotMethods)
plot = utils.UncachedAccessor(DataArrayPlotAccessor)

def _title_for_slice(self, truncate: int = 50) -> str:
"""
Expand Down
8 changes: 5 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@

from ..coding.calendar_ops import convert_calendar, interp_calendar
from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings
from ..plot.dataset_plot import _Dataset_PlotMethods
from ..plot.accessor import DatasetPlotAccessor
from . import alignment
from . import dtypes as xrdtypes
from . import duck_array_ops, formatting, formatting_html, ops, utils
Expand Down Expand Up @@ -7483,7 +7483,7 @@ def imag(self: T_Dataset) -> T_Dataset:
"""
return self.map(lambda x: x.imag, keep_attrs=True)

plot = utils.UncachedAccessor(_Dataset_PlotMethods)
plot = utils.UncachedAccessor(DatasetPlotAccessor)

def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset:
"""Returns a ``Dataset`` with variables that match specific conditions.
Expand Down Expand Up @@ -8575,7 +8575,9 @@ def curvefit(
or not isinstance(coords, Iterable)
):
coords = [coords]
coords_ = [self[coord] if isinstance(coord, str) else coord for coord in coords]
coords_: Sequence[DataArray] = [
self[coord] if isinstance(coord, str) else coord for coord in coords
]

# Determine whether any coords are dims on self
for coord in coords_:
Expand Down
10 changes: 9 additions & 1 deletion xarray/core/pycompat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from importlib import import_module
from typing import Any, Literal

import numpy as np
from packaging.version import Version
Expand All @@ -9,6 +10,8 @@

integer_types = (int, np.integer)

ModType = Literal["dask", "pint", "cupy", "sparse"]


class DuckArrayModule:
"""
Expand All @@ -18,7 +21,12 @@ class DuckArrayModule:
https://github.com/pydata/xarray/pull/5561#discussion_r664815718
"""

def __init__(self, mod):
module: ModType | None
version: Version
type: tuple[type[Any]] # TODO: improve this? maybe Generic
available: bool

def __init__(self, mod: ModType) -> None:
try:
duck_array_module = import_module(mod)
duck_array_version = Version(duck_array_module.__version__)
Expand Down
3 changes: 3 additions & 0 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,10 @@ def dtype(self) -> np.dtype:
CoarsenBoundaryOptions = Literal["exact", "trim", "pad"]
SideOptions = Literal["left", "right"]

ScaleOptions = Literal["linear", "symlog", "log", "logit", None]
HueStyleOptions = Literal["continuous", "discrete", None]
AspectOptions = Union[Literal["auto", "equal"], float, None]
ExtendOptions = Literal["neither", "both", "min", "max", None]

# TODO: Wait until mypy supports recursive objects in combination with typevars
_T = TypeVar("_T")
Expand Down
Loading

0 comments on commit da9c1d1

Please sign in to comment.