Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

✨ Add plot_doas function that only plots DOAS related information #135

Merged
merged 5 commits into from
Feb 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pyglotaran_extras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pyglotaran_extras.io.setup_case_study import setup_case_study
from pyglotaran_extras.plotting.plot_coherent_artifact import plot_coherent_artifact
from pyglotaran_extras.plotting.plot_data import plot_data_overview
from pyglotaran_extras.plotting.plot_doas import plot_doas
from pyglotaran_extras.plotting.plot_guidance import plot_guidance
from pyglotaran_extras.plotting.plot_irf_dispersion_center import plot_irf_dispersion_center
from pyglotaran_extras.plotting.plot_overview import plot_overview
Expand All @@ -15,12 +16,13 @@
"setup_case_study",
"plot_coherent_artifact",
"plot_data_overview",
"plot_doas",
"plot_guidance",
"plot_irf_dispersion_center",
"plot_overview",
"plot_simple_overview",
"plot_fitted_traces",
"select_plot_wavelengths",
"plot_guidance",
"plot_irf_dispersion_center",
]

__version__ = "0.6.0"
188 changes: 120 additions & 68 deletions pyglotaran_extras/plotting/plot_doas.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,22 @@
"""Module containing plot functionality for damped oscillations (DOAS)."""
"""Module containing DOAS (Damped Oscillation) plotting functionality."""
from __future__ import annotations

from typing import TYPE_CHECKING
from warnings import warn
from typing import Literal

import matplotlib.pyplot as plt
import numpy as np
from cycler import Cycler

from pyglotaran_extras.deprecation.deprecation_utils import FIG_ONLY_WARNING
from pyglotaran_extras.deprecation.deprecation_utils import PyglotaranExtrasApiDeprecationWarning
from pyglotaran_extras.io.load_data import load_data
from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import abs_max
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi
from pyglotaran_extras.plotting.utils import extract_irf_location
from pyglotaran_extras.plotting.utils import shift_time_axis_by_irf_location

if TYPE_CHECKING:
from cycler import Cycler
from glotaran.project.result import Result
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes
Expand All @@ -22,74 +25,123 @@


def plot_doas(
result: DatasetConvertible | Result,
figsize: tuple[int, int] = (25, 25),
dataset: DatasetConvertible | Result,
*,
damped_oscillation: list[str] | None = None,
time_range: tuple[float, float] | None = None,
spectral: float = 0,
main_irf_nr: int | None = 0,
normalize: bool = False,
figsize: tuple[int, int] = (20, 5),
show_zero_line: bool = True,
cycler: Cycler | None = PlotStyle().cycler,
figure_only: bool = True,
) -> Figure | tuple[Figure, Axes]:
"""Plot Damped oscillation associated spectra (DOAS).
oscillation_type: Literal["cos", "sin"] = "cos",
title: str | None = "Damped oscillations",
) -> tuple[Figure, Axes]:
"""Plot DOAS (Damped Oscillation) related data of the optimization result.

Parameters
----------
result: DatasetConvertible | Result
Result from a pyglotaran optimization as dataset, Path or Result object.
figsize : tuple[int, int]
Size of the figure (N, M) in inches. Defaults to (18, 16).
cycler : Cycler | None
Plot style cycler to use. Defaults to PlotStyle().cycler.
figure_only: bool
Whether or not to only return the figure.
This is a deprecation helper argument to transition to a consistent return value
consisting of the :class:`Figure` and the :class:`Axes`. Defaults to True.
dataset: DatasetConvertible | Result
Result dataset from a pyglotaran optimization.
damped_oscillation: list[str] | None
List of oscillation names which should be plotted.
Defaults to None which means that all oscillations will be plotted.
time_range: tuple[float, float] | None
Start and end time for the Oscillation plot, if ``main_irf_nr`` is not None the value are
relative to the IRF location. Defaults to None which means that the full time range is
used.
spectral: float
Value of the spectral axis that should be used to select the data for the Oscillation
plot this value does not need to be an exact existing value and only has effect if the
IRF has dispersion. Defaults to 0 which means that the Oscillation plot at lowest
spectral value will be shown.
main_irf_nr: int | None
Index of the main ``irf`` component when using an ``irf`` parametrized with multiple peaks
and is used to shift the time axis. If it is none ``None`` the shifting will be
deactivated. Defaults to 0.
normalize: bool
Whether or not to normalize the DOAS spectra plot. If the DOAS spectra is normalized,
the Oscillation is scaled with the reciprocal of the normalization to compensate for this.
Defaults to False.
figsize: tuple[int, int]
Size of the figure (N, M) in inches. Defaults to (20, 5)
show_zero_line: bool
Whether or not to add a horizontal line at zero. Defaults to True
cycler: Cycler | None
Plot style cycler to use. Defaults to PlotStyle().cycler
oscillation_type: Literal["cos", "sin"]
Type of the oscillation to show in the oscillation plot. Defaults to "cos"
title: str | None
Title of the figure. Defaults to "Damped oscillations"

