👌 Make result input for plot_coherent_artifact more generic (#134)
* 👌 Made result input for plot_coherent_artifact more generic

* 👌 Warn when picking a single dataset from a result with multiple datasets

* 🩹 Add hack to prevent the numpy/netCDF4 import-order-dependent RuntimeWarning

* ⬆️🩹 Update isort in pre-commit to fix CI installation

See https://github.com/s-weigand/pyglotaran-extras/actions/runs/4039040323/jobs/6943465082
s-weigand authored Jan 29, 2023
1 parent 6c6a7a4 commit c1385c9
Showing 8 changed files with 143 additions and 19 deletions.
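
A hedged usage sketch of the headline change (not part of the commit; the result path and dataset names are assumed for illustration):

# Sketch only: plot_coherent_artifact now accepts more than an xr.Dataset.
# The result path and dataset names below are assumed, not from this commit.
from glotaran.io import load_result

from pyglotaran_extras.plotting.plot_coherent_artifact import plot_coherent_artifact

result = load_result("result/result.yml")  # assumed location of a saved result

fig, axes = plot_coherent_artifact(result.data["dataset_1"])  # old style: pass the dataset
fig, axes = plot_coherent_artifact(result)  # new: a Result works directly
fig, axes = plot_coherent_artifact("result/dataset_1.nc")  # new: a path to a saved dataset
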
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -33,7 +33,7 @@ repos:
language_version: python3

- repo: https://github.com/PyCQA/isort
rev: 5.11.4
rev: 5.12.0
hooks:
- id: isort
minimum_pre_commit_version: 2.9.0
18 changes: 17 additions & 1 deletion pyglotaran_extras/io/load_data.py
@@ -2,6 +2,7 @@
from __future__ import annotations

from pathlib import Path
from warnings import warn

import xarray as xr
from glotaran.io import load_dataset
@@ -10,7 +11,9 @@
from pyglotaran_extras.types import DatasetConvertible


def load_data(result: DatasetConvertible | Result, dataset_name: str | None = None) -> xr.Dataset:
def load_data(
result: DatasetConvertible | Result, dataset_name: str | None = None, *, _stacklevel: int = 2
) -> xr.Dataset:
"""Extract a single dataset from a :class:`DatasetConvertible` object.
Parameters
@@ -20,6 +23,10 @@ def load_data(result: DatasetConvertible | Result, dataset_name: str | None = No
dataset_name : str, optional
Name of a specific dataset contained in ``result``, if not provided
the first dataset will be extracted. Defaults to None.
_stacklevel: int
Stacklevel of the warning raised when ``result`` is of class ``Result``,
contains multiple datasets, and no ``dataset_name`` is provided. Changing this value is
only required if you call this function from inside another function. Defaults to 2.
Returns
-------
@@ -39,6 +46,15 @@ def load_data(result: DatasetConvertible | Result, dataset_name: str | None = No
if dataset_name is not None:
return result.data[dataset_name]
keys = list(result.data)
if len(keys) > 1:
warn(
UserWarning(
f"Result contains multiple datasets, auto selecting {keys[0]!r}.\n"
f"Pass the dataset set you want to plot (e.g. result.data[{keys[0]!r}]) , "
f"to deactivate this Warning.\nPossible dataset names are: {keys}"
),
stacklevel=_stacklevel,
)
return result.data[keys[0]]
if isinstance(result, (str, Path)):
return load_data(load_dataset(result))
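
A rough sketch of the new auto-selection warning and how to avoid it (the result path and dataset names are assumptions, not taken from the diff):

# Sketch only: load_data with a Result that holds several datasets.
# The result path and dataset names are assumed for illustration.
from glotaran.io import load_result

from pyglotaran_extras.io.load_data import load_data

result = load_result("result/result.yml")  # assumed result with 'dataset_1' and 'dataset_2'

data = load_data(result)  # emits the UserWarning and auto-selects 'dataset_1'
data = load_data(result, dataset_name="dataset_2")  # explicit name, no warning
data = load_data(result.data["dataset_2"])  # passing a dataset directly also stays silent
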
32 changes: 19 additions & 13 deletions pyglotaran_extras/plotting/plot_coherent_artifact.py
@@ -6,19 +6,22 @@

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr
from cycler import Cycler

from pyglotaran_extras.io.load_data import load_data
from pyglotaran_extras.plotting.utils import abs_max
from pyglotaran_extras.plotting.utils import add_cycler_if_not_none

if TYPE_CHECKING:
from cycler import Cycler
from glotaran.project.result import Result
from matplotlib.figure import Figure
from matplotlib.pyplot import Axes

from pyglotaran_extras.types import DatasetConvertible


def plot_coherent_artifact(
res: xr.Dataset,
dataset: DatasetConvertible | Result,
*,
time_range: tuple[float, float] | None = None,
spectral: float = 0,
@@ -34,7 +37,7 @@ def plot_coherent_artifact(
Parameters
----------
res: xr.Dataset
dataset: DatasetConvertible | Result
Result dataset from a pyglotaran optimization.
time_range: tuple[float, float] | None
Start and end time for the IRF derivative plot. Defaults to None which means that
@@ -65,20 +68,21 @@
"""
fig, axes = plt.subplots(1, 2, figsize=figsize)
add_cycler_if_not_none(axes, cycler)
dataset = load_data(dataset, _stacklevel=3)

if (
"coherent_artifact_response" not in res
or "coherent_artifact_associated_spectra" not in res
"coherent_artifact_response" not in dataset
or "coherent_artifact_associated_spectra" not in dataset
):
warn(
UserWarning(f"Dataset does not contain coherent artifact data:\n {res.data_vars}"),
UserWarning(f"Dataset does not contain coherent artifact data:\n {dataset.data_vars}"),
stacklevel=2,
)
return fig, axes

irf_max = abs_max(res.coherent_artifact_response, result_dims=("coherent_artifact_order"))
irf_max = abs_max(dataset.coherent_artifact_response, result_dims=("coherent_artifact_order"))
irfas_max = abs_max(
res.coherent_artifact_associated_spectra, result_dims=("coherent_artifact_order")
dataset.coherent_artifact_associated_spectra, result_dims=("coherent_artifact_order")
)
scales = np.sqrt(irfas_max * irf_max)
norm_factor = 1
@@ -90,7 +94,7 @@ def plot_coherent_artifact(
irf_y_label = f"normalized {irf_y_label}"

plot_slice_irf = (
res.coherent_artifact_response.sel(spectral=spectral, method="nearest")
dataset.coherent_artifact_response.sel(spectral=spectral, method="nearest")
/ irf_max
* scales
/ norm_factor
@@ -102,7 +106,9 @@ def plot_coherent_artifact(
axes[0].set_title("IRF Derivatives")
axes[0].set_ylabel(f"{irf_y_label} (a.u.)")

plot_slice_irfas = res.coherent_artifact_associated_spectra / irfas_max * scales * norm_factor
plot_slice_irfas = (
dataset.coherent_artifact_associated_spectra / irfas_max * scales * norm_factor
)
plot_slice_irfas.plot.line(x="spectral", ax=axes[1])
axes[1].get_legend().remove()
axes[1].set_title("IRFAS")
@@ -113,9 +119,9 @@ def plot_coherent_artifact(
axes[1].axhline(0, color="k", linewidth=1)

#
if res.coords["coherent_artifact_order"][0] == 1:
if dataset.coords["coherent_artifact_order"][0] == 1:
axes[0].legend(
[f"{int(ax_label)-1}" for ax_label in res.coords["coherent_artifact_order"]],
[f"{int(ax_label)-1}" for ax_label in dataset.coords["coherent_artifact_order"]],
title="coherent_artifact_order",
)
if title:
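
The keyword-only _stacklevel argument exists so wrapper functions can point the warning at their caller rather than at the wrapper itself; a minimal sketch with a hypothetical helper name:

# Sketch only: a hypothetical helper forwarding _stacklevel, mirroring the
# _stacklevel=3 used by the built-in plot functions above.
import xarray as xr

from pyglotaran_extras.io.load_data import load_data


def my_plot_helper(result) -> xr.Dataset:
    # load_data warns with stacklevel=2 by default; this helper adds one call
    # frame, so stacklevel=3 attributes the warning to my_plot_helper's caller.
    return load_data(result, _stacklevel=3)
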
2 changes: 1 addition & 1 deletion pyglotaran_extras/plotting/plot_data.py
@@ -62,7 +62,7 @@ def plot_data_overview(
tuple[Figure, Axes]
Figure and axes which can then be refined by the user.
"""
dataset = load_data(dataset)
dataset = load_data(dataset, _stacklevel=3)

fig = plt.figure(figsize=figsize)
data_ax = cast(Axis, plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig))
2 changes: 1 addition & 1 deletion pyglotaran_extras/plotting/plot_guidance.py
@@ -45,7 +45,7 @@ def plot_guidance(
tuple[Figure, Axes]
Figure and axes which can then be refined by the user.
"""
res = load_data(result)
res = load_data(result, _stacklevel=3)
fig, axes = plt.subplots(1, 2, figsize=figsize)

for axis in axes:
4 changes: 2 additions & 2 deletions pyglotaran_extras/plotting/plot_overview.py
@@ -104,7 +104,7 @@ def plot_overview(
If ``figure_only`` is True, Figure object which contains the plots (deprecated).
If ``figure_only`` is False, Figure object which contains the plots and the Axes.
"""
res = load_data(result)
res = load_data(result, _stacklevel=3)

if res.coords["time"].values.size == 1:
fig, axes = plot_guidance(res)
@@ -197,7 +197,7 @@ def plot_simple_overview(
If ``figure_only`` is True, Figure object which contains the plots (deprecated).
If ``figure_only`` is False, Figure object which contains the plots and the Axes.
"""
res = load_data(result)
res = load_data(result, _stacklevel=3)

fig, axes = plt.subplots(2, 3, figsize=figsize, constrained_layout=True)
for ax in axes.flatten():
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -1,3 +1,10 @@
# isort: off
# Hack around https://github.com/pydata/xarray/issues/7259 which also affects pyglotaran <= 0.7.0
import numpy # noqa
import netCDF4 # noqa

# isort: on

from dataclasses import replace

import pytest
95 changes: 95 additions & 0 deletions tests/io/test_load_data.py
@@ -0,0 +1,95 @@
"""Tests for ``pyglotaran_extras.io.load_data``."""
from __future__ import annotations

from pathlib import Path
from typing import TYPE_CHECKING

import pytest
import xarray as xr
from glotaran.io import load_result

from pyglotaran_extras.io.load_data import load_data

if TYPE_CHECKING:

from _pytest.recwarn import WarningsRecorder
from glotaran.project.result import Result

MULTI_DATASET_WARNING = (
"Result contains multiple datasets, auto-selecting 'dataset_1'.\n"
"Pass the dataset you want to plot (e.g. result.data['dataset_1']) "
"to deactivate this warning.\nPossible dataset names are: ['dataset_1', 'dataset_2']"
)


def run_load_data_test(result: xr.Dataset, compare: xr.Dataset | None = None):
"""Factored out test runner function for ``test_load_data``."""
assert isinstance(result, xr.Dataset)
assert hasattr(result, "data")
if compare is not None:
assert result.equals(compare)


def test_load_data(
result_sequential_spectral_decay: Result, tmp_path: Path, recwarn: WarningsRecorder
):
"""All input_type permutations result in a ``xr.Dataset``."""
compare_dataset = result_sequential_spectral_decay.data["dataset_1"]

from_result = load_data(result_sequential_spectral_decay)

run_load_data_test(from_result, compare_dataset)

from_dataset = load_data(compare_dataset)

run_load_data_test(from_dataset, compare_dataset)

result_sequential_spectral_decay.save(tmp_path / "result.yml")

from_file = load_data(tmp_path / "dataset_1.nc")

run_load_data_test(from_file, compare_dataset)

data_array = xr.DataArray([[1, 2], [3, 4]])
from_data_array = load_data(data_array)

run_load_data_test(from_data_array)
assert data_array.equals(from_data_array.data)

# No warning so far
assert len(recwarn) == 0

# Ensure not to mutate original fixture
result_multi_dataset = load_result(tmp_path / "result.yml")
result_multi_dataset.data["dataset_2"] = xr.Dataset({"foo": [1]})

from_result_multi_dataset = load_data(result_multi_dataset)

run_load_data_test(from_result_multi_dataset, compare_dataset)

assert len(recwarn) == 1

assert recwarn[0].category == UserWarning
assert recwarn[0].message.args[0] == MULTI_DATASET_WARNING
assert Path(recwarn[0].filename) == Path(__file__)

def wrapped_call(result: Result):
return load_data(result, _stacklevel=3)

result_wrapped_call = wrapped_call(result_multi_dataset)

run_load_data_test(result_wrapped_call, compare_dataset)

assert len(recwarn) == 2

assert recwarn[1].category == UserWarning
assert recwarn[1].message.args[0] == MULTI_DATASET_WARNING
assert Path(recwarn[1].filename) == Path(__file__)

with pytest.raises(TypeError) as excinfo:
load_data([1, 2])

assert str(excinfo.value) == (
"Result needs to be of type typing.Union[xarray.core.dataset.Dataset, "
"xarray.core.dataarray.DataArray, str, pathlib.Path], but was [1, 2]."
)
