diff --git a/pyglotaran_extras/__init__.py b/pyglotaran_extras/__init__.py index 009a7f51..edaed120 100644 --- a/pyglotaran_extras/__init__.py +++ b/pyglotaran_extras/__init__.py @@ -3,6 +3,7 @@ from pyglotaran_extras.io.setup_case_study import setup_case_study from pyglotaran_extras.plotting.plot_coherent_artifact import plot_coherent_artifact from pyglotaran_extras.plotting.plot_data import plot_data_overview +from pyglotaran_extras.plotting.plot_doas import plot_doas from pyglotaran_extras.plotting.plot_guidance import plot_guidance from pyglotaran_extras.plotting.plot_irf_dispersion_center import plot_irf_dispersion_center from pyglotaran_extras.plotting.plot_overview import plot_overview @@ -15,12 +16,13 @@ "setup_case_study", "plot_coherent_artifact", "plot_data_overview", + "plot_doas", + "plot_guidance", + "plot_irf_dispersion_center", "plot_overview", "plot_simple_overview", "plot_fitted_traces", "select_plot_wavelengths", - "plot_guidance", - "plot_irf_dispersion_center", ] __version__ = "0.6.0" diff --git a/pyglotaran_extras/plotting/plot_doas.py b/pyglotaran_extras/plotting/plot_doas.py index dc777fcb..cf75dbdc 100644 --- a/pyglotaran_extras/plotting/plot_doas.py +++ b/pyglotaran_extras/plotting/plot_doas.py @@ -1,19 +1,22 @@ -"""Module containing plot functionality for damped oscillations (DOAS).""" +"""Module containing DOAS (Damped Oscillation) plotting functionality.""" from __future__ import annotations from typing import TYPE_CHECKING -from warnings import warn +from typing import Literal import matplotlib.pyplot as plt +import numpy as np +from cycler import Cycler -from pyglotaran_extras.deprecation.deprecation_utils import FIG_ONLY_WARNING -from pyglotaran_extras.deprecation.deprecation_utils import PyglotaranExtrasApiDeprecationWarning from pyglotaran_extras.io.load_data import load_data from pyglotaran_extras.plotting.style import PlotStyle +from pyglotaran_extras.plotting.utils import abs_max from pyglotaran_extras.plotting.utils import add_cycler_if_not_none +from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi +from pyglotaran_extras.plotting.utils import extract_irf_location +from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location if TYPE_CHECKING: - from cycler import Cycler from glotaran.project.result import Result from matplotlib.figure import Figure from matplotlib.pyplot import Axes @@ -22,74 +25,123 @@ def plot_doas( - result: DatasetConvertible | Result, - figsize: tuple[int, int] = (25, 25), + dataset: DatasetConvertible | Result, + *, + damped_oscillation: list[str] | None = None, + time_range: tuple[float, float] | None = None, + spectral: float = 0, + main_irf_nr: int | None = 0, + normalize: bool = False, + figsize: tuple[int, int] = (20, 5), + show_zero_line: bool = True, cycler: Cycler | None = PlotStyle().cycler, - figure_only: bool = True, -) -> Figure | tuple[Figure, Axes]: - """Plot Damped oscillation associated spectra (DOAS). + oscillation_type: Literal["cos", "sin"] = "cos", + title: str | None = "Damped oscillations", +) -> tuple[Figure, Axes]: + """Plot DOAS (Damped Oscillation) related data of the optimization result. Parameters ---------- - result: DatasetConvertible | Result - Result from a pyglotaran optimization as dataset, Path or Result object. - figsize : tuple[int, int] - Size of the figure (N, M) in inches. Defaults to (18, 16). - cycler : Cycler | None - Plot style cycler to use. Defaults to PlotStyle().cycler. - figure_only: bool - Whether or not to only return the figure. - This is a deprecation helper argument to transition to a consistent return value - consisting of the :class:`Figure` and the :class:`Axes`. Defaults to True. + dataset: DatasetConvertible | Result + Result dataset from a pyglotaran optimization. + damped_oscillation: list[str] | None + List of oscillation names which should be plotted. + Defaults to None which means that all oscillations will be plotted. + time_range: tuple[float, float] | None + Start and end time for the Oscillation plot, if ``main_irf_nr`` is not None the value are + relative to the IRF location. Defaults to None which means that the full time range is + used. + spectral: float + Value of the spectral axis that should be used to select the data for the Oscillation + plot this value does not need to be an exact existing value and only has effect if the + IRF has dispersion. Defaults to 0 which means that the Oscillation plot at lowest + spectral value will be shown. + main_irf_nr: int | None + Index of the main ``irf`` component when using an ``irf`` parametrized with multiple peaks + and is used to shift the time axis. If it is none ``None`` the shifting will be + deactivated. Defaults to 0. + normalize: bool + Whether or not to normalize the DOAS spectra plot. If the DOAS spectra is normalized, + the Oscillation is scaled with the reciprocal of the normalization to compensate for this. + Defaults to False. + figsize: tuple[int, int] + Size of the figure (N, M) in inches. Defaults to (20, 5) + show_zero_line: bool + Whether or not to add a horizontal line at zero. Defaults to True + cycler: Cycler | None + Plot style cycler to use. Defaults to PlotStyle().cycler + oscillation_type: Literal["cos", "sin"] + Type of the oscillation to show in the oscillation plot. Defaults to "cos" + title: str | None + Title of the figure. Defaults to "Damped oscillations" Returns ------- - Figure|tuple[Figure, Axes] - If ``figure_only`` is True, Figure object which contains the plots (deprecated). - If ``figure_only`` is False, Figure object which contains the plots and the Axes. + tuple[Figure, Axes] + Figure object which contains the plots and the Axes. + + See Also + -------- + calculate_ticks_in_units_of_pi """ - dataset = load_data(result) - - # Create M x N plotting grid - M = 6 - N = 3 - - fig, axes = plt.subplots(M, N, figsize=figsize) - - for ax in axes.flatten(): - add_cycler_if_not_none(ax, cycler) - - # Plot data - dataset.species_associated_spectra.plot.line(x="spectral", ax=axes[0, 0]) - dataset.decay_associated_spectra.plot.line(x="spectral", ax=axes[0, 1]) - - if "spectral" in dataset.species_concentration.coords: - dataset.species_concentration.isel(spectral=0).plot.line(x="time", ax=axes[1, 0]) - else: - dataset.species_concentration.plot.line(x="time", ax=axes[1, 0]) - axes[1, 0].set_xscale("symlog", linthreshx=1) - - if "dampened_oscillation_associated_spectra" in dataset: - dataset.dampened_oscillation_cos.isel(spectral=0).sel(time=slice(-1, 10)).plot.line( - x="time", ax=axes[1, 1] - ) - dataset.dampened_oscillation_associated_spectra.plot.line(x="spectral", ax=axes[2, 0]) - dataset.dampened_oscillation_phase.plot.line(x="spectral", ax=axes[2, 1]) - - dataset.residual_left_singular_vectors.isel(left_singular_value_index=0).plot(ax=axes[0, 2]) - dataset.residual_singular_values.plot.line("ro-", yscale="log", ax=axes[1, 2]) - dataset.residual_right_singular_vectors.isel(right_singular_value_index=0).plot(ax=axes[2, 2]) - - interval = int(dataset.spectral.size / 11) - for i in range(0): - axi = axes[i % 3, int(i / 3) + 3] - index = (i + 1) * interval - dataset.data.isel(spectral=index).plot(ax=axi) - dataset.residual.isel(spectral=index).plot(ax=axi) - dataset.fitted_data.isel(spectral=index).plot(ax=axi) - - plt.tight_layout(pad=5, w_pad=2.0, h_pad=2.0) - if figure_only is False: - return fig, axes - warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2) - return fig + dataset = load_data(dataset, _stacklevel=3) + + fig, axes = plt.subplots(1, 3, figsize=figsize) + + add_cycler_if_not_none(axes, cycler) + + time_sel_kwargs = ( + {"time": slice(time_range[0], time_range[1])} if time_range is not None else {} + ) + osc_sel_kwargs = ( + {"damped_oscillation": damped_oscillation} if damped_oscillation is not None else {} + ) + + irf_location = extract_irf_location(dataset, spectral, main_irf_nr) + + oscillations = shift_time_axis_by_irf_location( + dataset[f"damped_oscillation_{oscillation_type}"] + .sel(spectral=spectral, method="nearest") + .sel(**osc_sel_kwargs), + irf_location, + ) + oscillations_spectra = dataset["damped_oscillation_associated_spectra"].sel(**osc_sel_kwargs) + + damped_oscillation_phase = dataset["damped_oscillation_phase"].sel(**osc_sel_kwargs) + + osc_max = abs_max((oscillations - 1), result_dims="damped_oscillation") + spectra_max = abs_max(oscillations_spectra, result_dims="damped_oscillation") + scales = np.sqrt(osc_max * spectra_max) + + norm_factor = scales.max() if normalize is True else 1 + + ((oscillations - 1) / osc_max * scales * norm_factor).sel(**time_sel_kwargs).plot.line( + x="time", ax=axes[0] + ) + + (oscillations_spectra / spectra_max * scales / norm_factor).plot.line(x="spectral", ax=axes[1]) + + damped_oscillation_phase.plot.line(x="spectral", ax=axes[2]) + + axes[0].set_title(f"{oscillation_type.capitalize()} Oscillations") + axes[1].set_title("Spectra") + axes[2].set_title("Phases") + + axes[1].set_ylabel("Normalized DOAS" if normalize is True else "DOAS") + + axes[2].set_yticks( + *calculate_ticks_in_units_of_pi(damped_oscillation_phase), rotation="horizontal" + ) + axes[2].set_ylabel("Phase (π)") + + axes[1].get_legend().remove() + axes[2].get_legend().remove() + + if show_zero_line is True: + [ax.axhline(0, color="k", linewidth=1) for ax in axes.flatten()] + + if title: + fig.suptitle(title, fontsize=16) + + fig.tight_layout() + return fig, axes diff --git a/pyglotaran_extras/plotting/utils.py b/pyglotaran_extras/plotting/utils.py index 8139db46..455dc7cf 100644 --- a/pyglotaran_extras/plotting/utils.py +++ b/pyglotaran_extras/plotting/utils.py @@ -8,6 +8,7 @@ import numpy as np import xarray as xr +from pyglotaran_extras.inspect.utils import pretty_format_numerical_iterable from pyglotaran_extras.io.utils import result_dataset_mapping if TYPE_CHECKING: @@ -109,7 +110,7 @@ def extract_irf_dispersion_center( def extract_irf_location( - res: xr.Dataset, center_λ: float | None = None, main_irf_nr: int = 0 + res: xr.Dataset, center_λ: float | None = None, main_irf_nr: int | None = 0 ) -> float: """Determine location of the ``irf``, which can be used to shift plots. @@ -120,14 +121,16 @@ def extract_irf_location( center_λ: float | None Center wavelength (λ in nm) main_irf_nr : int - Index of the main ``irf`` component when using an ``irf`` - parametrized with multiple peaks. Defaults to 0. + Index of the main ``irf`` component when using an ``irf`` parametrized with multiple peaks. + If it is none ``None`` the location will be 0. Defaults to 0. Returns ------- float Location of the ``irf`` """ + if main_irf_nr is None: + return 0 irf_dispersion_center = extract_irf_dispersion_center( res=res, main_irf_nr=main_irf_nr, as_dataarray=False ) @@ -406,3 +409,57 @@ def abs_max( result_dims = (result_dims,) reduce_dims = (dim for dim in data.dims if dim not in result_dims) return np.abs(data).max(dim=reduce_dims) + + +def calculate_ticks_in_units_of_pi( + values: np.ndarray | xr.DataArray, *, step_size: float = 0.5 +) -> tuple[Iterable[float], Iterable[str]]: + """Calculate tick values and labels in units of Pi. + + Parameters + ---------- + values: np.ndarray + Values which the ticks should be calculated for. + step_size: float + Step size of the ticks in units of pi. Defaults to 0.5 + + Returns + ------- + tuple[Iterable[float], Iterable[str]] + Tick values and tick labels + + See Also + -------- + pyglotaran_extras.plotting.plot_doas.plot_doas + + Examples + -------- + If you have a case study that uses a ``damped-oscillation`` megacomplex you can plot the + ``damped_oscillation_phase`` with y-tick in units of Pi by the following code given that the + dataset is saved under ``dataset.nc``. + + .. code-block:: python + import matplotlib.pyplot as plt + + from glotaran.io import load_dataset + from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi + + dataset = load_dataset("dataset.nc") + + fig, ax = plt.subplots(1, 1) + + damped_oscillation_phase = dataset["damped_oscillation_phase"].sel( + damped_oscillation=["osc1"] + ) + damped_oscillation_phase.plot.line(x="spectral", ax=ax) + + ax.set_yticks( + *calculate_ticks_in_units_of_pi(damped_oscillation_phase), rotation="horizontal" + ) + """ + values = np.array(values) + int_values_over_pi = np.round(values / np.pi / step_size) + tick_labels = np.arange(int_values_over_pi.min(), int_values_over_pi.max() + 1) * step_size + return tick_labels * np.pi, ( + str(val) for val in pretty_format_numerical_iterable(tick_labels, decimal_places=1) + ) diff --git a/tests/plotting/test_utils.py b/tests/plotting/test_utils.py index c5681207..aede5a88 100644 --- a/tests/plotting/test_utils.py +++ b/tests/plotting/test_utils.py @@ -6,6 +6,7 @@ import matplotlib import matplotlib.pyplot as plt +import numpy as np import pytest import xarray as xr from cycler import Cycler @@ -14,6 +15,7 @@ from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import abs_max from pyglotaran_extras.plotting.utils import add_cycler_if_not_none +from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi matplotlib.use("Agg") DEFAULT_CYCLER = plt.rcParams["axes.prop_cycle"] @@ -65,3 +67,21 @@ def test_abs_max(result_dims: Hashable | Iterable[Hashable], expected: xr.DataAr """Result values are positive and dimensions are preserved if result_dims is not empty.""" data = xr.DataArray([[-10, 20], [-30, 40]], coords={"dim1": [1, 2], "dim2": [3, 4]}) assert abs_max(data, result_dims=result_dims).equals(expected) + + +@pytest.mark.parametrize( + "step_size, expected_tick_values,expected_tick_labels", + ( + (0.5, np.linspace(-np.pi, 2 * np.pi, num=7), ["-1", "-0.5", "0", "0.5", "1", "1.5", "2"]), + (1, np.linspace(-np.pi, 2 * np.pi, num=4), ["-1", "0", "1", "2"]), + ), +) +def test_calculate_ticks_in_units_of_pi( + step_size: float, expected_tick_values: list[float], expected_tick_labels: list[str] +): + """Different values depending on ``step_size``.""" + values = np.linspace(-np.pi, 2 * np.pi) + tick_values, tick_labels = calculate_ticks_in_units_of_pi(values, step_size=step_size) + + assert np.allclose(list(tick_values), expected_tick_values) + assert list(tick_labels) == expected_tick_labels