Returns
-------
Figure|tuple[Figure, Axes]
If ``figure_only`` is True, Figure object which contains the plots (deprecated).
If ``figure_only`` is False, Figure object which contains the plots and the Axes.
tuple[Figure, Axes]
Figure object which contains the plots and the Axes.

See Also
--------
calculate_ticks_in_units_of_pi
"""
dataset = load_data(result)

# Create M x N plotting grid
M = 6
N = 3

fig, axes = plt.subplots(M, N, figsize=figsize)

for ax in axes.flatten():
add_cycler_if_not_none(ax, cycler)

# Plot data
dataset.species_associated_spectra.plot.line(x="spectral", ax=axes[0, 0])
dataset.decay_associated_spectra.plot.line(x="spectral", ax=axes[0, 1])

if "spectral" in dataset.species_concentration.coords:
dataset.species_concentration.isel(spectral=0).plot.line(x="time", ax=axes[1, 0])
else:
dataset.species_concentration.plot.line(x="time", ax=axes[1, 0])
axes[1, 0].set_xscale("symlog", linthreshx=1)

if "dampened_oscillation_associated_spectra" in dataset:
dataset.dampened_oscillation_cos.isel(spectral=0).sel(time=slice(-1, 10)).plot.line(
x="time", ax=axes[1, 1]
)
dataset.dampened_oscillation_associated_spectra.plot.line(x="spectral", ax=axes[2, 0])
dataset.dampened_oscillation_phase.plot.line(x="spectral", ax=axes[2, 1])

dataset.residual_left_singular_vectors.isel(left_singular_value_index=0).plot(ax=axes[0, 2])
dataset.residual_singular_values.plot.line("ro-", yscale="log", ax=axes[1, 2])
dataset.residual_right_singular_vectors.isel(right_singular_value_index=0).plot(ax=axes[2, 2])

interval = int(dataset.spectral.size / 11)
for i in range(0):
axi = axes[i % 3, int(i / 3) + 3]
index = (i + 1) * interval
dataset.data.isel(spectral=index).plot(ax=axi)
dataset.residual.isel(spectral=index).plot(ax=axi)
dataset.fitted_data.isel(spectral=index).plot(ax=axi)

plt.tight_layout(pad=5, w_pad=2.0, h_pad=2.0)
if figure_only is False:
return fig, axes
warn(PyglotaranExtrasApiDeprecationWarning(FIG_ONLY_WARNING), stacklevel=2)
return fig
dataset = load_data(dataset, _stacklevel=3)

fig, axes = plt.subplots(1, 3, figsize=figsize)

add_cycler_if_not_none(axes, cycler)

time_sel_kwargs = (
{"time": slice(time_range[0], time_range[1])} if time_range is not None else {}
)
osc_sel_kwargs = (
{"damped_oscillation": damped_oscillation} if damped_oscillation is not None else {}
)

irf_location = extract_irf_location(dataset, spectral, main_irf_nr)

oscillations = shift_time_axis_by_irf_location(
dataset[f"damped_oscillation_{oscillation_type}"]
.sel(spectral=spectral, method="nearest")
.sel(**osc_sel_kwargs),
irf_location,
)
oscillations_spectra = dataset["damped_oscillation_associated_spectra"].sel(**osc_sel_kwargs)

damped_oscillation_phase = dataset["damped_oscillation_phase"].sel(**osc_sel_kwargs)

osc_max = abs_max((oscillations - 1), result_dims="damped_oscillation")
spectra_max = abs_max(oscillations_spectra, result_dims="damped_oscillation")
scales = np.sqrt(osc_max * spectra_max)

norm_factor = scales.max() if normalize is True else 1

((oscillations - 1) / osc_max * scales * norm_factor).sel(**time_sel_kwargs).plot.line(
x="time", ax=axes[0]
)

(oscillations_spectra / spectra_max * scales / norm_factor).plot.line(x="spectral", ax=axes[1])

damped_oscillation_phase.plot.line(x="spectral", ax=axes[2])

axes[0].set_title(f"{oscillation_type.capitalize()} Oscillations")
axes[1].set_title("Spectra")
axes[2].set_title("Phases")

axes[1].set_ylabel("Normalized DOAS" if normalize is True else "DOAS")

axes[2].set_yticks(
*calculate_ticks_in_units_of_pi(damped_oscillation_phase), rotation="horizontal"
)
axes[2].set_ylabel("Phase (π)")

axes[1].get_legend().remove()
axes[2].get_legend().remove()

if show_zero_line is True:
[ax.axhline(0, color="k", linewidth=1) for ax in axes.flatten()]

if title:
fig.suptitle(title, fontsize=16)

fig.tight_layout()
return fig, axes
63 changes: 60 additions & 3 deletions pyglotaran_extras/plotting/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import xarray as xr

from pyglotaran_extras.inspect.utils import pretty_format_numerical_iterable
from pyglotaran_extras.io.utils import result_dataset_mapping

if TYPE_CHECKING:
Expand Down Expand Up @@ -109,7 +110,7 @@ def extract_irf_dispersion_center(


def extract_irf_location(
res: xr.Dataset, center_λ: float | None = None, main_irf_nr: int = 0
res: xr.Dataset, center_λ: float | None = None, main_irf_nr: int | None = 0
) -> float:
"""Determine location of the ``irf``, which can be used to shift plots.

