diff --git a/.github/workflows/ci-additional.yaml b/.github/workflows/ci-additional.yaml index 7087b43abc3..8f84b021952 100644 --- a/.github/workflows/ci-additional.yaml +++ b/.github/workflows/ci-additional.yaml @@ -72,7 +72,10 @@ jobs: python xarray/util/print_versions.py - name: Run doctests run: | - python -m pytest --doctest-modules xarray --ignore xarray/tests + # Raise an error if there are warnings in the doctests, with `-Werror`. + # This is a trial; if it presents an problem, feel free to remove. + # See https://github.com/pydata/xarray/issues/7164 for more info. + python -m pytest --doctest-modules xarray --ignore xarray/tests -Werror mypy: name: Mypy diff --git a/asv_bench/benchmarks/import.py b/asv_bench/benchmarks/import.py new file mode 100644 index 00000000000..4d326d41d75 --- /dev/null +++ b/asv_bench/benchmarks/import.py @@ -0,0 +1,18 @@ +class Import: + """Benchmark importing xarray""" + + def timeraw_import_xarray(self): + return """ + import xarray + """ + + def timeraw_import_xarray_plot(self): + return """ + import xarray.plot + """ + + def timeraw_import_xarray_backends(self): + return """ + from xarray.backends import list_engines + list_engines() + """ diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml index 8ff322ee6a4..1a2307aee5e 100644 --- a/ci/requirements/min-all-deps.yml +++ b/ci/requirements/min-all-deps.yml @@ -3,10 +3,10 @@ channels: - conda-forge - nodefaults dependencies: - # MINIMUM VERSIONS POLICY: see doc/installing.rst + # MINIMUM VERSIONS POLICY: see doc/user-guide/installing.rst # Run ci/min_deps_check.py to verify that this file respects the policy. # When upgrading python, numpy, or pandas, must also change - # doc/installing.rst and setup.py. + # doc/user-guide/installing.rst, doc/user-guide/plotting.rst and setup.py. - python=3.8 - boto3=1.18 - bottleneck=1.3 diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 30bc9f858f2..18c2539c04e 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -330,11 +330,6 @@ plot.scatter plot.surface - plot.FacetGrid.map_dataarray - plot.FacetGrid.set_titles - plot.FacetGrid.set_ticks - plot.FacetGrid.map - CFTimeIndex.all CFTimeIndex.any CFTimeIndex.append diff --git a/doc/api.rst b/doc/api.rst index c780b18a8c2..b47e11c71b9 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -703,6 +703,7 @@ DataArray DataArray.plot.line DataArray.plot.pcolormesh DataArray.plot.step + DataArray.plot.scatter DataArray.plot.surface @@ -719,6 +720,7 @@ Faceting plot.FacetGrid.map_dataarray plot.FacetGrid.map_dataarray_line plot.FacetGrid.map_dataset + plot.FacetGrid.map_plot1d plot.FacetGrid.set_axis_labels plot.FacetGrid.set_ticks plot.FacetGrid.set_titles diff --git a/doc/user-guide/plotting.rst b/doc/user-guide/plotting.rst index 9fb34712f32..6952a018c7f 100644 --- a/doc/user-guide/plotting.rst +++ b/doc/user-guide/plotting.rst @@ -27,7 +27,7 @@ Matplotlib must be installed before xarray can plot. To use xarray's plotting capabilities with time coordinates containing ``cftime.datetime`` objects -`nc-time-axis `_ v1.2.0 or later +`nc-time-axis `_ v1.3.0 or later needs to be installed. For more extensive plotting applications consider the following projects: @@ -106,7 +106,13 @@ The simplest way to make a plot is to call the :py:func:`DataArray.plot()` metho @savefig plotting_1d_simple.png width=4in air1d.plot() -Xarray uses the coordinate name along with metadata ``attrs.long_name``, ``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available) to label the axes. The names ``long_name``, ``standard_name`` and ``units`` are copied from the `CF-conventions spec `_. When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``. The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``. +Xarray uses the coordinate name along with metadata ``attrs.long_name``, +``attrs.standard_name``, ``DataArray.name`` and ``attrs.units`` (if available) +to label the axes. +The names ``long_name``, ``standard_name`` and ``units`` are copied from the +`CF-conventions spec `_. +When choosing names, the order of precedence is ``long_name``, ``standard_name`` and finally ``DataArray.name``. +The y-axis label in the above plot was constructed from the ``long_name`` and ``units`` attributes of ``air1d``. .. ipython:: python @@ -340,7 +346,10 @@ The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes d y="lat", hue="lon", xincrease=False, yincrease=False ) -In addition, one can use ``xscale, yscale`` to set axes scaling; ``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits. These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``, ``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively. +In addition, one can use ``xscale, yscale`` to set axes scaling; +``xticks, yticks`` to set axes ticks and ``xlim, ylim`` to set axes limits. +These accept the same values as the matplotlib methods ``Axes.set_(x,y)scale()``, +``Axes.set_(x,y)ticks()``, ``Axes.set_(x,y)lim()`` respectively. Two Dimensions @@ -350,7 +359,8 @@ Two Dimensions Simple Example ================ -The default method :py:meth:`DataArray.plot` calls :py:func:`xarray.plot.pcolormesh` by default when the data is two-dimensional. +The default method :py:meth:`DataArray.plot` calls :py:func:`xarray.plot.pcolormesh` +by default when the data is two-dimensional. .. ipython:: python :okwarning: @@ -585,7 +595,10 @@ Faceting here refers to splitting an array along one or two dimensions and plotting each group. Xarray's basic plotting is useful for plotting two dimensional arrays. What about three or four dimensional arrays? That's where facets become helpful. -The general approach to plotting here is called “small multiples”, where the same kind of plot is repeated multiple times, and the specific use of small multiples to display the same relationship conditioned on one or more other variables is often called a “trellis plot”. +The general approach to plotting here is called “small multiples”, where the +same kind of plot is repeated multiple times, and the specific use of small +multiples to display the same relationship conditioned on one or more other +variables is often called a “trellis plot”. Consider the temperature data set. There are 4 observations per day for two years which makes for 2920 values along the time dimension. @@ -670,8 +683,8 @@ Faceted plotting supports other arguments common to xarray 2d plots. @savefig plot_facet_robust.png g = hasoutliers.plot.pcolormesh( - "lon", - "lat", + x="lon", + y="lat", col="time", col_wrap=3, robust=True, @@ -711,7 +724,7 @@ they have been plotted. .. ipython:: python :okwarning: - g = t.plot.imshow("lon", "lat", col="time", col_wrap=3, robust=True) + g = t.plot.imshow(x="lon", y="lat", col="time", col_wrap=3, robust=True) for i, ax in enumerate(g.axes.flat): ax.set_title("Air Temperature %d" % i) @@ -727,7 +740,8 @@ they have been plotted. axis labels, axis ticks and plot titles. See :py:meth:`~xarray.plot.FacetGrid.set_titles`, :py:meth:`~xarray.plot.FacetGrid.set_xlabels`, :py:meth:`~xarray.plot.FacetGrid.set_ylabels` and :py:meth:`~xarray.plot.FacetGrid.set_ticks` for more information. -Plotting functions can be applied to each subset of the data by calling :py:meth:`~xarray.plot.FacetGrid.map_dataarray` or to each subplot by calling :py:meth:`~xarray.plot.FacetGrid.map`. +Plotting functions can be applied to each subset of the data by calling +:py:meth:`~xarray.plot.FacetGrid.map_dataarray` or to each subplot by calling :py:meth:`~xarray.plot.FacetGrid.map`. TODO: add an example of using the ``map`` method to plot dataset variables (e.g., with ``plt.quiver``). @@ -742,14 +756,32 @@ Consider this dataset .. ipython:: python - ds = xr.tutorial.scatter_example_dataset() + ds = xr.tutorial.scatter_example_dataset(seed=42) ds Scatter ~~~~~~~ -Suppose we want to scatter ``A`` against ``B`` +Let's plot the ``A`` DataArray as a function of the ``y`` coord + +.. ipython:: python + :okwarning: + + ds.A + + @savefig da_A_y.png + ds.A.plot.scatter(x="y") + +Same plot can be displayed using the dataset: + +.. ipython:: python + :okwarning: + + @savefig ds_A_y.png + ds.plot.scatter(x="y", y="A") + +Now suppose we want to scatter the ``A`` DataArray against the ``B`` DataArray .. ipython:: python :okwarning: @@ -765,25 +797,36 @@ The ``hue`` kwarg lets you vary the color by variable value @savefig ds_hue_scatter.png ds.plot.scatter(x="A", y="B", hue="w") -When ``hue`` is specified, a colorbar is added for numeric ``hue`` DataArrays by -default and a legend is added for non-numeric ``hue`` DataArrays (as above). -You can force a legend instead of a colorbar by setting ``hue_style='discrete'``. -Additionally, the boolean kwarg ``add_guide`` can be used to prevent the display of a legend or colorbar (as appropriate). +You can force a legend instead of a colorbar by setting ``add_legend=True, add_colorbar=False``. .. ipython:: python :okwarning: - ds = ds.assign(w=[1, 2, 3, 5]) @savefig ds_discrete_legend_hue_scatter.png - ds.plot.scatter(x="A", y="B", hue="w", hue_style="discrete") + ds.plot.scatter(x="A", y="B", hue="w", add_legend=True, add_colorbar=False) + +.. ipython:: python + :okwarning: + + @savefig ds_discrete_colorbar_hue_scatter.png + ds.plot.scatter(x="A", y="B", hue="w", add_legend=False, add_colorbar=True) -The ``markersize`` kwarg lets you vary the point's size by variable value. You can additionally pass ``size_norm`` to control how the variable's values are mapped to point sizes. +The ``markersize`` kwarg lets you vary the point's size by variable value. +You can additionally pass ``size_norm`` to control how the variable's values are mapped to point sizes. .. ipython:: python :okwarning: @savefig ds_hue_size_scatter.png - ds.plot.scatter(x="A", y="B", hue="z", hue_style="discrete", markersize="z") + ds.plot.scatter(x="A", y="B", hue="y", markersize="z") + +The ``z`` kwarg lets you plot the data along the z-axis as well. + +.. ipython:: python + :okwarning: + + @savefig ds_hue_size_scatter_z.png + ds.plot.scatter(x="A", y="B", z="z", hue="y", markersize="x") Faceting is also possible @@ -791,10 +834,18 @@ Faceting is also possible :okwarning: @savefig ds_facet_scatter.png - ds.plot.scatter(x="A", y="B", col="x", row="z", hue="w", hue_style="discrete") + ds.plot.scatter(x="A", y="B", hue="y", markersize="x", row="x", col="w") + +And adding the z-axis + +.. ipython:: python + :okwarning: + @savefig ds_facet_scatter_z.png + ds.plot.scatter(x="A", y="B", z="z", hue="y", markersize="x", row="x", col="w") -For more advanced scatter plots, we recommend converting the relevant data variables to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``. +For more advanced scatter plots, we recommend converting the relevant data variables +to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``. Quiver ~~~~~~ @@ -816,7 +867,8 @@ where ``u`` and ``v`` denote the x and y direction components of the arrow vecto @savefig ds_facet_quiver.png ds.plot.quiver(x="x", y="y", u="A", v="B", col="w", row="z", scale=4) -``scale`` is required for faceted quiver plots. The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer. +``scale`` is required for faceted quiver plots. +The scale determines the number of data units per arrow length unit, i.e. a smaller scale parameter makes the arrow longer. Streamplot ~~~~~~~~~~ @@ -830,7 +882,8 @@ Visualizing vector fields is also supported with streamline plots: ds.isel(w=1, z=1).plot.streamplot(x="x", y="y", u="A", v="B") -where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines. Again, faceting is also possible: +where ``u`` and ``v`` denote the x and y direction components of the vectors tangent to the streamlines. +Again, faceting is also possible: .. ipython:: python :okwarning: @@ -983,7 +1036,7 @@ instead of the default ones: ) @savefig plotting_example_2d_irreg.png width=4in - da.plot.pcolormesh("lon", "lat") + da.plot.pcolormesh(x="lon", y="lat") Note that in this case, xarray still follows the pixel centered convention. This might be undesirable in some cases, for example when your data is defined @@ -996,7 +1049,7 @@ this convention when plotting on a map: import cartopy.crs as ccrs ax = plt.subplot(projection=ccrs.PlateCarree()) - da.plot.pcolormesh("lon", "lat", ax=ax) + da.plot.pcolormesh(x="lon", y="lat", ax=ax) ax.scatter(lon, lat, transform=ccrs.PlateCarree()) ax.coastlines() @savefig plotting_example_2d_irreg_map.png width=4in @@ -1009,7 +1062,7 @@ You can however decide to infer the cell boundaries and use the :okwarning: ax = plt.subplot(projection=ccrs.PlateCarree()) - da.plot.pcolormesh("lon", "lat", ax=ax, infer_intervals=True) + da.plot.pcolormesh(x="lon", y="lat", ax=ax, infer_intervals=True) ax.scatter(lon, lat, transform=ccrs.PlateCarree()) ax.coastlines() @savefig plotting_example_2d_irreg_map_infer.png width=4in diff --git a/doc/whats-new.rst b/doc/whats-new.rst index ed630098631..7ccb9ea4525 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -23,14 +23,22 @@ v2022.10.1 (unreleased) New Features ~~~~~~~~~~~~ +- Add static typing to plot accessors (:issue:`6949`, :pull:`7052`). + By `Michael Niklas `_. Breaking changes ~~~~~~~~~~~~~~~~ +- Many arguments of plotmethods have been made keyword-only. +- ``xarray.plot.plot`` module renamed to ``xarray.plot.dataarray_plot`` to prevent + shadowing of the ``plot`` method. (:issue:`6949`, :pull:`7052`). + By `Michael Niklas `_. Deprecations ~~~~~~~~~~~~ +- Positional arguments for all plot methods have been deprecated (:issue:`6949`, :pull:`7052`). + By `Michael Niklas `_. Bug fixes ~~~~~~~~~ @@ -42,6 +50,9 @@ Documentation Internal Changes ~~~~~~~~~~~~~~~~ +- Doctests fail on any warnings (:pull:`7166`) + By `Maximilian Roos `_. + .. _whats-new.2022.10.0: @@ -64,8 +75,8 @@ New Features the z argument. (:pull:`6778`) By `Jimmy Westling `_. - Include the variable name in the error message when CF decoding fails to allow - for easier identification of problematic variables (:issue:`7145`, - :pull:`7147`). By `Spencer Clark `_. + for easier identification of problematic variables (:issue:`7145`, :pull:`7147`). + By `Spencer Clark `_. Breaking changes ~~~~~~~~~~~~~~~~ diff --git a/pyproject.toml b/pyproject.toml index 53331b6b66f..271abc0aab1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,7 @@ module = [ "importlib_metadata.*", "iris.*", "matplotlib.*", + "mpl_toolkits.*", "Nio.*", "nc_time_axis.*", "numbagg.*", diff --git a/setup.cfg b/setup.cfg index 89c3de7d5e2..72dca2dec63 100644 --- a/setup.cfg +++ b/setup.cfg @@ -165,7 +165,6 @@ float_to_top = true default_section = THIRDPARTY known_first_party = xarray - [aliases] test = pytest diff --git a/xarray/backends/rasterio_.py b/xarray/backends/rasterio_.py index d3153eb3e18..acd9070320b 100644 --- a/xarray/backends/rasterio_.py +++ b/xarray/backends/rasterio_.py @@ -187,49 +187,6 @@ def open_rasterio( `_ for more information). - You can generate 2D coordinates from the file's attributes with:: - - >>> from affine import Affine - >>> da = xr.open_rasterio( - ... "https://github.com/rasterio/rasterio/raw/1.2.1/tests/data/RGB.byte.tif" - ... ) - >>> da - - [1703814 values with dtype=uint8] - Coordinates: - * band (band) int64 1 2 3 - * y (y) float64 2.827e+06 2.826e+06 2.826e+06 ... 2.612e+06 2.612e+06 - * x (x) float64 1.021e+05 1.024e+05 1.027e+05 ... 3.389e+05 3.392e+05 - Attributes: - transform: (300.0379266750948, 0.0, 101985.0, 0.0, -300.041782729805... - crs: +init=epsg:32618 - res: (300.0379266750948, 300.041782729805) - is_tiled: 0 - nodatavals: (0.0, 0.0, 0.0) - scales: (1.0, 1.0, 1.0) - offsets: (0.0, 0.0, 0.0) - AREA_OR_POINT: Area - >>> transform = Affine(*da.attrs["transform"]) - >>> transform - Affine(300.0379266750948, 0.0, 101985.0, - 0.0, -300.041782729805, 2826915.0) - >>> nx, ny = da.sizes["x"], da.sizes["y"] - >>> x, y = transform * np.meshgrid(np.arange(nx) + 0.5, np.arange(ny) + 0.5) - >>> x - array([[102135.01896334, 102435.05689001, 102735.09481669, ..., - 338564.90518331, 338864.94310999, 339164.98103666], - [102135.01896334, 102435.05689001, 102735.09481669, ..., - 338564.90518331, 338864.94310999, 339164.98103666], - [102135.01896334, 102435.05689001, 102735.09481669, ..., - 338564.90518331, 338864.94310999, 339164.98103666], - ..., - [102135.01896334, 102435.05689001, 102735.09481669, ..., - 338564.90518331, 338864.94310999, 339164.98103666], - [102135.01896334, 102435.05689001, 102735.09481669, ..., - 338564.90518331, 338864.94310999, 339164.98103666], - [102135.01896334, 102435.05689001, 102735.09481669, ..., - 338564.90518331, 338864.94310999, 339164.98103666]]) - Parameters ---------- filename : str, rasterio.DatasetReader, or rasterio.WarpedVRT diff --git a/xarray/core/alignment.py b/xarray/core/alignment.py index c68b57633c4..8c95acadb2e 100644 --- a/xarray/core/alignment.py +++ b/xarray/core/alignment.py @@ -38,7 +38,7 @@ if TYPE_CHECKING: from .dataarray import DataArray from .dataset import Dataset - from .types import JoinOptions, T_DataArray, T_DataArrayOrSet, T_Dataset + from .types import JoinOptions, T_DataArray, T_Dataset, T_DataWithCoords DataAlignable = TypeVar("DataAlignable", bound=DataWithCoords) @@ -944,8 +944,8 @@ def _get_broadcast_dims_map_common_coords(args, exclude): def _broadcast_helper( - arg: T_DataArrayOrSet, exclude, dims_map, common_coords -) -> T_DataArrayOrSet: + arg: T_DataWithCoords, exclude, dims_map, common_coords +) -> T_DataWithCoords: from .dataarray import DataArray from .dataset import Dataset @@ -976,14 +976,16 @@ def _broadcast_dataset(ds: T_Dataset) -> T_Dataset: # remove casts once https://github.com/python/mypy/issues/12800 is resolved if isinstance(arg, DataArray): - return cast("T_DataArrayOrSet", _broadcast_array(arg)) + return cast("T_DataWithCoords", _broadcast_array(arg)) elif isinstance(arg, Dataset): - return cast("T_DataArrayOrSet", _broadcast_dataset(arg)) + return cast("T_DataWithCoords", _broadcast_dataset(arg)) else: raise ValueError("all input must be Dataset or DataArray objects") -def broadcast(*args, exclude=None): +# TODO: this typing is too restrictive since it cannot deal with mixed +# DataArray and Dataset types...? Is this a problem? +def broadcast(*args: T_DataWithCoords, exclude=None) -> tuple[T_DataWithCoords, ...]: """Explicitly broadcast any number of DataArray or Dataset objects against one another. diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index a7e193c79a0..8d971c53917 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -22,7 +22,7 @@ from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex -from ..plot.plot import _PlotMethods +from ..plot.accessor import DataArrayPlotAccessor from ..plot.utils import _get_units_from_attrs from . import alignment, computation, dtypes, indexing, ops, utils from ._reductions import DataArrayReductions @@ -4189,7 +4189,7 @@ def _inplace_binary_op(self: T_DataArray, other: Any, f: Callable) -> T_DataArra def _copy_attrs_from(self, other: DataArray | Dataset | Variable) -> None: self.attrs = other.attrs - plot = utils.UncachedAccessor(_PlotMethods) + plot = utils.UncachedAccessor(DataArrayPlotAccessor) def _title_for_slice(self, truncate: int = 50) -> str: """ @@ -5261,9 +5261,9 @@ def idxmin( >>> array.min() array(-2) - >>> array.argmin() - - array(4) + >>> array.argmin(...) + {'x': + array(4)} >>> array.idxmin() array('e', dtype='>> array.max() array(2) - >>> array.argmax() - - array(1) + >>> array.argmax(...) + {'x': + array(1)} >>> array.idxmax() array('b', dtype='>> array.min() array(-1) - >>> array.argmin() - - array(2) >>> array.argmin(...) {'x': array(2)} @@ -5555,9 +5552,6 @@ def argmax( >>> array.max() array(3) - >>> array.argmax() - - array(3) >>> array.argmax(...) {'x': array(3)} @@ -6031,7 +6025,7 @@ def groupby( >>> da = xr.DataArray( ... np.linspace(0, 1826, num=1827), - ... coords=[pd.date_range("1/1/2000", "31/12/2004", freq="D")], + ... coords=[pd.date_range("2000-01-01", "2004-12-31", freq="D")], ... dims="time", ... ) >>> da diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index c6c97129dc9..ab1d36a9e54 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -35,7 +35,7 @@ from ..coding.calendar_ops import convert_calendar, interp_calendar from ..coding.cftimeindex import CFTimeIndex, _parse_array_of_cftime_strings -from ..plot.dataset_plot import _Dataset_PlotMethods +from ..plot.accessor import DatasetPlotAccessor from . import alignment from . import dtypes as xrdtypes from . import duck_array_ops, formatting, formatting_html, ops, utils @@ -7483,7 +7483,7 @@ def imag(self: T_Dataset) -> T_Dataset: """ return self.map(lambda x: x.imag, keep_attrs=True) - plot = utils.UncachedAccessor(_Dataset_PlotMethods) + plot = utils.UncachedAccessor(DatasetPlotAccessor) def filter_by_attrs(self: T_Dataset, **kwargs) -> T_Dataset: """Returns a ``Dataset`` with variables that match specific conditions. @@ -8575,7 +8575,9 @@ def curvefit( or not isinstance(coords, Iterable) ): coords = [coords] - coords_ = [self[coord] if isinstance(coord, str) else coord for coord in coords] + coords_: Sequence[DataArray] = [ + self[coord] if isinstance(coord, str) else coord for coord in coords + ] # Determine whether any coords are dims on self for coord in coords_: diff --git a/xarray/core/pycompat.py b/xarray/core/pycompat.py index 09ee13e4941..93b61ecc3e8 100644 --- a/xarray/core/pycompat.py +++ b/xarray/core/pycompat.py @@ -1,6 +1,7 @@ from __future__ import annotations from importlib import import_module +from typing import Any, Literal import numpy as np from packaging.version import Version @@ -9,6 +10,8 @@ integer_types = (int, np.integer) +ModType = Literal["dask", "pint", "cupy", "sparse"] + class DuckArrayModule: """ @@ -18,7 +21,12 @@ class DuckArrayModule: https://github.com/pydata/xarray/pull/5561#discussion_r664815718 """ - def __init__(self, mod): + module: ModType | None + version: Version + type: tuple[type[Any]] # TODO: improve this? maybe Generic + available: bool + + def __init__(self, mod: ModType) -> None: try: duck_array_module = import_module(mod) duck_array_version = Version(duck_array_module.__version__) diff --git a/xarray/core/types.py b/xarray/core/types.py index d0764c4a791..2b65f4d23e6 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -164,7 +164,10 @@ def dtype(self) -> np.dtype: CoarsenBoundaryOptions = Literal["exact", "trim", "pad"] SideOptions = Literal["left", "right"] +ScaleOptions = Literal["linear", "symlog", "log", "logit", None] HueStyleOptions = Literal["continuous", "discrete", None] +AspectOptions = Union[Literal["auto", "equal"], float, None] +ExtendOptions = Literal["neither", "both", "min", "max", None] # TODO: Wait until mypy supports recursive objects in combination with typevars _T = TypeVar("_T") diff --git a/xarray/plot/__init__.py b/xarray/plot/__init__.py index 28ae0cf32e7..bac62673ee1 100644 --- a/xarray/plot/__init__.py +++ b/xarray/plot/__init__.py @@ -1,6 +1,24 @@ +""" +Use this module directly: + import xarray.plot as xplt + +Or use the methods on a DataArray or Dataset: + DataArray.plot._____ + Dataset.plot._____ +""" +from .dataarray_plot import ( + contour, + contourf, + hist, + imshow, + line, + pcolormesh, + plot, + step, + surface, +) from .dataset_plot import scatter from .facetgrid import FacetGrid -from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step, surface __all__ = [ "plot", diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py new file mode 100644 index 00000000000..273d0f4f921 --- /dev/null +++ b/xarray/plot/accessor.py @@ -0,0 +1,1301 @@ +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING, Any, Hashable, Iterable, Literal, NoReturn, overload + +import numpy as np + +# Accessor methods have the same name as plotting methods, so we need a different namespace +from . import dataarray_plot, dataset_plot + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.collections import LineCollection, PathCollection, QuadMesh + from matplotlib.colors import Normalize + from matplotlib.container import BarContainer + from matplotlib.contour import QuadContourSet + from matplotlib.image import AxesImage + from matplotlib.quiver import Quiver + from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection + from numpy.typing import ArrayLike + + from ..core.dataarray import DataArray + from ..core.dataset import Dataset + from ..core.types import AspectOptions, HueStyleOptions, ScaleOptions + from .facetgrid import FacetGrid + + +class DataArrayPlotAccessor: + """ + Enables use of xarray.plot functions as attributes on a DataArray. + For example, DataArray.plot.imshow + """ + + _da: DataArray + + __slots__ = ("_da",) + __doc__ = dataarray_plot.plot.__doc__ + + def __init__(self, darray: DataArray) -> None: + self._da = darray + + # Should return Any such that the user does not run into problems + # with the many possible return values + @functools.wraps(dataarray_plot.plot, assigned=("__doc__", "__annotations__")) + def __call__(self, **kwargs) -> Any: + return dataarray_plot.plot(self._da, **kwargs) + + @functools.wraps(dataarray_plot.hist) + def hist(self, *args, **kwargs) -> tuple[np.ndarray, np.ndarray, BarContainer]: + return dataarray_plot.hist(self._da, *args, **kwargs) + + @overload + def line( # type: ignore[misc] # None is hashable :( + self, + *args: Any, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, + ) -> list[Line3D]: + ... + + @overload + def line( + self, + *args: Any, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, + ) -> FacetGrid[DataArray]: + ... + + @overload + def line( + self, + *args: Any, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, + ) -> FacetGrid[DataArray]: + ... + + @functools.wraps(dataarray_plot.line) + def line(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: + return dataarray_plot.line(self._da, *args, **kwargs) + + @overload + def step( # type: ignore[misc] # None is hashable :( + self, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + **kwargs: Any, + ) -> list[Line3D]: + ... + + @overload + def step( + self, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: + ... + + @overload + def step( + self, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + **kwargs: Any, + ) -> FacetGrid[DataArray]: + ... + + @functools.wraps(dataarray_plot.step) + def step(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: + return dataarray_plot.step(self._da, *args, **kwargs) + + @overload + def scatter( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs, + ) -> PathCollection: + ... + + @overload + def scatter( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs, + ) -> FacetGrid[DataArray]: + ... + + @overload + def scatter( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs, + ) -> FacetGrid[DataArray]: + ... + + @functools.wraps(dataarray_plot.scatter) + def scatter(self, *args, **kwargs): + return dataarray_plot.scatter(self._da, *args, **kwargs) + + @overload + def imshow( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> AxesImage: + ... + + @overload + def imshow( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: + ... + + @overload + def imshow( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: + ... + + @functools.wraps(dataarray_plot.imshow) + def imshow(self, *args, **kwargs) -> AxesImage: + return dataarray_plot.imshow(self._da, *args, **kwargs) + + @overload + def contour( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> QuadContourSet: + ... + + @overload + def contour( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: + ... + + @overload + def contour( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: + ... + + @functools.wraps(dataarray_plot.contour) + def contour(self, *args, **kwargs) -> QuadContourSet: + return dataarray_plot.contour(self._da, *args, **kwargs) + + @overload + def contourf( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> QuadContourSet: + ... + + @overload + def contourf( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: + ... + + @overload + def contourf( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid: + ... + + @functools.wraps(dataarray_plot.contourf) + def contourf(self, *args, **kwargs) -> QuadContourSet: + return dataarray_plot.contourf(self._da, *args, **kwargs) + + @overload + def pcolormesh( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> QuadMesh: + ... + + @overload + def pcolormesh( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid: + ... + + @overload + def pcolormesh( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid: + ... + + @functools.wraps(dataarray_plot.pcolormesh) + def pcolormesh(self, *args, **kwargs) -> QuadMesh: + return dataarray_plot.pcolormesh(self._da, *args, **kwargs) + + @overload + def surface( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> Poly3DCollection: + ... + + @overload + def surface( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid: + ... + + @overload + def surface( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap=None, + center=None, + robust: bool = False, + extend=None, + levels=None, + infer_intervals=None, + colors=None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> FacetGrid: + ... + + @functools.wraps(dataarray_plot.surface) + def surface(self, *args, **kwargs) -> Poly3DCollection: + return dataarray_plot.surface(self._da, *args, **kwargs) + + +class DatasetPlotAccessor: + """ + Enables use of xarray.plot functions as attributes on a Dataset. + For example, Dataset.plot.scatter + """ + + _ds: Dataset + __slots__ = ("_ds",) + + def __init__(self, dataset: Dataset) -> None: + self._ds = dataset + + def __call__(self, *args, **kwargs) -> NoReturn: + raise ValueError( + "Dataset.plot cannot be called directly. Use " + "an explicit plot method, e.g. ds.plot.scatter(...)" + ) + + @overload + def scatter( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> PathCollection: + ... + + @overload + def scatter( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: + ... + + @overload + def scatter( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: + ... + + @functools.wraps(dataset_plot.scatter) + def scatter(self, *args, **kwargs) -> PathCollection | FacetGrid[DataArray]: + return dataset_plot.scatter(self._ds, *args, **kwargs) + + @overload + def quiver( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: None = None, # no wrap -> primitive + row: None = None, # no wrap -> primitive + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> Quiver: + ... + + @overload + def quiver( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable, # wrap -> FacetGrid + row: Hashable | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> FacetGrid: + ... + + @overload + def quiver( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable | None = None, + row: Hashable, # wrap -> FacetGrid + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> FacetGrid: + ... + + @functools.wraps(dataset_plot.quiver) + def quiver(self, *args, **kwargs) -> Quiver | FacetGrid: + return dataset_plot.quiver(self._ds, *args, **kwargs) + + @overload + def streamplot( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: None = None, # no wrap -> primitive + row: None = None, # no wrap -> primitive + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> LineCollection: + ... + + @overload + def streamplot( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable, # wrap -> FacetGrid + row: Hashable | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> FacetGrid: + ... + + @overload + def streamplot( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable | None = None, + row: Hashable, # wrap -> FacetGrid + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals=None, + center=None, + levels=None, + robust: bool | None = None, + colors=None, + extend=None, + cmap=None, + **kwargs: Any, + ) -> FacetGrid: + ... + + @functools.wraps(dataset_plot.streamplot) + def streamplot(self, *args, **kwargs) -> LineCollection | FacetGrid: + return dataset_plot.streamplot(self._ds, *args, **kwargs) diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py new file mode 100644 index 00000000000..ae44297058b --- /dev/null +++ b/xarray/plot/dataarray_plot.py @@ -0,0 +1,2466 @@ +from __future__ import annotations + +import functools +import warnings +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Hashable, + Iterable, + Literal, + MutableMapping, + cast, + overload, +) + +import numpy as np +import pandas as pd +from packaging.version import Version + +from ..core.alignment import broadcast +from ..core.concat import concat +from .facetgrid import _easy_facetgrid +from .utils import ( + _LINEWIDTH_RANGE, + _MARKERSIZE_RANGE, + _add_colorbar, + _add_legend, + _assert_valid_xy, + _determine_guide, + _ensure_plottable, + _infer_interval_breaks, + _infer_xy_labels, + _Normalize, + _process_cmap_cbar_kwargs, + _rescale_imshow_rgb, + _resolve_intervals_1dplot, + _resolve_intervals_2dplot, + _update_axes, + get_axis, + import_matplotlib_pyplot, + label_from_attrs, +) + +if TYPE_CHECKING: + from matplotlib.axes import Axes + from matplotlib.collections import PathCollection, QuadMesh + from matplotlib.colors import Colormap, Normalize + from matplotlib.container import BarContainer + from matplotlib.contour import QuadContourSet + from matplotlib.image import AxesImage + from mpl_toolkits.mplot3d.art3d import Line3D, Poly3DCollection + from numpy.typing import ArrayLike + + from ..core.dataarray import DataArray + from ..core.types import ( + AspectOptions, + ExtendOptions, + HueStyleOptions, + ScaleOptions, + T_DataArray, + ) + from .facetgrid import FacetGrid + + +def _infer_line_data( + darray: DataArray, x: Hashable | None, y: Hashable | None, hue: Hashable | None +) -> tuple[DataArray, DataArray, DataArray | None, str]: + + ndims = len(darray.dims) + + if x is not None and y is not None: + raise ValueError("Cannot specify both x and y kwargs for line plots.") + + if x is not None: + _assert_valid_xy(darray, x, "x") + + if y is not None: + _assert_valid_xy(darray, y, "y") + + if ndims == 1: + huename = None + hueplt = None + huelabel = "" + + if x is not None: + xplt = darray[x] + yplt = darray + + elif y is not None: + xplt = darray + yplt = darray[y] + + else: # Both x & y are None + dim = darray.dims[0] + xplt = darray[dim] + yplt = darray + + else: + if x is None and y is None and hue is None: + raise ValueError("For 2D inputs, please specify either hue, x or y.") + + if y is None: + if hue is not None: + _assert_valid_xy(darray, hue, "hue") + xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) + xplt = darray[xname] + if xplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + otherdim = darray.dims[otherindex] + yplt = darray.transpose(otherdim, huename, transpose_coords=False) + xplt = xplt.transpose(otherdim, huename, transpose_coords=False) + else: + raise ValueError( + "For 2D inputs, hue must be a dimension" + " i.e. one of " + repr(darray.dims) + ) + + else: + (xdim,) = darray[xname].dims + (huedim,) = darray[huename].dims + yplt = darray.transpose(xdim, huedim) + + else: + yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) + yplt = darray[yname] + if yplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + otherdim = darray.dims[otherindex] + xplt = darray.transpose(otherdim, huename, transpose_coords=False) + yplt = yplt.transpose(otherdim, huename, transpose_coords=False) + else: + raise ValueError( + "For 2D inputs, hue must be a dimension" + " i.e. one of " + repr(darray.dims) + ) + + else: + (ydim,) = darray[yname].dims + (huedim,) = darray[huename].dims + xplt = darray.transpose(ydim, huedim) + + huelabel = label_from_attrs(darray[huename]) + hueplt = darray[huename] + + return xplt, yplt, hueplt, huelabel + + +def _infer_plot_dims( + darray: DataArray, + dims_plot: MutableMapping[str, Hashable], + default_guess: Iterable[str] = ("x", "hue", "size"), +) -> MutableMapping[str, Hashable]: + """ + Guess what dims to plot if some of the values in dims_plot are None which + happens when the user has not defined all available ways of visualizing + the data. + + Parameters + ---------- + darray : DataArray + The DataArray to check. + dims_plot : T_DimsPlot + Dims defined by the user to plot. + default_guess : Iterable[str], optional + Default values and order to retrieve dims if values in dims_plot is + missing, default: ("x", "hue", "size"). + """ + dims_plot_exist = {k: v for k, v in dims_plot.items() if v is not None} + dims_avail = tuple(v for v in darray.dims if v not in dims_plot_exist.values()) + + # If dims_plot[k] isn't defined then fill with one of the available dims: + for k, v in zip(default_guess, dims_avail): + if dims_plot.get(k, None) is None: + dims_plot[k] = v + + for k, v in dims_plot.items(): + _assert_valid_xy(darray, v, k) + + return dims_plot + + +def _infer_line_data2( + darray: T_DataArray, + dims_plot: MutableMapping[str, Hashable], + plotfunc_name: None | str = None, +) -> dict[str, T_DataArray]: + # Guess what dims to use if some of the values in plot_dims are None: + dims_plot = _infer_plot_dims(darray, dims_plot) + + # If there are more than 1 dimension in the array than stack all the + # dimensions so the plotter can plot anything: + if darray.ndim > 1: + # When stacking dims the lines will continue connecting. For floats + # this can be solved by adding a nan element inbetween the flattening + # points: + dims_T = [] + if np.issubdtype(darray.dtype, np.floating): + for v in ["z", "x"]: + dim = dims_plot.get(v, None) + if (dim is not None) and (dim in darray.dims): + darray_nan = np.nan * darray.isel({dim: -1}) + darray = concat([darray, darray_nan], dim=dim) + dims_T.append(dims_plot[v]) + + # Lines should never connect to the same coordinate when stacked, + # transpose to avoid this as much as possible: + darray = darray.transpose(..., *dims_T) + + # Array is now ready to be stacked: + darray = darray.stack(_stacked_dim=darray.dims) + + # Broadcast together all the chosen variables: + out = dict(y=darray) + out.update({k: darray[v] for k, v in dims_plot.items() if v is not None}) + out = dict(zip(out.keys(), broadcast(*(out.values())))) + + return out + + +# return type is Any due to the many different possibilities +def plot( + darray: DataArray, + *, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + subplot_kws: dict[str, Any] | None = None, + **kwargs: Any, +) -> Any: + """ + Default plot of DataArray using :py:mod:`matplotlib:matplotlib.pyplot`. + + Calls xarray plotting function based on the dimensions of + the squeezed DataArray. + + =============== =========================== + Dimensions Plotting function + =============== =========================== + 1 :py:func:`xarray.plot.line` + 2 :py:func:`xarray.plot.pcolormesh` + Anything else :py:func:`xarray.plot.hist` + =============== =========================== + + Parameters + ---------- + darray : DataArray + row : Hashable or None, optional + If passed, make row faceted plots on this dimension name. + col : Hashable or None, optional + If passed, make column faceted plots on this dimension name. + col_wrap : int or None, optional + Use together with ``col`` to wrap faceted plots. + ax : matplotlib axes object, optional + Axes on which to plot. By default, use the current axes. + Mutually exclusive with ``size``, ``figsize`` and facets. + hue : Hashable or None, optional + If passed, make faceted line plots with hue on this dimension name. + subplot_kws : dict, optional + Dictionary of keyword arguments for Matplotlib subplots + (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). + **kwargs : optional + Additional keyword arguments for Matplotlib. + + See Also + -------- + xarray.DataArray.squeeze + """ + darray = darray.squeeze().compute() + + plot_dims = set(darray.dims) + plot_dims.discard(row) + plot_dims.discard(col) + plot_dims.discard(hue) + + ndims = len(plot_dims) + + plotfunc: Callable + if ndims in [1, 2]: + if row or col: + kwargs["subplot_kws"] = subplot_kws + kwargs["row"] = row + kwargs["col"] = col + kwargs["col_wrap"] = col_wrap + if ndims == 1: + plotfunc = line + kwargs["hue"] = hue + elif ndims == 2: + if hue: + plotfunc = line + kwargs["hue"] = hue + else: + plotfunc = pcolormesh + kwargs["subplot_kws"] = subplot_kws + else: + if row or col or hue: + raise ValueError( + "Only 1d and 2d plots are supported for facets in xarray. " + "See the package `Seaborn` for more options." + ) + plotfunc = hist + + kwargs["ax"] = ax + + return plotfunc(darray, **kwargs) + + +@overload +def line( # type: ignore[misc] # None is hashable :( + darray: DataArray, + *args: Any, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, +) -> list[Line3D]: + ... + + +@overload +def line( + darray, + *args: Any, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@overload +def line( + darray, + *args: Any, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +# This function signature should not change so that it can use +# matplotlib format strings +def line( + darray: DataArray, + *args: Any, + row: Hashable | None = None, + col: Hashable | None = None, + figsize: Iterable[float] | None = None, + aspect: AspectOptions = None, + size: float | None = None, + ax: Axes | None = None, + hue: Hashable | None = None, + x: Hashable | None = None, + y: Hashable | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + add_legend: bool = True, + _labels: bool = True, + **kwargs: Any, +) -> list[Line3D] | FacetGrid[DataArray]: + """ + Line plot of DataArray values. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.plot`. + + Parameters + ---------- + darray : DataArray + Either 1D or 2D. If 2D, one of ``hue``, ``x`` or ``y`` must be provided. + row : Hashable, optional + If passed, make row faceted plots on this dimension name. + col : Hashable, optional + If passed, make column faceted plots on this dimension name. + figsize : tuple, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + aspect : "auto", "equal", scalar or None, optional + Aspect ratio of plot, so that ``aspect * size`` gives the *width* in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size: + *height* (in inches) of each plot. See also: ``aspect``. + ax : matplotlib axes object, optional + Axes on which to plot. By default, the current is used. + Mutually exclusive with ``size`` and ``figsize``. + hue : Hashable, optional + Dimension or coordinate for which you want multiple lines plotted. + If plotting against a 2D coordinate, ``hue`` must be a dimension. + x, y : Hashable, optional + Dimension, coordinate or multi-index level for *x*, *y* axis. + Only one of these may be specified. + The other will be used for values from the DataArray on which this + plot method is called. + xincrease : bool or None, optional + Should the values on the *x* axis be increasing from left to right? + if ``None``, use the default for the Matplotlib function. + yincrease : bool or None, optional + Should the values on the *y* axis be increasing from top to bottom? + if ``None``, use the default for the Matplotlib function. + xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional + Specifies scaling for the *x*- and *y*-axis, respectively. + xticks, yticks : array-like, optional + Specify tick locations for *x*- and *y*-axis. + xlim, ylim : array-like, optional + Specify *x*- and *y*-axis limits. + add_legend : bool, default: True + Add legend with *y* axis coordinates (2D inputs only). + *args, **kwargs : optional + Additional arguments to :py:func:`matplotlib:matplotlib.pyplot.plot`. + + Returns + ------- + primitive : list of Line3D or FacetGrid + When either col or row is given, returns a FacetGrid, otherwise + a list of matplotlib Line3D objects. + """ + # Handle facetgrids first + if row or col: + allargs = locals().copy() + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + return _easy_facetgrid(darray, line, kind="line", **allargs) + + ndims = len(darray.dims) + if ndims > 2: + raise ValueError( + "Line plots are for 1- or 2-dimensional DataArrays. " + "Passed DataArray has {ndims} " + "dimensions".format(ndims=ndims) + ) + + # The allargs dict passed to _easy_facetgrid above contains args + if args == (): + args = kwargs.pop("args", ()) + else: + assert "args" not in kwargs + + ax = get_axis(figsize, size, aspect, ax) + xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) + + # Remove pd.Intervals if contained in xplt.values and/or yplt.values. + xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( + xplt.to_numpy(), yplt.to_numpy(), kwargs + ) + xlabel = label_from_attrs(xplt, extra=x_suffix) + ylabel = label_from_attrs(yplt, extra=y_suffix) + + _ensure_plottable(xplt_val, yplt_val) + + primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) + + if _labels: + if xlabel is not None: + ax.set_xlabel(xlabel) + + if ylabel is not None: + ax.set_ylabel(ylabel) + + ax.set_title(darray._title_for_slice()) + + if darray.ndim == 2 and add_legend: + assert hueplt is not None + ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) + + # Rotate dates on xlabels + # Do this without calling autofmt_xdate so that x-axes ticks + # on other subplots (if any) are not deleted. + # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots + if np.issubdtype(xplt.dtype, np.datetime64): + for xlabels in ax.get_xticklabels(): + xlabels.set_rotation(30) + xlabels.set_ha("right") + + _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) + + return primitive + + +@overload +def step( # type: ignore[misc] # None is hashable :( + darray: DataArray, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + **kwargs: Any, +) -> list[Line3D]: + ... + + +@overload +def step( + darray: DataArray, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@overload +def step( + darray: DataArray, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +def step( + darray: DataArray, + *args: Any, + where: Literal["pre", "post", "mid"] = "pre", + drawstyle: str | None = None, + ds: str | None = None, + row: Hashable | None = None, + col: Hashable | None = None, + **kwargs: Any, +) -> list[Line3D] | FacetGrid[DataArray]: + """ + Step plot of DataArray values. + + Similar to :py:func:`matplotlib:matplotlib.pyplot.step`. + + Parameters + ---------- + where : {'pre', 'post', 'mid'}, default: 'pre' + Define where the steps should be placed: + + - ``'pre'``: The y value is continued constantly to the left from + every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the + value ``y[i]``. + - ``'post'``: The y value is continued constantly to the right from + every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the + value ``y[i]``. + - ``'mid'``: Steps occur half-way between the *x* positions. + + Note that this parameter is ignored if one coordinate consists of + :py:class:`pandas.Interval` values, e.g. as a result of + :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual + boundaries of the interval are used. + drawstyle, ds : str or None, optional + Additional drawstyle. Only use one of drawstyle and ds. + row : Hashable, optional + If passed, make row faceted plots on this dimension name. + col : Hashable, optional + If passed, make column faceted plots on this dimension name. + *args, **kwargs : optional + Additional arguments for :py:func:`xarray.plot.line`. + + Returns + ------- + primitive : list of Line3D or FacetGrid + When either col or row is given, returns a FacetGrid, otherwise + a list of matplotlib Line3D objects. + """ + if where not in {"pre", "post", "mid"}: + raise ValueError("'where' argument to step must be 'pre', 'post' or 'mid'") + + if ds is not None: + if drawstyle is None: + drawstyle = ds + else: + raise TypeError("ds and drawstyle are mutually exclusive") + if drawstyle is None: + drawstyle = "" + drawstyle = "steps-" + where + drawstyle + + return line(darray, *args, drawstyle=drawstyle, col=col, row=row, **kwargs) + + +def hist( + darray: DataArray, + *args: Any, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + xincrease: bool | None = None, + yincrease: bool | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + **kwargs: Any, +) -> tuple[np.ndarray, np.ndarray, BarContainer]: + """ + Histogram of DataArray. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.hist`. + + Plots *N*-dimensional arrays by first flattening the array. + + Parameters + ---------- + darray : DataArray + Can have any number of dimensions. + figsize : Iterable of float, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + aspect : "auto", "equal", scalar or None, optional + Aspect ratio of plot, so that ``aspect * size`` gives the *width* in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size: + *height* (in inches) of each plot. See also: ``aspect``. + ax : matplotlib axes object, optional + Axes on which to plot. By default, use the current axes. + Mutually exclusive with ``size`` and ``figsize``. + xincrease : bool or None, optional + Should the values on the *x* axis be increasing from left to right? + if ``None``, use the default for the Matplotlib function. + yincrease : bool or None, optional + Should the values on the *y* axis be increasing from top to bottom? + if ``None``, use the default for the Matplotlib function. + xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional + Specifies scaling for the *x*- and *y*-axis, respectively. + xticks, yticks : array-like, optional + Specify tick locations for *x*- and *y*-axis. + xlim, ylim : array-like, optional + Specify *x*- and *y*-axis limits. + **kwargs : optional + Additional keyword arguments to :py:func:`matplotlib:matplotlib.pyplot.hist`. + + """ + assert len(args) == 0 + + ax = get_axis(figsize, size, aspect, ax) + + no_nan = np.ravel(darray.to_numpy()) + no_nan = no_nan[pd.notnull(no_nan)] + + primitive = ax.hist(no_nan, **kwargs) + + ax.set_title(darray._title_for_slice()) + ax.set_xlabel(label_from_attrs(darray)) + + _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) + + return primitive + + +def _plot1d(plotfunc): + """Decorator for common 1d plotting logic.""" + commondoc = """ + Parameters + ---------- + darray : DataArray + Must be 2 dimensional, unless creating faceted plots. + x : Hashable or None, optional + Coordinate for x axis. If None use darray.dims[1]. + y : Hashable or None, optional + Coordinate for y axis. If None use darray.dims[0]. + z : Hashable or None, optional + If specified plot 3D and use this coordinate for *z* axis. + hue : Hashable or None, optional + Dimension or coordinate for which you want multiple lines plotted. + hue_style: {'discrete', 'continuous'} or None, optional + How to use the ``hue`` variable: + + - ``'continuous'`` -- continuous color scale + (default for numeric ``hue`` variables) + - ``'discrete'`` -- a color for each unique value, + using the default color cycle + (default for non-numeric ``hue`` variables) + + markersize: Hashable or None, optional + scatter only. Variable by which to vary size of scattered points. + linewidth: Hashable or None, optional + Variable by which to vary linewidth. + row : Hashable, optional + If passed, make row faceted plots on this dimension name. + col : Hashable, optional + If passed, make column faceted plots on this dimension name. + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots + ax : matplotlib axes object, optional + If None, uses the current axis. Not applicable when using facets. + figsize : Iterable[float] or None, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + aspect : "auto", "equal", scalar or None, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + xincrease : bool or None, default: True + Should the values on the x axes be increasing from left to right? + if None, use the default for the matplotlib function. + yincrease : bool or None, default: True + Should the values on the y axes be increasing from top to bottom? + if None, use the default for the matplotlib function. + add_legend : bool or None, optional + If True use xarray metadata to add a legend. + add_colorbar : bool or None, optional + If True add a colorbar. + add_labels : bool or None, optional + If True use xarray metadata to label axes + add_title : bool or None, optional + If True use xarray metadata to add a title + subplot_kws : dict, optional + Dictionary of keyword arguments for matplotlib subplots. Only applies + to FacetGrid plotting. + xscale : {'linear', 'symlog', 'log', 'logit'} or None, optional + Specifies scaling for the x-axes. + yscale : {'linear', 'symlog', 'log', 'logit'} or None, optional + Specifies scaling for the y-axes. + xticks : ArrayLike or None, optional + Specify tick locations for x-axes. + yticks : ArrayLike or None, optional + Specify tick locations for y-axes. + xlim : ArrayLike or None, optional + Specify x-axes limits. + ylim : ArrayLike or None, optional + Specify y-axes limits. + cmap : matplotlib colormap name or colormap, optional + The mapping from data values to color space. Either a + Matplotlib colormap name or object. If not provided, this will + be either ``'viridis'`` (if the function infers a sequential + dataset) or ``'RdBu_r'`` (if the function infers a diverging + dataset). + See :doc:`Choosing Colormaps in Matplotlib ` + for more information. + + If *seaborn* is installed, ``cmap`` may also be a + `seaborn color palette `_. + Note: if ``cmap`` is a seaborn color palette, + ``levels`` must also be specified. + vmin : float or None, optional + Lower value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + vmax : float or None, optional + Upper value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + norm : matplotlib.colors.Normalize, optional + If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding + kwarg must be ``None``. + extend : {'neither', 'both', 'min', 'max'}, optional + How to draw arrows extending the colorbar beyond its limits. If not + provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits. + levels : int or array-like, optional + Split the colormap (``cmap``) into discrete color intervals. If an integer + is provided, "nice" levels are chosen based on the data range: this can + imply that the final number of levels is not exactly the expected one. + Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to + setting ``levels=np.linspace(vmin, vmax, N)``. + **kwargs : optional + Additional arguments to wrapped matplotlib function + + Returns + ------- + artist : + The same type of primitive artist that the wrapped matplotlib + function returns + """ + + # Build on the original docstring + plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" + + @functools.wraps( + plotfunc, assigned=("__module__", "__name__", "__qualname__", "__doc__") + ) + def newplotfunc( + darray: DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs, + ) -> Any: + # All 1d plots in xarray share this function signature. + # Method signature below should be consistent. + + if subplot_kws is None: + subplot_kws = dict() + + # Handle facetgrids first + if row or col: + if z is not None: + subplot_kws.update(projection="3d") + + allargs = locals().copy() + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + allargs["plotfunc"] = globals()[plotfunc.__name__] + + return _easy_facetgrid(darray, kind="plot1d", **allargs) + + # The allargs dict passed to _easy_facetgrid above contains args + if args == (): + args = kwargs.pop("args", ()) + + if args: + assert "args" not in kwargs + # TODO: Deprecated since 2022.10: + msg = "Using positional arguments is deprecated for plot methods, use keyword arguments instead." + assert x is None + x = args[0] + if len(args) > 1: + assert y is None + y = args[1] + if len(args) > 2: + assert z is None + z = args[2] + if len(args) > 3: + assert hue is None + hue = args[3] + if len(args) > 4: + raise ValueError(msg) + else: + warnings.warn(msg, DeprecationWarning, stacklevel=2) + del args + + _is_facetgrid = kwargs.pop("_is_facetgrid", False) + + if plotfunc.__name__ == "scatter": + size_ = markersize + size_r = _MARKERSIZE_RANGE + else: + size_ = linewidth + size_r = _LINEWIDTH_RANGE + + # Get data to plot: + dims_plot = dict(x=x, z=z, hue=hue, size=size_) + plts = _infer_line_data2(darray, dims_plot, plotfunc.__name__) + xplt = plts.pop("x", None) + yplt = plts.pop("y", None) + zplt = plts.pop("z", None) + kwargs.update(zplt=zplt) + hueplt = plts.pop("hue", None) + sizeplt = plts.pop("size", None) + + # Handle size and hue: + hueplt_norm = _Normalize(data=hueplt) + kwargs.update(hueplt=hueplt_norm.values) + sizeplt_norm = _Normalize( + data=sizeplt, width=size_r, _is_facetgrid=_is_facetgrid + ) + kwargs.update(sizeplt=sizeplt_norm.values) + cmap_params_subset = kwargs.pop("cmap_params_subset", {}) + cbar_kwargs = kwargs.pop("cbar_kwargs", {}) + + if hueplt_norm.data is not None: + if not hueplt_norm.data_is_numeric: + # Map hue values back to its original value: + cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks) + levels = kwargs.get("levels", hueplt_norm.levels) + + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + plotfunc, + cast("DataArray", hueplt_norm.values).data, + **locals(), + ) + + # subset that can be passed to scatter, hist2d + if not cmap_params_subset: + ckw = {vv: cmap_params[vv] for vv in ("vmin", "vmax", "norm", "cmap")} + cmap_params_subset.update(**ckw) + + if z is not None: + if ax is None: + subplot_kws.update(projection="3d") + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + # Using 30, 30 minimizes rotation of the plot. Making it easier to + # build on your intuition from 2D plots: + plt = import_matplotlib_pyplot() + if Version(plt.matplotlib.__version__) < Version("3.5.0"): + ax.view_init(azim=30, elev=30) + else: + # https://github.com/matplotlib/matplotlib/pull/19873 + ax.view_init(azim=30, elev=30, vertical_axis="y") + else: + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + + primitive = plotfunc( + xplt, + yplt, + ax=ax, + add_labels=add_labels, + **cmap_params_subset, + **kwargs, + ) + + if np.any(np.asarray(add_labels)) and add_title: + ax.set_title(darray._title_for_slice()) + + add_colorbar_, add_legend_ = _determine_guide( + hueplt_norm, + sizeplt_norm, + add_colorbar, + add_legend, + plotfunc_name=plotfunc.__name__, + ) + + if add_colorbar_: + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data) + + _add_colorbar( + primitive, ax, kwargs.get("cbar_ax", None), cbar_kwargs, cmap_params + ) + + if add_legend_: + if plotfunc.__name__ in ["scatter", "line"]: + _add_legend( + hueplt_norm + if add_legend or not add_colorbar_ + else _Normalize(None), + sizeplt_norm, + primitive, + legend_ax=ax, + plotfunc=plotfunc.__name__, + ) + else: + hueplt_norm_values: list[np.ndarray | None] + if hueplt_norm.data is not None: + hueplt_norm_values = list( + cast("DataArray", hueplt_norm.data).to_numpy() + ) + else: + hueplt_norm_values = [hueplt_norm.data] + + if plotfunc.__name__ == "hist": + ax.legend( + handles=primitive[-1], + labels=hueplt_norm_values, + title=label_from_attrs(hueplt_norm.data), + ) + else: + ax.legend( + handles=primitive, + labels=hueplt_norm_values, + title=label_from_attrs(hueplt_norm.data), + ) + + _update_axes( + ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim + ) + + return primitive + + # we want to actually expose the signature of newplotfunc + # and not the copied **kwargs from the plotfunc which + # functools.wraps adds, so delete the wrapped attr + del newplotfunc.__wrapped__ + + return newplotfunc + + +def _add_labels( + add_labels: bool | Iterable[bool], + darrays: Iterable[DataArray], + suffixes: Iterable[str], + rotate_labels: Iterable[bool], + ax: Axes, +) -> None: + # Set x, y, z labels: + add_labels = [add_labels] * 3 if isinstance(add_labels, bool) else add_labels + for axis, add_label, darray, suffix, rotate_label in zip( + ("x", "y", "z"), add_labels, darrays, suffixes, rotate_labels + ): + if darray is None: + continue + + if add_label: + label = label_from_attrs(darray, extra=suffix) + if label is not None: + getattr(ax, f"set_{axis}label")(label) + + if rotate_label and np.issubdtype(darray.dtype, np.datetime64): + # Rotate dates on xlabels + # Do this without calling autofmt_xdate so that x-axes ticks + # on other subplots (if any) are not deleted. + # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots + for labels in getattr(ax, f"get_{axis}ticklabels")(): + labels.set_rotation(30) + labels.set_ha("right") + + +@overload +def scatter( + darray: DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs, +) -> PathCollection: + ... + + +@overload +def scatter( + darray: DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs, +) -> FacetGrid[DataArray]: + ... + + +@overload +def scatter( + darray: DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs, +) -> FacetGrid[DataArray]: + ... + + +@_plot1d +def scatter( + xplt: DataArray | None, + yplt: DataArray | None, + ax: Axes, + add_labels: bool | Iterable[bool] = True, + **kwargs, +) -> PathCollection: + """Scatter variables against each other. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.scatter`. + """ + plt = import_matplotlib_pyplot() + + if "u" in kwargs or "v" in kwargs: + raise ValueError("u, v are not allowed in scatter plots.") + + zplt: DataArray | None = kwargs.pop("zplt", None) + hueplt: DataArray | None = kwargs.pop("hueplt", None) + sizeplt: DataArray | None = kwargs.pop("sizeplt", None) + + # Add a white border to make it easier seeing overlapping markers: + kwargs.update(edgecolors=kwargs.pop("edgecolors", "w")) + + if hueplt is not None: + kwargs.update(c=hueplt.to_numpy().ravel()) + + if sizeplt is not None: + kwargs.update(s=sizeplt.to_numpy().ravel()) + + if Version(plt.matplotlib.__version__) < Version("3.5.0"): + # Plot the data. 3d plots has the z value in upward direction + # instead of y. To make jumping between 2d and 3d easy and intuitive + # switch the order so that z is shown in the depthwise direction: + axis_order = ["x", "z", "y"] + else: + # Switching axis order not needed in 3.5.0, can also simplify the code + # that uses axis_order: + # https://github.com/matplotlib/matplotlib/pull/19873 + axis_order = ["x", "y", "z"] + + plts_dict: dict[str, DataArray | None] = dict(x=xplt, y=yplt, z=zplt) + plts_or_none = [plts_dict[v] for v in axis_order] + plts = [p for p in plts_or_none if p is not None] + primitive = ax.scatter(*[p.to_numpy().ravel() for p in plts], **kwargs) + _add_labels(add_labels, plts, ("", "", ""), (True, False, False), ax) + + return primitive + + +def _plot2d(plotfunc): + """Decorator for common 2d plotting logic.""" + commondoc = """ + Parameters + ---------- + darray : DataArray + Must be two-dimensional, unless creating faceted plots. + x : Hashable or None, optional + Coordinate for *x* axis. If ``None``, use ``darray.dims[1]``. + y : Hashable or None, optional + Coordinate for *y* axis. If ``None``, use ``darray.dims[0]``. + figsize : Iterable or float or None, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + size : scalar, optional + If provided, create a new figure for the plot with the given size: + *height* (in inches) of each plot. See also: ``aspect``. + aspect : "auto", "equal", scalar or None, optional + Aspect ratio of plot, so that ``aspect * size`` gives the *width* in + inches. Only used if a ``size`` is provided. + ax : matplotlib axes object, optional + Axes on which to plot. By default, use the current axes. + Mutually exclusive with ``size`` and ``figsize``. + row : Hashable or None, optional + If passed, make row faceted plots on this dimension name. + col : Hashable or None, optional + If passed, make column faceted plots on this dimension name. + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots. + xincrease : None, True, or False, optional + Should the values on the *x* axis be increasing from left to right? + If ``None``, use the default for the Matplotlib function. + yincrease : None, True, or False, optional + Should the values on the *y* axis be increasing from top to bottom? + If ``None``, use the default for the Matplotlib function. + add_colorbar : bool, optional + Add colorbar to axes. + add_labels : bool, optional + Use xarray metadata to label axes. + vmin : float or None, optional + Lower value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + vmax : float or None, optional + Upper value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + cmap : matplotlib colormap name or colormap, optional + The mapping from data values to color space. If not provided, this + will be either be ``'viridis'`` (if the function infers a sequential + dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset). + See :doc:`Choosing Colormaps in Matplotlib ` + for more information. + + If *seaborn* is installed, ``cmap`` may also be a + `seaborn color palette `_. + Note: if ``cmap`` is a seaborn color palette and the plot type + is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified. + center : float, optional + The value at which to center the colormap. Passing this value implies + use of a diverging colormap. Setting it to ``False`` prevents use of a + diverging colormap. + robust : bool, optional + If ``True`` and ``vmin`` or ``vmax`` are absent, the colormap range is + computed with 2nd and 98th percentiles instead of the extreme values. + extend : {'neither', 'both', 'min', 'max'}, optional + How to draw arrows extending the colorbar beyond its limits. If not + provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits. + levels : int or array-like, optional + Split the colormap (``cmap``) into discrete color intervals. If an integer + is provided, "nice" levels are chosen based on the data range: this can + imply that the final number of levels is not exactly the expected one. + Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to + setting ``levels=np.linspace(vmin, vmax, N)``. + infer_intervals : bool, optional + Only applies to pcolormesh. If ``True``, the coordinate intervals are + passed to pcolormesh. If ``False``, the original coordinates are used + (this can be useful for certain map projections). The default is to + always infer intervals, unless the mesh is irregular and plotted on + a map projection. + colors : str or array-like of color-like, optional + A single color or a sequence of colors. If the plot type is not ``'contour'`` + or ``'contourf'``, the ``levels`` argument is required. + subplot_kws : dict, optional + Dictionary of keyword arguments for Matplotlib subplots. Only used + for 2D and faceted plots. + (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). + cbar_ax : matplotlib axes object, optional + Axes in which to draw the colorbar. + cbar_kwargs : dict, optional + Dictionary of keyword arguments to pass to the colorbar + (see :meth:`matplotlib:matplotlib.figure.Figure.colorbar`). + xscale : {'linear', 'symlog', 'log', 'logit'} or None, optional + Specifies scaling for the x-axes. + yscale : {'linear', 'symlog', 'log', 'logit'} or None, optional + Specifies scaling for the y-axes. + xticks : ArrayLike or None, optional + Specify tick locations for x-axes. + yticks : ArrayLike or None, optional + Specify tick locations for y-axes. + xlim : ArrayLike or None, optional + Specify x-axes limits. + ylim : ArrayLike or None, optional + Specify y-axes limits. + norm : matplotlib.colors.Normalize, optional + If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding + kwarg must be ``None``. + **kwargs : optional + Additional keyword arguments to wrapped Matplotlib function. + + Returns + ------- + artist : + The same type of primitive artist that the wrapped Matplotlib + function returns. + """ + + # Build on the original docstring + plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" + + @functools.wraps( + plotfunc, assigned=("__module__", "__name__", "__qualname__", "__doc__") + ) + def newplotfunc( + darray: DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, + ) -> Any: + # All 2d plots in xarray share this function signature. + + if args: + # TODO: Deprecated since 2022.10: + msg = "Using positional arguments is deprecated for plot methods, use keyword arguments instead." + assert x is None + x = args[0] + if len(args) > 1: + assert y is None + y = args[1] + if len(args) > 2: + raise ValueError(msg) + else: + warnings.warn(msg, DeprecationWarning, stacklevel=2) + del args + + # Decide on a default for the colorbar before facetgrids + if add_colorbar is None: + add_colorbar = True + if plotfunc.__name__ == "contour" or ( + plotfunc.__name__ == "surface" and cmap is None + ): + add_colorbar = False + imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == ( + 3 + (row is not None) + (col is not None) + ) + if imshow_rgb: + # Don't add a colorbar when showing an image with explicit colors + add_colorbar = False + # Matplotlib does not support normalising RGB data, so do it here. + # See eg. https://github.com/matplotlib/matplotlib/pull/10220 + if robust or vmax is not None or vmin is not None: + darray = _rescale_imshow_rgb(darray.as_numpy(), vmin, vmax, robust) + vmin, vmax, robust = None, None, False + + if subplot_kws is None: + subplot_kws = dict() + + if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False): + if ax is None: + # TODO: Importing Axes3D is no longer necessary in matplotlib >= 3.2. + # Remove when minimum requirement of matplotlib is 3.2: + from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa: F401 + + # delete so it does not end up in locals() + del Axes3D + + # Need to create a "3d" Axes instance for surface plots + subplot_kws["projection"] = "3d" + + # In facet grids, shared axis labels don't make sense for surface plots + sharex = False + sharey = False + + # Handle facetgrids first + if row or col: + allargs = locals().copy() + del allargs["darray"] + del allargs["imshow_rgb"] + allargs.update(allargs.pop("kwargs")) + # Need the decorated plotting function + allargs["plotfunc"] = globals()[plotfunc.__name__] + return _easy_facetgrid(darray, kind="dataarray", **allargs) + + plt = import_matplotlib_pyplot() + + if ( + plotfunc.__name__ == "surface" + and not kwargs.get("_is_facetgrid", False) + and ax is not None + ): + import mpl_toolkits # type: ignore + + if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D): + raise ValueError( + "If ax is passed to surface(), it must be created with " + 'projection="3d"' + ) + + rgb = kwargs.pop("rgb", None) + if rgb is not None and plotfunc.__name__ != "imshow": + raise ValueError('The "rgb" keyword is only valid for imshow()') + elif rgb is not None and not imshow_rgb: + raise ValueError( + 'The "rgb" keyword is only valid for imshow()' + "with a three-dimensional array (per facet)" + ) + + xlab, ylab = _infer_xy_labels( + darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb + ) + + xval = darray[xlab] + yval = darray[ylab] + + if xval.ndim > 1 or yval.ndim > 1 or plotfunc.__name__ == "surface": + # Passing 2d coordinate values, need to ensure they are transposed the same + # way as darray. + # Also surface plots always need 2d coordinates + xval = xval.broadcast_like(darray) + yval = yval.broadcast_like(darray) + dims = darray.dims + else: + dims = (yval.dims[0], xval.dims[0]) + + # May need to transpose for correct x, y labels + # xlab may be the name of a coord, we have to check for dim names + if imshow_rgb: + # For RGB[A] images, matplotlib requires the color dimension + # to be last. In Xarray the order should be unimportant, so + # we transpose to (y, x, color) to make this work. + yx_dims = (ylab, xlab) + dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims) + + if dims != darray.dims: + darray = darray.transpose(*dims, transpose_coords=True) + + # better to pass the ndarrays directly to plotting functions + xvalnp = xval.to_numpy() + yvalnp = yval.to_numpy() + + # Pass the data as a masked ndarray too + zval = darray.to_masked_array(copy=False) + + # Replace pd.Intervals if contained in xval or yval. + xplt, xlab_extra = _resolve_intervals_2dplot(xvalnp, plotfunc.__name__) + yplt, ylab_extra = _resolve_intervals_2dplot(yvalnp, plotfunc.__name__) + + _ensure_plottable(xplt, yplt, zval) + + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + plotfunc, + zval.data, + **locals(), + _is_facetgrid=kwargs.pop("_is_facetgrid", False), + ) + + if "contour" in plotfunc.__name__: + # extend is a keyword argument only for contour and contourf, but + # passing it to the colorbar is sufficient for imshow and + # pcolormesh + kwargs["extend"] = cmap_params["extend"] + kwargs["levels"] = cmap_params["levels"] + # if colors == a single color, matplotlib draws dashed negative + # contours. we lose this feature if we pass cmap and not colors + if isinstance(colors, str): + cmap_params["cmap"] = None + kwargs["colors"] = colors + + if "pcolormesh" == plotfunc.__name__: + kwargs["infer_intervals"] = infer_intervals + kwargs["xscale"] = xscale + kwargs["yscale"] = yscale + + if "imshow" == plotfunc.__name__ and isinstance(aspect, str): + # forbid usage of mpl strings + raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray") + + ax = get_axis(figsize, size, aspect, ax, **subplot_kws) + + primitive = plotfunc( + xplt, + yplt, + zval, + ax=ax, + cmap=cmap_params["cmap"], + vmin=cmap_params["vmin"], + vmax=cmap_params["vmax"], + norm=cmap_params["norm"], + **kwargs, + ) + + # Label the plot with metadata + if add_labels: + ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra)) + ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) + ax.set_title(darray._title_for_slice()) + if plotfunc.__name__ == "surface": + ax.set_zlabel(label_from_attrs(darray)) + + if add_colorbar: + if add_labels and "label" not in cbar_kwargs: + cbar_kwargs["label"] = label_from_attrs(darray) + cbar = _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + elif cbar_ax is not None or cbar_kwargs: + # inform the user about keywords which aren't used + raise ValueError( + "cbar_ax and cbar_kwargs can't be used with add_colorbar=False." + ) + + # origin kwarg overrides yincrease + if "origin" in kwargs: + yincrease = None + + _update_axes( + ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim + ) + + # Rotate dates on xlabels + # Do this without calling autofmt_xdate so that x-axes ticks + # on other subplots (if any) are not deleted. + # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots + if np.issubdtype(xplt.dtype, np.datetime64): + for xlabels in ax.get_xticklabels(): + xlabels.set_rotation(30) + xlabels.set_ha("right") + + return primitive + + # we want to actually expose the signature of newplotfunc + # and not the copied **kwargs from the plotfunc which + # functools.wraps adds, so delete the wrapped attr + del newplotfunc.__wrapped__ + + return newplotfunc + + +@overload +def imshow( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> AxesImage: + ... + + +@overload +def imshow( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@overload +def imshow( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@_plot2d +def imshow( + x: np.ndarray, y: np.ndarray, z: np.ma.core.MaskedArray, ax: Axes, **kwargs: Any +) -> AxesImage: + """ + Image plot of 2D DataArray. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.imshow`. + + While other plot methods require the DataArray to be strictly + two-dimensional, ``imshow`` also accepts a 3D array where some + dimension can be interpreted as RGB or RGBA color channels and + allows this dimension to be specified via the kwarg ``rgb=``. + + Unlike :py:func:`matplotlib:matplotlib.pyplot.imshow`, which ignores ``vmin``/``vmax`` + for RGB(A) data, + xarray *will* use ``vmin`` and ``vmax`` for RGB(A) data + by applying a single scaling factor and offset to all bands. + Passing ``robust=True`` infers ``vmin`` and ``vmax`` + :ref:`in the usual way `. + Additionally the y-axis is not inverted by default, you can + restore the matplotlib behavior by setting `yincrease=False`. + + .. note:: + This function needs uniformly spaced coordinates to + properly label the axes. Call :py:meth:`DataArray.plot` to check. + + The pixels are centered on the coordinates. For example, if the coordinate + value is 3.2, then the pixels for those coordinates will be centered on 3.2. + """ + + if x.ndim != 1 or y.ndim != 1: + raise ValueError( + "imshow requires 1D coordinates, try using pcolormesh or contour(f)" + ) + + def _center_pixels(x): + """Center the pixels on the coordinates.""" + if np.issubdtype(x.dtype, str): + # When using strings as inputs imshow converts it to + # integers. Choose extent values which puts the indices in + # in the center of the pixels: + return 0 - 0.5, len(x) - 0.5 + + try: + # Center the pixels assuming uniform spacing: + xstep = 0.5 * (x[1] - x[0]) + except IndexError: + # Arbitrary default value, similar to matplotlib behaviour: + xstep = 0.1 + + return x[0] - xstep, x[-1] + xstep + + # Center the pixels: + left, right = _center_pixels(x) + top, bottom = _center_pixels(y) + + defaults: dict[str, Any] = {"origin": "upper", "interpolation": "nearest"} + + if not hasattr(ax, "projection"): + # not for cartopy geoaxes + defaults["aspect"] = "auto" + + # Allow user to override these defaults + defaults.update(kwargs) + + if defaults["origin"] == "upper": + defaults["extent"] = [left, right, bottom, top] + else: + defaults["extent"] = [left, right, top, bottom] + + if z.ndim == 3: + # matplotlib imshow uses black for missing data, but Xarray makes + # missing data transparent. We therefore add an alpha channel if + # there isn't one, and set it to transparent where data is masked. + if z.shape[-1] == 3: + alpha = np.ma.ones(z.shape[:2] + (1,), dtype=z.dtype) + if np.issubdtype(z.dtype, np.integer): + alpha *= 255 + z = np.ma.concatenate((z, alpha), axis=2) + else: + z = z.copy() + z[np.any(z.mask, axis=-1), -1] = 0 + + primitive = ax.imshow(z, **defaults) + + # If x or y are strings the ticklabels have been replaced with + # integer indices. Replace them back to strings: + for axis, v in [("x", x), ("y", y)]: + if np.issubdtype(v.dtype, str): + getattr(ax, f"set_{axis}ticks")(np.arange(len(v))) + getattr(ax, f"set_{axis}ticklabels")(v) + + return primitive + + +@overload +def contour( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> QuadContourSet: + ... + + +@overload +def contour( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@overload +def contour( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@_plot2d +def contour( + x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **kwargs: Any +) -> QuadContourSet: + """ + Contour plot of 2D DataArray. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.contour`. + """ + primitive = ax.contour(x, y, z, **kwargs) + return primitive + + +@overload +def contourf( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> QuadContourSet: + ... + + +@overload +def contourf( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@overload +def contourf( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@_plot2d +def contourf( + x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **kwargs: Any +) -> QuadContourSet: + """ + Filled contour plot of 2D DataArray. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.contourf`. + """ + primitive = ax.contourf(x, y, z, **kwargs) + return primitive + + +@overload +def pcolormesh( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> QuadMesh: + ... + + +@overload +def pcolormesh( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@overload +def pcolormesh( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@_plot2d +def pcolormesh( + x: np.ndarray, + y: np.ndarray, + z: np.ndarray, + ax: Axes, + xscale: ScaleOptions | None = None, + yscale: ScaleOptions | None = None, + infer_intervals=None, + **kwargs: Any, +) -> QuadMesh: + """ + Pseudocolor plot of 2D DataArray. + + Wraps :py:func:`matplotlib:matplotlib.pyplot.pcolormesh`. + """ + + # decide on a default for infer_intervals (GH781) + x = np.asarray(x) + if infer_intervals is None: + if hasattr(ax, "projection"): + if len(x.shape) == 1: + infer_intervals = True + else: + infer_intervals = False + else: + infer_intervals = True + + if ( + infer_intervals + and not np.issubdtype(x.dtype, str) + and ( + (np.shape(x)[0] == np.shape(z)[1]) + or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])) + ) + ): + if len(x.shape) == 1: + x = _infer_interval_breaks(x, check_monotonic=True, scale=xscale) + else: + # we have to infer the intervals on both axes + x = _infer_interval_breaks(x, axis=1, scale=xscale) + x = _infer_interval_breaks(x, axis=0, scale=xscale) + + if ( + infer_intervals + and not np.issubdtype(y.dtype, str) + and (np.shape(y)[0] == np.shape(z)[0]) + ): + if len(y.shape) == 1: + y = _infer_interval_breaks(y, check_monotonic=True, scale=yscale) + else: + # we have to infer the intervals on both axes + y = _infer_interval_breaks(y, axis=1, scale=yscale) + y = _infer_interval_breaks(y, axis=0, scale=yscale) + + primitive = ax.pcolormesh(x, y, z, **kwargs) + + # by default, pcolormesh picks "round" values for bounds + # this results in ugly looking plots with lots of surrounding whitespace + if not hasattr(ax, "projection") and x.ndim == 1 and y.ndim == 1: + # not a cartopy geoaxis + ax.set_xlim(x[0], x[-1]) + ax.set_ylim(y[0], y[-1]) + + return primitive + + +@overload +def surface( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> Poly3DCollection: + ... + + +@overload +def surface( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@overload +def surface( + darray: DataArray, + x: Hashable | None = None, + y: Hashable | None = None, + *, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_colorbar: bool | None = None, + add_labels: bool = True, + vmin: float | None = None, + vmax: float | None = None, + cmap: str | Colormap | None = None, + center: float | None = None, + robust: bool = False, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + infer_intervals=None, + colors: str | ArrayLike | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cbar_kwargs: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + norm: Normalize | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@_plot2d +def surface( + x: np.ndarray, y: np.ndarray, z: np.ndarray, ax: Axes, **kwargs: Any +) -> Poly3DCollection: + """ + Surface plot of 2D DataArray. + + Wraps :py:meth:`matplotlib:mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`. + """ + primitive = ax.plot_surface(x, y, z, **kwargs) + return primitive diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 2b10b2afcfe..55819b0ab9f 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -2,11 +2,12 @@ import functools import inspect -from typing import TYPE_CHECKING, Any, Callable, Hashable, Mapping +import warnings +from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, TypeVar, overload from ..core.alignment import broadcast +from . import dataarray_plot from .facetgrid import _easy_facetgrid -from .plot import _PlotMethods from .utils import ( _add_colorbar, _get_nice_quiver_magnitude, @@ -16,24 +17,16 @@ ) if TYPE_CHECKING: - from ..core.dataarray import DataArray - from ..core.types import T_Dataset - - -class _Dataset_PlotMethods: - """ - Enables use of xarray.plot functions as attributes on a Dataset. - For example, Dataset.plot.scatter - """ + from matplotlib.axes import Axes + from matplotlib.collections import LineCollection, PathCollection + from matplotlib.colors import Colormap, Normalize + from matplotlib.quiver import Quiver + from numpy.typing import ArrayLike - def __init__(self, dataset): - self._ds = dataset - - def __call__(self, *args, **kwargs): - raise ValueError( - "Dataset.plot cannot be called directly. Use " - "an explicit plot method, e.g. ds.plot.scatter(...)" - ) + from ..core.dataarray import DataArray + from ..core.dataset import Dataset + from ..core.types import AspectOptions, ExtendOptions, HueStyleOptions, ScaleOptions + from .facetgrid import FacetGrid def _dsplot(plotfunc): @@ -42,64 +35,62 @@ def _dsplot(plotfunc): ---------- ds : Dataset - x, y : str - Variable names for the *x* and *y* grid positions. - u, v : str, optional - Variable names for the *u* and *v* velocities - (in *x* and *y* direction, respectively; quiver/streamplot plots only). - hue: str, optional + x : Hashable or None, optional + Variable name for x-axis. + y : Hashable or None, optional + Variable name for y-axis. + u : Hashable or None, optional + Variable name for the *u* velocity (in *x* direction). + quiver/streamplot plots only. + v : Hashable or None, optional + Variable name for the *v* velocity (in *y* direction). + quiver/streamplot plots only. + hue: Hashable or None, optional Variable by which to color scatter points or arrows. - hue_style: {'continuous', 'discrete'}, optional + hue_style: {'continuous', 'discrete'} or None, optional How to use the ``hue`` variable: - ``'continuous'`` -- continuous color scale (default for numeric ``hue`` variables) - ``'discrete'`` -- a color for each unique value, using the default color cycle (default for non-numeric ``hue`` variables) - markersize: str, optional - Variable by which to vary the size of scattered points (scatter plot only). - size_norm: matplotlib.colors.Normalize or tuple, optional - Used to normalize the ``markersize`` variable. - If a tuple is passed, the values will be passed to - :py:class:`matplotlib:matplotlib.colors.Normalize` as arguments. - Default: no normalization (``vmin=None``, ``vmax=None``, ``clip=False``). - scale: scalar, optional - Quiver only. Number of data units per arrow length unit. - Use this to control the length of the arrows: larger values lead to - smaller arrows. - add_guide: bool, optional, default: True - Add a guide that depends on ``hue_style``: - - ``'continuous'`` -- build a colorbar - - ``'discrete'`` -- build a legend - row : str, optional + row : Hashable or None, optional If passed, make row faceted plots on this dimension name. - col : str, optional + col : Hashable or None, optional If passed, make column faceted plots on this dimension name. col_wrap : int, optional Use together with ``col`` to wrap faceted plots. - ax : matplotlib axes object, optional + ax : matplotlib axes object or None, optional If ``None``, use the current axes. Not applicable when using facets. - subplot_kws : dict, optional + figsize : Iterable[float] or None, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + aspect : "auto", "equal", scalar or None, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + sharex : bool or None, optional + If True all subplots share the same x-axis. + sharey : bool or None, optional + If True all subplots share the same y-axis. + add_guide: bool or None, optional + Add a guide that depends on ``hue_style``: + + - ``'continuous'`` -- build a colorbar + - ``'discrete'`` -- build a legend + + subplot_kws : dict or None, optional Dictionary of keyword arguments for Matplotlib subplots (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). Only applies to FacetGrid plotting. - aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the *width* in - inches. Only used if a ``size`` is provided. - size : scalar, optional - If provided, create a new figure for the plot with the given size: - *height* (in inches) of each plot. See also: ``aspect``. - norm : matplotlib.colors.Normalize, optional - If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding - kwarg must be ``None``. - vmin, vmax : float, optional - Values to anchor the colormap, otherwise they are inferred from the - data and other keyword arguments. When a diverging dataset is inferred, - setting one of these values will fix the other by symmetry around - ``center``. Setting both values prevents use of a diverging colormap. - If discrete levels are provided as an explicit list, both of these - values are ignored. + cbar_kwargs : dict, optional + Dictionary of keyword arguments to pass to the colorbar + (see :meth:`matplotlib:matplotlib.figure.Figure.colorbar`). + cbar_ax : matplotlib axes object, optional + Axes in which to draw the colorbar. cmap : matplotlib colormap name or colormap, optional The mapping from data values to color space. Either a Matplotlib colormap name or object. If not provided, this will @@ -113,9 +104,25 @@ def _dsplot(plotfunc): `seaborn color palette `_. Note: if ``cmap`` is a seaborn color palette, ``levels`` must also be specified. - colors : str or array-like of color-like, optional - A single color or a list of colors. The ``levels`` argument - is required. + vmin : float or None, optional + Lower value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + vmax : float or None, optional + Upper value to anchor the colormap, otherwise it is inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting `vmin` or `vmax` will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + norm : matplotlib.colors.Normalize, optional + If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding + kwarg must be ``None``. + infer_intervals: bool | None + If True the intervals are infered. center : float, optional The value at which to center the colormap. Passing this value implies use of a diverging colormap. Setting it to ``False`` prevents use of a @@ -123,6 +130,9 @@ def _dsplot(plotfunc): robust : bool, optional If ``True`` and ``vmin`` or ``vmax`` are absent, the colormap range is computed with 2nd and 98th percentiles instead of the extreme values. + colors : str or array-like of color-like, optional + A single color or a list of colors. The ``levels`` argument + is required. extend : {'neither', 'both', 'min', 'max'}, optional How to draw arrows extending the colorbar beyond its limits. If not provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits. @@ -139,40 +149,66 @@ def _dsplot(plotfunc): # Build on the original docstring plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" - @functools.wraps(plotfunc) + @functools.wraps( + plotfunc, assigned=("__module__", "__name__", "__qualname__", "__doc__") + ) def newplotfunc( - ds, - x=None, - y=None, - u=None, - v=None, - hue=None, - hue_style=None, - col=None, - row=None, - ax=None, - figsize=None, - size=None, - col_wrap=None, - sharex=True, - sharey=True, - aspect=None, - subplot_kws=None, - add_guide=None, - cbar_kwargs=None, - cbar_ax=None, - vmin=None, - vmax=None, - norm=None, - infer_intervals=None, - center=None, - levels=None, - robust=None, - colors=None, - extend=None, - cmap=None, - **kwargs, - ): + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: AspectOptions = None, + sharex: bool = True, + sharey: bool = True, + add_guide: bool | None = None, + subplot_kws: dict[str, Any] | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, + ) -> Any: + + if args: + # TODO: Deprecated since 2022.10: + msg = "Using positional arguments is deprecated for plot methods, use keyword arguments instead." + assert x is None + x = args[0] + if len(args) > 1: + assert y is None + y = args[1] + if len(args) > 2: + assert u is None + u = args[2] + if len(args) > 3: + assert v is None + v = args[3] + if len(args) > 4: + assert hue is None + hue = args[4] + if len(args) > 5: + raise ValueError(msg) + else: + warnings.warn(msg, DeprecationWarning, stacklevel=2) + del args _is_facetgrid = kwargs.pop("_is_facetgrid", False) if _is_facetgrid: # facetgrid call @@ -271,61 +307,138 @@ def newplotfunc( return primitive - @functools.wraps(newplotfunc) - def plotmethod( - _PlotMethods_obj, - x=None, - y=None, - u=None, - v=None, - hue=None, - hue_style=None, - col=None, - row=None, - ax=None, - figsize=None, - col_wrap=None, - sharex=True, - sharey=True, - aspect=None, - size=None, - subplot_kws=None, - add_guide=None, - cbar_kwargs=None, - cbar_ax=None, - vmin=None, - vmax=None, - norm=None, - infer_intervals=None, - center=None, - levels=None, - robust=None, - colors=None, - extend=None, - cmap=None, - **kwargs, - ): - """ - The method should have the same signature as the function. - - This just makes the method work on Plotmethods objects, - and passes all the other arguments straight through. - """ - allargs = locals() - allargs["ds"] = _PlotMethods_obj._ds - allargs.update(kwargs) - for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]: - del allargs[arg] - return newplotfunc(**allargs) - - # Add to class _PlotMethods - setattr(_Dataset_PlotMethods, plotmethod.__name__, plotmethod) + # we want to actually expose the signature of newplotfunc + # and not the copied **kwargs from the plotfunc which + # functools.wraps adds, so delete the wrapped attr + del newplotfunc.__wrapped__ return newplotfunc +@overload +def quiver( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: None = None, # no wrap -> primitive + row: None = None, # no wrap -> primitive + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> Quiver: + ... + + +@overload +def quiver( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable, # wrap -> FacetGrid + row: Hashable | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> FacetGrid[Dataset]: + ... + + +@overload +def quiver( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable | None = None, + row: Hashable, # wrap -> FacetGrid + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> FacetGrid[Dataset]: + ... + + @_dsplot -def quiver(ds, x, y, ax, u, v, **kwargs): +def quiver( + ds: Dataset, + x: Hashable, + y: Hashable, + ax: Axes, + u: Hashable, + v: Hashable, + **kwargs: Any, +) -> Quiver: """Quiver plot of Dataset variables. Wraps :py:func:`matplotlib:matplotlib.pyplot.quiver`. @@ -335,9 +448,9 @@ def quiver(ds, x, y, ax, u, v, **kwargs): if x is None or y is None or u is None or v is None: raise ValueError("Must specify x, y, u, v for quiver plots.") - x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v]) + dx, dy, du, dv = broadcast(ds[x], ds[y], ds[u], ds[v]) - args = [x.values, y.values, u.values, v.values] + args = [dx.values, dy.values, du.values, dv.values] hue = kwargs.pop("hue") cmap_params = kwargs.pop("cmap_params") @@ -356,8 +469,130 @@ def quiver(ds, x, y, ax, u, v, **kwargs): return hdl +@overload +def streamplot( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: None = None, # no wrap -> primitive + row: None = None, # no wrap -> primitive + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> LineCollection: + ... + + +@overload +def streamplot( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable, # wrap -> FacetGrid + row: Hashable | None = None, + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> FacetGrid[Dataset]: + ... + + +@overload +def streamplot( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + u: Hashable | None = None, + v: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + col: Hashable | None = None, + row: Hashable, # wrap -> FacetGrid + ax: Axes | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + col_wrap: int | None = None, + sharex: bool = True, + sharey: bool = True, + aspect: AspectOptions = None, + subplot_kws: dict[str, Any] | None = None, + add_guide: bool | None = None, + cbar_kwargs: dict[str, Any] | None = None, + cbar_ax: Axes | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + infer_intervals: bool | None = None, + center: float | None = None, + levels: ArrayLike | None = None, + robust: bool | None = None, + colors: str | ArrayLike | None = None, + extend: ExtendOptions = None, + cmap: str | Colormap | None = None, + **kwargs: Any, +) -> FacetGrid[Dataset]: + ... + + @_dsplot -def streamplot(ds, x, y, ax, u, v, **kwargs): +def streamplot( + ds: Dataset, + x: Hashable, + y: Hashable, + ax: Axes, + u: Hashable, + v: Hashable, + **kwargs: Any, +) -> LineCollection: """Plot streamlines of Dataset variables. Wraps :py:func:`matplotlib:matplotlib.pyplot.streamplot`. @@ -372,25 +607,27 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): # the dimension of x must be the second dimension. 'y' cannot vary with 'columns' so # the dimension of y must be the first dimension. If x and y are both 2d, assume the # user has got them right already. - if len(ds[x].dims) == 1: - xdim = ds[x].dims[0] - if len(ds[y].dims) == 1: - ydim = ds[y].dims[0] + xdim = ds[x].dims[0] if len(ds[x].dims) == 1 else None + ydim = ds[y].dims[0] if len(ds[y].dims) == 1 else None if xdim is not None and ydim is None: - ydim = set(ds[y].dims) - {xdim} + ydims = set(ds[y].dims) - {xdim} + if len(ydims) == 1: + ydim = next(iter(ydims)) if ydim is not None and xdim is None: - xdim = set(ds[x].dims) - {ydim} + xdims = set(ds[x].dims) - {ydim} + if len(xdims) == 1: + xdim = next(iter(xdims)) - x, y, u, v = broadcast(ds[x], ds[y], ds[u], ds[v]) + dx, dy, du, dv = broadcast(ds[x], ds[y], ds[u], ds[v]) if xdim is not None and ydim is not None: # Need to ensure the arrays are transposed correctly - x = x.transpose(ydim, xdim) - y = y.transpose(ydim, xdim) - u = u.transpose(ydim, xdim) - v = v.transpose(ydim, xdim) + dx = dx.transpose(ydim, xdim) + dy = dy.transpose(ydim, xdim) + du = du.transpose(ydim, xdim) + dv = dv.transpose(ydim, xdim) - args = [x.values, y.values, u.values, v.values] + args = [dx.values, dy.values, du.values, dv.values] hue = kwargs.pop("hue") cmap_params = kwargs.pop("cmap_params") @@ -410,12 +647,12 @@ def streamplot(ds, x, y, ax, u, v, **kwargs): return hdl.lines -def _attach_to_plot_class(plotfunc: Callable) -> None: - """ - Set the function to the plot class and add a common docstring. +F = TypeVar("F", bound=Callable) - Use this decorator when relying on DataArray.plot methods for - creating the Dataset plot. + +def _update_doc_to_dataset(dataarray_plotfunc: Callable) -> Callable[[F], F]: + """ + Add a common docstring by re-using the DataArray one. TODO: Reduce code duplication. @@ -424,42 +661,48 @@ def _attach_to_plot_class(plotfunc: Callable) -> None: handle the conversion between Dataset and DataArray. * Improve docstring handling, maybe reword the DataArray versions to explain Datasets better. - * Consider automatically adding all _PlotMethods to - _Dataset_PlotMethods. Parameters ---------- - plotfunc : function + dataarray_plotfunc : Callable Function that returns a finished plot primitive. """ - # Build on the original docstring: - original_doc = getattr(_PlotMethods, plotfunc.__name__, object) - commondoc = original_doc.__doc__ - if commondoc is not None: - doc_warning = ( - f"This docstring was copied from xr.DataArray.plot.{original_doc.__name__}." - " Some inconsistencies may exist." - ) - # Add indentation so it matches the original doc: - commondoc = f"\n\n {doc_warning}\n\n {commondoc}" + + # Build on the original docstring + da_doc = dataarray_plotfunc.__doc__ + if da_doc is None: + raise NotImplementedError("DataArray plot method requires a docstring") + + da_str = """ + Parameters + ---------- + darray : DataArray + """ + ds_str = """ + + The `y` DataArray will be used as base, any other variables are added as coords. + + Parameters + ---------- + ds : Dataset + """ + # TODO: improve this? + if da_str in da_doc: + ds_doc = da_doc.replace(da_str, ds_str).replace("darray", "ds") else: - commondoc = "" - plotfunc.__doc__ = ( - f" {plotfunc.__doc__}\n\n" - " The `y` DataArray will be used as base," - " any other variables are added as coords.\n\n" - f"{commondoc}" - ) + ds_doc = da_doc - @functools.wraps(plotfunc) - def plotmethod(self, *args, **kwargs): - return plotfunc(self._ds, *args, **kwargs) + @functools.wraps(dataarray_plotfunc) + def wrapper(dataset_plotfunc: F) -> F: + dataset_plotfunc.__doc__ = ds_doc + return dataset_plotfunc - # Add to class _PlotMethods - setattr(_Dataset_PlotMethods, plotmethod.__name__, plotmethod) + return wrapper -def _normalize_args(plotmethod: str, args, kwargs) -> dict[str, Any]: +def _normalize_args( + plotmethod: str, args: tuple[Any, ...], kwargs: dict[str, Any] +) -> dict[str, Any]: from ..core.dataarray import DataArray # Determine positional arguments keyword by inspecting the @@ -474,7 +717,7 @@ def _normalize_args(plotmethod: str, args, kwargs) -> dict[str, Any]: return locals_ -def _temp_dataarray(ds: T_Dataset, y: Hashable, locals_: Mapping) -> DataArray: +def _temp_dataarray(ds: Dataset, y: Hashable, locals_: dict[str, Any]) -> DataArray: """Create a temporary datarray with extra coords.""" from ..core.dataarray import DataArray @@ -499,12 +742,175 @@ def _temp_dataarray(ds: T_Dataset, y: Hashable, locals_: Mapping) -> DataArray: return DataArray(_y, coords=coords) -@_attach_to_plot_class -def scatter(ds: T_Dataset, x: Hashable, y: Hashable, *args, **kwargs): +@overload +def scatter( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> PathCollection: + ... + + +@overload +def scatter( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@overload +def scatter( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: + ... + + +@_update_doc_to_dataset(dataarray_plot.scatter) +def scatter( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> PathCollection | FacetGrid[DataArray]: """Scatter plot Dataset data variables against each other.""" - plotmethod = "scatter" - kwargs.update(x=x) - locals_ = _normalize_args(plotmethod, args, kwargs) + locals_ = locals() + del locals_["ds"] + locals_.update(locals_.pop("kwargs", {})) da = _temp_dataarray(ds, y, locals_) - return getattr(da.plot, plotmethod)(*locals_.pop("args", ()), **locals_) + return da.plot.scatter(*locals_.pop("args", ()), **locals_) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 5202489c1ec..c88fb8b9318 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -3,18 +3,28 @@ import functools import itertools import warnings -from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, Literal +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Generic, + Hashable, + Iterable, + Literal, + TypeVar, + cast, +) import numpy as np from ..core.formatting import format_item +from ..core.types import HueStyleOptions, T_Xarray from .utils import ( _LINEWIDTH_RANGE, _MARKERSIZE_RANGE, _add_legend, _determine_guide, _get_nice_quiver_magnitude, - _infer_meta_data, _infer_xy_labels, _Normalize, _parse_size, @@ -33,7 +43,7 @@ from matplotlib.text import Annotation from ..core.dataarray import DataArray - from ..core.types import HueStyleOptions, Self + # Overrides axes.labelsize, xtick.major.size, ytick.major.size # from mpl.rcParams @@ -55,7 +65,10 @@ def _nicetitle(coord, value, maxchar, template): return title -class FacetGrid: +T_FacetGrid = TypeVar("T_FacetGrid", bound="FacetGrid") + + +class FacetGrid(Generic[T_Xarray]): """ Initialize the Matplotlib figure and FacetGrid object. @@ -96,7 +109,7 @@ class FacetGrid: sometimes the rightmost grid positions in the bottom row. """ - data: DataArray + data: T_Xarray name_dicts: np.ndarray fig: Figure axes: np.ndarray @@ -121,7 +134,7 @@ class FacetGrid: def __init__( self, - data: DataArray, + data: T_Xarray, col: Hashable | None = None, row: Hashable | None = None, col_wrap: int | None = None, @@ -135,8 +148,8 @@ def __init__( """ Parameters ---------- - data : DataArray - xarray DataArray to be plotted. + data : DataArray or Dataset + DataArray or Dataset to be plotted. row, col : str Dimension names that define subsets of the data, which will be drawn on separate facets in the grid. @@ -278,8 +291,12 @@ def _bottom_axes(self) -> np.ndarray: return self.axes[-1, :] def map_dataarray( - self, func: Callable, x: Hashable | None, y: Hashable | None, **kwargs: Any - ) -> FacetGrid: + self: T_FacetGrid, + func: Callable, + x: Hashable | None, + y: Hashable | None, + **kwargs: Any, + ) -> T_FacetGrid: """ Apply a plotting function to a 2d facet's subset of the data. @@ -347,8 +364,12 @@ def map_dataarray( return self def map_plot1d( - self, func: Callable, x: Hashable, y: Hashable, **kwargs: Any - ) -> Self: + self: T_FacetGrid, + func: Callable, + x: Hashable | None, + y: Hashable | None, + **kwargs: Any, + ) -> T_FacetGrid: """ Apply a plotting function to a 1d facet's subset of the data. @@ -385,18 +406,24 @@ def map_plot1d( hueplt_norm = _Normalize(hueplt) self._hue_var = hueplt cbar_kwargs = kwargs.pop("cbar_kwargs", {}) - if not hueplt_norm.data_is_numeric: - # TODO: Ticks seems a little too hardcoded, since it will always - # show all the values. But maybe it's ok, since plotting hundreds - # of categorical data isn't that meaningful anyway. - cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks) - kwargs.update(levels=hueplt_norm.levels) - if "label" not in cbar_kwargs: - cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data) - cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - func, hueplt_norm.values.to_numpy(), cbar_kwargs=cbar_kwargs, **kwargs - ) - self._cmap_extend = cmap_params.get("extend") + + if hueplt_norm.data is not None: + if not hueplt_norm.data_is_numeric: + # TODO: Ticks seems a little too hardcoded, since it will always + # show all the values. But maybe it's ok, since plotting hundreds + # of categorical data isn't that meaningful anyway. + cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks) + kwargs.update(levels=hueplt_norm.levels) + + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + func, + cast("DataArray", hueplt_norm.values).data, + cbar_kwargs=cbar_kwargs, + **kwargs, + ) + self._cmap_extend = cmap_params.get("extend") + else: + cmap_params = {} # Handle sizes: _size_r = _MARKERSIZE_RANGE if func.__name__ == "scatter" else _LINEWIDTH_RANGE @@ -404,7 +431,7 @@ def map_plot1d( size = kwargs.get(_size, None) sizeplt = self.data[size] if size else None - sizeplt_norm = _Normalize(sizeplt, _size_r) + sizeplt_norm = _Normalize(data=sizeplt, width=_size_r) if size: self.data[size] = sizeplt_norm.values kwargs.update(**{_size: size}) @@ -496,12 +523,15 @@ def map_plot1d( if add_colorbar: # Colorbar is after legend so it correctly fits the plot: + if "label" not in cbar_kwargs: + cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data) + self.add_colorbar(**cbar_kwargs) return self def map_dataarray_line( - self, + self: T_FacetGrid, func: Callable, x: Hashable | None, y: Hashable | None, @@ -509,8 +539,8 @@ def map_dataarray_line( add_legend: bool = True, _labels=None, **kwargs: Any, - ) -> FacetGrid: - from .plot import _infer_line_data + ) -> T_FacetGrid: + from .dataarray_plot import _infer_line_data for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value @@ -543,7 +573,7 @@ def map_dataarray_line( return self def map_dataset( - self, + self: T_FacetGrid, func: Callable, x: Hashable | None = None, y: Hashable | None = None, @@ -551,7 +581,8 @@ def map_dataset( hue_style: HueStyleOptions = None, add_guide: bool | None = None, **kwargs: Any, - ) -> FacetGrid: + ) -> T_FacetGrid: + from .dataset_plot import _infer_meta_data kwargs["add_guide"] = False @@ -706,7 +737,7 @@ def _get_largest_lims(self) -> dict[str, tuple[float, float]]: Examples -------- >>> ds = xr.tutorial.scatter_example_dataset(seed=42) - >>> fg = ds.plot.scatter("A", "B", hue="y", row="x", col="w") + >>> fg = ds.plot.scatter(x="A", y="B", hue="y", row="x", col="w") >>> round(fg._get_largest_lims()["x"][0], 3) -0.334 """ @@ -748,7 +779,7 @@ def _set_lims( Examples -------- >>> ds = xr.tutorial.scatter_example_dataset(seed=42) - >>> fg = ds.plot.scatter("A", "B", hue="y", row="x", col="w") + >>> fg = ds.plot.scatter(x="A", y="B", hue="y", row="x", col="w") >>> fg._set_lims(x=(-0.3, 0.3), y=(0, 2), z=(0, 4)) >>> fg.axes[0, 0].get_xlim(), fg.axes[0, 0].get_ylim() ((-0.3, 0.3), (0.0, 2.0)) @@ -899,7 +930,9 @@ def set_ticks( ): tick.label1.set_fontsize(fontsize) - def map(self, func: Callable, *args: Any, **kwargs: Any) -> FacetGrid: + def map( + self: T_FacetGrid, func: Callable, *args: Hashable, **kwargs: Any + ) -> T_FacetGrid: """ Apply a plotting function to each facet's subset of the data. @@ -910,7 +943,7 @@ def map(self, func: Callable, *args: Any, **kwargs: Any) -> FacetGrid: must plot to the currently active matplotlib Axes and take a `color` keyword argument. If faceting on the `hue` dimension, it must also take a `label` keyword argument. - *args : strings + *args : Hashable Column names in self.data that identify variables with data to plot. The data for each variable is passed to `func` in the order the variables are specified in the call. @@ -941,7 +974,7 @@ def map(self, func: Callable, *args: Any, **kwargs: Any) -> FacetGrid: def _easy_facetgrid( - data: DataArray, + data: T_Xarray, plotfunc: Callable, kind: Literal["line", "dataarray", "dataset", "plot1d"], x: Hashable | None = None, @@ -957,7 +990,7 @@ def _easy_facetgrid( ax: Axes | None = None, figsize: Iterable[float] | None = None, **kwargs: Any, -) -> FacetGrid: +) -> FacetGrid[T_Xarray]: """ Convenience method to call xarray.plot.FacetGrid from 2d plotting methods @@ -1001,4 +1034,6 @@ def _easy_facetgrid( if kind == "dataset": return g.map_dataset(plotfunc, x, y, **kwargs) - raise ValueError(f"kind must be one of `line`, `dataarray`, `dataset`, got {kind}") + raise ValueError( + f"kind must be one of `line`, `dataarray`, `dataset` or `plot1d`, got {kind}" + ) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py deleted file mode 100644 index da2794159ba..00000000000 --- a/xarray/plot/plot.py +++ /dev/null @@ -1,1555 +0,0 @@ -""" -Use this module directly: - import xarray.plot as xplt - -Or use the methods on a DataArray or Dataset: - DataArray.plot._____ - Dataset.plot._____ -""" -from __future__ import annotations - -import functools -from typing import TYPE_CHECKING, Any, Hashable, Iterable, MutableMapping, Sequence - -import numpy as np -import pandas as pd -from packaging.version import Version - -from ..core.alignment import broadcast -from ..core.concat import concat -from .facetgrid import _easy_facetgrid -from .utils import ( - _LINEWIDTH_RANGE, - _MARKERSIZE_RANGE, - _add_colorbar, - _add_legend, - _assert_valid_xy, - _determine_guide, - _ensure_plottable, - _infer_interval_breaks, - _infer_xy_labels, - _Normalize, - _process_cmap_cbar_kwargs, - _rescale_imshow_rgb, - _resolve_intervals_1dplot, - _resolve_intervals_2dplot, - _update_axes, - get_axis, - import_matplotlib_pyplot, - label_from_attrs, -) - -if TYPE_CHECKING: - from ..core.types import T_DataArray - from .facetgrid import FacetGrid - - try: - import matplotlib.pyplot as plt - except ImportError: - plt: Any = None # type: ignore - - Collection = plt.matplotlib.collections.Collection - - -def _infer_line_data(darray, x, y, hue): - - ndims = len(darray.dims) - - if x is not None and y is not None: - raise ValueError("Cannot specify both x and y kwargs for line plots.") - - if x is not None: - _assert_valid_xy(darray, x, "x") - - if y is not None: - _assert_valid_xy(darray, y, "y") - - if ndims == 1: - huename = None - hueplt = None - huelabel = "" - - if x is not None: - xplt = darray[x] - yplt = darray - - elif y is not None: - xplt = darray - yplt = darray[y] - - else: # Both x & y are None - dim = darray.dims[0] - xplt = darray[dim] - yplt = darray - - else: - if x is None and y is None and hue is None: - raise ValueError("For 2D inputs, please specify either hue, x or y.") - - if y is None: - if hue is not None: - _assert_valid_xy(darray, hue, "hue") - xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) - xplt = darray[xname] - if xplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - otherdim = darray.dims[otherindex] - yplt = darray.transpose(otherdim, huename, transpose_coords=False) - xplt = xplt.transpose(otherdim, huename, transpose_coords=False) - else: - raise ValueError( - "For 2D inputs, hue must be a dimension" - " i.e. one of " + repr(darray.dims) - ) - - else: - (xdim,) = darray[xname].dims - (huedim,) = darray[huename].dims - yplt = darray.transpose(xdim, huedim) - - else: - yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) - yplt = darray[yname] - if yplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - otherdim = darray.dims[otherindex] - xplt = darray.transpose(otherdim, huename, transpose_coords=False) - yplt = yplt.transpose(otherdim, huename, transpose_coords=False) - else: - raise ValueError( - "For 2D inputs, hue must be a dimension" - " i.e. one of " + repr(darray.dims) - ) - - else: - (ydim,) = darray[yname].dims - (huedim,) = darray[huename].dims - xplt = darray.transpose(ydim, huedim) - - huelabel = label_from_attrs(darray[huename]) - hueplt = darray[huename] - - return xplt, yplt, hueplt, huelabel - - -def _infer_plot_dims( - darray: T_DataArray, - dims_plot: MutableMapping[str, Hashable], - default_guess: Iterable[str] = ("x", "hue", "size"), -) -> MutableMapping[str, Hashable]: - """ - Guess what dims to plot if some of the values in dims_plot are None which - happens when the user has not defined all available ways of visualizing - the data. - - Parameters - ---------- - darray : T_DataArray - The DataArray to check. - dims_plot : T_DimsPlot - Dims defined by the user to plot. - default_guess : Iterable[str], optional - Default values and order to retrieve dims if values in dims_plot is - missing, default: ("x", "hue", "size"). - """ - dims_plot_exist = {k: v for k, v in dims_plot.items() if v is not None} - dims_avail = tuple(v for v in darray.dims if v not in dims_plot_exist.values()) - - # If dims_plot[k] isn't defined then fill with one of the available dims: - for k, v in zip(default_guess, dims_avail): - if dims_plot.get(k, None) is None: - dims_plot[k] = v - - for k, v in dims_plot.items(): - _assert_valid_xy(darray, v, k) - - return dims_plot - - -def _infer_line_data2( - darray: T_DataArray, - dims_plot: MutableMapping[str, Hashable], - plotfunc_name: None | str = None, -) -> dict[str, T_DataArray]: - # Guess what dims to use if some of the values in plot_dims are None: - dims_plot = _infer_plot_dims(darray, dims_plot) - - # If there are more than 1 dimension in the array than stack all the - # dimensions so the plotter can plot anything: - if darray.ndim > 1: - # When stacking dims the lines will continue connecting. For floats - # this can be solved by adding a nan element inbetween the flattening - # points: - dims_T = [] - if np.issubdtype(darray.dtype, np.floating): - for v in ["z", "x"]: - dim = dims_plot.get(v, None) - if (dim is not None) and (dim in darray.dims): - darray_nan = np.nan * darray.isel({dim: -1}) - darray = concat([darray, darray_nan], dim=dim) - dims_T.append(dims_plot[v]) - - # Lines should never connect to the same coordinate when stacked, - # transpose to avoid this as much as possible: - darray = darray.transpose(..., *dims_T) - - # Array is now ready to be stacked: - darray = darray.stack(_stacked_dim=darray.dims) - - # Broadcast together all the chosen variables: - out = dict(y=darray) - out.update({k: darray[v] for k, v in dims_plot.items() if v is not None}) - out = dict(zip(out.keys(), broadcast(*(out.values())))) - - return out - - -def plot( - darray, - row=None, - col=None, - col_wrap=None, - ax=None, - hue=None, - rtol=0.01, - subplot_kws=None, - **kwargs, -): - """ - Default plot of DataArray using :py:mod:`matplotlib:matplotlib.pyplot`. - - Calls xarray plotting function based on the dimensions of - the squeezed DataArray. - - =============== =========================== - Dimensions Plotting function - =============== =========================== - 1 :py:func:`xarray.plot.line` - 2 :py:func:`xarray.plot.pcolormesh` - Anything else :py:func:`xarray.plot.hist` - =============== =========================== - - Parameters - ---------- - darray : DataArray - row : str, optional - If passed, make row faceted plots on this dimension name. - col : str, optional - If passed, make column faceted plots on this dimension name. - hue : str, optional - If passed, make faceted line plots with hue on this dimension name. - col_wrap : int, optional - Use together with ``col`` to wrap faceted plots. - ax : matplotlib axes object, optional - Axes on which to plot. By default, use the current axes. - Mutually exclusive with ``size``, ``figsize`` and facets. - rtol : float, optional - Relative tolerance used to determine if the indexes - are uniformly spaced. Usually a small positive number. - subplot_kws : dict, optional - Dictionary of keyword arguments for Matplotlib subplots - (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). - **kwargs : optional - Additional keyword arguments for Matplotlib. - - See Also - -------- - xarray.DataArray.squeeze - """ - darray = darray.squeeze().compute() - - plot_dims = set(darray.dims) - plot_dims.discard(row) - plot_dims.discard(col) - plot_dims.discard(hue) - - ndims = len(plot_dims) - - error_msg = ( - "Only 1d and 2d plots are supported for facets in xarray. " - "See the package `Seaborn` for more options." - ) - - if ndims in [1, 2]: - if row or col: - kwargs["subplot_kws"] = subplot_kws - kwargs["row"] = row - kwargs["col"] = col - kwargs["col_wrap"] = col_wrap - if ndims == 1: - plotfunc = line - kwargs["hue"] = hue - elif ndims == 2: - if hue: - plotfunc = line - kwargs["hue"] = hue - else: - plotfunc = pcolormesh - kwargs["subplot_kws"] = subplot_kws - else: - if row or col or hue: - raise ValueError(error_msg) - plotfunc = hist - - kwargs["ax"] = ax - - return plotfunc(darray, **kwargs) - - -# This function signature should not change so that it can use -# matplotlib format strings -def line( - darray, - *args, - row=None, - col=None, - figsize=None, - aspect=None, - size=None, - ax=None, - hue=None, - x=None, - y=None, - xincrease=None, - yincrease=None, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - add_legend=True, - _labels=True, - **kwargs, -): - """ - Line plot of DataArray values. - - Wraps :py:func:`matplotlib:matplotlib.pyplot.plot`. - - Parameters - ---------- - darray : DataArray - Either 1D or 2D. If 2D, one of ``hue``, ``x`` or ``y`` must be provided. - figsize : tuple, optional - A tuple (width, height) of the figure in inches. - Mutually exclusive with ``size`` and ``ax``. - aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the *width* in - inches. Only used if a ``size`` is provided. - size : scalar, optional - If provided, create a new figure for the plot with the given size: - *height* (in inches) of each plot. See also: ``aspect``. - ax : matplotlib axes object, optional - Axes on which to plot. By default, the current is used. - Mutually exclusive with ``size`` and ``figsize``. - hue : str, optional - Dimension or coordinate for which you want multiple lines plotted. - If plotting against a 2D coordinate, ``hue`` must be a dimension. - x, y : str, optional - Dimension, coordinate or multi-index level for *x*, *y* axis. - Only one of these may be specified. - The other will be used for values from the DataArray on which this - plot method is called. - xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional - Specifies scaling for the *x*- and *y*-axis, respectively. - xticks, yticks : array-like, optional - Specify tick locations for *x*- and *y*-axis. - xlim, ylim : array-like, optional - Specify *x*- and *y*-axis limits. - xincrease : None, True, or False, optional - Should the values on the *x* axis be increasing from left to right? - if ``None``, use the default for the Matplotlib function. - yincrease : None, True, or False, optional - Should the values on the *y* axis be increasing from top to bottom? - if ``None``, use the default for the Matplotlib function. - add_legend : bool, optional - Add legend with *y* axis coordinates (2D inputs only). - *args, **kwargs : optional - Additional arguments to :py:func:`matplotlib:matplotlib.pyplot.plot`. - """ - # Handle facetgrids first - if row or col: - allargs = locals().copy() - allargs.update(allargs.pop("kwargs")) - allargs.pop("darray") - return _easy_facetgrid(darray, line, kind="line", **allargs) - - ndims = len(darray.dims) - if ndims > 2: - raise ValueError( - "Line plots are for 1- or 2-dimensional DataArrays. " - "Passed DataArray has {ndims} " - "dimensions".format(ndims=ndims) - ) - - # The allargs dict passed to _easy_facetgrid above contains args - if args == (): - args = kwargs.pop("args", ()) - else: - assert "args" not in kwargs - - ax = get_axis(figsize, size, aspect, ax) - xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) - - # Remove pd.Intervals if contained in xplt.values and/or yplt.values. - xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( - xplt.to_numpy(), yplt.to_numpy(), kwargs - ) - xlabel = label_from_attrs(xplt, extra=x_suffix) - ylabel = label_from_attrs(yplt, extra=y_suffix) - - _ensure_plottable(xplt_val, yplt_val) - - primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) - - if _labels: - if xlabel is not None: - ax.set_xlabel(xlabel) - - if ylabel is not None: - ax.set_ylabel(ylabel) - - ax.set_title(darray._title_for_slice()) - - if darray.ndim == 2 and add_legend: - ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label) - - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots - if np.issubdtype(xplt.dtype, np.datetime64): - for xlabels in ax.get_xticklabels(): - xlabels.set_rotation(30) - xlabels.set_ha("right") - - _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) - - return primitive - - -def step(darray, *args, where="pre", drawstyle=None, ds=None, **kwargs): - """ - Step plot of DataArray values. - - Similar to :py:func:`matplotlib:matplotlib.pyplot.step`. - - Parameters - ---------- - where : {'pre', 'post', 'mid'}, default: 'pre' - Define where the steps should be placed: - - - ``'pre'``: The y value is continued constantly to the left from - every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the - value ``y[i]``. - - ``'post'``: The y value is continued constantly to the right from - every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the - value ``y[i]``. - - ``'mid'``: Steps occur half-way between the *x* positions. - - Note that this parameter is ignored if one coordinate consists of - :py:class:`pandas.Interval` values, e.g. as a result of - :py:func:`xarray.Dataset.groupby_bins`. In this case, the actual - boundaries of the interval are used. - *args, **kwargs : optional - Additional arguments for :py:func:`xarray.plot.line`. - """ - if where not in {"pre", "post", "mid"}: - raise ValueError("'where' argument to step must be 'pre', 'post' or 'mid'") - - if ds is not None: - if drawstyle is None: - drawstyle = ds - else: - raise TypeError("ds and drawstyle are mutually exclusive") - if drawstyle is None: - drawstyle = "" - drawstyle = "steps-" + where + drawstyle - - return line(darray, *args, drawstyle=drawstyle, **kwargs) - - -def hist( - darray, - figsize=None, - size=None, - aspect=None, - ax=None, - xincrease=None, - yincrease=None, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - **kwargs, -): - """ - Histogram of DataArray. - - Wraps :py:func:`matplotlib:matplotlib.pyplot.hist`. - - Plots *N*-dimensional arrays by first flattening the array. - - Parameters - ---------- - darray : DataArray - Can have any number of dimensions. - figsize : tuple, optional - A tuple (width, height) of the figure in inches. - Mutually exclusive with ``size`` and ``ax``. - aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the *width* in - inches. Only used if a ``size`` is provided. - size : scalar, optional - If provided, create a new figure for the plot with the given size: - *height* (in inches) of each plot. See also: ``aspect``. - ax : matplotlib axes object, optional - Axes on which to plot. By default, use the current axes. - Mutually exclusive with ``size`` and ``figsize``. - **kwargs : optional - Additional keyword arguments to :py:func:`matplotlib:matplotlib.pyplot.hist`. - - """ - ax = get_axis(figsize, size, aspect, ax) - - no_nan = np.ravel(darray.to_numpy()) - no_nan = no_nan[pd.notnull(no_nan)] - - primitive = ax.hist(no_nan, **kwargs) - - ax.set_title(darray._title_for_slice()) - ax.set_xlabel(label_from_attrs(darray)) - - _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) - - return primitive - - -# MUST run before any 2d plotting functions are defined since -# _plot2d decorator adds them as methods here. -class _PlotMethods: - """ - Enables use of xarray.plot functions as attributes on a DataArray. - For example, DataArray.plot.imshow - """ - - __slots__ = ("_da",) - - def __init__(self, darray): - self._da = darray - - def __call__(self, **kwargs): - return plot(self._da, **kwargs) - - # we can't use functools.wraps here since that also modifies the name / qualname - __doc__ = __call__.__doc__ = plot.__doc__ - __call__.__wrapped__ = plot # type: ignore[attr-defined] - __call__.__annotations__ = plot.__annotations__ - - @functools.wraps(hist) - def hist(self, ax=None, **kwargs): - return hist(self._da, ax=ax, **kwargs) - - @functools.wraps(line) - def line(self, *args, **kwargs): - return line(self._da, *args, **kwargs) - - @functools.wraps(step) - def step(self, *args, **kwargs): - return step(self._da, *args, **kwargs) - - -def override_signature(f): - def wrapper(func): - func.__wrapped__ = f - - return func - - return wrapper - - -def _plot1d(plotfunc): - """ - Decorator for common 1d plotting logic. - - Also adds the 1d plot method to class _PlotMethods. - """ - commondoc = """ - Parameters - ---------- - darray : DataArray - Must be 2 dimensional, unless creating faceted plots - x : string, optional - Coordinate for x axis. If None use darray.dims[1] - y : string, optional - Coordinate for y axis. If None use darray.dims[0] - hue : string, optional - Dimension or coordinate for which you want multiple lines plotted. - figsize : tuple, optional - A tuple (width, height) of the figure in inches. - Mutually exclusive with ``size`` and ``ax``. - aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the width in - inches. Only used if a ``size`` is provided. - size : scalar, optional - If provided, create a new figure for the plot with the given size. - Height (in inches) of each plot. See also: ``aspect``. - ax : matplotlib.axes.Axes, optional - Axis on which to plot this figure. By default, use the current axis. - Mutually exclusive with ``size`` and ``figsize``. - row : string, optional - If passed, make row faceted plots on this dimension name - col : string, optional - If passed, make column faceted plots on this dimension name - col_wrap : int, optional - Use together with ``col`` to wrap faceted plots - xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional - Specifies scaling for the x- and y-axes respectively - xticks, yticks : Specify tick locations for x- and y-axes - xlim, ylim : Specify x- and y-axes limits - xincrease : None, True, or False, optional - Should the values on the x axes be increasing from left to right? - if None, use the default for the matplotlib function. - yincrease : None, True, or False, optional - Should the values on the y axes be increasing from top to bottom? - if None, use the default for the matplotlib function. - add_labels : bool, optional - Use xarray metadata to label axes - subplot_kws : dict, optional - Dictionary of keyword arguments for matplotlib subplots. Only used - for FacetGrid plots. - **kwargs : optional - Additional arguments to wrapped matplotlib function - - Returns - ------- - artist : - The same type of primitive artist that the wrapped matplotlib - function returns - """ - - # Build on the original docstring - plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" - - # plotfunc and newplotfunc have different signatures: - # - plotfunc: (x, y, z, ax, **kwargs) - # - newplotfunc: (darray, *args, x, y, **kwargs) - # where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray - # and variable names. newplotfunc also explicitly lists most kwargs, so we - # need to shorten it - def signature( - darray: T_DataArray, *args, x: Hashable, **kwargs - ) -> Collection | FacetGrid: - pass - - @override_signature(signature) - @functools.wraps(plotfunc) - def newplotfunc( - darray: T_DataArray, - *args, - x: Hashable = None, - y: Hashable = None, - z: Hashable = None, - hue: Hashable = None, - hue_style=None, - markersize: Hashable = None, - linewidth: Hashable = None, - figsize=None, - size=None, - aspect=None, - ax=None, - row: Hashable = None, - col: Hashable = None, - col_wrap=None, - xincrease=True, - yincrease=True, - add_legend: bool | None = None, - add_colorbar: bool | None = None, - add_labels: bool | Sequence[bool] = True, - add_title: bool = True, - subplot_kws: dict | None = None, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - cmap=None, - vmin=None, - vmax=None, - norm=None, - extend=None, - levels=None, - **kwargs, - ) -> Collection | FacetGrid: - # All 1d plots in xarray share this function signature. - # Method signature below should be consistent. - - if subplot_kws is None: - subplot_kws = dict() - - # Handle facetgrids first - if row or col: - if z is not None: - subplot_kws.update(projection="3d") - - allargs = locals().copy() - allargs.update(allargs.pop("kwargs")) - allargs.pop("darray") - allargs["plotfunc"] = globals()[plotfunc.__name__] - - return _easy_facetgrid(darray, kind="plot1d", **allargs) - - # The allargs dict passed to _easy_facetgrid above contains args - if args == (): - args = kwargs.pop("args", ()) - else: - assert "args" not in kwargs - - _is_facetgrid = kwargs.pop("_is_facetgrid", False) - - if markersize is not None: - size_ = markersize - size_r = _MARKERSIZE_RANGE - else: - size_ = linewidth - size_r = _LINEWIDTH_RANGE - - # Get data to plot: - dims_plot = dict(x=x, z=z, hue=hue, size=size_) - plts = _infer_line_data2(darray, dims_plot, plotfunc.__name__) - xplt = plts.pop("x", None) - yplt = plts.pop("y", None) - zplt = plts.pop("z", None) - kwargs.update(zplt=zplt) - hueplt = plts.pop("hue", None) - sizeplt = plts.pop("size", None) - - # Handle size and hue: - hueplt_norm = _Normalize(hueplt) - kwargs.update(hueplt=hueplt_norm.values) - sizeplt_norm = _Normalize(sizeplt, size_r, _is_facetgrid) - kwargs.update(sizeplt=sizeplt_norm.values) - cmap_params_subset = kwargs.pop("cmap_params_subset", {}) - cbar_kwargs = kwargs.pop("cbar_kwargs", {}) - - if hueplt_norm.data is not None: - if not hueplt_norm.data_is_numeric: - # Map hue values back to its original value: - cbar_kwargs.update(format=hueplt_norm.format, ticks=hueplt_norm.ticks) - levels = kwargs.get("levels", hueplt_norm.levels) - - cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - plotfunc, - hueplt_norm.values.data, - **locals(), - ) - - # subset that can be passed to scatter, hist2d - if not cmap_params_subset: - ckw = {vv: cmap_params[vv] for vv in ("vmin", "vmax", "norm", "cmap")} - cmap_params_subset.update(**ckw) - - if z is not None: - if ax is None: - subplot_kws.update(projection="3d") - ax = get_axis(figsize, size, aspect, ax, **subplot_kws) - # Using 30, 30 minimizes rotation of the plot. Making it easier to - # build on your intuition from 2D plots: - plt = import_matplotlib_pyplot() - if Version(plt.matplotlib.__version__) < Version("3.5.0"): - ax.view_init(azim=30, elev=30) - else: - # https://github.com/matplotlib/matplotlib/pull/19873 - ax.view_init(azim=30, elev=30, vertical_axis="y") - else: - ax = get_axis(figsize, size, aspect, ax, **subplot_kws) - - primitive = plotfunc( - xplt, - yplt, - *args, - ax=ax, - add_labels=add_labels, - **cmap_params_subset, - **kwargs, - ) - - if np.any(add_labels) and add_title: - ax.set_title(darray._title_for_slice()) - - add_colorbar_, add_legend_ = _determine_guide( - hueplt_norm, - sizeplt_norm, - add_colorbar, - add_legend, - plotfunc_name=plotfunc.__name__, - ) - - if add_colorbar_: - if "label" not in cbar_kwargs: - cbar_kwargs["label"] = label_from_attrs(hueplt_norm.data) - - _add_colorbar( - primitive, ax, kwargs.get("cbar_ax", None), cbar_kwargs, cmap_params - ) - - if add_legend_: - if plotfunc.__name__ == "hist": - ax.legend( - handles=primitive[-1], - labels=list(hueplt_norm.values.to_numpy()), - title=label_from_attrs(hueplt_norm.data), - ) - elif plotfunc.__name__ in ["scatter", "line"]: - _add_legend( - hueplt_norm - if add_legend or not add_colorbar_ - else _Normalize(None), - sizeplt_norm, - primitive, - legend_ax=ax, - plotfunc=plotfunc.__name__, - ) - else: - ax.legend( - handles=primitive, - labels=list(hueplt_norm.values.to_numpy()), - title=label_from_attrs(hueplt_norm.data), - ) - - _update_axes( - ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim - ) - - return primitive - - # For use as DataArray.plot.plotmethod - @functools.wraps(newplotfunc) - def plotmethod( - _PlotMethods_obj, - *args, - x: Hashable = None, - y: Hashable = None, - z: Hashable = None, - hue: Hashable = None, - hue_style=None, - markersize: Hashable = None, - linewidth: Hashable = None, - figsize=None, - size=None, - aspect=None, - ax=None, - row: Hashable = None, - col: Hashable = None, - col_wrap=None, - xincrease=True, - yincrease=True, - add_legend: bool | None = None, - add_colorbar: bool | None = None, - add_labels: bool | Sequence[bool] = True, - subplot_kws=None, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - cmap=None, - vmin=None, - vmax=None, - norm=None, - extend=None, - levels=None, - **kwargs, - ) -> Collection: - """ - The method should have the same signature as the function. - - This just makes the method work on Plotmethods objects, - and passes all the other arguments straight through. - """ - allargs = locals().copy() - allargs["darray"] = _PlotMethods_obj._da - allargs.update(kwargs) - for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]: - del allargs[arg] - return newplotfunc(**allargs) - - # Add to class _PlotMethods - setattr(_PlotMethods, plotmethod.__name__, plotmethod) - - return newplotfunc - - -def _add_labels( - add_labels: bool | Sequence[bool], - darrays: Sequence[T_DataArray], - suffixes: Iterable[str], - rotate_labels: Iterable[bool], - ax, -) -> None: - # Set x, y, z labels: - add_labels = [add_labels] * 3 if isinstance(add_labels, bool) else add_labels - for axis, add_label, darray, suffix, rotate_label in zip( - ("x", "y", "z"), add_labels, darrays, suffixes, rotate_labels - ): - if darray is None: - continue - - if add_label: - label = label_from_attrs(darray, extra=suffix) - if label is not None: - getattr(ax, f"set_{axis}label")(label) - - if rotate_label and np.issubdtype(darray.dtype, np.datetime64): - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots - for labels in getattr(ax, f"get_{axis}ticklabels")(): - labels.set_rotation(30) - labels.set_ha("right") - - -@_plot1d -def scatter( - xplt, yplt, *args, ax, add_labels: bool | Sequence[bool] = True, **kwargs -) -> plt.scatter: - plt = import_matplotlib_pyplot() - - zplt = kwargs.pop("zplt", None) - hueplt = kwargs.pop("hueplt", None) - sizeplt = kwargs.pop("sizeplt", None) - - # Add a white border to make it easier seeing overlapping markers: - kwargs.update(edgecolors=kwargs.pop("edgecolors", "w")) - - if hueplt is not None: - kwargs.update(c=hueplt.to_numpy().ravel()) - - if sizeplt is not None: - kwargs.update(s=sizeplt.to_numpy().ravel()) - - if Version(plt.matplotlib.__version__) < Version("3.5.0"): - # Plot the data. 3d plots has the z value in upward direction - # instead of y. To make jumping between 2d and 3d easy and intuitive - # switch the order so that z is shown in the depthwise direction: - axis_order = ["x", "z", "y"] - else: - # Switching axis order not needed in 3.5.0, can also simplify the code - # that uses axis_order: - # https://github.com/matplotlib/matplotlib/pull/19873 - axis_order = ["x", "y", "z"] - - plts_dict = dict(x=xplt, y=yplt, z=zplt) - plts = [plts_dict[v] for v in axis_order if plts_dict[v] is not None] - primitive = ax.scatter(*[v.to_numpy().ravel() for v in plts], **kwargs) - _add_labels(add_labels, plts, ("", "", ""), (True, False, False), ax) - - return primitive - - -def _plot2d(plotfunc): - """ - Decorator for common 2d plotting logic - - Also adds the 2d plot method to class _PlotMethods - """ - commondoc = """ - Parameters - ---------- - darray : DataArray - Must be two-dimensional, unless creating faceted plots. - x : str, optional - Coordinate for *x* axis. If ``None``, use ``darray.dims[1]``. - y : str, optional - Coordinate for *y* axis. If ``None``, use ``darray.dims[0]``. - figsize : tuple, optional - A tuple (width, height) of the figure in inches. - Mutually exclusive with ``size`` and ``ax``. - aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the *width* in - inches. Only used if a ``size`` is provided. - size : scalar, optional - If provided, create a new figure for the plot with the given size: - *height* (in inches) of each plot. See also: ``aspect``. - ax : matplotlib axes object, optional - Axes on which to plot. By default, use the current axes. - Mutually exclusive with ``size`` and ``figsize``. - row : string, optional - If passed, make row faceted plots on this dimension name. - col : string, optional - If passed, make column faceted plots on this dimension name. - col_wrap : int, optional - Use together with ``col`` to wrap faceted plots. - xscale, yscale : {'linear', 'symlog', 'log', 'logit'}, optional - Specifies scaling for the *x*- and *y*-axis, respectively. - xticks, yticks : array-like, optional - Specify tick locations for *x*- and *y*-axis. - xlim, ylim : array-like, optional - Specify *x*- and *y*-axis limits. - xincrease : None, True, or False, optional - Should the values on the *x* axis be increasing from left to right? - If ``None``, use the default for the Matplotlib function. - yincrease : None, True, or False, optional - Should the values on the *y* axis be increasing from top to bottom? - If ``None``, use the default for the Matplotlib function. - add_colorbar : bool, optional - Add colorbar to axes. - add_labels : bool, optional - Use xarray metadata to label axes. - norm : matplotlib.colors.Normalize, optional - If ``norm`` has ``vmin`` or ``vmax`` specified, the corresponding - kwarg must be ``None``. - vmin, vmax : float, optional - Values to anchor the colormap, otherwise they are inferred from the - data and other keyword arguments. When a diverging dataset is inferred, - setting one of these values will fix the other by symmetry around - ``center``. Setting both values prevents use of a diverging colormap. - If discrete levels are provided as an explicit list, both of these - values are ignored. - cmap : matplotlib colormap name or colormap, optional - The mapping from data values to color space. If not provided, this - will be either be ``'viridis'`` (if the function infers a sequential - dataset) or ``'RdBu_r'`` (if the function infers a diverging dataset). - See :doc:`Choosing Colormaps in Matplotlib ` - for more information. - - If *seaborn* is installed, ``cmap`` may also be a - `seaborn color palette `_. - Note: if ``cmap`` is a seaborn color palette and the plot type - is not ``'contour'`` or ``'contourf'``, ``levels`` must also be specified. - colors : str or array-like of color-like, optional - A single color or a sequence of colors. If the plot type is not ``'contour'`` - or ``'contourf'``, the ``levels`` argument is required. - center : float, optional - The value at which to center the colormap. Passing this value implies - use of a diverging colormap. Setting it to ``False`` prevents use of a - diverging colormap. - robust : bool, optional - If ``True`` and ``vmin`` or ``vmax`` are absent, the colormap range is - computed with 2nd and 98th percentiles instead of the extreme values. - extend : {'neither', 'both', 'min', 'max'}, optional - How to draw arrows extending the colorbar beyond its limits. If not - provided, ``extend`` is inferred from ``vmin``, ``vmax`` and the data limits. - levels : int or array-like, optional - Split the colormap (``cmap``) into discrete color intervals. If an integer - is provided, "nice" levels are chosen based on the data range: this can - imply that the final number of levels is not exactly the expected one. - Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to - setting ``levels=np.linspace(vmin, vmax, N)``. - infer_intervals : bool, optional - Only applies to pcolormesh. If ``True``, the coordinate intervals are - passed to pcolormesh. If ``False``, the original coordinates are used - (this can be useful for certain map projections). The default is to - always infer intervals, unless the mesh is irregular and plotted on - a map projection. - subplot_kws : dict, optional - Dictionary of keyword arguments for Matplotlib subplots. Only used - for 2D and faceted plots. - (see :py:meth:`matplotlib:matplotlib.figure.Figure.add_subplot`). - cbar_ax : matplotlib axes object, optional - Axes in which to draw the colorbar. - cbar_kwargs : dict, optional - Dictionary of keyword arguments to pass to the colorbar - (see :meth:`matplotlib:matplotlib.figure.Figure.colorbar`). - **kwargs : optional - Additional keyword arguments to wrapped Matplotlib function. - - Returns - ------- - artist : - The same type of primitive artist that the wrapped Matplotlib - function returns. - """ - - # Build on the original docstring - plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" - - # plotfunc and newplotfunc have different signatures: - # - plotfunc: (x, y, z, ax, **kwargs) - # - newplotfunc: (darray, x, y, **kwargs) - # where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray - # and variable names. newplotfunc also explicitly lists most kwargs, so we - # need to shorten it - def signature(darray, x, y, **kwargs): - pass - - @override_signature(signature) - @functools.wraps(plotfunc) - def newplotfunc( - darray, - x=None, - y=None, - figsize=None, - size=None, - aspect=None, - ax=None, - row=None, - col=None, - col_wrap=None, - xincrease=True, - yincrease=True, - add_colorbar=None, - add_labels=True, - vmin=None, - vmax=None, - cmap=None, - center=None, - robust=False, - extend=None, - levels=None, - infer_intervals=None, - colors=None, - subplot_kws=None, - cbar_ax=None, - cbar_kwargs=None, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - norm=None, - **kwargs, - ): - # All 2d plots in xarray share this function signature. - # Method signature below should be consistent. - - # Decide on a default for the colorbar before facetgrids - if add_colorbar is None: - add_colorbar = True - if plotfunc.__name__ == "contour" or ( - plotfunc.__name__ == "surface" and cmap is None - ): - add_colorbar = False - imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == ( - 3 + (row is not None) + (col is not None) - ) - if imshow_rgb: - # Don't add a colorbar when showing an image with explicit colors - add_colorbar = False - # Matplotlib does not support normalising RGB data, so do it here. - # See eg. https://github.com/matplotlib/matplotlib/pull/10220 - if robust or vmax is not None or vmin is not None: - darray = _rescale_imshow_rgb(darray.as_numpy(), vmin, vmax, robust) - vmin, vmax, robust = None, None, False - - if subplot_kws is None: - subplot_kws = dict() - - if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False): - if ax is None: - # TODO: Importing Axes3D is no longer necessary in matplotlib >= 3.2. - # Remove when minimum requirement of matplotlib is 3.2: - from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa: F401 - - # delete so it does not end up in locals() - del Axes3D - - # Need to create a "3d" Axes instance for surface plots - subplot_kws["projection"] = "3d" - - # In facet grids, shared axis labels don't make sense for surface plots - sharex = False - sharey = False - - # Handle facetgrids first - if row or col: - allargs = locals().copy() - del allargs["darray"] - del allargs["imshow_rgb"] - allargs.update(allargs.pop("kwargs")) - # Need the decorated plotting function - allargs["plotfunc"] = globals()[plotfunc.__name__] - return _easy_facetgrid(darray, kind="dataarray", **allargs) - - plt = import_matplotlib_pyplot() - - if ( - plotfunc.__name__ == "surface" - and not kwargs.get("_is_facetgrid", False) - and ax is not None - ): - import mpl_toolkits # type: ignore - - if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D): - raise ValueError( - "If ax is passed to surface(), it must be created with " - 'projection="3d"' - ) - - rgb = kwargs.pop("rgb", None) - if rgb is not None and plotfunc.__name__ != "imshow": - raise ValueError('The "rgb" keyword is only valid for imshow()') - elif rgb is not None and not imshow_rgb: - raise ValueError( - 'The "rgb" keyword is only valid for imshow()' - "with a three-dimensional array (per facet)" - ) - - xlab, ylab = _infer_xy_labels( - darray=darray, x=x, y=y, imshow=imshow_rgb, rgb=rgb - ) - - xval = darray[xlab] - yval = darray[ylab] - - if xval.ndim > 1 or yval.ndim > 1 or plotfunc.__name__ == "surface": - # Passing 2d coordinate values, need to ensure they are transposed the same - # way as darray. - # Also surface plots always need 2d coordinates - xval = xval.broadcast_like(darray) - yval = yval.broadcast_like(darray) - dims = darray.dims - else: - dims = (yval.dims[0], xval.dims[0]) - - # May need to transpose for correct x, y labels - # xlab may be the name of a coord, we have to check for dim names - if imshow_rgb: - # For RGB[A] images, matplotlib requires the color dimension - # to be last. In Xarray the order should be unimportant, so - # we transpose to (y, x, color) to make this work. - yx_dims = (ylab, xlab) - dims = yx_dims + tuple(d for d in darray.dims if d not in yx_dims) - - if dims != darray.dims: - darray = darray.transpose(*dims, transpose_coords=True) - - # better to pass the ndarrays directly to plotting functions - xval = xval.to_numpy() - yval = yval.to_numpy() - - # Pass the data as a masked ndarray too - zval = darray.to_masked_array(copy=False) - - # Replace pd.Intervals if contained in xval or yval. - xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__) - yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__) - - _ensure_plottable(xplt, yplt, zval) - - cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( - plotfunc, - zval.data, - **locals(), - _is_facetgrid=kwargs.pop("_is_facetgrid", False), - ) - - if "contour" in plotfunc.__name__: - # extend is a keyword argument only for contour and contourf, but - # passing it to the colorbar is sufficient for imshow and - # pcolormesh - kwargs["extend"] = cmap_params["extend"] - kwargs["levels"] = cmap_params["levels"] - # if colors == a single color, matplotlib draws dashed negative - # contours. we lose this feature if we pass cmap and not colors - if isinstance(colors, str): - cmap_params["cmap"] = None - kwargs["colors"] = colors - - if "pcolormesh" == plotfunc.__name__: - kwargs["infer_intervals"] = infer_intervals - kwargs["xscale"] = xscale - kwargs["yscale"] = yscale - - if "imshow" == plotfunc.__name__ and isinstance(aspect, str): - # forbid usage of mpl strings - raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray") - - ax = get_axis(figsize, size, aspect, ax, **subplot_kws) - - primitive = plotfunc( - xplt, - yplt, - zval, - ax=ax, - cmap=cmap_params["cmap"], - vmin=cmap_params["vmin"], - vmax=cmap_params["vmax"], - norm=cmap_params["norm"], - **kwargs, - ) - - # Label the plot with metadata - if add_labels: - ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra)) - ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) - ax.set_title(darray._title_for_slice()) - if plotfunc.__name__ == "surface": - ax.set_zlabel(label_from_attrs(darray)) - - if add_colorbar: - if add_labels and "label" not in cbar_kwargs: - cbar_kwargs["label"] = label_from_attrs(darray) - cbar = _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) - elif cbar_ax is not None or cbar_kwargs: - # inform the user about keywords which aren't used - raise ValueError( - "cbar_ax and cbar_kwargs can't be used with add_colorbar=False." - ) - - # origin kwarg overrides yincrease - if "origin" in kwargs: - yincrease = None - - _update_axes( - ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim - ) - - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots - if np.issubdtype(xplt.dtype, np.datetime64): - for xlabels in ax.get_xticklabels(): - xlabels.set_rotation(30) - xlabels.set_ha("right") - - return primitive - - # For use as DataArray.plot.plotmethod - @functools.wraps(newplotfunc) - def plotmethod( - _PlotMethods_obj, - x=None, - y=None, - figsize=None, - size=None, - aspect=None, - ax=None, - row=None, - col=None, - col_wrap=None, - xincrease=True, - yincrease=True, - add_colorbar=None, - add_labels=True, - vmin=None, - vmax=None, - cmap=None, - colors=None, - center=None, - robust=False, - extend=None, - levels=None, - infer_intervals=None, - subplot_kws=None, - cbar_ax=None, - cbar_kwargs=None, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - norm=None, - **kwargs, - ): - """ - The method should have the same signature as the function. - - This just makes the method work on Plotmethods objects, - and passes all the other arguments straight through. - """ - allargs = locals() - allargs["darray"] = _PlotMethods_obj._da - allargs.update(kwargs) - for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]: - del allargs[arg] - return newplotfunc(**allargs) - - # Add to class _PlotMethods - setattr(_PlotMethods, plotmethod.__name__, plotmethod) - - return newplotfunc - - -@_plot2d -def imshow(x, y, z, ax, **kwargs): - """ - Image plot of 2D DataArray. - - Wraps :py:func:`matplotlib:matplotlib.pyplot.imshow`. - - While other plot methods require the DataArray to be strictly - two-dimensional, ``imshow`` also accepts a 3D array where some - dimension can be interpreted as RGB or RGBA color channels and - allows this dimension to be specified via the kwarg ``rgb=``. - - Unlike :py:func:`matplotlib:matplotlib.pyplot.imshow`, which ignores ``vmin``/``vmax`` - for RGB(A) data, - xarray *will* use ``vmin`` and ``vmax`` for RGB(A) data - by applying a single scaling factor and offset to all bands. - Passing ``robust=True`` infers ``vmin`` and ``vmax`` - :ref:`in the usual way `. - - .. note:: - This function needs uniformly spaced coordinates to - properly label the axes. Call :py:meth:`DataArray.plot` to check. - - The pixels are centered on the coordinates. For example, if the coordinate - value is 3.2, then the pixels for those coordinates will be centered on 3.2. - """ - - if x.ndim != 1 or y.ndim != 1: - raise ValueError( - "imshow requires 1D coordinates, try using pcolormesh or contour(f)" - ) - - def _center_pixels(x): - """Center the pixels on the coordinates.""" - if np.issubdtype(x.dtype, str): - # When using strings as inputs imshow converts it to - # integers. Choose extent values which puts the indices in - # in the center of the pixels: - return 0 - 0.5, len(x) - 0.5 - - try: - # Center the pixels assuming uniform spacing: - xstep = 0.5 * (x[1] - x[0]) - except IndexError: - # Arbitrary default value, similar to matplotlib behaviour: - xstep = 0.1 - - return x[0] - xstep, x[-1] + xstep - - # Center the pixels: - left, right = _center_pixels(x) - top, bottom = _center_pixels(y) - - defaults = {"origin": "upper", "interpolation": "nearest"} - - if not hasattr(ax, "projection"): - # not for cartopy geoaxes - defaults["aspect"] = "auto" - - # Allow user to override these defaults - defaults.update(kwargs) - - if defaults["origin"] == "upper": - defaults["extent"] = [left, right, bottom, top] - else: - defaults["extent"] = [left, right, top, bottom] - - if z.ndim == 3: - # matplotlib imshow uses black for missing data, but Xarray makes - # missing data transparent. We therefore add an alpha channel if - # there isn't one, and set it to transparent where data is masked. - if z.shape[-1] == 3: - alpha = np.ma.ones(z.shape[:2] + (1,), dtype=z.dtype) - if np.issubdtype(z.dtype, np.integer): - alpha *= 255 - z = np.ma.concatenate((z, alpha), axis=2) - else: - z = z.copy() - z[np.any(z.mask, axis=-1), -1] = 0 - - primitive = ax.imshow(z, **defaults) - - # If x or y are strings the ticklabels have been replaced with - # integer indices. Replace them back to strings: - for axis, v in [("x", x), ("y", y)]: - if np.issubdtype(v.dtype, str): - getattr(ax, f"set_{axis}ticks")(np.arange(len(v))) - getattr(ax, f"set_{axis}ticklabels")(v) - - return primitive - - -@_plot2d -def contour(x, y, z, ax, **kwargs): - """ - Contour plot of 2D DataArray. - - Wraps :py:func:`matplotlib:matplotlib.pyplot.contour`. - """ - primitive = ax.contour(x, y, z, **kwargs) - return primitive - - -@_plot2d -def contourf(x, y, z, ax, **kwargs): - """ - Filled contour plot of 2D DataArray. - - Wraps :py:func:`matplotlib:matplotlib.pyplot.contourf`. - """ - primitive = ax.contourf(x, y, z, **kwargs) - return primitive - - -@_plot2d -def pcolormesh(x, y, z, ax, xscale=None, yscale=None, infer_intervals=None, **kwargs): - """ - Pseudocolor plot of 2D DataArray. - - Wraps :py:func:`matplotlib:matplotlib.pyplot.pcolormesh`. - """ - - # decide on a default for infer_intervals (GH781) - x = np.asarray(x) - if infer_intervals is None: - if hasattr(ax, "projection"): - if len(x.shape) == 1: - infer_intervals = True - else: - infer_intervals = False - else: - infer_intervals = True - - if ( - infer_intervals - and not np.issubdtype(x.dtype, str) - and ( - (np.shape(x)[0] == np.shape(z)[1]) - or ((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])) - ) - ): - if len(x.shape) == 1: - x = _infer_interval_breaks(x, check_monotonic=True, scale=xscale) - else: - # we have to infer the intervals on both axes - x = _infer_interval_breaks(x, axis=1, scale=xscale) - x = _infer_interval_breaks(x, axis=0, scale=xscale) - - if ( - infer_intervals - and not np.issubdtype(y.dtype, str) - and (np.shape(y)[0] == np.shape(z)[0]) - ): - if len(y.shape) == 1: - y = _infer_interval_breaks(y, check_monotonic=True, scale=yscale) - else: - # we have to infer the intervals on both axes - y = _infer_interval_breaks(y, axis=1, scale=yscale) - y = _infer_interval_breaks(y, axis=0, scale=yscale) - - primitive = ax.pcolormesh(x, y, z, **kwargs) - - # by default, pcolormesh picks "round" values for bounds - # this results in ugly looking plots with lots of surrounding whitespace - if not hasattr(ax, "projection") and x.ndim == 1 and y.ndim == 1: - # not a cartopy geoaxis - ax.set_xlim(x[0], x[-1]) - ax.set_ylim(y[0], y[-1]) - - return primitive - - -@_plot2d -def surface(x, y, z, ax, **kwargs): - """ - Surface plot of 2D DataArray. - - Wraps :py:meth:`matplotlib:mpl_toolkits.mplot3d.axes3d.Axes3D.plot_surface`. - """ - primitive = ax.plot_surface(x, y, z, **kwargs) - return primitive diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index d1fe0cd0bb7..e27695c4347 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -5,7 +5,16 @@ import warnings from datetime import datetime from inspect import getfullargspec -from typing import TYPE_CHECKING, Any, Callable, Hashable, Iterable, Mapping, Sequence +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Hashable, + Iterable, + Mapping, + Sequence, + overload, +) import numpy as np import pandas as pd @@ -31,8 +40,13 @@ if TYPE_CHECKING: from matplotlib.axes import Axes + from matplotlib.colors import Normalize + from matplotlib.ticker import FuncFormatter + from numpy.typing import ArrayLike from ..core.dataarray import DataArray + from ..core.dataset import Dataset + from ..core.types import AspectOptions, ScaleOptions try: import matplotlib.pyplot as plt @@ -42,8 +56,8 @@ ROBUST_PERCENTILE = 2.0 # copied from seaborn -_MARKERSIZE_RANGE = np.array([18.0, 72.0]) -_LINEWIDTH_RANGE = np.array([1.5, 6.0]) +_MARKERSIZE_RANGE = (18.0, 72.0) +_LINEWIDTH_RANGE = (1.5, 6.0) def import_matplotlib_pyplot(): @@ -319,7 +333,10 @@ def _determine_cmap_params( def _infer_xy_labels_3d( - darray: DataArray, x: Hashable | None, y: Hashable | None, rgb: Hashable | None + darray: DataArray | Dataset, + x: Hashable | None, + y: Hashable | None, + rgb: Hashable | None, ) -> tuple[Hashable, Hashable]: """ Determine x and y labels for showing RGB images. @@ -378,7 +395,7 @@ def _infer_xy_labels_3d( def _infer_xy_labels( - darray: DataArray, + darray: DataArray | Dataset, x: Hashable | None, y: Hashable | None, imshow: bool = False, @@ -417,7 +434,9 @@ def _infer_xy_labels( # TODO: Can by used to more than x or y, rename? -def _assert_valid_xy(darray: DataArray, xy: Hashable | None, name: str) -> None: +def _assert_valid_xy( + darray: DataArray | Dataset, xy: Hashable | None, name: str +) -> None: """ make sure x and y passed to plotting functions are valid """ @@ -441,7 +460,7 @@ def _assert_valid_xy(darray: DataArray, xy: Hashable | None, name: str) -> None: def get_axis( figsize: Iterable[float] | None = None, size: float | None = None, - aspect: float | None = None, + aspect: AspectOptions = None, ax: Axes | None = None, **subplot_kws: Any, ) -> Axes: @@ -462,10 +481,14 @@ def get_axis( if size is not None: if ax is not None: raise ValueError("cannot provide both `size` and `ax` arguments") - if aspect is None: + if aspect is None or aspect == "auto": width, height = mpl.rcParams["figure.figsize"] - aspect = width / height - figsize = (size * aspect, size) + faspect = width / height + elif aspect == "equal": + faspect = 1 + else: + faspect = aspect + figsize = (size * faspect, size) _, ax = plt.subplots(figsize=figsize, subplot_kw=subplot_kws) return ax @@ -757,16 +780,16 @@ def _rescale_imshow_rgb(darray, vmin, vmax, robust): def _update_axes( - ax, - xincrease, - yincrease, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, -): + ax: Axes, + xincrease: bool | None, + yincrease: bool | None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, +) -> None: """ Update axes with provided parameters """ @@ -885,7 +908,7 @@ def _process_cmap_cbar_kwargs( levels=None, _is_facetgrid=False, **kwargs, -): +) -> tuple[dict[str, Any], dict[str, Any]]: """ Parameters ---------- @@ -895,8 +918,8 @@ def _process_cmap_cbar_kwargs( Returns ------- - cmap_params - cbar_kwargs + cmap_params : dict + cbar_kwargs : dict """ if func.__name__ == "surface": # Leave user to specify cmap settings for surface plots @@ -1284,21 +1307,40 @@ def _infer_meta_data(ds, x, y, hue, hue_style, add_guide, funcname): } +@overload +def _parse_size( + data: None, + norm: tuple[float | None, float | None, bool] | Normalize | None, +) -> None: + ... + + +@overload +def _parse_size( + data: DataArray, + norm: tuple[float | None, float | None, bool] | Normalize | None, +) -> pd.Series: + ... + + # copied from seaborn -def _parse_size(data, norm): +def _parse_size( + data: DataArray | None, + norm: tuple[float | None, float | None, bool] | Normalize | None, +) -> None | pd.Series: import matplotlib as mpl if data is None: return None - data = data.values.flatten() + flatdata = data.values.flatten() - if not _is_numeric(data): - levels = np.unique(data) + if not _is_numeric(flatdata): + levels = np.unique(flatdata) numbers = np.arange(1, 1 + len(levels))[::-1] else: - levels = numbers = np.sort(np.unique(data)) + levels = numbers = np.sort(np.unique(flatdata)) min_width, max_width = _MARKERSIZE_RANGE # width_range = min_width, max_width @@ -1310,6 +1352,7 @@ def _parse_size(data, norm): elif not isinstance(norm, mpl.colors.Normalize): err = "``size_norm`` must be None, tuple, or Normalize object." raise ValueError(err) + assert isinstance(norm, mpl.colors.Normalize) norm.clip = True if not norm.scaled(): @@ -1341,53 +1384,58 @@ class _Normalize(Sequence): The default is None. """ + _data: DataArray | None + _data_unique: np.ndarray + _data_unique_index: np.ndarray + _data_unique_inverse: np.ndarray + _data_is_numeric: bool + _width: tuple[float, float] | None + __slots__ = ( "_data", + "_data_unique", + "_data_unique_index", + "_data_unique_inverse", "_data_is_numeric", "_width", - "_unique", - "_unique_index", - "_unique_inverse", - "plt", ) - def __init__(self, data, width=None, _is_facetgrid=False): + def __init__( + self, + data: DataArray | None, + width: tuple[float, float] | None = None, + _is_facetgrid: bool = False, + ) -> None: self._data = data self._width = width if not _is_facetgrid else None - self.plt = import_matplotlib_pyplot() pint_array_type = DuckArrayModule("pint").type - to_unique = data.to_numpy() if isinstance(self._type, pint_array_type) else data - unique, unique_inverse = np.unique(to_unique, return_inverse=True) - self._unique = unique - self._unique_index = np.arange(0, unique.size) - if data is not None: - self._unique_inverse = data.copy(data=unique_inverse.reshape(data.shape)) - self._data_is_numeric = _is_numeric(data) - else: - self._unique_inverse = unique_inverse - self._data_is_numeric = False + to_unique = ( + data.to_numpy() # type: ignore[union-attr] + if isinstance(data if data is None else data.data, pint_array_type) + else data + ) + data_unique, data_unique_inverse = np.unique(to_unique, return_inverse=True) # type: ignore[call-overload] + self._data_unique = data_unique + self._data_unique_index = np.arange(0, data_unique.size) + self._data_unique_inverse = data_unique_inverse + self._data_is_numeric = False if data is None else _is_numeric(data) def __repr__(self) -> str: with np.printoptions(precision=4, suppress=True, threshold=5): return ( f"<_Normalize(data, width={self._width})>\n" - f"{self._unique} -> {self.values_unique}" + f"{self._data_unique} -> {self._values_unique}" ) def __len__(self) -> int: - return len(self._unique) + return len(self._data_unique) def __getitem__(self, key): - return self._unique[key] + return self._data_unique[key] @property - def _type(self): - data = self.data - return data.data if data is not None else data - - @property - def data(self): + def data(self) -> DataArray | None: return self._data @property @@ -1400,11 +1448,23 @@ def data_is_numeric(self) -> bool: >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) >>> _Normalize(a).data_is_numeric False + + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> _Normalize(a).data_is_numeric + True """ return self._data_is_numeric - def _calc_widths(self, y): - if self._width is None or y is None: + @overload + def _calc_widths(self, y: np.ndarray) -> np.ndarray: + ... + + @overload + def _calc_widths(self, y: DataArray) -> DataArray: + ... + + def _calc_widths(self, y: np.ndarray | DataArray) -> np.ndarray | DataArray: + if self._width is None: return y x0, x1 = self._width @@ -1414,18 +1474,23 @@ def _calc_widths(self, y): return widths - def _indexes_centered(self, x) -> None | Any: + @overload + def _indexes_centered(self, x: np.ndarray) -> np.ndarray: + ... + + @overload + def _indexes_centered(self, x: DataArray) -> DataArray: + ... + + def _indexes_centered(self, x: np.ndarray | DataArray) -> np.ndarray | DataArray: """ Offset indexes to make sure being in the center of self.levels. ["a", "b", "c"] -> [1, 3, 5] """ - if self.data is None: - return None - else: - return x * 2 + 1 + return x * 2 + 1 @property - def values(self): + def values(self) -> DataArray | None: """ Return a normalized number array for the unique levels. @@ -1453,43 +1518,52 @@ def values(self): array([27., 18., 18., 27., 54., 72.]) Dimensions without coordinates: dim_0 """ - return self._calc_widths( - self.data - if self.data_is_numeric - else self._indexes_centered(self._unique_inverse) - ) + if self.data is None: + return None - def _integers(self): - """ - Return integers. - ["a", "b", "c"] -> [1, 3, 5] - """ - return self._indexes_centered(self._unique_index) + val: DataArray + if self.data_is_numeric: + val = self.data + else: + arr = self._indexes_centered(self._data_unique_inverse) + val = self.data.copy(data=arr.reshape(self.data.shape)) + + return self._calc_widths(val) @property - def values_unique(self) -> np.ndarray: + def _values_unique(self) -> np.ndarray | None: """ Return unique values. Examples -------- >>> a = xr.DataArray(["b", "a", "a", "b", "c"]) - >>> _Normalize(a).values_unique + >>> _Normalize(a)._values_unique array([1, 3, 5]) - >>> a = xr.DataArray([2, 1, 1, 2, 3]) - >>> _Normalize(a).values_unique - array([1, 2, 3]) - >>> _Normalize(a, width=[18, 72]).values_unique + + >>> _Normalize(a, width=[18, 72])._values_unique array([18., 45., 72.]) + + >>> a = xr.DataArray([0.5, 0, 0, 0.5, 2, 3]) + >>> _Normalize(a)._values_unique + array([0. , 0.5, 2. , 3. ]) + + >>> _Normalize(a, width=[18, 72])._values_unique + array([18., 27., 54., 72.]) """ - return ( - self._integers() - if not self.data_is_numeric - else self._calc_widths(self._unique) - ) + if self.data is None: + return None + + val: np.ndarray + if self.data_is_numeric: + val = self._data_unique + else: + val = self._indexes_centered(self._data_unique_index) + + return self._calc_widths(val) @property - def ticks(self) -> None | np.ndarray: + def ticks(self) -> np.ndarray | None: """ Return ticks for plt.colorbar if the data is not numeric. @@ -1499,7 +1573,13 @@ def ticks(self) -> None | np.ndarray: >>> _Normalize(a).ticks array([1, 3, 5]) """ - return self._integers() if not self.data_is_numeric else None + val: None | np.ndarray + if self.data_is_numeric: + val = None + else: + val = self._indexes_centered(self._data_unique_index) + + return val @property def levels(self) -> np.ndarray: @@ -1513,11 +1593,16 @@ def levels(self) -> np.ndarray: >>> _Normalize(a).levels array([0, 2, 4, 6]) """ - return np.append(self._unique_index, np.max(self._unique_index) + 1) * 2 + return ( + np.append(self._data_unique_index, np.max(self._data_unique_index) + 1) * 2 + ) @property def _lookup(self) -> pd.Series: - return pd.Series(dict(zip(self.values_unique, self._unique))) + if self._values_unique is None: + raise ValueError("self.data can't be None.") + + return pd.Series(dict(zip(self._values_unique, self._data_unique))) def _lookup_arr(self, x) -> np.ndarray: # Use reindex to be less sensitive to float errors. reindex only @@ -1527,7 +1612,7 @@ def _lookup_arr(self, x) -> np.ndarray: return self._lookup.sort_index().reindex(x, method="nearest").to_numpy() @property - def format(self) -> plt.FuncFormatter: + def format(self) -> FuncFormatter: """ Return a FuncFormatter that maps self.values elements back to the original value as a string. Useful with plt.colorbar. @@ -1545,11 +1630,12 @@ def format(self) -> plt.FuncFormatter: >>> aa.format(1) '3.0' """ + plt = import_matplotlib_pyplot() def _func(x: Any, pos: None | Any = None): return f"{self._lookup_arr([x])[0]}" - return self.plt.FuncFormatter(_func) + return plt.FuncFormatter(_func) @property def func(self) -> Callable[[Any, None | Any], Any]: @@ -1595,7 +1681,7 @@ def _determine_guide( else: add_colorbar = False - if (add_legend) and hueplt_norm.data is None and sizeplt_norm.data is None: + if add_legend and hueplt_norm.data is None and sizeplt_norm.data is None: raise KeyError("Cannot create a legend when hue and markersize is None.") if add_legend is None: if ( diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py index 08afdffc3b1..8d6b8f11475 100644 --- a/xarray/tests/__init__.py +++ b/xarray/tests/__init__.py @@ -4,7 +4,6 @@ import platform import warnings from contextlib import contextmanager, nullcontext -from typing import Any from unittest import mock # noqa: F401 import numpy as np @@ -43,7 +42,9 @@ ) -def _importorskip(modname: str, minversion: str | None = None) -> tuple[bool, Any]: +def _importorskip( + modname: str, minversion: str | None = None +) -> tuple[bool, pytest.MarkDecorator]: try: mod = importlib.import_module(modname) has = True diff --git a/xarray/tests/test_formatting.py b/xarray/tests/test_formatting.py index 7c6e6ae1489..abbff51e0f9 100644 --- a/xarray/tests/test_formatting.py +++ b/xarray/tests/test_formatting.py @@ -637,12 +637,12 @@ def test__mapping_repr_recursive() -> None: # GH:issue:7111 # direct recursion - ds = xr.Dataset({"a": [["x"], [1, 2, 3]]}) + ds = xr.Dataset({"a": ("x", [1, 2, 3])}) ds.attrs["ds"] = ds formatting.dataset_repr(ds) # indirect recursion - ds2 = xr.Dataset({"b": [["y"], [1, 2, 3]]}) + ds2 = xr.Dataset({"b": ("y", [1, 2, 3])}) ds.attrs["ds"] = ds2 ds2.attrs["ds"] = ds formatting.dataset_repr(ds2) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 2a51bf89943..d675de87484 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -5,7 +5,7 @@ import math from copy import copy from datetime import datetime -from typing import Any +from typing import Any, Callable, Hashable, Literal import numpy as np import pandas as pd @@ -14,8 +14,8 @@ import xarray as xr import xarray.plot as xplt from xarray import DataArray, Dataset +from xarray.plot.dataarray_plot import _infer_interval_breaks from xarray.plot.dataset_plot import _infer_meta_data -from xarray.plot.plot import _infer_interval_breaks from xarray.plot.utils import ( _assert_valid_xy, _build_discrete_cmap, @@ -170,16 +170,16 @@ def contourf_called(self, plotmethod): class TestPlot(PlotTestCase): @pytest.fixture(autouse=True) - def setup_array(self): + def setup_array(self) -> None: self.darray = DataArray(easy_array((2, 3, 4))) - def test_accessor(self): - from ..plot.plot import _PlotMethods + def test_accessor(self) -> None: + from xarray.plot.accessor import DataArrayPlotAccessor - assert DataArray.plot is _PlotMethods - assert isinstance(self.darray.plot, _PlotMethods) + assert DataArray.plot is DataArrayPlotAccessor + assert isinstance(self.darray.plot, DataArrayPlotAccessor) - def test_label_from_attrs(self): + def test_label_from_attrs(self) -> None: da = self.darray.copy() assert "" == label_from_attrs(da) @@ -209,7 +209,7 @@ def test_label_from_attrs(self): da.attrs = dict(long_name=long_latex_name) assert label_from_attrs(da) == long_latex_name - def test1d(self): + def test1d(self) -> None: self.darray[:, 0, 0].plot() with pytest.raises(ValueError, match=r"x must be one of None, 'dim_0'"): @@ -218,14 +218,14 @@ def test1d(self): with pytest.raises(TypeError, match=r"complex128"): (self.darray[:, 0, 0] + 1j).plot() - def test_1d_bool(self): + def test_1d_bool(self) -> None: xr.ones_like(self.darray[:, 0, 0], dtype=bool).plot() - def test_1d_x_y_kw(self): + def test_1d_x_y_kw(self) -> None: z = np.arange(10) da = DataArray(np.cos(z), dims=["z"], coords=[z], name="f") - xy = [[None, None], [None, "z"], ["z", None]] + xy: list[list[None | str]] = [[None, None], [None, "z"], ["z", None]] f, ax = plt.subplots(3, 1) for aa, (x, y) in enumerate(xy): @@ -241,7 +241,7 @@ def test_1d_x_y_kw(self): with pytest.raises(ValueError, match=rf"y {error_msg}"): da.plot(y="f") - def test_multiindex_level_as_coord(self): + def test_multiindex_level_as_coord(self) -> None: da = xr.DataArray( np.arange(5), dims="x", @@ -258,7 +258,7 @@ def test_multiindex_level_as_coord(self): assert_array_equal(h.get_ydata(), da[y].values) # Test for bug in GH issue #2725 - def test_infer_line_data(self): + def test_infer_line_data(self) -> None: current = DataArray( name="I", data=np.array([5, 8]), @@ -277,7 +277,7 @@ def test_infer_line_data(self): line = current.plot.line()[0] assert_array_equal(line.get_xdata(), current.coords["t"].values) - def test_line_plot_along_1d_coord(self): + def test_line_plot_along_1d_coord(self) -> None: # Test for bug in GH #3334 x_coord = xr.DataArray(data=[0.1, 0.2], dims=["x"]) t_coord = xr.DataArray(data=[10, 20], dims=["t"]) @@ -294,7 +294,7 @@ def test_line_plot_along_1d_coord(self): line = da.plot(y="time", hue="x")[0] assert_array_equal(line.get_ydata(), da.coords["time"].values) - def test_line_plot_wrong_hue(self): + def test_line_plot_wrong_hue(self) -> None: da = xr.DataArray( data=np.array([[0, 1], [5, 9]]), dims=["x", "t"], @@ -303,7 +303,7 @@ def test_line_plot_wrong_hue(self): with pytest.raises(ValueError, match="hue must be one of"): da.plot(x="t", hue="wrong_coord") - def test_2d_line(self): + def test_2d_line(self) -> None: with pytest.raises(ValueError, match=r"hue"): self.darray[:, :, 0].plot.line() @@ -316,7 +316,7 @@ def test_2d_line(self): with pytest.raises(ValueError, match=r"Cannot"): self.darray[:, :, 0].plot.line(x="dim_1", y="dim_0", hue="dim_1") - def test_2d_line_accepts_legend_kw(self): + def test_2d_line_accepts_legend_kw(self) -> None: self.darray[:, :, 0].plot.line(x="dim_0", add_legend=False) assert not plt.gca().get_legend() plt.cla() @@ -325,21 +325,21 @@ def test_2d_line_accepts_legend_kw(self): # check whether legend title is set assert plt.gca().get_legend().get_title().get_text() == "dim_1" - def test_2d_line_accepts_x_kw(self): + def test_2d_line_accepts_x_kw(self) -> None: self.darray[:, :, 0].plot.line(x="dim_0") assert plt.gca().get_xlabel() == "dim_0" plt.cla() self.darray[:, :, 0].plot.line(x="dim_1") assert plt.gca().get_xlabel() == "dim_1" - def test_2d_line_accepts_hue_kw(self): + def test_2d_line_accepts_hue_kw(self) -> None: self.darray[:, :, 0].plot.line(hue="dim_0") assert plt.gca().get_legend().get_title().get_text() == "dim_0" plt.cla() self.darray[:, :, 0].plot.line(hue="dim_1") assert plt.gca().get_legend().get_title().get_text() == "dim_1" - def test_2d_coords_line_plot(self): + def test_2d_coords_line_plot(self) -> None: lon, lat = np.meshgrid(np.linspace(-20, 20, 5), np.linspace(0, 30, 4)) lon += lat / 10 lat += lon / 10 @@ -360,7 +360,7 @@ def test_2d_coords_line_plot(self): with pytest.raises(ValueError, match="For 2D inputs, hue must be a dimension"): da.plot.line(x="lon", hue="lat") - def test_2d_coord_line_plot_coords_transpose_invariant(self): + def test_2d_coord_line_plot_coords_transpose_invariant(self) -> None: # checks for bug reported in GH #3933 x = np.arange(10) y = np.arange(20) @@ -371,20 +371,20 @@ def test_2d_coord_line_plot_coords_transpose_invariant(self): ds["v"] = ds.x + ds.y ds["v"].plot.line(y="z", hue="x") - def test_2d_before_squeeze(self): + def test_2d_before_squeeze(self) -> None: a = DataArray(easy_array((1, 5))) a.plot() - def test2d_uniform_calls_imshow(self): + def test2d_uniform_calls_imshow(self) -> None: assert self.imshow_called(self.darray[:, :, 0].plot.imshow) @pytest.mark.slow - def test2d_nonuniform_calls_contourf(self): + def test2d_nonuniform_calls_contourf(self) -> None: a = self.darray[:, :, 0] a.coords["dim_1"] = [2, 1, 89] assert self.contourf_called(a.plot.contourf) - def test2d_1d_2d_coordinates_contourf(self): + def test2d_1d_2d_coordinates_contourf(self) -> None: sz = (20, 10) depth = easy_array(sz) a = DataArray( @@ -396,7 +396,7 @@ def test2d_1d_2d_coordinates_contourf(self): a.plot.contourf(x="time", y="depth") a.plot.contourf(x="depth", y="time") - def test2d_1d_2d_coordinates_pcolormesh(self): + def test2d_1d_2d_coordinates_pcolormesh(self) -> None: # Test with equal coordinates to catch bug from #5097 sz = 10 y2d, x2d = np.meshgrid(np.arange(sz), np.arange(sz)) @@ -424,7 +424,7 @@ def test2d_1d_2d_coordinates_pcolormesh(self): _, unique_counts = np.unique(v[:-1], axis=0, return_counts=True) assert np.all(unique_counts == 1) - def test_contourf_cmap_set(self): + def test_contourf_cmap_set(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"]) cmap = mpl.cm.viridis @@ -451,7 +451,7 @@ def test_contourf_cmap_set(self): # check the set_over color assert pl.cmap(np.inf) == cmap(np.inf) - def test_contourf_cmap_set_with_bad_under_over(self): + def test_contourf_cmap_set_with_bad_under_over(self) -> None: a = DataArray(easy_array((4, 4)), dims=["z", "time"]) # make a copy here because we want a local cmap that we will modify. @@ -487,13 +487,13 @@ def test_contourf_cmap_set_with_bad_under_over(self): # check the set_over color has been kept assert pl.cmap(np.inf) == cmap(np.inf) - def test3d(self): + def test3d(self) -> None: self.darray.plot() - def test_can_pass_in_axis(self): + def test_can_pass_in_axis(self) -> None: self.pass_in_axis(self.darray.plot) - def test__infer_interval_breaks(self): + def test__infer_interval_breaks(self) -> None: assert_array_equal([-0.5, 0.5, 1.5], _infer_interval_breaks([0, 1])) assert_array_equal( [-0.5, 0.5, 5.0, 9.5, 10.5], _infer_interval_breaks([0, 1, 9, 10]) @@ -518,7 +518,7 @@ def test__infer_interval_breaks(self): with pytest.raises(ValueError): _infer_interval_breaks(np.array([0, 2, 1]), check_monotonic=True) - def test__infer_interval_breaks_logscale(self): + def test__infer_interval_breaks_logscale(self) -> None: """ Check if interval breaks are defined in the logspace if scale="log" """ @@ -538,7 +538,7 @@ def test__infer_interval_breaks_logscale(self): x = _infer_interval_breaks(x, axis=0, scale="log") np.testing.assert_allclose(x, expected_interval_breaks) - def test__infer_interval_breaks_logscale_invalid_coords(self): + def test__infer_interval_breaks_logscale_invalid_coords(self) -> None: """ Check error is raised when passing non-positive coordinates with logscale """ @@ -551,7 +551,7 @@ def test__infer_interval_breaks_logscale_invalid_coords(self): with pytest.raises(ValueError): _infer_interval_breaks(x, scale="log") - def test_geo_data(self): + def test_geo_data(self) -> None: # Regression test for gh2250 # Realistic coordinates taken from the example dataset lat = np.array( @@ -583,7 +583,7 @@ def test_geo_data(self): ax = plt.gca() assert ax.has_data() - def test_datetime_dimension(self): + def test_datetime_dimension(self) -> None: nrow = 3 ncol = 4 time = pd.date_range("2000-01-01", periods=nrow) @@ -596,7 +596,7 @@ def test_datetime_dimension(self): @pytest.mark.slow @pytest.mark.filterwarnings("ignore:tight_layout cannot") - def test_convenient_facetgrid(self): + def test_convenient_facetgrid(self) -> None: a = easy_array((10, 15, 4)) d = DataArray(a, dims=["y", "x", "z"]) d.coords["z"] = list("abcd") @@ -613,7 +613,7 @@ def test_convenient_facetgrid(self): d[0].plot(x="x", y="y", col="z", ax=plt.gca()) @pytest.mark.slow - def test_subplot_kws(self): + def test_subplot_kws(self) -> None: a = easy_array((10, 15, 4)) d = DataArray(a, dims=["y", "x", "z"]) d.coords["z"] = list("abcd") @@ -630,7 +630,7 @@ def test_subplot_kws(self): assert ax.get_facecolor()[0:3] == mpl.colors.to_rgb("r") @pytest.mark.slow - def test_plot_size(self): + def test_plot_size(self) -> None: self.darray[:, 0, 0].plot(figsize=(13, 5)) assert tuple(plt.gcf().get_size_inches()) == (13, 5) @@ -657,7 +657,7 @@ def test_plot_size(self): @pytest.mark.slow @pytest.mark.filterwarnings("ignore:tight_layout cannot") - def test_convenient_facetgrid_4d(self): + def test_convenient_facetgrid_4d(self) -> None: a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=["y", "x", "columns", "rows"]) g = d.plot(x="x", y="y", col="columns", row="rows") @@ -669,28 +669,28 @@ def test_convenient_facetgrid_4d(self): with pytest.raises(ValueError, match=r"[Ff]acet"): d.plot(x="x", y="y", col="columns", ax=plt.gca()) - def test_coord_with_interval(self): + def test_coord_with_interval(self) -> None: """Test line plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot() - def test_coord_with_interval_x(self): + def test_coord_with_interval_x(self) -> None: """Test line plot with intervals explicitly on x axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot(x="dim_0_bins") - def test_coord_with_interval_y(self): + def test_coord_with_interval_y(self) -> None: """Test line plot with intervals explicitly on y axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot(y="dim_0_bins") - def test_coord_with_interval_xy(self): + def test_coord_with_interval_xy(self) -> None: """Test line plot with intervals on both x and y axes.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).dim_0_bins.plot() @pytest.mark.parametrize("dim", ("x", "y")) - def test_labels_with_units_with_interval(self, dim): + def test_labels_with_units_with_interval(self, dim) -> None: """Test line plot with intervals and a units attribute.""" bins = [-1, 0, 1, 2] arr = self.darray.groupby_bins("dim_0", bins).mean(...) @@ -706,75 +706,75 @@ def test_labels_with_units_with_interval(self, dim): class TestPlot1D(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: d = [0, 1.1, 0, 2] self.darray = DataArray(d, coords={"period": range(len(d))}, dims="period") self.darray.period.attrs["units"] = "s" - def test_xlabel_is_index_name(self): + def test_xlabel_is_index_name(self) -> None: self.darray.plot() assert "period [s]" == plt.gca().get_xlabel() - def test_no_label_name_on_x_axis(self): + def test_no_label_name_on_x_axis(self) -> None: self.darray.plot(y="period") assert "" == plt.gca().get_xlabel() - def test_no_label_name_on_y_axis(self): + def test_no_label_name_on_y_axis(self) -> None: self.darray.plot() assert "" == plt.gca().get_ylabel() - def test_ylabel_is_data_name(self): + def test_ylabel_is_data_name(self) -> None: self.darray.name = "temperature" self.darray.attrs["units"] = "degrees_Celsius" self.darray.plot() assert "temperature [degrees_Celsius]" == plt.gca().get_ylabel() - def test_xlabel_is_data_name(self): + def test_xlabel_is_data_name(self) -> None: self.darray.name = "temperature" self.darray.attrs["units"] = "degrees_Celsius" self.darray.plot(y="period") assert "temperature [degrees_Celsius]" == plt.gca().get_xlabel() - def test_format_string(self): + def test_format_string(self) -> None: self.darray.plot.line("ro") - def test_can_pass_in_axis(self): + def test_can_pass_in_axis(self) -> None: self.pass_in_axis(self.darray.plot.line) - def test_nonnumeric_index(self): + def test_nonnumeric_index(self) -> None: a = DataArray([1, 2, 3], {"letter": ["a", "b", "c"]}, dims="letter") a.plot.line() - def test_primitive_returned(self): + def test_primitive_returned(self) -> None: p = self.darray.plot.line() assert isinstance(p[0], mpl.lines.Line2D) @pytest.mark.slow - def test_plot_nans(self): + def test_plot_nans(self) -> None: self.darray[1] = np.nan self.darray.plot.line() - def test_x_ticks_are_rotated_for_time(self): + def test_x_ticks_are_rotated_for_time(self) -> None: time = pd.date_range("2000-01-01", "2000-01-10") a = DataArray(np.arange(len(time)), [("t", time)]) a.plot.line() rotation = plt.gca().get_xticklabels()[0].get_rotation() assert rotation != 0 - def test_xyincrease_false_changes_axes(self): + def test_xyincrease_false_changes_axes(self) -> None: self.darray.plot.line(xincrease=False, yincrease=False) xlim = plt.gca().get_xlim() ylim = plt.gca().get_ylim() diffs = xlim[1] - xlim[0], ylim[1] - ylim[0] assert all(x < 0 for x in diffs) - def test_slice_in_title(self): + def test_slice_in_title(self) -> None: self.darray.coords["d"] = 10.009 self.darray.plot.line() title = plt.gca().get_title() assert "d = 10.01" == title - def test_slice_in_title_single_item_array(self): + def test_slice_in_title_single_item_array(self) -> None: """Edge case for data of shape (1, N) or (N, 1).""" darray = self.darray.expand_dims({"d": np.array([10.009])}) darray.plot.line(x="period") @@ -784,55 +784,55 @@ def test_slice_in_title_single_item_array(self): class TestPlotStep(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: self.darray = DataArray(easy_array((2, 3, 4))) - def test_step(self): + def test_step(self) -> None: hdl = self.darray[0, 0].plot.step() assert "steps" in hdl[0].get_drawstyle() @pytest.mark.parametrize("where", ["pre", "post", "mid"]) - def test_step_with_where(self, where): + def test_step_with_where(self, where) -> None: hdl = self.darray[0, 0].plot.step(where=where) assert hdl[0].get_drawstyle() == f"steps-{where}" - def test_step_with_hue(self): + def test_step_with_hue(self) -> None: hdl = self.darray[0].plot.step(hue="dim_2") assert hdl[0].get_drawstyle() == "steps-pre" @pytest.mark.parametrize("where", ["pre", "post", "mid"]) - def test_step_with_hue_and_where(self, where): + def test_step_with_hue_and_where(self, where) -> None: hdl = self.darray[0].plot.step(hue="dim_2", where=where) assert hdl[0].get_drawstyle() == f"steps-{where}" - def test_drawstyle_steps(self): + def test_drawstyle_steps(self) -> None: hdl = self.darray[0].plot(hue="dim_2", drawstyle="steps") assert hdl[0].get_drawstyle() == "steps" @pytest.mark.parametrize("where", ["pre", "post", "mid"]) - def test_drawstyle_steps_with_where(self, where): + def test_drawstyle_steps_with_where(self, where) -> None: hdl = self.darray[0].plot(hue="dim_2", drawstyle=f"steps-{where}") assert hdl[0].get_drawstyle() == f"steps-{where}" - def test_coord_with_interval_step(self): + def test_coord_with_interval_step(self) -> None: """Test step plot with intervals.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step() assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) - def test_coord_with_interval_step_x(self): + def test_coord_with_interval_step_x(self) -> None: """Test step plot with intervals explicitly on x axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins") assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) - def test_coord_with_interval_step_y(self): + def test_coord_with_interval_step_y(self) -> None: """Test step plot with intervals explicitly on y axis.""" bins = [-1, 0, 1, 2] self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins") assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) - def test_coord_with_interval_step_x_and_y_raises_valueeerror(self): + def test_coord_with_interval_step_x_and_y_raises_valueeerror(self) -> None: """Test that step plot with intervals both on x and y axes raises an error.""" arr = xr.DataArray( [pd.Interval(0, 1), pd.Interval(1, 2)], @@ -844,41 +844,41 @@ def test_coord_with_interval_step_x_and_y_raises_valueeerror(self): class TestPlotHistogram(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: self.darray = DataArray(easy_array((2, 3, 4))) - def test_3d_array(self): + def test_3d_array(self) -> None: self.darray.plot.hist() - def test_xlabel_uses_name(self): + def test_xlabel_uses_name(self) -> None: self.darray.name = "testpoints" self.darray.attrs["units"] = "testunits" self.darray.plot.hist() assert "testpoints [testunits]" == plt.gca().get_xlabel() - def test_title_is_histogram(self): + def test_title_is_histogram(self) -> None: self.darray.coords["d"] = 10 self.darray.plot.hist() assert "d = 10" == plt.gca().get_title() - def test_can_pass_in_kwargs(self): + def test_can_pass_in_kwargs(self) -> None: nbins = 5 self.darray.plot.hist(bins=nbins) assert nbins == len(plt.gca().patches) - def test_can_pass_in_axis(self): + def test_can_pass_in_axis(self) -> None: self.pass_in_axis(self.darray.plot.hist) - def test_primitive_returned(self): + def test_primitive_returned(self) -> None: h = self.darray.plot.hist() assert isinstance(h[-1][0], mpl.patches.Rectangle) @pytest.mark.slow - def test_plot_nans(self): + def test_plot_nans(self) -> None: self.darray[0, 0, 0] = np.nan self.darray.plot.hist() - def test_hist_coord_with_interval(self): + def test_hist_coord_with_interval(self) -> None: ( self.darray.groupby_bins("dim_0", [-1, 0, 1, 2]) .mean(...) @@ -889,10 +889,10 @@ def test_hist_coord_with_interval(self): @requires_matplotlib class TestDetermineCmapParams: @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: self.data = np.linspace(0, 1, num=100) - def test_robust(self): + def test_robust(self) -> None: cmap_params = _determine_cmap_params(self.data, robust=True) assert cmap_params["vmin"] == np.percentile(self.data, 2) assert cmap_params["vmax"] == np.percentile(self.data, 98) @@ -901,7 +901,7 @@ def test_robust(self): assert cmap_params["levels"] is None assert cmap_params["norm"] is None - def test_center(self): + def test_center(self) -> None: cmap_params = _determine_cmap_params(self.data, center=0.5) assert cmap_params["vmax"] - 0.5 == 0.5 - cmap_params["vmin"] assert cmap_params["cmap"] == "RdBu_r" @@ -909,22 +909,22 @@ def test_center(self): assert cmap_params["levels"] is None assert cmap_params["norm"] is None - def test_cmap_sequential_option(self): + def test_cmap_sequential_option(self) -> None: with xr.set_options(cmap_sequential="magma"): cmap_params = _determine_cmap_params(self.data) assert cmap_params["cmap"] == "magma" - def test_cmap_sequential_explicit_option(self): + def test_cmap_sequential_explicit_option(self) -> None: with xr.set_options(cmap_sequential=mpl.cm.magma): cmap_params = _determine_cmap_params(self.data) assert cmap_params["cmap"] == mpl.cm.magma - def test_cmap_divergent_option(self): + def test_cmap_divergent_option(self) -> None: with xr.set_options(cmap_divergent="magma"): cmap_params = _determine_cmap_params(self.data, center=0.5) assert cmap_params["cmap"] == "magma" - def test_nan_inf_are_ignored(self): + def test_nan_inf_are_ignored(self) -> None: cmap_params1 = _determine_cmap_params(self.data) data = self.data data[50:55] = np.nan @@ -934,7 +934,7 @@ def test_nan_inf_are_ignored(self): assert cmap_params1["vmax"] == cmap_params2["vmax"] @pytest.mark.slow - def test_integer_levels(self): + def test_integer_levels(self) -> None: data = self.data + 1 # default is to cover full data range but with no guarantee on Nlevels @@ -973,7 +973,7 @@ def test_integer_levels(self): assert cmap_params["cmap"].name == "viridis" assert cmap_params["extend"] == "both" - def test_list_levels(self): + def test_list_levels(self) -> None: data = self.data + 1 orig_levels = [0, 1, 2, 3, 4, 5] @@ -990,7 +990,7 @@ def test_list_levels(self): cmap_params = _determine_cmap_params(data, levels=wrap_levels(orig_levels)) assert_array_equal(cmap_params["levels"], orig_levels) - def test_divergentcontrol(self): + def test_divergentcontrol(self) -> None: neg = self.data - 0.1 pos = self.data @@ -1068,7 +1068,7 @@ def test_divergentcontrol(self): # specifying levels makes cmap a Colormap object assert cmap_params["cmap"].name == "RdBu_r" - def test_norm_sets_vmin_vmax(self): + def test_norm_sets_vmin_vmax(self) -> None: vmin = self.data.min() vmax = self.data.max() @@ -1112,13 +1112,13 @@ def setUp(self): plt.close("all") @pytest.mark.slow - def test_recover_from_seaborn_jet_exception(self): + def test_recover_from_seaborn_jet_exception(self) -> None: pal = _color_palette("jet", 4) assert type(pal) == np.ndarray assert len(pal) == 4 @pytest.mark.slow - def test_build_discrete_cmap(self): + def test_build_discrete_cmap(self) -> None: for (cmap, levels, extend, filled) in [ ("jet", [0, 1], "both", False), ("hot", [-4, 4], "max", True), @@ -1136,7 +1136,7 @@ def test_build_discrete_cmap(self): assert ncmap.colorbar_extend == "max" @pytest.mark.slow - def test_discrete_colormap_list_of_levels(self): + def test_discrete_colormap_list_of_levels(self) -> None: for extend, levels in [ ("max", [-1, 2, 4, 8, 10]), ("both", [2, 5, 10, 11]), @@ -1155,7 +1155,7 @@ def test_discrete_colormap_list_of_levels(self): assert len(levels) - 1 == len(primitive.cmap.colors) @pytest.mark.slow - def test_discrete_colormap_int_levels(self): + def test_discrete_colormap_int_levels(self) -> None: for extend, levels, vmin, vmax, cmap in [ ("neither", 7, None, None, None), ("neither", 7, None, 20, mpl.cm.RdBu), @@ -1181,13 +1181,13 @@ def test_discrete_colormap_int_levels(self): assert "max" == primitive.cmap.colorbar_extend assert levels >= len(primitive.cmap.colors) - def test_discrete_colormap_list_levels_and_vmin_or_vmax(self): + def test_discrete_colormap_list_levels_and_vmin_or_vmax(self) -> None: levels = [0, 5, 10, 15] primitive = self.darray.plot(levels=levels, vmin=-3, vmax=20) assert primitive.norm.vmax == max(levels) assert primitive.norm.vmin == min(levels) - def test_discrete_colormap_provided_boundary_norm(self): + def test_discrete_colormap_provided_boundary_norm(self) -> None: norm = mpl.colors.BoundaryNorm([0, 5, 10, 15], 4) primitive = self.darray.plot.contourf(norm=norm) np.testing.assert_allclose(primitive.levels, norm.boundaries) @@ -1201,11 +1201,15 @@ class Common2dMixin: Should have the same name as the method. """ + darray: DataArray + plotfunc: staticmethod + pass_in_axis: Callable + # Needs to be overridden in TestSurface for facet grid plots subplot_kws: dict[Any, Any] | None = None @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: da = DataArray( easy_array((10, 15), start=-1), dims=["y", "x"], @@ -1218,7 +1222,7 @@ def setUp(self): ds["y2d"] = DataArray(y, dims=["y", "x"]) ds = ds.set_coords(["x2d", "y2d"]) # set darray and plot method - self.darray = ds.testvar + self.darray: DataArray = ds.testvar # Add CF-compliant metadata self.darray.attrs["long_name"] = "a_long_name" @@ -1230,30 +1234,30 @@ def setUp(self): self.plotmethod = getattr(self.darray.plot, self.plotfunc.__name__) - def test_label_names(self): + def test_label_names(self) -> None: self.plotmethod() assert "x_long_name [x_units]" == plt.gca().get_xlabel() assert "y_long_name [y_units]" == plt.gca().get_ylabel() - def test_1d_raises_valueerror(self): + def test_1d_raises_valueerror(self) -> None: with pytest.raises(ValueError, match=r"DataArray must be 2d"): self.plotfunc(self.darray[0, :]) - def test_bool(self): + def test_bool(self) -> None: xr.ones_like(self.darray, dtype=bool).plot() - def test_complex_raises_typeerror(self): + def test_complex_raises_typeerror(self) -> None: with pytest.raises(TypeError, match=r"complex128"): (self.darray + 1j).plot() - def test_3d_raises_valueerror(self): + def test_3d_raises_valueerror(self) -> None: a = DataArray(easy_array((2, 3, 4))) if self.plotfunc.__name__ == "imshow": pytest.skip() with pytest.raises(ValueError, match=r"DataArray must be 2d"): self.plotfunc(a) - def test_nonnumeric_index(self): + def test_nonnumeric_index(self) -> None: a = DataArray(easy_array((3, 2)), coords=[["a", "b", "c"], ["d", "e"]]) if self.plotfunc.__name__ == "surface": # ax.plot_surface errors with nonnumerics: @@ -1262,7 +1266,7 @@ def test_nonnumeric_index(self): else: self.plotfunc(a) - def test_multiindex_raises_typeerror(self): + def test_multiindex_raises_typeerror(self) -> None: a = DataArray( easy_array((3, 2)), dims=("x", "y"), @@ -1272,10 +1276,10 @@ def test_multiindex_raises_typeerror(self): with pytest.raises(TypeError, match=r"[Pp]lot"): self.plotfunc(a) - def test_can_pass_in_axis(self): + def test_can_pass_in_axis(self) -> None: self.pass_in_axis(self.plotmethod) - def test_xyincrease_defaults(self): + def test_xyincrease_defaults(self) -> None: # With default settings the axis must be ordered regardless # of the coords order. @@ -1291,28 +1295,28 @@ def test_xyincrease_defaults(self): bounds = plt.gca().get_xlim() assert bounds[0] < bounds[1] - def test_xyincrease_false_changes_axes(self): + def test_xyincrease_false_changes_axes(self) -> None: self.plotmethod(xincrease=False, yincrease=False) xlim = plt.gca().get_xlim() ylim = plt.gca().get_ylim() diffs = xlim[0] - 14, xlim[1] - 0, ylim[0] - 9, ylim[1] - 0 assert all(abs(x) < 1 for x in diffs) - def test_xyincrease_true_changes_axes(self): + def test_xyincrease_true_changes_axes(self) -> None: self.plotmethod(xincrease=True, yincrease=True) xlim = plt.gca().get_xlim() ylim = plt.gca().get_ylim() diffs = xlim[0] - 0, xlim[1] - 14, ylim[0] - 0, ylim[1] - 9 assert all(abs(x) < 1 for x in diffs) - def test_x_ticks_are_rotated_for_time(self): + def test_x_ticks_are_rotated_for_time(self) -> None: time = pd.date_range("2000-01-01", "2000-01-10") a = DataArray(np.random.randn(2, len(time)), [("xx", [1, 2]), ("t", time)]) a.plot(x="t") rotation = plt.gca().get_xticklabels()[0].get_rotation() assert rotation != 0 - def test_plot_nans(self): + def test_plot_nans(self) -> None: x1 = self.darray[:5] x2 = self.darray.copy() x2[5:] = np.nan @@ -1323,25 +1327,25 @@ def test_plot_nans(self): @pytest.mark.filterwarnings("ignore::UserWarning") @pytest.mark.filterwarnings("ignore:invalid value encountered") - def test_can_plot_all_nans(self): + def test_can_plot_all_nans(self) -> None: # regression test for issue #1780 self.plotfunc(DataArray(np.full((2, 2), np.nan))) @pytest.mark.filterwarnings("ignore: Attempting to set") - def test_can_plot_axis_size_one(self): + def test_can_plot_axis_size_one(self) -> None: if self.plotfunc.__name__ not in ("contour", "contourf"): self.plotfunc(DataArray(np.ones((1, 1)))) - def test_disallows_rgb_arg(self): + def test_disallows_rgb_arg(self) -> None: with pytest.raises(ValueError): # Always invalid for most plots. Invalid for imshow with 2D data. self.plotfunc(DataArray(np.ones((2, 2))), rgb="not None") - def test_viridis_cmap(self): + def test_viridis_cmap(self) -> None: cmap_name = self.plotmethod(cmap="viridis").get_cmap().name assert "viridis" == cmap_name - def test_default_cmap(self): + def test_default_cmap(self) -> None: cmap_name = self.plotmethod().get_cmap().name assert "RdBu_r" == cmap_name @@ -1349,26 +1353,26 @@ def test_default_cmap(self): assert "viridis" == cmap_name @requires_seaborn - def test_seaborn_palette_as_cmap(self): + def test_seaborn_palette_as_cmap(self) -> None: cmap_name = self.plotmethod(levels=2, cmap="husl").get_cmap().name assert "husl" == cmap_name - def test_can_change_default_cmap(self): + def test_can_change_default_cmap(self) -> None: cmap_name = self.plotmethod(cmap="Blues").get_cmap().name assert "Blues" == cmap_name - def test_diverging_color_limits(self): + def test_diverging_color_limits(self) -> None: artist = self.plotmethod() vmin, vmax = artist.get_clim() assert round(abs(-vmin - vmax), 7) == 0 - def test_xy_strings(self): - self.plotmethod("y", "x") + def test_xy_strings(self) -> None: + self.plotmethod(x="y", y="x") ax = plt.gca() assert "y_long_name [y_units]" == ax.get_xlabel() assert "x_long_name [x_units]" == ax.get_ylabel() - def test_positional_coord_string(self): + def test_positional_coord_string(self) -> None: self.plotmethod(y="x") ax = plt.gca() assert "x_long_name [x_units]" == ax.get_ylabel() @@ -1379,26 +1383,26 @@ def test_positional_coord_string(self): assert "x_long_name [x_units]" == ax.get_xlabel() assert "y_long_name [y_units]" == ax.get_ylabel() - def test_bad_x_string_exception(self): + def test_bad_x_string_exception(self) -> None: with pytest.raises(ValueError, match=r"x and y cannot be equal."): self.plotmethod(x="y", y="y") error_msg = "must be one of None, 'x', 'x2d', 'y', 'y2d'" with pytest.raises(ValueError, match=rf"x {error_msg}"): - self.plotmethod("not_a_real_dim", "y") + self.plotmethod(x="not_a_real_dim", y="y") with pytest.raises(ValueError, match=rf"x {error_msg}"): self.plotmethod(x="not_a_real_dim") with pytest.raises(ValueError, match=rf"y {error_msg}"): self.plotmethod(y="not_a_real_dim") self.darray.coords["z"] = 100 - def test_coord_strings(self): + def test_coord_strings(self) -> None: # 1d coords (same as dims) assert {"x", "y"} == set(self.darray.dims) self.plotmethod(y="y", x="x") - def test_non_linked_coords(self): + def test_non_linked_coords(self) -> None: # plot with coordinate names that are not dimensions self.darray.coords["newy"] = self.darray.y + 150 # Normal case, without transpose @@ -1410,7 +1414,7 @@ def test_non_linked_coords(self): # simply ensure that these high coords were passed over assert np.min(ax.get_ylim()) > 100.0 - def test_non_linked_coords_transpose(self): + def test_non_linked_coords_transpose(self) -> None: # plot with coordinate names that are not dimensions, # and with transposed y and x axes # This used to raise an error with pcolormesh and contour @@ -1424,7 +1428,7 @@ def test_non_linked_coords_transpose(self): # simply ensure that these high coords were passed over assert np.min(ax.get_xlim()) > 100.0 - def test_multiindex_level_as_coord(self): + def test_multiindex_level_as_coord(self) -> None: da = DataArray( easy_array((3, 2)), dims=("x", "y"), @@ -1445,7 +1449,7 @@ def test_multiindex_level_as_coord(self): with pytest.raises(ValueError, match=r"y must be one of None, 'a', 'b', 'x'"): self.plotfunc(da, x="a", y="y") - def test_default_title(self): + def test_default_title(self) -> None: a = DataArray(easy_array((4, 3, 2)), dims=["a", "b", "c"]) a.coords["c"] = [0, 1] a.coords["d"] = "foo" @@ -1453,11 +1457,11 @@ def test_default_title(self): title = plt.gca().get_title() assert "c = 1, d = foo" == title or "d = foo, c = 1" == title - def test_colorbar_default_label(self): + def test_colorbar_default_label(self) -> None: self.plotmethod(add_colorbar=True) assert "a_long_name [a_units]" in text_in_fig() - def test_no_labels(self): + def test_no_labels(self) -> None: self.darray.name = "testvar" self.darray.attrs["units"] = "test_units" self.plotmethod(add_labels=False) @@ -1469,7 +1473,7 @@ def test_no_labels(self): ]: assert string not in alltxt - def test_colorbar_kwargs(self): + def test_colorbar_kwargs(self) -> None: # replace label self.darray.attrs.pop("long_name") self.darray.attrs["units"] = "test_units" @@ -1520,7 +1524,7 @@ def test_colorbar_kwargs(self): cbar_kwargs={"label": "label"}, ) - def test_verbose_facetgrid(self): + def test_verbose_facetgrid(self) -> None: a = easy_array((10, 15, 3)) d = DataArray(a, dims=["y", "x", "z"]) g = xplt.FacetGrid(d, col="z", subplot_kws=self.subplot_kws) @@ -1528,15 +1532,14 @@ def test_verbose_facetgrid(self): for ax in g.axes.flat: assert ax.has_data() - def test_2d_function_and_method_signature_same(self): - func_sig = inspect.getcallargs(self.plotfunc, self.darray) - method_sig = inspect.getcallargs(self.plotmethod) - del method_sig["_PlotMethods_obj"] - del func_sig["darray"] - assert func_sig == method_sig + def test_2d_function_and_method_signature_same(self) -> None: + func_sig = inspect.signature(self.plotfunc) + method_sig = inspect.signature(self.plotmethod) + for argname, param in method_sig.parameters.items(): + assert func_sig.parameters[argname] == param @pytest.mark.filterwarnings("ignore:tight_layout cannot") - def test_convenient_facetgrid(self): + def test_convenient_facetgrid(self) -> None: a = easy_array((10, 15, 4)) d = DataArray(a, dims=["y", "x", "z"]) g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) @@ -1568,7 +1571,7 @@ def test_convenient_facetgrid(self): assert "" == ax.get_xlabel() @pytest.mark.filterwarnings("ignore:tight_layout cannot") - def test_convenient_facetgrid_4d(self): + def test_convenient_facetgrid_4d(self) -> None: a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=["y", "x", "columns", "rows"]) g = self.plotfunc(d, x="x", y="y", col="columns", row="rows") @@ -1578,7 +1581,7 @@ def test_convenient_facetgrid_4d(self): assert ax.has_data() @pytest.mark.filterwarnings("ignore:This figure includes") - def test_facetgrid_map_only_appends_mappables(self): + def test_facetgrid_map_only_appends_mappables(self) -> None: a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=["y", "x", "columns", "rows"]) g = self.plotfunc(d, x="x", y="y", col="columns", row="rows") @@ -1590,7 +1593,7 @@ def test_facetgrid_map_only_appends_mappables(self): assert expected == actual - def test_facetgrid_cmap(self): + def test_facetgrid_cmap(self) -> None: # Regression test for GH592 data = np.random.random(size=(20, 25, 12)) + np.linspace(-3, 3, 12) d = DataArray(data, dims=["x", "y", "time"]) @@ -1600,7 +1603,7 @@ def test_facetgrid_cmap(self): # check that all colormaps are the same assert len({m.get_cmap().name for m in fg._mappables}) == 1 - def test_facetgrid_cbar_kwargs(self): + def test_facetgrid_cbar_kwargs(self) -> None: a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=["y", "x", "columns", "rows"]) g = self.plotfunc( @@ -1616,25 +1619,25 @@ def test_facetgrid_cbar_kwargs(self): if g.cbar is not None: assert get_colorbar_label(g.cbar) == "test_label" - def test_facetgrid_no_cbar_ax(self): + def test_facetgrid_no_cbar_ax(self) -> None: a = easy_array((10, 15, 2, 3)) d = DataArray(a, dims=["y", "x", "columns", "rows"]) with pytest.raises(ValueError): self.plotfunc(d, x="x", y="y", col="columns", row="rows", cbar_ax=1) - def test_cmap_and_color_both(self): + def test_cmap_and_color_both(self) -> None: with pytest.raises(ValueError): self.plotmethod(colors="k", cmap="RdBu") - def test_2d_coord_with_interval(self): + def test_2d_coord_with_interval(self) -> None: for dim in self.darray.dims: gp = self.darray.groupby_bins(dim, range(15), restore_coord_dims=True).mean( - dim + [dim] ) for kind in ["imshow", "pcolormesh", "contourf", "contour"]: getattr(gp.plot, kind)() - def test_colormap_error_norm_and_vmin_vmax(self): + def test_colormap_error_norm_and_vmin_vmax(self) -> None: norm = mpl.colors.LogNorm(0.1, 1e1) with pytest.raises(ValueError): @@ -1650,17 +1653,17 @@ class TestContourf(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.contourf) @pytest.mark.slow - def test_contourf_called(self): + def test_contourf_called(self) -> None: # Having both statements ensures the test works properly assert not self.contourf_called(self.darray.plot.imshow) assert self.contourf_called(self.darray.plot.contourf) - def test_primitive_artist_returned(self): + def test_primitive_artist_returned(self) -> None: artist = self.plotmethod() assert isinstance(artist, mpl.contour.QuadContourSet) @pytest.mark.slow - def test_extend(self): + def test_extend(self) -> None: artist = self.plotmethod() assert artist.extend == "neither" @@ -1678,7 +1681,7 @@ def test_extend(self): assert artist.extend == "max" @pytest.mark.slow - def test_2d_coord_names(self): + def test_2d_coord_names(self) -> None: self.plotmethod(x="x2d", y="y2d") # make sure labels came out ok ax = plt.gca() @@ -1686,7 +1689,7 @@ def test_2d_coord_names(self): assert "y2d" == ax.get_ylabel() @pytest.mark.slow - def test_levels(self): + def test_levels(self) -> None: artist = self.plotmethod(levels=[-0.5, -0.4, 0.1]) assert artist.extend == "both" @@ -1705,7 +1708,7 @@ class TestContour(Common2dMixin, PlotTestCase): def _color_as_tuple(c): return tuple(c[:3]) - def test_colors(self): + def test_colors(self) -> None: # with single color, we don't want rgb array artist = self.plotmethod(colors="k") @@ -1722,7 +1725,7 @@ def test_colors(self): # the last color is now under "over" assert self._color_as_tuple(artist.cmap._rgba_over) == (0.0, 0.0, 1.0) - def test_colors_np_levels(self): + def test_colors_np_levels(self) -> None: # https://github.com/pydata/xarray/issues/3284 levels = np.array([-0.5, 0.0, 0.5, 1.0]) @@ -1732,23 +1735,23 @@ def test_colors_np_levels(self): # the last color is now under "over" assert self._color_as_tuple(artist.cmap._rgba_over) == (0.0, 0.0, 1.0) - def test_cmap_and_color_both(self): + def test_cmap_and_color_both(self) -> None: with pytest.raises(ValueError): self.plotmethod(colors="k", cmap="RdBu") - def list_of_colors_in_cmap_raises_error(self): + def list_of_colors_in_cmap_raises_error(self) -> None: with pytest.raises(ValueError, match=r"list of colors"): self.plotmethod(cmap=["k", "b"]) @pytest.mark.slow - def test_2d_coord_names(self): + def test_2d_coord_names(self) -> None: self.plotmethod(x="x2d", y="y2d") # make sure labels came out ok ax = plt.gca() assert "x2d" == ax.get_xlabel() assert "y2d" == ax.get_ylabel() - def test_single_level(self): + def test_single_level(self) -> None: # this used to raise an error, but not anymore since # add_colorbar defaults to false self.plotmethod(levels=[0.1]) @@ -1759,23 +1762,23 @@ class TestPcolormesh(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.pcolormesh) - def test_primitive_artist_returned(self): + def test_primitive_artist_returned(self) -> None: artist = self.plotmethod() assert isinstance(artist, mpl.collections.QuadMesh) - def test_everything_plotted(self): + def test_everything_plotted(self) -> None: artist = self.plotmethod() assert artist.get_array().size == self.darray.size @pytest.mark.slow - def test_2d_coord_names(self): + def test_2d_coord_names(self) -> None: self.plotmethod(x="x2d", y="y2d") # make sure labels came out ok ax = plt.gca() assert "x2d" == ax.get_xlabel() assert "y2d" == ax.get_ylabel() - def test_dont_infer_interval_breaks_for_cartopy(self): + def test_dont_infer_interval_breaks_for_cartopy(self) -> None: # Regression for GH 781 ax = plt.gca() # Simulate a Cartopy Axis @@ -1794,7 +1797,7 @@ class TestPcolormeshLogscale(PlotTestCase): plotfunc = staticmethod(xplt.pcolormesh) @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: self.boundaries = (-1, 9, -4, 3) shape = (8, 11) x = np.logspace(self.boundaries[0], self.boundaries[1], shape[1]) @@ -1807,7 +1810,7 @@ def setUp(self): ) self.darray = da - def test_interval_breaks_logspace(self): + def test_interval_breaks_logspace(self) -> None: """ Check if the outer vertices of the pcolormesh are the expected values @@ -1838,22 +1841,22 @@ class TestImshow(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.imshow) @pytest.mark.slow - def test_imshow_called(self): + def test_imshow_called(self) -> None: # Having both statements ensures the test works properly assert not self.imshow_called(self.darray.plot.contourf) assert self.imshow_called(self.darray.plot.imshow) - def test_xy_pixel_centered(self): + def test_xy_pixel_centered(self) -> None: self.darray.plot.imshow(yincrease=False) assert np.allclose([-0.5, 14.5], plt.gca().get_xlim()) assert np.allclose([9.5, -0.5], plt.gca().get_ylim()) - def test_default_aspect_is_auto(self): + def test_default_aspect_is_auto(self) -> None: self.darray.plot.imshow() assert "auto" == plt.gca().get_aspect() @pytest.mark.slow - def test_cannot_change_mpl_aspect(self): + def test_cannot_change_mpl_aspect(self) -> None: with pytest.raises(ValueError, match=r"not available in xarray"): self.darray.plot.imshow(aspect="equal") @@ -1864,45 +1867,45 @@ def test_cannot_change_mpl_aspect(self): assert tuple(plt.gcf().get_size_inches()) == (10, 5) @pytest.mark.slow - def test_primitive_artist_returned(self): + def test_primitive_artist_returned(self) -> None: artist = self.plotmethod() assert isinstance(artist, mpl.image.AxesImage) @pytest.mark.slow @requires_seaborn - def test_seaborn_palette_needs_levels(self): + def test_seaborn_palette_needs_levels(self) -> None: with pytest.raises(ValueError): self.plotmethod(cmap="husl") - def test_2d_coord_names(self): + def test_2d_coord_names(self) -> None: with pytest.raises(ValueError, match=r"requires 1D coordinates"): self.plotmethod(x="x2d", y="y2d") - def test_plot_rgb_image(self): + def test_plot_rgb_image(self) -> None: DataArray( easy_array((10, 15, 3), start=0), dims=["y", "x", "band"] ).plot.imshow() assert 0 == len(find_possible_colorbars()) - def test_plot_rgb_image_explicit(self): + def test_plot_rgb_image_explicit(self) -> None: DataArray( easy_array((10, 15, 3), start=0), dims=["y", "x", "band"] ).plot.imshow(y="y", x="x", rgb="band") assert 0 == len(find_possible_colorbars()) - def test_plot_rgb_faceted(self): + def test_plot_rgb_faceted(self) -> None: DataArray( easy_array((2, 2, 10, 15, 3), start=0), dims=["a", "b", "y", "x", "band"] ).plot.imshow(row="a", col="b") assert 0 == len(find_possible_colorbars()) - def test_plot_rgba_image_transposed(self): + def test_plot_rgba_image_transposed(self) -> None: # We can handle the color axis being in any position DataArray( easy_array((4, 10, 15), start=0), dims=["band", "y", "x"] ).plot.imshow() - def test_warns_ambigious_dim(self): + def test_warns_ambigious_dim(self) -> None: arr = DataArray(easy_array((3, 3, 3)), dims=["y", "x", "band"]) with pytest.warns(UserWarning): arr.plot.imshow() @@ -1910,40 +1913,45 @@ def test_warns_ambigious_dim(self): arr.plot.imshow(rgb="band") arr.plot.imshow(x="x", y="y") - def test_rgb_errors_too_many_dims(self): + def test_rgb_errors_too_many_dims(self) -> None: arr = DataArray(easy_array((3, 3, 3, 3)), dims=["y", "x", "z", "band"]) with pytest.raises(ValueError): arr.plot.imshow(rgb="band") - def test_rgb_errors_bad_dim_sizes(self): + def test_rgb_errors_bad_dim_sizes(self) -> None: arr = DataArray(easy_array((5, 5, 5)), dims=["y", "x", "band"]) with pytest.raises(ValueError): arr.plot.imshow(rgb="band") - def test_normalize_rgb_imshow(self): - for kwargs in ( - dict(vmin=-1), - dict(vmax=2), - dict(vmin=-1, vmax=1), - dict(vmin=0, vmax=0), - dict(vmin=0, robust=True), - dict(vmax=-1, robust=True), - ): - da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) - arr = da.plot.imshow(**kwargs).get_array() - assert 0 <= arr.min() <= arr.max() <= 1, kwargs + @pytest.mark.parametrize( + ["vmin", "vmax", "robust"], + [ + (-1, None, False), + (None, 2, False), + (-1, 1, False), + (0, 0, False), + (0, None, True), + (None, -1, True), + ], + ) + def test_normalize_rgb_imshow( + self, vmin: float | None, vmax: float | None, robust: bool + ) -> None: + da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) + arr = da.plot.imshow(vmin=vmin, vmax=vmax, robust=robust).get_array() + assert 0 <= arr.min() <= arr.max() <= 1 - def test_normalize_rgb_one_arg_error(self): + def test_normalize_rgb_one_arg_error(self) -> None: da = DataArray(easy_array((5, 5, 3), start=-0.6, stop=1.4)) # If passed one bound that implies all out of range, error: - for kwargs in [dict(vmax=-1), dict(vmin=2)]: + for vmin, vmax in ((None, -1), (2, None)): with pytest.raises(ValueError): - da.plot.imshow(**kwargs) + da.plot.imshow(vmin=vmin, vmax=vmax) # If passed two that's just moving the range, *not* an error: - for kwargs in [dict(vmax=-1, vmin=-1.2), dict(vmin=2, vmax=2.1)]: - da.plot.imshow(**kwargs) + for vmin2, vmax2 in ((-1.2, -1), (2, 2.1)): + da.plot.imshow(vmin=vmin2, vmax=vmax2) - def test_imshow_rgb_values_in_valid_range(self): + def test_imshow_rgb_values_in_valid_range(self) -> None: da = DataArray(np.arange(75, dtype="uint8").reshape((5, 5, 3))) _, ax = plt.subplots() out = da.plot.imshow(ax=ax).get_array() @@ -1951,12 +1959,12 @@ def test_imshow_rgb_values_in_valid_range(self): assert (out[..., :3] == da.values).all() # Compare without added alpha @pytest.mark.filterwarnings("ignore:Several dimensions of this array") - def test_regression_rgb_imshow_dim_size_one(self): + def test_regression_rgb_imshow_dim_size_one(self) -> None: # Regression: https://github.com/pydata/xarray/issues/1966 da = DataArray(easy_array((1, 3, 3), start=0.0, stop=1.0)) da.plot.imshow() - def test_origin_overrides_xyincrease(self): + def test_origin_overrides_xyincrease(self) -> None: da = DataArray(easy_array((3, 2)), coords=[[-2, 0, 2], [-1, 1]]) with figure_context(): da.plot.imshow(origin="upper") @@ -1974,12 +1982,12 @@ class TestSurface(Common2dMixin, PlotTestCase): plotfunc = staticmethod(xplt.surface) subplot_kws = {"projection": "3d"} - def test_primitive_artist_returned(self): + def test_primitive_artist_returned(self) -> None: artist = self.plotmethod() assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection) @pytest.mark.slow - def test_2d_coord_names(self): + def test_2d_coord_names(self) -> None: self.plotmethod(x="x2d", y="y2d") # make sure labels came out ok ax = plt.gca() @@ -1987,34 +1995,34 @@ def test_2d_coord_names(self): assert "y2d" == ax.get_ylabel() assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel() - def test_xyincrease_false_changes_axes(self): + def test_xyincrease_false_changes_axes(self) -> None: # Does not make sense for surface plots pytest.skip("does not make sense for surface plots") - def test_xyincrease_true_changes_axes(self): + def test_xyincrease_true_changes_axes(self) -> None: # Does not make sense for surface plots pytest.skip("does not make sense for surface plots") - def test_can_pass_in_axis(self): + def test_can_pass_in_axis(self) -> None: self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"}) - def test_default_cmap(self): + def test_default_cmap(self) -> None: # Does not make sense for surface plots with default arguments pytest.skip("does not make sense for surface plots") - def test_diverging_color_limits(self): + def test_diverging_color_limits(self) -> None: # Does not make sense for surface plots with default arguments pytest.skip("does not make sense for surface plots") - def test_colorbar_kwargs(self): + def test_colorbar_kwargs(self) -> None: # Does not make sense for surface plots with default arguments pytest.skip("does not make sense for surface plots") - def test_cmap_and_color_both(self): + def test_cmap_and_color_both(self) -> None: # Does not make sense for surface plots with default arguments pytest.skip("does not make sense for surface plots") - def test_seaborn_palette_as_cmap(self): + def test_seaborn_palette_as_cmap(self) -> None: # seaborn does not work with mpl_toolkits.mplot3d with pytest.raises(ValueError): super().test_seaborn_palette_as_cmap() @@ -2022,7 +2030,7 @@ def test_seaborn_palette_as_cmap(self): # Need to modify this test for surface(), because all subplots should have labels, # not just left and bottom @pytest.mark.filterwarnings("ignore:tight_layout cannot") - def test_convenient_facetgrid(self): + def test_convenient_facetgrid(self) -> None: a = easy_array((10, 15, 4)) d = DataArray(a, dims=["y", "x", "z"]) g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2) @@ -2041,28 +2049,28 @@ def test_convenient_facetgrid(self): assert "y" == ax.get_ylabel() assert "x" == ax.get_xlabel() - def test_viridis_cmap(self): + def test_viridis_cmap(self) -> None: return super().test_viridis_cmap() - def test_can_change_default_cmap(self): + def test_can_change_default_cmap(self) -> None: return super().test_can_change_default_cmap() - def test_colorbar_default_label(self): + def test_colorbar_default_label(self) -> None: return super().test_colorbar_default_label() - def test_facetgrid_map_only_appends_mappables(self): + def test_facetgrid_map_only_appends_mappables(self) -> None: return super().test_facetgrid_map_only_appends_mappables() class TestFacetGrid(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: d = easy_array((10, 15, 3)) self.darray = DataArray(d, dims=["y", "x", "z"], coords={"z": ["a", "b", "c"]}) self.g = xplt.FacetGrid(self.darray, col="z") @pytest.mark.slow - def test_no_args(self): + def test_no_args(self) -> None: self.g.map_dataarray(xplt.contourf, "x", "y") # Don't want colorbar labeled with 'None' @@ -2073,7 +2081,7 @@ def test_no_args(self): assert ax.has_data() @pytest.mark.slow - def test_names_appear_somewhere(self): + def test_names_appear_somewhere(self) -> None: self.darray.name = "testvar" self.g.map_dataarray(xplt.contourf, "x", "y") for k, ax in zip("abc", self.g.axes.flat): @@ -2085,7 +2093,7 @@ def test_names_appear_somewhere(self): assert label in alltxt @pytest.mark.slow - def test_text_not_super_long(self): + def test_text_not_super_long(self) -> None: self.darray.coords["z"] = [100 * letter for letter in "abc"] g = xplt.FacetGrid(self.darray, col="z") g.map_dataarray(xplt.contour, "x", "y") @@ -2097,7 +2105,7 @@ def test_text_not_super_long(self): assert t0.endswith("...") @pytest.mark.slow - def test_colorbar(self): + def test_colorbar(self) -> None: vmin = self.darray.values.min() vmax = self.darray.values.max() expected = np.array((vmin, vmax)) @@ -2111,7 +2119,7 @@ def test_colorbar(self): assert 1 == len(find_possible_colorbars()) @pytest.mark.slow - def test_empty_cell(self): + def test_empty_cell(self) -> None: g = xplt.FacetGrid(self.darray, col="z", col_wrap=2) g.map_dataarray(xplt.imshow, "x", "y") @@ -2120,12 +2128,12 @@ def test_empty_cell(self): assert not bottomright.get_visible() @pytest.mark.slow - def test_norow_nocol_error(self): + def test_norow_nocol_error(self) -> None: with pytest.raises(ValueError, match=r"[Rr]ow"): xplt.FacetGrid(self.darray) @pytest.mark.slow - def test_groups(self): + def test_groups(self) -> None: self.g.map_dataarray(xplt.imshow, "x", "y") upperleft_dict = self.g.name_dicts[0, 0] upperleft_array = self.darray.loc[upperleft_dict] @@ -2134,19 +2142,19 @@ def test_groups(self): assert_equal(upperleft_array, z0) @pytest.mark.slow - def test_float_index(self): + def test_float_index(self) -> None: self.darray.coords["z"] = [0.1, 0.2, 0.4] g = xplt.FacetGrid(self.darray, col="z") g.map_dataarray(xplt.imshow, "x", "y") @pytest.mark.slow - def test_nonunique_index_error(self): + def test_nonunique_index_error(self) -> None: self.darray.coords["z"] = [0.1, 0.2, 0.2] with pytest.raises(ValueError, match=r"[Uu]nique"): xplt.FacetGrid(self.darray, col="z") @pytest.mark.slow - def test_robust(self): + def test_robust(self) -> None: z = np.zeros((20, 20, 2)) darray = DataArray(z, dims=["y", "x", "z"]) darray[:, :, 1] = 1 @@ -2168,7 +2176,7 @@ def test_robust(self): assert largest < 21 @pytest.mark.slow - def test_can_set_vmin_vmax(self): + def test_can_set_vmin_vmax(self) -> None: vmin, vmax = 50.0, 1000.0 expected = np.array((vmin, vmax)) self.g.map_dataarray(xplt.imshow, "x", "y", vmin=vmin, vmax=vmax) @@ -2178,7 +2186,7 @@ def test_can_set_vmin_vmax(self): assert np.allclose(expected, clim) @pytest.mark.slow - def test_vmin_vmax_equal(self): + def test_vmin_vmax_equal(self) -> None: # regression test for GH3734 fg = self.g.map_dataarray(xplt.imshow, "x", "y", vmin=50, vmax=50) for mappable in fg._mappables: @@ -2186,14 +2194,14 @@ def test_vmin_vmax_equal(self): @pytest.mark.slow @pytest.mark.filterwarnings("ignore") - def test_can_set_norm(self): + def test_can_set_norm(self) -> None: norm = mpl.colors.SymLogNorm(0.1) self.g.map_dataarray(xplt.imshow, "x", "y", norm=norm) for image in plt.gcf().findobj(mpl.image.AxesImage): assert image.norm is norm @pytest.mark.slow - def test_figure_size(self): + def test_figure_size(self) -> None: assert_array_equal(self.g.fig.get_size_inches(), (10, 3)) @@ -2216,7 +2224,7 @@ def test_figure_size(self): g = xplt.plot(self.darray, row=2, col="z", ax=plt.gca(), size=6) @pytest.mark.slow - def test_num_ticks(self): + def test_num_ticks(self) -> None: nticks = 99 maxticks = nticks + 1 self.g.map_dataarray(xplt.imshow, "x", "y") @@ -2231,14 +2239,14 @@ def test_num_ticks(self): assert yticks >= nticks / 2.0 @pytest.mark.slow - def test_map(self): + def test_map(self) -> None: assert self.g._finalized is False self.g.map(plt.contourf, "x", "y", ...) assert self.g._finalized is True self.g.map(lambda: None) @pytest.mark.slow - def test_map_dataset(self): + def test_map_dataset(self) -> None: g = xplt.FacetGrid(self.darray.to_dataset(name="foo"), col="z") g.map(plt.contourf, "x", "y", "foo") @@ -2257,7 +2265,7 @@ def test_map_dataset(self): assert 1 == len(find_possible_colorbars()) @pytest.mark.slow - def test_set_axis_labels(self): + def test_set_axis_labels(self) -> None: g = self.g.map_dataarray(xplt.contourf, "x", "y") g.set_axis_labels("longitude", "latitude") alltxt = text_in_fig() @@ -2265,7 +2273,7 @@ def test_set_axis_labels(self): assert label in alltxt @pytest.mark.slow - def test_facetgrid_colorbar(self): + def test_facetgrid_colorbar(self) -> None: a = easy_array((10, 15, 4)) d = DataArray(a, dims=["y", "x", "z"], name="foo") @@ -2279,7 +2287,7 @@ def test_facetgrid_colorbar(self): assert 0 == len(find_possible_colorbars()) @pytest.mark.slow - def test_facetgrid_polar(self): + def test_facetgrid_polar(self) -> None: # test if polar projection in FacetGrid does not raise an exception self.darray.plot.pcolormesh( col="z", subplot_kws=dict(projection="polar"), sharex=False, sharey=False @@ -2289,7 +2297,7 @@ def test_facetgrid_polar(self): @pytest.mark.filterwarnings("ignore:tight_layout cannot") class TestFacetGrid4d(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: a = easy_array((10, 15, 3, 2)) darray = DataArray(a, dims=["y", "x", "col", "row"]) darray.coords["col"] = np.array( @@ -2301,7 +2309,7 @@ def setUp(self): self.darray = darray - def test_title_kwargs(self): + def test_title_kwargs(self) -> None: g = xplt.FacetGrid(self.darray, col="col", row="row") g.set_titles(template="{value}", weight="bold") @@ -2314,7 +2322,7 @@ def test_title_kwargs(self): assert property_in_axes_text("weight", "bold", label, ax) @pytest.mark.slow - def test_default_labels(self): + def test_default_labels(self) -> None: g = xplt.FacetGrid(self.darray, col="col", row="row") assert (2, 3) == g.axes.shape @@ -2344,10 +2352,10 @@ def test_default_labels(self): @pytest.mark.filterwarnings("ignore:tight_layout cannot") class TestFacetedLinePlotsLegend(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: self.darray = xr.tutorial.scatter_example_dataset() - def test_legend_labels(self): + def test_legend_labels(self) -> None: fg = self.darray.A.plot.line(col="x", row="w", hue="z") all_legend_labels = [t.get_text() for t in fg.figlegend.texts] # labels in legend should be ['0', '1', '2', '3'] @@ -2357,7 +2365,7 @@ def test_legend_labels(self): @pytest.mark.filterwarnings("ignore:tight_layout cannot") class TestFacetedLinePlots(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: self.darray = DataArray( np.random.randn(10, 6, 3, 4), dims=["hue", "x", "col", "row"], @@ -2371,14 +2379,14 @@ def setUp(self): self.darray.col.attrs["units"] = "colunits" self.darray.row.attrs["units"] = "rowunits" - def test_facetgrid_shape(self): + def test_facetgrid_shape(self) -> None: g = self.darray.plot(row="row", col="col", hue="hue") assert g.axes.shape == (len(self.darray.row), len(self.darray.col)) g = self.darray.plot(row="col", col="row", hue="hue") assert g.axes.shape == (len(self.darray.col), len(self.darray.row)) - def test_unnamed_args(self): + def test_unnamed_args(self) -> None: g = self.darray.plot.line("o--", row="row", col="col", hue="hue") lines = [ q for q in g.axes.flat[0].get_children() if isinstance(q, mpl.lines.Line2D) @@ -2387,7 +2395,7 @@ def test_unnamed_args(self): assert lines[0].get_marker() == "o" assert lines[0].get_linestyle() == "--" - def test_default_labels(self): + def test_default_labels(self) -> None: g = self.darray.plot(row="row", col="col", hue="hue") # Rightmost column should be labeled for label, ax in zip(self.darray.coords["row"].values, g.axes[:, -1]): @@ -2401,7 +2409,7 @@ def test_default_labels(self): for ax in g.axes[:, 0]: assert substring_in_axes(self.darray.name, ax) - def test_test_empty_cell(self): + def test_test_empty_cell(self) -> None: g = ( self.darray.isel(row=1) .drop_vars("row") @@ -2411,7 +2419,7 @@ def test_test_empty_cell(self): assert not bottomright.has_data() assert not bottomright.get_visible() - def test_set_axis_labels(self): + def test_set_axis_labels(self) -> None: g = self.darray.plot(row="row", col="col", hue="hue") g.set_axis_labels("longitude", "latitude") alltxt = text_in_fig() @@ -2419,15 +2427,15 @@ def test_set_axis_labels(self): assert "longitude" in alltxt assert "latitude" in alltxt - def test_axes_in_faceted_plot(self): + def test_axes_in_faceted_plot(self) -> None: with pytest.raises(ValueError): self.darray.plot.line(row="row", col="col", x="x", ax=plt.axes()) - def test_figsize_and_size(self): + def test_figsize_and_size(self) -> None: with pytest.raises(ValueError): - self.darray.plot.line(row="row", col="col", x="x", size=3, figsize=4) + self.darray.plot.line(row="row", col="col", x="x", size=3, figsize=(4, 3)) - def test_wrong_num_of_dimensions(self): + def test_wrong_num_of_dimensions(self) -> None: with pytest.raises(ValueError): self.darray.plot(row="row", hue="hue") self.darray.plot.line(row="row", hue="hue") @@ -2436,7 +2444,7 @@ def test_wrong_num_of_dimensions(self): @requires_matplotlib class TestDatasetQuiverPlots(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: das = [ DataArray( np.random.randn(3, 3, 4, 4), @@ -2455,7 +2463,7 @@ def setUp(self): ds["mag"] = np.hypot(ds.u, ds.v) self.ds = ds - def test_quiver(self): + def test_quiver(self) -> None: with figure_context(): hdl = self.ds.isel(row=0, col=0).plot.quiver(x="x", y="y", u="u", v="v") assert isinstance(hdl, mpl.quiver.Quiver) @@ -2467,13 +2475,14 @@ def test_quiver(self): x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete" ) - def test_facetgrid(self): + def test_facetgrid(self) -> None: with figure_context(): fg = self.ds.plot.quiver( x="x", y="y", u="u", v="v", row="row", col="col", scale=1, hue="mag" ) for handle in fg._mappables: assert isinstance(handle, mpl.quiver.Quiver) + assert fg.quiverkey is not None assert "uunits" in fg.quiverkey.text.get_text() with figure_context(): @@ -2519,7 +2528,7 @@ def test_add_guide(self, add_guide, hue_style, legend, colorbar): @requires_matplotlib class TestDatasetStreamplotPlots(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: das = [ DataArray( np.random.randn(3, 3, 2, 2), @@ -2538,7 +2547,7 @@ def setUp(self): ds["mag"] = np.hypot(ds.u, ds.v) self.ds = ds - def test_streamline(self): + def test_streamline(self) -> None: with figure_context(): hdl = self.ds.isel(row=0, col=0).plot.streamplot(x="x", y="y", u="u", v="v") assert isinstance(hdl, mpl.collections.LineCollection) @@ -2550,7 +2559,7 @@ def test_streamline(self): x="x", y="y", u="u", v="v", hue="mag", hue_style="discrete" ) - def test_facetgrid(self): + def test_facetgrid(self) -> None: with figure_context(): fg = self.ds.plot.streamplot( x="x", y="y", u="u", v="v", row="row", col="col", hue="mag" @@ -2574,7 +2583,7 @@ def test_facetgrid(self): @requires_matplotlib class TestDatasetScatterPlots(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: das = [ DataArray( np.random.randn(3, 3, 4, 4), @@ -2593,21 +2602,52 @@ def setUp(self): ds.B.attrs["units"] = "Bunits" self.ds = ds - def test_accessor(self): - from ..plot.dataset_plot import _Dataset_PlotMethods + def test_accessor(self) -> None: + from ..plot.accessor import DatasetPlotAccessor + + assert Dataset.plot is DatasetPlotAccessor + assert isinstance(self.ds.plot, DatasetPlotAccessor) + + @pytest.mark.parametrize( + "add_guide, hue_style, legend, colorbar", + [ + (None, None, False, True), + (False, None, False, False), + (True, None, False, True), + (True, "continuous", False, True), + (False, "discrete", False, False), + (True, "discrete", True, False), + ], + ) + def test_add_guide( + self, + add_guide: bool | None, + hue_style: Literal["continuous", "discrete", None], + legend: bool, + colorbar: bool, + ) -> None: - assert Dataset.plot is _Dataset_PlotMethods - assert isinstance(self.ds.plot, _Dataset_PlotMethods) + meta_data = _infer_meta_data( + self.ds, + x="A", + y="B", + hue="hue", + hue_style=hue_style, + add_guide=add_guide, + funcname="scatter", + ) + assert meta_data["add_legend"] is legend + assert meta_data["add_colorbar"] is colorbar - def test_facetgrid_shape(self): + def test_facetgrid_shape(self) -> None: g = self.ds.plot.scatter(x="A", y="B", row="row", col="col") assert g.axes.shape == (len(self.ds.row), len(self.ds.col)) g = self.ds.plot.scatter(x="A", y="B", row="col", col="row") assert g.axes.shape == (len(self.ds.col), len(self.ds.row)) - def test_default_labels(self): - g = self.ds.plot.scatter("A", "B", row="row", col="col", hue="hue") + def test_default_labels(self) -> None: + g = self.ds.plot.scatter(x="A", y="B", row="row", col="col", hue="hue") # Top row should be labeled for label, ax in zip(self.ds.coords["col"].values, g.axes[0, :]): @@ -2621,22 +2661,34 @@ def test_default_labels(self): for ax in g.axes[:, 0]: assert ax.get_ylabel() == "B [Bunits]" - def test_axes_in_faceted_plot(self): + def test_axes_in_faceted_plot(self) -> None: with pytest.raises(ValueError): self.ds.plot.scatter(x="A", y="B", row="row", ax=plt.axes()) - def test_figsize_and_size(self): + def test_figsize_and_size(self) -> None: with pytest.raises(ValueError): - self.ds.plot.scatter(x="A", y="B", row="row", size=3, figsize=4) + self.ds.plot.scatter(x="A", y="B", row="row", size=3, figsize=(4, 3)) @pytest.mark.parametrize( "x, y, hue, add_legend, add_colorbar, error_type", [ - ("A", "The Spanish Inquisition", None, None, None, KeyError), - ("The Spanish Inquisition", "B", None, None, True, ValueError), + pytest.param( + "A", "The Spanish Inquisition", None, None, None, KeyError, id="bad_y" + ), + pytest.param( + "The Spanish Inquisition", "B", None, None, True, ValueError, id="bad_x" + ), ], ) - def test_bad_args(self, x, y, hue, add_legend, add_colorbar, error_type): + def test_bad_args( + self, + x: Hashable, + y: Hashable, + hue: Hashable | None, + add_legend: bool | None, + add_colorbar: bool | None, + error_type: type[Exception], + ): with pytest.raises(error_type): self.ds.plot.scatter( x=x, y=y, hue=hue, add_legend=add_legend, add_colorbar=add_colorbar @@ -2644,7 +2696,7 @@ def test_bad_args(self, x, y, hue, add_legend, add_colorbar, error_type): @pytest.mark.xfail(reason="datetime,timedelta hue variable not supported.") @pytest.mark.parametrize("hue_style", ["discrete", "continuous"]) - def test_datetime_hue(self, hue_style): + def test_datetime_hue(self, hue_style: Literal["discrete", "continuous"]) -> None: ds2 = self.ds.copy() ds2["hue"] = pd.date_range("2000-1-1", periods=4) ds2.plot.scatter(x="A", y="B", hue="hue", hue_style=hue_style) @@ -2652,30 +2704,35 @@ def test_datetime_hue(self, hue_style): ds2["hue"] = pd.timedelta_range("-1D", periods=4, freq="D") ds2.plot.scatter(x="A", y="B", hue="hue", hue_style=hue_style) - def test_facetgrid_hue_style(self): - # Can't move this to pytest.mark.parametrize because py37-bare-minimum - # doesn't have matplotlib. - for hue_style in ("discrete", "continuous"): - g = self.ds.plot.scatter( - x="A", y="B", row="row", col="col", hue="hue", hue_style=hue_style - ) - # 'discrete' and 'continuous', should be single PathCollection - assert isinstance(g._mappables[-1], mpl.collections.PathCollection) + @pytest.mark.parametrize("hue_style", ["discrete", "continuous"]) + def test_facetgrid_hue_style( + self, hue_style: Literal["discrete", "continuous"] + ) -> None: + g = self.ds.plot.scatter( + x="A", y="B", row="row", col="col", hue="hue", hue_style=hue_style + ) + assert isinstance(g._mappables[-1], mpl.collections.PathCollection) @pytest.mark.parametrize( - "x, y, hue, markersize", [("A", "B", "x", "col"), ("x", "row", "A", "B")] + ["x", "y", "hue", "markersize"], + [("A", "B", "x", "col"), ("x", "row", "A", "B")], ) - def test_scatter(self, x, y, hue, markersize): + def test_scatter( + self, x: Hashable, y: Hashable, hue: Hashable, markersize: Hashable + ) -> None: self.ds.plot.scatter(x=x, y=y, hue=hue, markersize=markersize) - def test_non_numeric_legend(self): + with pytest.raises(ValueError, match=r"u, v"): + self.ds.plot.scatter(x=x, y=y, u="col", v="row") + + def test_non_numeric_legend(self) -> None: ds2 = self.ds.copy() ds2["hue"] = ["a", "b", "c", "d"] pc = ds2.plot.scatter(x="A", y="B", hue="hue") # should make a discrete legend assert pc.axes.legend_ is not None - def test_legend_labels(self): + def test_legend_labels(self) -> None: # regression test for #4126: incorrect legend labels ds2 = self.ds.copy() ds2["hue"] = ["a", "a", "b", "b"] @@ -2690,11 +2747,13 @@ def test_legend_labels(self): ] assert actual == expected - def test_legend_labels_facetgrid(self): + def test_legend_labels_facetgrid(self) -> None: ds2 = self.ds.copy() ds2["hue"] = ["d", "a", "c", "b"] g = ds2.plot.scatter(x="A", y="B", hue="hue", markersize="x", col="col") - actual = tuple(t.get_text() for t in g.figlegend.texts) + legend = g.figlegend + assert legend is not None + actual = tuple(t.get_text() for t in legend.texts) expected = ( "x [xunits]", "$\\mathdefault{0}$", @@ -2703,14 +2762,14 @@ def test_legend_labels_facetgrid(self): ) assert actual == expected - def test_add_legend_by_default(self): + def test_add_legend_by_default(self) -> None: sc = self.ds.plot.scatter(x="A", y="B", hue="hue") assert len(sc.figure.axes) == 2 class TestDatetimePlot(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: """ Create a DataArray with a time-axis that contains datetime objects. """ @@ -2722,11 +2781,11 @@ def setUp(self): self.darray = darray - def test_datetime_line_plot(self): + def test_datetime_line_plot(self) -> None: # test if line plot raises no Exception self.darray.plot.line() - def test_datetime_units(self): + def test_datetime_units(self) -> None: # test that matplotlib-native datetime works: fig, ax = plt.subplots() ax.plot(self.darray["time"], self.darray) @@ -2735,7 +2794,7 @@ def test_datetime_units(self): # mpl.dates.AutoDateLocator passes and no other subclasses: assert type(ax.xaxis.get_major_locator()) is mpl.dates.AutoDateLocator - def test_datetime_plot1d(self): + def test_datetime_plot1d(self) -> None: # Test that matplotlib-native datetime works: p = self.darray.plot.line() ax = p[0].axes @@ -2744,7 +2803,7 @@ def test_datetime_plot1d(self): # mpl.dates.AutoDateLocator passes and no other subclasses: assert type(ax.xaxis.get_major_locator()) is mpl.dates.AutoDateLocator - def test_datetime_plot2d(self): + def test_datetime_plot2d(self) -> None: # Test that matplotlib-native datetime works: da = DataArray( np.arange(3 * 4).reshape(3, 4), @@ -2768,7 +2827,7 @@ def test_datetime_plot2d(self): @requires_cftime class TestCFDatetimePlot(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: """ Create a DataArray with a time-axis that contains cftime.datetime objects. @@ -2781,13 +2840,13 @@ def setUp(self): self.darray = darray - def test_cfdatetime_line_plot(self): + def test_cfdatetime_line_plot(self) -> None: self.darray.isel(x=0).plot.line() - def test_cfdatetime_pcolormesh_plot(self): + def test_cfdatetime_pcolormesh_plot(self) -> None: self.darray.plot.pcolormesh() - def test_cfdatetime_contour_plot(self): + def test_cfdatetime_contour_plot(self) -> None: self.darray.plot.contour() @@ -2795,7 +2854,7 @@ def test_cfdatetime_contour_plot(self): @pytest.mark.skipif(has_nc_time_axis, reason="nc_time_axis is installed") class TestNcAxisNotInstalled(PlotTestCase): @pytest.fixture(autouse=True) - def setUp(self): + def setUp(self) -> None: """ Create a DataArray with a time-axis that contains cftime.datetime objects. @@ -2809,7 +2868,7 @@ def setUp(self): self.darray = darray - def test_ncaxis_notinstalled_line_plot(self): + def test_ncaxis_notinstalled_line_plot(self) -> None: with pytest.raises(ImportError, match=r"optional `nc-time-axis`"): self.darray.plot.line() @@ -2847,60 +2906,60 @@ def data_array_logspaced(self, request): ) @pytest.mark.parametrize("xincrease", [True, False]) - def test_xincrease_kwarg(self, data_array, xincrease): + def test_xincrease_kwarg(self, data_array, xincrease) -> None: with figure_context(): data_array.plot(xincrease=xincrease) assert plt.gca().xaxis_inverted() == (not xincrease) @pytest.mark.parametrize("yincrease", [True, False]) - def test_yincrease_kwarg(self, data_array, yincrease): + def test_yincrease_kwarg(self, data_array, yincrease) -> None: with figure_context(): data_array.plot(yincrease=yincrease) assert plt.gca().yaxis_inverted() == (not yincrease) @pytest.mark.parametrize("xscale", ["linear", "logit", "symlog"]) - def test_xscale_kwarg(self, data_array, xscale): + def test_xscale_kwarg(self, data_array, xscale) -> None: with figure_context(): data_array.plot(xscale=xscale) assert plt.gca().get_xscale() == xscale @pytest.mark.parametrize("yscale", ["linear", "logit", "symlog"]) - def test_yscale_kwarg(self, data_array, yscale): + def test_yscale_kwarg(self, data_array, yscale) -> None: with figure_context(): data_array.plot(yscale=yscale) assert plt.gca().get_yscale() == yscale - def test_xscale_log_kwarg(self, data_array_logspaced): + def test_xscale_log_kwarg(self, data_array_logspaced) -> None: xscale = "log" with figure_context(): data_array_logspaced.plot(xscale=xscale) assert plt.gca().get_xscale() == xscale - def test_yscale_log_kwarg(self, data_array_logspaced): + def test_yscale_log_kwarg(self, data_array_logspaced) -> None: yscale = "log" with figure_context(): data_array_logspaced.plot(yscale=yscale) assert plt.gca().get_yscale() == yscale - def test_xlim_kwarg(self, data_array): + def test_xlim_kwarg(self, data_array) -> None: with figure_context(): expected = (0.0, 1000.0) data_array.plot(xlim=[0, 1000]) assert plt.gca().get_xlim() == expected - def test_ylim_kwarg(self, data_array): + def test_ylim_kwarg(self, data_array) -> None: with figure_context(): data_array.plot(ylim=[0, 1000]) expected = (0.0, 1000.0) assert plt.gca().get_ylim() == expected - def test_xticks_kwarg(self, data_array): + def test_xticks_kwarg(self, data_array) -> None: with figure_context(): data_array.plot(xticks=np.arange(5)) expected = np.arange(5).tolist() assert_array_equal(plt.gca().get_xticks(), expected) - def test_yticks_kwarg(self, data_array): + def test_yticks_kwarg(self, data_array) -> None: with figure_context(): data_array.plot(yticks=np.arange(5)) expected = np.arange(5) @@ -2909,7 +2968,7 @@ def test_yticks_kwarg(self, data_array): @requires_matplotlib @pytest.mark.parametrize("plotfunc", ["pcolormesh", "contourf", "contour"]) -def test_plot_transposed_nondim_coord(plotfunc): +def test_plot_transposed_nondim_coord(plotfunc) -> None: x = np.linspace(0, 10, 101) h = np.linspace(3, 7, 101) s = np.linspace(0, 1, 51) @@ -2927,7 +2986,7 @@ def test_plot_transposed_nondim_coord(plotfunc): @requires_matplotlib @pytest.mark.parametrize("plotfunc", ["pcolormesh", "imshow"]) -def test_plot_transposes_properly(plotfunc): +def test_plot_transposes_properly(plotfunc) -> None: # test that we aren't mistakenly transposing when the 2 dimensions have equal sizes. da = xr.DataArray([np.sin(2 * np.pi / 10 * np.arange(10))] * 10, dims=("y", "x")) with figure_context(): @@ -2939,7 +2998,7 @@ def test_plot_transposes_properly(plotfunc): @requires_matplotlib -def test_facetgrid_single_contour(): +def test_facetgrid_single_contour() -> None: # regression test for GH3569 x, y = np.meshgrid(np.arange(12), np.arange(12)) z = xr.DataArray(np.sqrt(x**2 + y**2)) @@ -2987,6 +3046,8 @@ def test_get_axis_raises(): pytest.param(None, 5, None, False, {}, id="size"), pytest.param(None, 5.5, None, False, {"label": "test"}, id="size_kwargs"), pytest.param(None, 5, 1, False, {}, id="size+aspect"), + pytest.param(None, 5, "auto", False, {}, id="auto_aspect"), + pytest.param(None, 5, "equal", False, {}, id="equal_aspect"), pytest.param(None, None, None, True, {}, id="ax"), pytest.param(None, None, None, False, {}, id="default"), pytest.param(None, None, None, False, {"label": "test"}, id="default_kwargs"), @@ -3036,7 +3097,7 @@ def test_get_axis_current() -> None: @requires_matplotlib -def test_maybe_gca(): +def test_maybe_gca() -> None: with figure_context(): ax = _maybe_gca(aspect=1) @@ -3076,7 +3137,9 @@ def test_maybe_gca(): ("A", "B", "z", "y", "x", "w", None, True, True), ], ) -def test_datarray_scatter(x, y, z, hue, markersize, row, col, add_legend, add_colorbar): +def test_datarray_scatter( + x, y, z, hue, markersize, row, col, add_legend, add_colorbar +) -> None: """Test datarray scatter. Merge with TestPlot1D eventually.""" ds = xr.tutorial.scatter_example_dataset()