diff --git a/doc/api-hidden.rst b/doc/api-hidden.rst index 1900c208532..d2b6de64345 100644 --- a/doc/api-hidden.rst +++ b/doc/api-hidden.rst @@ -391,6 +391,7 @@ plot.imshow plot.pcolormesh plot.scatter + plot.lines plot.surface CFTimeIndex.all diff --git a/doc/api.rst b/doc/api.rst index 342ae08e1a4..d8c447e960a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -1128,6 +1128,7 @@ Dataset :template: autosummary/accessor_method.rst Dataset.plot.scatter + Dataset.plot.lines Dataset.plot.quiver Dataset.plot.streamplot @@ -1152,6 +1153,7 @@ DataArray DataArray.plot.pcolormesh DataArray.plot.step DataArray.plot.scatter + DataArray.plot.lines DataArray.plot.surface diff --git a/doc/user-guide/plotting.rst b/doc/user-guide/plotting.rst index 42cbd1eb5b0..6455bc9da3d 100644 --- a/doc/user-guide/plotting.rst +++ b/doc/user-guide/plotting.rst @@ -847,6 +847,75 @@ And adding the z-axis For more advanced scatter plots, we recommend converting the relevant data variables to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``. +Lines +~~~~~ + +.. ipython:: python + :suppress: + + plt.close("all") + +:py:func:`xarray.plot.lines` calls matplotlib.collections.LineCollection under the hood, +allowing multiple lines being drawn efficiently. It uses similar arguments as +:py:func:`xarray.plot.scatter`. + +Let's return to the air temperature dataset: + +.. ipython:: python + :okwarning: + + airtemps = xr.tutorial.open_dataset("air_temperature") + air = airtemps.air - 273.15 + air.attrs = airtemps.air.attrs + air.attrs["units"] = "deg C" + + @savefig lines_air_hue.png + air.isel(lon=10).plot.lines(x="time", hue="lat") + +Make it a little more transparent: + +.. ipython:: python + :okwarning: + + @savefig lines_air_hue_alpha.png + air.isel(lon=10).plot.lines(x="time", hue="lat", alpha=0.2) + +Zoom in a little on the x-axis, and compare a few latitudes and longitudes, +group them using ``hue`` and ``linewidth``. The ``linewidth`` kwarg works in +a similar way as ``markersize`` kwarg for scatter plots, it lets you vary the +line's size by variable value. + +.. ipython:: python + :okwarning: + + air_zoom = air.isel(time=slice(1200, 1500), lat=[5, 10, 15], lon=[10, 15]) + + @savefig lines_hue_linewidth.png + air_zoom.plot.lines(x="time", hue="lat", linewidth="lon", add_colorbar=False) + +Lines can modify the linestyle but does not allow markers. Instead combine :py:func:`xarray.plot.lines` +with :py:func:`xarray.plot.scatter`: + +.. ipython:: python + :okwarning: + + air.isel(lat=10, lon=10)[:200].plot.lines(x="time", color="k", linestyle="dashed") + air.isel(lat=10, lon=10)[:200].plot.scatter(x="time", color="k", marker="^") + @savefig lines_linestyle_marker.png + plt.draw() + + +Switching to another dataset with more variables we can analyse in similar +fashion as :py:func:`xarray.plot.scatter`: + +.. ipython:: python + :okwarning: + + ds = xr.tutorial.scatter_example_dataset(seed=42) + + @savefig lines_xyhuewidthrowcol.png + ds.plot.lines(x="A", y="B", hue="y", linewidth="x", row="x", col="w") + Quiver ~~~~~~ diff --git a/doc/whats-new.rst b/doc/whats-new.rst index aab51d71b09..da17d50008d 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -36,6 +36,10 @@ New Features iso8601-parser (:pull:`9885`). By `Kai Mühlbauer `_. +- Added new plot method :py:meth:`DataArray.plot.lines` which allows creating line plots efficiently in + a similiar manner to :py:meth:`DataArray.plot.scatter`, also available for datasets. (:pull:`7173`) + By `Jimmy Westling `_. + Breaking changes ~~~~~~~~~~~~~~~~ - Methods including ``dropna``, ``rank``, ``idxmax``, ``idxmin`` require diff --git a/xarray/plot/__init__.py b/xarray/plot/__init__.py index 49f12e13bfc..58721ab7e0e 100644 --- a/xarray/plot/__init__.py +++ b/xarray/plot/__init__.py @@ -13,6 +13,7 @@ hist, imshow, line, + lines, pcolormesh, plot, step, @@ -28,6 +29,7 @@ "hist", "imshow", "line", + "lines", "pcolormesh", "plot", "scatter", diff --git a/xarray/plot/accessor.py b/xarray/plot/accessor.py index 9db4ae4e3f7..426b0b80497 100644 --- a/xarray/plot/accessor.py +++ b/xarray/plot/accessor.py @@ -135,6 +135,130 @@ def line( def line(self, *args, **kwargs) -> list[Line3D] | FacetGrid[DataArray]: return dataarray_plot.line(self._da, *args, **kwargs) + @overload + def lines( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> LineCollection: ... + + @overload + def lines( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @overload + def lines( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @functools.wraps(dataarray_plot.lines) + def lines(self, *args, **kwargs) -> LineCollection | FacetGrid[DataArray]: + return dataarray_plot.lines(self._da, *args, **kwargs) + @overload def step( # type: ignore[misc,unused-ignore] # None is hashable :( self, @@ -923,6 +1047,130 @@ def __call__(self, *args, **kwargs) -> NoReturn: "an explicit plot method, e.g. ds.plot.scatter(...)" ) + @overload + def lines( # type: ignore[misc,unused-ignore] # None is hashable :( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> LineCollection: ... + + @overload + def lines( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @overload + def lines( + self, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap=None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend=None, + levels=None, + **kwargs: Any, + ) -> FacetGrid[DataArray]: ... + + @functools.wraps(dataset_plot.lines) + def lines(self, *args, **kwargs) -> LineCollection | FacetGrid[DataArray]: + return dataset_plot.lines(self._ds, *args, **kwargs) + @overload def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( self, diff --git a/xarray/plot/dataarray_plot.py b/xarray/plot/dataarray_plot.py index cca9fe4f561..fda1422b83c 100644 --- a/xarray/plot/dataarray_plot.py +++ b/xarray/plot/dataarray_plot.py @@ -23,6 +23,7 @@ _guess_coords_to_plot, _infer_interval_breaks, _infer_xy_labels, + _line, _Normalize, _process_cmap_cbar_kwargs, _rescale_imshow_rgb, @@ -36,7 +37,7 @@ if TYPE_CHECKING: from matplotlib.axes import Axes - from matplotlib.collections import PathCollection, QuadMesh + from matplotlib.collections import LineCollection, PathCollection, QuadMesh from matplotlib.colors import Colormap, Normalize from matplotlib.container import BarContainer from matplotlib.contour import QuadContourSet @@ -186,18 +187,32 @@ def _prepare_plot1d_data( """ # If there are more than 1 dimension in the array than stack all the # dimensions so the plotter can plot anything: - if darray.ndim > 1: + if darray.ndim >= 2: # When stacking dims the lines will continue connecting. For floats # this can be solved by adding a nan element in between the flattening # points: - dims_T = [] - if np.issubdtype(darray.dtype, np.floating): - for v in ["z", "x"]: - dim = coords_to_plot.get(v, None) - if (dim is not None) and (dim in darray.dims): - darray_nan = np.nan * darray.isel({dim: -1}) - darray = concat([darray, darray_nan], dim=dim) - dims_T.append(coords_to_plot[v]) + dims_T: list[Hashable] = [] + if plotfunc_name == "lines" and np.issubdtype(darray.dtype, np.floating): + i = 0 + for v in ("z", "x"): + coord = coords_to_plot.get(v, None) + if coord is not None: + if coord in darray.dims: + # Dimension coordinate: + d = coord + else: + # Coordinate with multiple dimensions: + c = darray[coord] + dims_filt = dict.fromkeys(c.dims) + for k in dims_filt.keys() & set(dims_T): + dims_filt.pop(k) + + d = tuple(dims_filt.keys())[i] + + darray_nan = np.nan * darray.isel({d: -1}) + darray = concat([darray, darray_nan], dim=d) + dims_T.append(d) + # i += 1 # Lines should never connect to the same coordinate when stacked, # transpose to avoid this as much as possible: @@ -472,6 +487,10 @@ def line( primitive : list of Line3D or FacetGrid When either col or row is given, returns a FacetGrid, otherwise a list of matplotlib Line3D objects. + + See also + -------- + Use :py:func:`xarray.plot.lines` for efficient plotting of many lines. """ # Handle facetgrids first if row or col: @@ -1038,7 +1057,7 @@ def newplotfunc( ) if add_legend_: - if plotfunc.__name__ in ["scatter", "line"]: + if plotfunc.__name__ in ["scatter", "lines"]: _add_legend( ( hueplt_norm @@ -1108,6 +1127,177 @@ def _add_labels( _set_concise_date(ax, axis=axis) +@overload +def lines( # type: ignore[misc,unused-ignore] # None is hashable :( + darray: DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs, +) -> LineCollection: ... + + +@overload +def lines( + darray: T_DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs, +) -> FacetGrid[DataArray]: ... + + +@overload +def lines( + darray: T_DataArray, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs, +) -> FacetGrid[DataArray]: ... + + +@_plot1d +def lines( + xplt: DataArray | None, + yplt: DataArray | None, + ax: Axes, + add_labels: bool | Iterable[bool] = True, + **kwargs, +) -> LineCollection: + """ + Line plot of DataArray values. + + Wraps :func:`matplotlib:matplotlib.collections.LineCollection` which allows + efficient plotting of many lines in a similar fashion to + :py:func:`xarray.plot.scatter`. + """ + if "u" in kwargs or "v" in kwargs: + raise ValueError("u, v are not allowed in lines plots.") + + zplt: DataArray | None = kwargs.pop("zplt", None) + hueplt: DataArray | None = kwargs.pop("hueplt", None) + sizeplt: DataArray | None = kwargs.pop("sizeplt", None) + + if hueplt is not None: + kwargs.update(c=hueplt.to_numpy().ravel()) + + if sizeplt is not None: + kwargs.update(s=sizeplt.to_numpy().ravel()) + + plts_or_none = (xplt, yplt, zplt) + _add_labels(add_labels, plts_or_none, ("", "", ""), ax) + + xplt_np = None if xplt is None else xplt.to_numpy().ravel() + yplt_np = None if yplt is None else yplt.to_numpy().ravel() + zplt_np = None if zplt is None else zplt.to_numpy().ravel() + plts_np = tuple(p for p in (xplt_np, yplt_np, zplt_np) if p is not None) + + if len(plts_np) == 3: + import mpl_toolkits + + assert isinstance(ax, mpl_toolkits.mplot3d.axes3d.Axes3D) + return _line(ax, *plts_np, **kwargs) + + if len(plts_np) == 2: + return _line(ax, *plts_np, **kwargs) + + raise ValueError("At least two variables required for a lines plot.") + + @overload def scatter( # type: ignore[misc,unused-ignore] # None is hashable :( darray: DataArray, diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py index 44d4ffa676a..e5aafd3e7d5 100644 --- a/xarray/plot/dataset_plot.py +++ b/xarray/plot/dataset_plot.py @@ -916,3 +916,180 @@ def scatter( da = _temp_dataarray(ds, y, locals_) return da.plot.scatter(*locals_.pop("args", ()), **locals_) + + +@overload +def lines( # type: ignore[misc,unused-ignore] # None is hashable :( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: None = None, # no wrap -> primitive + col: None = None, # no wrap -> primitive + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> LineCollection: ... + + +@overload +def lines( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable, # wrap -> FacetGrid + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: ... + + +@overload +def lines( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable, # wrap -> FacetGrid + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> FacetGrid[DataArray]: ... + + +@_update_doc_to_dataset(dataarray_plot.lines) +def lines( + ds: Dataset, + *args: Any, + x: Hashable | None = None, + y: Hashable | None = None, + z: Hashable | None = None, + hue: Hashable | None = None, + hue_style: HueStyleOptions = None, + markersize: Hashable | None = None, + linewidth: Hashable | None = None, + figsize: Iterable[float] | None = None, + size: float | None = None, + aspect: float | None = None, + ax: Axes | None = None, + row: Hashable | None = None, + col: Hashable | None = None, + col_wrap: int | None = None, + xincrease: bool | None = True, + yincrease: bool | None = True, + add_legend: bool | None = None, + add_colorbar: bool | None = None, + add_labels: bool | Iterable[bool] = True, + add_title: bool = True, + subplot_kws: dict[str, Any] | None = None, + xscale: ScaleOptions = None, + yscale: ScaleOptions = None, + xticks: ArrayLike | None = None, + yticks: ArrayLike | None = None, + xlim: ArrayLike | None = None, + ylim: ArrayLike | None = None, + cmap: str | Colormap | None = None, + vmin: float | None = None, + vmax: float | None = None, + norm: Normalize | None = None, + extend: ExtendOptions = None, + levels: ArrayLike | None = None, + **kwargs: Any, +) -> LineCollection | FacetGrid[DataArray]: + """ + Line plot Dataset data variables against each other. + + Wraps :func:`matplotlib:matplotlib.collections.LineCollection` which allows + efficient plotting of many lines in a similar fashion to + :py:func:`xarray.plot.scatter`. + """ + locals_ = locals() + del locals_["ds"] + locals_.update(locals_.pop("kwargs", {})) + da = _temp_dataarray(ds, y, locals_) + + return da.plot.lines(*locals_.pop("args", ()), **locals_) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 82105d5fb6a..9175e46de62 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -481,7 +481,7 @@ def map_plot1d( func_kwargs["add_title"] = False add_labels_ = np.zeros(self.axs.shape + (3,), dtype=bool) - if kwargs.get("z") is not None: + if coords_to_plot["z"] is not None: # 3d plots looks better with all labels. 3d plots can't sharex either so it # is easy to get lost while rotating the plots: add_labels_[:] = True @@ -494,10 +494,10 @@ def map_plot1d( # Set up the lists of names for the row and column facet variables: if self._single_group: full = tuple( - {self._single_group: x} - for x in range(self.data[self._single_group].size) + {self._single_group: v} + for v in range(self.data[self._single_group].size) ) - empty = tuple(None for x in range(self._nrow * self._ncol - len(full))) + empty = (None,) * (self._nrow * self._ncol - len(full)) name_d = full + empty else: rowcols = itertools.product( @@ -520,8 +520,8 @@ def map_plot1d( subset = self.data.isel(d) mappable = func( subset, - x=x, - y=y, + x=coords_to_plot["x"], + z=coords_to_plot["z"], ax=ax, hue=hue, _size=size_, diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index c1f0cc11f54..0e6e0c86710 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -38,8 +38,12 @@ if TYPE_CHECKING: from matplotlib.axes import Axes - from matplotlib.colors import Normalize + from matplotlib.collections import LineCollection + from matplotlib.colors import Colormap, Normalize + from matplotlib.lines import Line2D from matplotlib.ticker import FuncFormatter + from matplotlib.typing import ColorType, DrawStyleType, LineStyleType + from mpl_toolkits.mplot3d.art3d import Line3DCollection from numpy.typing import ArrayLike from xarray.core.dataarray import DataArray @@ -1076,9 +1080,9 @@ def _get_color_and_size(value): elif prop == "sizes": if isinstance(self, mpl.collections.LineCollection): - arr = self.get_linewidths() + arr = np.ma.asarray(self.get_linewidths()) else: - arr = self.get_sizes() + arr = np.ma.asarray(self.get_sizes()) _color = kwargs.pop("color", "k") def _get_color_and_size(value): @@ -1091,7 +1095,7 @@ def _get_color_and_size(value): ) # Get the unique values and their labels: - values = np.unique(arr) + values = np.unique(arr[~arr.mask]) label_values = np.asarray(func(values)) label_values_are_numeric = np.issubdtype(label_values.dtype, np.number) @@ -1735,16 +1739,19 @@ def _add_legend( # values correctly. Order might be different because # legend_elements uses np.unique instead of pd.unique, # FacetGrid.add_legend might have troubles with this: - hdl, lbl = [], [] + hdl: list[Line2D] = [] + lbl: list[str] = [] for p in primitive: hdl_, lbl_ = legend_elements(p, prop, num="auto", func=huesizeplt.func) hdl += hdl_ lbl += lbl_ - # Only save unique values: - u, ind = np.unique(lbl, return_index=True) - ind = np.argsort(ind) - lbl = u[ind].tolist() + # Only save unique values, don't sort values as it was already sort in + # legend_elements: + lbl_ = np.array(lbl) + _, ind = np.unique(lbl_, return_index=True) + ind = np.sort(ind) + lbl = lbl_[ind].tolist() hdl = np.array(hdl)[ind].tolist() # Add a subtitle: @@ -1848,6 +1855,251 @@ def _guess_coords_to_plot( return coords_to_plot +@overload +def _line( + self, # Axes, + x: float | ArrayLike, + y: float | ArrayLike, + z: None = ..., + s: float | ArrayLike | None = ..., + c: Sequence[ColorType] | ColorType | None = ..., + linestyle: LineStyleType | None = ..., + cmap: str | Colormap | None = ..., + norm: str | Normalize | None = ..., + vmin: float | None = ..., + vmax: float | None = ..., + alpha: float | None = ..., + linewidths: float | Sequence[float] | None = ..., + *, + edgecolors: Literal["face", "none"] | ColorType | Sequence[ColorType] | None = ..., + plotnonfinite: bool = ..., + data=..., + **kwargs, +) -> LineCollection: ... + + +@overload +def _line( + self, # Axes3D, + x: float | ArrayLike, + y: float | ArrayLike, + z: float | ArrayLike = ..., + s: float | ArrayLike | None = ..., + c: Sequence[ColorType] | ColorType | None = ..., + linestyle: LineStyleType | None = ..., + cmap: str | Colormap | None = ..., + norm: str | Normalize | None = ..., + vmin: float | None = ..., + vmax: float | None = ..., + alpha: float | None = ..., + linewidths: float | Sequence[float] | None = ..., + *, + edgecolors: Literal["face", "none"] | ColorType | Sequence[ColorType] | None = ..., + plotnonfinite: bool = ..., + data=..., + drawstyle: DrawStyleType = ..., + **kwargs, +) -> Line3DCollection: ... + + +def _line( + self, # Axes | Axes3D + x: float | ArrayLike, + y: float | ArrayLike, + z: float | ArrayLike | None = None, + s: float | ArrayLike | None = None, + c: Sequence[ColorType] | ColorType | None = None, + linestyle: LineStyleType | None = None, + cmap: str | Colormap | None = None, + norm: str | Normalize | None = None, + vmin: float | None = None, + vmax: float | None = None, + alpha: float | None = None, + linewidths: float | Sequence[float] | None = None, + *, + edgecolors: Literal["face", "none"] | ColorType | Sequence[ColorType] | None = None, + plotnonfinite: bool = False, + data=None, + drawstyle: DrawStyleType = "default", + **kwargs, +) -> LineCollection | Line3DCollection: + """ + ax.scatter-like wrapper for LineCollection. + + This function helps the handling of datetimes since Linecollection doesn't + support it directly, just like PatchCollection doesn't either. + + The function attempts to be as similar to the scatter version as possible. + """ + import matplotlib.cbook as cbook + import matplotlib.collections as mcoll + import matplotlib.pyplot as plt + from matplotlib import _api + + rcParams = plt.matplotlib.rcParams + + def _parse_lines_color_args( + self, c, edgecolors, kwargs, xsize, get_next_color_func + ): + if edgecolors is None: + # Use "face" instead of rcParams['scatter.edgecolors'] + edgecolors = "face" + + c, colors, edgecolors = self._parse_scatter_color_args( + c, + edgecolors, + kwargs, + x_.size, + get_next_color_func=self._get_patches_for_fill.get_next_color, + ) + + return c, colors, edgecolors + + # add edgecolors and linewidths to kwargs so they + # can be processed by normailze_kwargs + if edgecolors is not None: + kwargs.update({"edgecolors": edgecolors}) + if linewidths is not None: + kwargs.update({"linewidths": linewidths}) + + kwargs = cbook.normalize_kwargs(kwargs, mcoll.Collection) + # re direct linewidth and edgecolor so it can be + # further processed by the rest of the function + linewidths = kwargs.pop("linewidth", None) + edgecolors = kwargs.pop("edgecolor", None) + + # Process **kwargs to handle aliases, conflicts with explicit kwargs: + x_: np.ndarray + y_: np.ndarray + x_, y_ = self._process_unit_info( + [("x", x), ("y", y)], kwargs + ) # type ignore[union-attr] + + # Handle z inputs: + if z is not None: + from mpl_toolkits.mplot3d.art3d import Line3DCollection + + LineCollection_ = Line3DCollection + add_collection_ = self.add_collection3d + auto_scale = self.auto_scale_xyz + auto_scale_args: tuple[Any, ...] = (x_, y_, z, self.has_data()) + else: + LineCollection_ = plt.matplotlib.collections.LineCollection + add_collection_ = self.add_collection + auto_scale = self._request_autoscale_view + auto_scale_args = tuple() + + if s is None: + s = np.array([rcParams["lines.linewidth"]]) + + s_: np.ndarray = np.ma.ravel(s) + if len(s_) not in (1, x_.size) or ( + not np.issubdtype(s_.dtype, np.floating) + and not np.issubdtype(s_.dtype, np.integer) + ): + raise ValueError( + "s must be a scalar, " "or float array-like with the same size as x and y" + ) + + # get the original edgecolor the user passed before we normalize + orig_edgecolor = edgecolors + if edgecolors is None: + orig_edgecolor = kwargs.get("edgecolor", None) + c, colors, edgecolors = _parse_lines_color_args( + self, + c, + edgecolors, + kwargs, + x_.size, + get_next_color_func=self._get_patches_for_fill.get_next_color, + ) + + if plotnonfinite and colors is None: + c = np.ma.masked_invalid(c) + ( + x_, + y_, + s_, + edgecolors, + linewidths, + ) = cbook._combine_masks( # type: ignore[attr-defined] # non-public? + x_, y_, s_, edgecolors, linewidths + ) + else: + ( + x_, + y_, + s_, + c, + colors, + edgecolors, + linewidths, + ) = cbook._combine_masks( # type: ignore[attr-defined] # non-public? + x_, y_, s_, c, colors, edgecolors, linewidths + ) + + # Unmask edgecolors if it was actually a single RGB or RGBA. + if ( + x_.size in (3, 4) + and isinstance(edgecolors, np.ma.MaskedArray) + and not np.ma.is_masked(orig_edgecolor) + ): + edgecolors = edgecolors.data + + # load default linestyle from rcParams + if linestyle is None: + linestyle = rcParams["lines.linestyle"] + + if drawstyle == "default": + # Draw linear lines: + xyz = list(v for v in (x_, y_, z) if v is not None) + else: + # Draw stepwise lines: + from matplotlib.cbook import STEP_LOOKUP_MAP + + step_func = STEP_LOOKUP_MAP[drawstyle] + xyz = step_func(*tuple(v for v in (x_, y_, z) if v is not None)) + + # Broadcast arrays to correct format: + # https://stackoverflow.com/questions/42215777/matplotlib-line-color-in-3d + points = np.stack(np.broadcast_arrays(*xyz), axis=-1).reshape(-1, 1, len(xyz)) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + + collection = LineCollection_( + segments, + linewidths=s_, + linestyles=linestyle, + facecolors=colors, + edgecolors=edgecolors, + alpha=alpha, + # offset_transform=kwargs.pop("transform", self.transData), + ) + # collection.set_transform(plt.matplotlib.transforms.IdentityTransform()) + collection.update(kwargs) + + if colors is None: + collection.set_array(c) + collection.set_cmap(cmap) + collection.set_norm(norm) + collection._scale_norm(norm, vmin, vmax) + else: + extra_kwargs = {"cmap": cmap, "norm": norm, "vmin": vmin, "vmax": vmax} + extra_keys = [k for k, v in extra_kwargs.items() if v is not None] + if any(extra_keys): + keys_str = ", ".join(f"'{k}'" for k in extra_keys) + _api.warn_external( + "No data for colormapping provided via 'c'. " + f"Parameters {keys_str} will be ignored" + ) + collection._internal_update(kwargs) + + add_collection_(collection) + + auto_scale(*auto_scale_args) + + return collection + + def _set_concise_date(ax: Axes, axis: Literal["x", "y", "z"] = "x") -> None: """ Use ConciseDateFormatter which is meant to improve the diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 8e00b943de8..fc3a255868b 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -2920,6 +2920,57 @@ def test_legend_labels_facetgrid(self) -> None: ) assert actual == expected + def test_legend_labels_facegrid2(self) -> None: + ds = xr.tutorial.scatter_example_dataset(seed=42) + + g = ds.plot.scatter( + x="A", y="B", hue="y", markersize="x", row="x", col="w", add_colorbar=False + ) + + legend = g.figlegend + assert legend is not None + actual_text = [t.get_text() for t in legend.texts] + expected_text = [ + "y [yunits]", + "$\\mathdefault{0.0}$", + "$\\mathdefault{0.1}$", + "$\\mathdefault{0.2}$", + "$\\mathdefault{0.3}$", + "$\\mathdefault{0.4}$", + "$\\mathdefault{0.5}$", + "$\\mathdefault{0.6}$", + "$\\mathdefault{0.7}$", + "$\\mathdefault{0.8}$", + "$\\mathdefault{0.9}$", + "$\\mathdefault{1.0}$", + "x [xunits]", + "$\\mathdefault{0}$", + "$\\mathdefault{1}$", + "$\\mathdefault{2}$", + ] + assert actual_text == expected_text + + actual_size = [v.get_markersize() for v in legend.get_lines()] + expected_size = [ + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 4.242640687119285, + 6.708203932499369, + 8.48528137423857, + ] + np.testing.assert_allclose(expected_size, actual_size) + def test_add_legend_by_default(self) -> None: sc = self.ds.plot.scatter(x="A", y="B", hue="hue") fig = sc.figure @@ -3288,8 +3339,9 @@ def test_maybe_gca() -> None: @requires_matplotlib +@pytest.mark.parametrize("plotfunc", ["scatter", "lines"]) @pytest.mark.parametrize( - "x, y, z, hue, markersize, row, col, add_legend, add_colorbar", + "x, y, z, hue, _size, row, col, add_legend, add_colorbar", [ ("A", "B", None, None, None, None, None, None, None), ("B", "A", None, "w", None, None, None, True, None), @@ -3298,30 +3350,33 @@ def test_maybe_gca() -> None: ("B", "A", "z", "w", None, None, None, True, None), ("A", "B", "z", "y", "x", None, None, True, True), ("A", "B", "z", "y", "x", "w", None, True, True), + ("A", "B", "z", "y", "x", "w", "x", True, True), ], ) -def test_datarray_scatter( - x, y, z, hue, markersize, row, col, add_legend, add_colorbar +def test_plot1d_functions( + x: Hashable, + y: Hashable, + z: Hashable, + hue: Hashable, + _size: Hashable, + row: Hashable, + col: Hashable, + add_legend: bool | None, + add_colorbar: bool | None, + plotfunc: str, ) -> None: - """Test datarray scatter. Merge with TestPlot1D eventually.""" - ds = xr.tutorial.scatter_example_dataset() - - extra_coords = [v for v in [x, hue, markersize] if v is not None] - - # Base coords: - coords = dict(ds.coords) - - # Add extra coords to the DataArray: - coords.update({v: ds[v] for v in extra_coords}) - - darray = xr.DataArray(ds[y], coords=coords) + """Test plot1d function. Merge with TestPlot1D eventually.""" + ds = xr.tutorial.scatter_example_dataset(seed=42) with figure_context(): - darray.plot.scatter( + getattr(ds.plot, plotfunc)( x=x, + y=y, z=z, hue=hue, - markersize=markersize, + _size=_size, + row=row, + col=col, add_legend=add_legend, add_colorbar=add_colorbar, ) @@ -3449,6 +3504,107 @@ def test_plot1d_filtered_nulls() -> None: assert expected == actual +@requires_matplotlib +@pytest.mark.parametrize("plotfunc", ["lines"]) +def test_plot1d_lines_color(plotfunc: str, x="z", color="b") -> None: + from matplotlib.colors import to_rgba_array + + ds = xr.tutorial.scatter_example_dataset(seed=42) + + darray = ds.A.sel(x=0, y=0) + + with figure_context(): + fig, ax = plt.subplots() + getattr(darray.plot, plotfunc)(x=x, color=color) + coll = ax.collections[0] + + # Make sure color is respected: + expected_color = np.asarray(to_rgba_array(color)) + actual_color = np.asarray(coll.get_edgecolor()) + np.testing.assert_allclose(expected_color, actual_color) + + +@requires_matplotlib +@pytest.mark.parametrize("plotfunc", ["lines"]) +def test_plot1d_lines_linestyle(plotfunc: str, x="z", linestyle="dashed") -> None: + # TODO: Is there a public function that converts linestyle to dash pattern? + from matplotlib.lines import ( # type: ignore[attr-defined] + _get_dash_pattern, + _scale_dashes, + ) + + ds = xr.tutorial.scatter_example_dataset(seed=42) + + darray = ds.A.sel(x=0, y=0) + + with figure_context(): + fig, ax = plt.subplots() + getattr(darray.plot, plotfunc)(x=x, linestyle=linestyle) + coll = ax.collections[0] + + # Make sure linestyle is respected: + w = np.atleast_1d(coll.get_linewidth())[0] + expected_linestyle = [_scale_dashes(*_get_dash_pattern(linestyle), w)] + actual_linestyle = coll.get_linestyle() + assert expected_linestyle == actual_linestyle + + +@requires_matplotlib +def test_plot1d_lines_facetgrid_legend() -> None: + # asserts that order is correct, only unique values, no nans/masked values. + + ds = xr.tutorial.scatter_example_dataset(seed=42) + + with figure_context(): + g = ds.plot.lines( + x="A", y="B", hue="y", linewidth="x", row="x", col="w", add_colorbar=False + ) + + legend = g.figlegend + assert legend is not None + actual_text = [t.get_text() for t in legend.texts] + expected_text = [ + "y [yunits]", + "$\\mathdefault{0.0}$", + "$\\mathdefault{0.1}$", + "$\\mathdefault{0.2}$", + "$\\mathdefault{0.3}$", + "$\\mathdefault{0.4}$", + "$\\mathdefault{0.5}$", + "$\\mathdefault{0.6}$", + "$\\mathdefault{0.7}$", + "$\\mathdefault{0.8}$", + "$\\mathdefault{0.9}$", + "$\\mathdefault{1.0}$", + "x [xunits]", + "$\\mathdefault{0}$", + "$\\mathdefault{1}$", + "$\\mathdefault{2}$", + ] + assert expected_text == actual_text + + actual_size = [v.get_linewidth() for v in legend.get_lines()] + expected_size = [ + 1.5, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 6.0, + 1.5, + 1.224744871391589, + 1.9364916731037085, + 2.449489742783178, + ] + np.testing.assert_allclose(expected_size, actual_size) + + @requires_matplotlib def test_9155() -> None: # A test for types from issue #9155