Expand All @@ -120,14 +121,16 @@ def extract_irf_location(
center_λ: float | None
Center wavelength (λ in nm)
main_irf_nr : int
Index of the main ``irf`` component when using an ``irf``
parametrized with multiple peaks. Defaults to 0.
Index of the main ``irf`` component when using an ``irf`` parametrized with multiple peaks.
If it is none ``None`` the location will be 0. Defaults to 0.

Returns
-------
float
Location of the ``irf``
"""
if main_irf_nr is None:
return 0
irf_dispersion_center = extract_irf_dispersion_center(
res=res, main_irf_nr=main_irf_nr, as_dataarray=False
)
Expand Down Expand Up @@ -406,3 +409,57 @@ def abs_max(
result_dims = (result_dims,)
reduce_dims = (dim for dim in data.dims if dim not in result_dims)
return np.abs(data).max(dim=reduce_dims)


def calculate_ticks_in_units_of_pi(
values: np.ndarray | xr.DataArray, *, step_size: float = 0.5
) -> tuple[Iterable[float], Iterable[str]]:
"""Calculate tick values and labels in units of Pi.

Parameters
----------
values: np.ndarray
Values which the ticks should be calculated for.
step_size: float
Step size of the ticks in units of pi. Defaults to 0.5

Returns
-------
tuple[Iterable[float], Iterable[str]]
Tick values and tick labels

See Also
--------
pyglotaran_extras.plotting.plot_doas.plot_doas

Examples
--------
If you have a case study that uses a ``damped-oscillation`` megacomplex you can plot the
``damped_oscillation_phase`` with y-tick in units of Pi by the following code given that the
dataset is saved under ``dataset.nc``.

.. code-block:: python
import matplotlib.pyplot as plt

from glotaran.io import load_dataset
from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi

dataset = load_dataset("dataset.nc")

fig, ax = plt.subplots(1, 1)

damped_oscillation_phase = dataset["damped_oscillation_phase"].sel(
damped_oscillation=["osc1"]
)
damped_oscillation_phase.plot.line(x="spectral", ax=ax)

ax.set_yticks(
*calculate_ticks_in_units_of_pi(damped_oscillation_phase), rotation="horizontal"
)
"""
values = np.array(values)
int_values_over_pi = np.round(values / np.pi / step_size)
tick_labels = np.arange(int_values_over_pi.min(), int_values_over_pi.max() + 1) * step_size
return tick_labels * np.pi, (
str(val) for val in pretty_format_numerical_iterable(tick_labels, decimal_places=1)
)
20 changes: 20 additions & 0 deletions tests/plotting/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pytest
import xarray as xr
from cycler import Cycler
Expand All @@ -14,6 +15,7 @@
from pyglotaran_extras.plotting.style import PlotStyle
from pyglotaran_extras.plotting.utils import abs_max
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none
from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi

matplotlib.use("Agg")
DEFAULT_CYCLER = plt.rcParams["axes.prop_cycle"]
Expand Down Expand Up @@ -65,3 +67,21 @@ def test_abs_max(result_dims: Hashable | Iterable[Hashable], expected: xr.DataAr
"""Result values are positive and dimensions are preserved if result_dims is not empty."""
data = xr.DataArray([[-10, 20], [-30, 40]], coords={"dim1": [1, 2], "dim2": [3, 4]})
assert abs_max(data, result_dims=result_dims).equals(expected)


@pytest.mark.parametrize(
"step_size, expected_tick_values,expected_tick_labels",
(
(0.5, np.linspace(-np.pi, 2 * np.pi, num=7), ["-1", "-0.5", "0", "0.5", "1", "1.5", "2"]),
(1, np.linspace(-np.pi, 2 * np.pi, num=4), ["-1", "0", "1", "2"]),
),
)
def test_calculate_ticks_in_units_of_pi(
step_size: float, expected_tick_values: list[float], expected_tick_labels: list[str]
):
"""Different values depending on ``step_size``."""
values = np.linspace(-np.pi, 2 * np.pi)
tick_values, tick_labels = calculate_ticks_in_units_of_pi(values, step_size=step_size)

assert np.allclose(list(tick_values), expected_tick_values)
assert list(tick_labels) == expected_tick_labels