Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

👌 Add das_cycler and svd_cycler to plot collection functions #218

Merged
merged 4 commits into from
Nov 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,15 @@

## 0.8.0 (Unreleased)

- 🧰👌 Switch tooling to ruff (#197)
- 🩹 Fix crash when plotting spectral model result (#200)
- 👷♻️ Use hatch as build backend (#204)
- 🧰 Use black-pre-commit-mirror for 2x speedup (#205)
- 🧰🚀 Use ruff for formatting (#214)
- 👌 Use weighted residual instead of residual plots if present (#216)
- 👌 Add color map arguments to plot_data_overview (#217)
- 👌 Add das_cycler and svd_cycler to plot collection functions (#218)

(changes-0_7_1)=

## 0.7.1 (2023-07-27)
Expand Down
14 changes: 13 additions & 1 deletion pyglotaran_extras/plotting/plot_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from pyglotaran_extras.plotting.plot_svd import plot_lsv_data
from pyglotaran_extras.plotting.plot_svd import plot_rsv_data
from pyglotaran_extras.plotting.plot_svd import plot_sv_data
from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import MinorSymLogLocator
from pyglotaran_extras.plotting.utils import not_single_element_dims
from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location
Expand All @@ -22,6 +23,7 @@
from collections.abc import Hashable

import xarray as xr
from cycler import Cycler
from glotaran.project.result import Result
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes
Expand All @@ -41,6 +43,7 @@ def plot_data_overview(
cmap: str = "PuRd",
vmin: float | None = None,
vmax: float | None = None,
svd_cycler: Cycler | None = PlotStyle().cycler,
) -> tuple[Figure, Axes] | tuple[Figure, Axis]:
"""Plot data as filled contour plot and SVD components.

Expand Down Expand Up @@ -71,6 +74,8 @@ def plot_data_overview(
Lower value to anchor the colormap. Defaults to None meaning it inferred from the data.
vmax : float | None
Lower value to anchor the colormap. Defaults to None meaning it inferred from the data.
svd_cycler : Cycler | None
Plot style cycler to use for SVD plots. Defaults to ``PlotStyle().cycler``.

Returns
-------
Expand Down Expand Up @@ -111,9 +116,16 @@ def plot_data_overview(
linlog=linlog,
linthresh=linthresh,
irf_location=irf_location,
cycler=svd_cycler,
)
plot_sv_data(dataset, sv_ax)
plot_rsv_data(dataset, rsv_ax, indices=range(nr_of_data_svd_vectors), show_legend=False)
plot_rsv_data(
dataset,
rsv_ax,
indices=range(nr_of_data_svd_vectors),
show_legend=False,
cycler=svd_cycler,
)
if show_data_svd_legend is True:
rsv_ax.legend(title="singular value index", loc="lower right", bbox_to_anchor=(1.13, 1))
fig.suptitle(title, fontsize=16)
Expand Down
31 changes: 27 additions & 4 deletions pyglotaran_extras/plotting/plot_overview.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import extract_irf_location
from pyglotaran_extras.types import Unset

if TYPE_CHECKING:
from cycler import Cycler
Expand All @@ -29,6 +30,7 @@
from matplotlib.pyplot import Axes

from pyglotaran_extras.types import DatasetConvertible
from pyglotaran_extras.types import UnsetType


def plot_overview(
Expand All @@ -48,6 +50,8 @@ def plot_overview(
show_residual_svd_legend: bool = True,
show_irf_dispersion_center: bool = True,
show_zero_line: bool = True,
das_cycler: Cycler | None | UnsetType = Unset,
svd_cycler: Cycler | None | UnsetType = Unset,
) -> tuple[Figure, Axes]:
"""Plot overview of the optimization result.

Expand Down Expand Up @@ -95,13 +99,24 @@ def plot_overview(
show_zero_line : bool
Whether or not to add a horizontal line at zero to the plots of the spectra.
Defaults to True.
das_cycler : Cycler | None | UnsetType
Plot style cycler to use for DAS plots. Defaults to ``Unset`` which means that the value
of ``cycler`` is used.
svd_cycler : Cycler | None | UnsetType
Plot style cycler to use for SVD plots. Defaults to ``Unset`` which means that the value
of ``cycler`` is used.

Returns
-------
tuple[Figure, Axes]
"""
res = load_data(result, _stacklevel=3)

if das_cycler is Unset:
das_cycler = cycler
if svd_cycler is Unset:
svd_cycler = cycler

if res.coords["time"].to_numpy().size == 1:
fig, axes = plot_guidance(res)
if figure_only is not None:
Expand All @@ -125,13 +140,15 @@ def plot_overview(
main_irf_nr=main_irf_nr,
cycler=cycler,
)
plot_spectra(res, axes[0:2, 1:3], cycler=cycler, show_zero_line=show_zero_line)
plot_spectra(
res, axes[0:2, 1:3], cycler=cycler, show_zero_line=show_zero_line, das_cycler=das_cycler
)
plot_svd(
res,
axes[2:4, 0:3],
linlog=linlog,
linthresh=linthresh,
cycler=cycler,
cycler=svd_cycler,
nr_of_data_svd_vectors=nr_of_data_svd_vectors,
nr_of_residual_svd_vectors=nr_of_residual_svd_vectors,
show_data_svd_legend=show_data_svd_legend,
Expand Down Expand Up @@ -161,6 +178,7 @@ def plot_simple_overview(
figure_only: bool | None = None,
show_irf_dispersion_center: bool = True,
show_data: bool | None = False,
svd_cycler: Cycler | None | UnsetType = Unset,
) -> tuple[Figure, Axes]:
"""Plot simple overview.

Expand All @@ -182,12 +200,17 @@ def plot_simple_overview(
show_data : bool | None
Whether to show the input data or residual. If set to ``None`` the plot is skipped
which improves plotting performance for big datasets. Defaults to False.
svd_cycler : Cycler | None | UnsetType
Plot style cycler to use for SVD plots. Defaults to ``Unset`` which means that the value
of ``cycler`` is used.

Returns
-------
tuple[Figure, Axes]
"""
res = load_data(result, _stacklevel=3)
if svd_cycler is Unset:
svd_cycler = cycler

fig, axes = plt.subplots(2, 3, figsize=figsize, constrained_layout=True)
for ax in axes.flatten():
Expand All @@ -200,8 +223,8 @@ def plot_simple_overview(

irf_location = extract_irf_location(res, center_λ=res.coords["spectral"].to_numpy()[0])

plot_lsv_residual(res, ax=axes[1, 0], irf_location=irf_location)
plot_rsv_residual(res, ax=axes[1, 1])
plot_lsv_residual(res, ax=axes[1, 0], irf_location=irf_location, cycler=svd_cycler)
plot_rsv_residual(res, ax=axes[1, 1], cycler=svd_cycler)

plot_residual(
res,
Expand Down
14 changes: 12 additions & 2 deletions pyglotaran_extras/plotting/plot_spectra.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,23 @@

from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.types import Unset

if TYPE_CHECKING:
import xarray as xr
from cycler import Cycler
from matplotlib.axis import Axis
from matplotlib.pyplot import Axes

from pyglotaran_extras.types import UnsetType


def plot_spectra(
res: xr.Dataset,
axes: Axes,
cycler: Cycler | None = PlotStyle().cycler,
show_zero_line: bool = True,
das_cycler: Cycler | None | UnsetType = Unset,
) -> None:
"""Plot spectra such as SAS and DAS as well as their normalize version on ``axes``.

Expand All @@ -33,11 +37,17 @@ def plot_spectra(
Plot style cycler to use. Defaults to PlotStyle().cycler.
show_zero_line : bool
Whether or not to add a horizontal line at zero. Defaults to True.
das_cycler : Cycler | None | UnsetType
Plot style cycler to use for DAS plots. Defaults to ``Unset`` which means that the value
of ``cycler`` is used.
"""
if das_cycler is Unset:
das_cycler = cycler

plot_sas(res, axes[0, 0], cycler=cycler, show_zero_line=show_zero_line)
plot_das(res, axes[0, 1], cycler=cycler, show_zero_line=show_zero_line)
plot_das(res, axes[0, 1], cycler=das_cycler, show_zero_line=show_zero_line)
plot_norm_sas(res, axes[1, 0], cycler=cycler, show_zero_line=show_zero_line)
plot_norm_das(res, axes[1, 1], cycler=cycler, show_zero_line=show_zero_line)
plot_norm_das(res, axes[1, 1], cycler=das_cycler, show_zero_line=show_zero_line)


def plot_sas(
Expand Down
33 changes: 23 additions & 10 deletions pyglotaran_extras/plotting/plot_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@

from glotaran.io.prepare_dataset import add_svd_to_dataset

from pyglotaran_extras.deprecation import warn_deprecated
from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import MinorSymLogLocator
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location
from pyglotaran_extras.types import Unset
from pyglotaran_extras.types import UnsetType

if TYPE_CHECKING:
from collections.abc import Sequence
Expand Down Expand Up @@ -80,7 +83,7 @@ def plot_svd(
show_legend=show_residual_svd_legend,
irf_location=irf_location,
)
plot_sv_residual(res, axes[0, 2], cycler=cycler)
plot_sv_residual(res, axes[0, 2])
add_svd_to_dataset(dataset=res, name="data")
plot_lsv_data(
res,
Expand All @@ -100,7 +103,7 @@ def plot_svd(
show_legend=show_data_svd_legend,
irf_location=irf_location,
)
plot_sv_data(res, axes[1, 2], cycler=cycler)
plot_sv_data(res, axes[1, 2])


def plot_lsv_data(
Expand Down Expand Up @@ -181,7 +184,7 @@ def plot_sv_data(
res: xr.Dataset,
ax: Axis,
indices: Sequence[int] = range(10),
cycler: Cycler | None = PlotStyle().cycler,
cycler: Cycler | None | UnsetType = Unset,
) -> None:
"""Plot singular values of the data matrix.

Expand All @@ -193,10 +196,15 @@ def plot_sv_data(
Axis to plot on.
indices : Sequence[int]
Indices of the singular vector to plot. Defaults to range(10).
cycler : Cycler | None
Plot style cycler to use. Defaults to PlotStyle().cycler.
cycler : Cycler | None | UnsetType
Deprecated since it has no effect. Defaults to Unset.
"""
add_cycler_if_not_none(ax, cycler)
if cycler is not Unset:
warn_deprecated(
deprecated_qual_name_usage="'cycler' argument in 'plot_sv_data'",
new_qual_name_usage="matplotlib on the axis directly",
to_be_removed_in_version="0.9.0",
)
dSV = res.data_singular_values # noqa: N806
dSV.sel(singular_value_index=indices[: len(dSV.singular_value_index)]).plot.line(
"ro-", yscale="log", ax=ax
Expand Down Expand Up @@ -288,7 +296,7 @@ def plot_sv_residual(
res: xr.Dataset,
ax: Axis,
indices: Sequence[int] = range(10),
cycler: Cycler | None = PlotStyle().cycler,
cycler: Cycler | None | UnsetType = Unset,
) -> None:
"""Plot singular values of the residual matrix.

Expand All @@ -300,10 +308,15 @@ def plot_sv_residual(
Axis to plot on.
indices : Sequence[int]
Indices of the singular vector to plot. Defaults to range(10).
cycler : Cycler | None
Plot style cycler to use. Defaults to PlotStyle().cycler.
cycler : Cycler | None | UnsetType
Deprecated since it has no effect. Defaults to Unset.
"""
add_cycler_if_not_none(ax, cycler)
if cycler is not Unset:
warn_deprecated(
deprecated_qual_name_usage="'cycler' argument in 'plot_sv_residual'",
new_qual_name_usage="matplotlib on the axis directly",
to_be_removed_in_version="0.9.0",
)
if "weighted_residual_singular_values" in res:
rSV = res.weighted_residual_singular_values # noqa: N806
else:
Expand Down
15 changes: 15 additions & 0 deletions pyglotaran_extras/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,21 @@
import xarray as xr
from glotaran.project.result import Result


class UnsetType:
"""Type for the ``Unset`` singleton."""

def __repr__(self) -> str: # noqa: DOC
"""Representation of instances in editors."""
return "Unset"


Unset = UnsetType()
"""Value to use as default for an arguments where None is a meaningful value.

This way we can prevent regressions.
"""

DatasetConvertible: TypeAlias = xr.Dataset | xr.DataArray | str | Path
"""Types of data which can be converted to a dataset."""
ResultLike: TypeAlias = (
Expand Down