diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 567e5ecd..0855621c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,7 +33,7 @@ repos: language_version: python3 - repo: https://github.com/PyCQA/isort - rev: 5.11.4 + rev: 5.12.0 hooks: - id: isort minimum_pre_commit_version: 2.9.0 diff --git a/pyglotaran_extras/io/load_data.py b/pyglotaran_extras/io/load_data.py index d8f71cd3..ec8e70e6 100644 --- a/pyglotaran_extras/io/load_data.py +++ b/pyglotaran_extras/io/load_data.py @@ -2,6 +2,7 @@ from __future__ import annotations from pathlib import Path +from warnings import warn import xarray as xr from glotaran.io import load_dataset @@ -10,7 +11,9 @@ from pyglotaran_extras.types import DatasetConvertible -def load_data(result: DatasetConvertible | Result, dataset_name: str | None = None) -> xr.Dataset: +def load_data( + result: DatasetConvertible | Result, dataset_name: str | None = None, *, _stacklevel: int = 2 +) -> xr.Dataset: """Extract a single dataset from a :class:`DatasetConvertible` object. Parameters @@ -20,6 +23,10 @@ def load_data(result: DatasetConvertible | Result, dataset_name: str | None = No dataset_name : str, optional Name of a specific dataset contained in ``result``, if not provided the first dataset will be extracted. Defaults to None. + _stacklevel: int + Stacklevel of the warning which is raised when ``result`` is of class ``Result``, + contains multiple datasets and no ``dataset_name`` is provided. Changing this value is + only required if you use this function inside of another function. Defaults to 2 Returns ------- @@ -39,6 +46,15 @@ def load_data(result: DatasetConvertible | Result, dataset_name: str | None = No if dataset_name is not None: return result.data[dataset_name] keys = list(result.data) + if len(keys) > 1: + warn( + UserWarning( + f"Result contains multiple datasets, auto selecting {keys[0]!r}.\n" + f"Pass the dataset set you want to plot (e.g. result.data[{keys[0]!r}]) , " + f"to deactivate this Warning.\nPossible dataset names are: {keys}" + ), + stacklevel=_stacklevel, + ) return result.data[keys[0]] if isinstance(result, (str, Path)): return load_data(load_dataset(result)) diff --git a/pyglotaran_extras/plotting/plot_coherent_artifact.py b/pyglotaran_extras/plotting/plot_coherent_artifact.py index 74ca6194..6f11e329 100644 --- a/pyglotaran_extras/plotting/plot_coherent_artifact.py +++ b/pyglotaran_extras/plotting/plot_coherent_artifact.py @@ -6,19 +6,22 @@ import matplotlib.pyplot as plt import numpy as np -import xarray as xr -from cycler import Cycler +from pyglotaran_extras.io.load_data import load_data from pyglotaran_extras.plotting.utils import abs_max from pyglotaran_extras.plotting.utils import add_cycler_if_not_none if TYPE_CHECKING: + from cycler import Cycler + from glotaran.project.result import Result from matplotlib.figure import Figure from matplotlib.pyplot import Axes + from pyglotaran_extras.types import DatasetConvertible + def plot_coherent_artifact( - res: xr.Dataset, + dataset: DatasetConvertible | Result, *, time_range: tuple[float, float] | None = None, spectral: float = 0, @@ -34,7 +37,7 @@ def plot_coherent_artifact( Parameters ---------- - res: xr.Dataset + dataset: DatasetConvertible | Result Result dataset from a pyglotaran optimization. time_range: tuple[float, float] | None Start and end time for the IRF derivative plot. Defaults to None which means that @@ -65,20 +68,21 @@ def plot_coherent_artifact( """ fig, axes = plt.subplots(1, 2, figsize=figsize) add_cycler_if_not_none(axes, cycler) + dataset = load_data(dataset, _stacklevel=3) if ( - "coherent_artifact_response" not in res - or "coherent_artifact_associated_spectra" not in res + "coherent_artifact_response" not in dataset + or "coherent_artifact_associated_spectra" not in dataset ): warn( - UserWarning(f"Dataset does not contain coherent artifact data:\n {res.data_vars}"), + UserWarning(f"Dataset does not contain coherent artifact data:\n {dataset.data_vars}"), stacklevel=2, ) return fig, axes - irf_max = abs_max(res.coherent_artifact_response, result_dims=("coherent_artifact_order")) + irf_max = abs_max(dataset.coherent_artifact_response, result_dims=("coherent_artifact_order")) irfas_max = abs_max( - res.coherent_artifact_associated_spectra, result_dims=("coherent_artifact_order") + dataset.coherent_artifact_associated_spectra, result_dims=("coherent_artifact_order") ) scales = np.sqrt(irfas_max * irf_max) norm_factor = 1 @@ -90,7 +94,7 @@ def plot_coherent_artifact( irf_y_label = f"normalized {irf_y_label}" plot_slice_irf = ( - res.coherent_artifact_response.sel(spectral=spectral, method="nearest") + dataset.coherent_artifact_response.sel(spectral=spectral, method="nearest") / irf_max * scales / norm_factor @@ -102,7 +106,9 @@ def plot_coherent_artifact( axes[0].set_title("IRF Derivatives") axes[0].set_ylabel(f"{irf_y_label} (a.u.)") - plot_slice_irfas = res.coherent_artifact_associated_spectra / irfas_max * scales * norm_factor + plot_slice_irfas = ( + dataset.coherent_artifact_associated_spectra / irfas_max * scales * norm_factor + ) plot_slice_irfas.plot.line(x="spectral", ax=axes[1]) axes[1].get_legend().remove() axes[1].set_title("IRFAS") @@ -113,9 +119,9 @@ def plot_coherent_artifact( axes[1].axhline(0, color="k", linewidth=1) # - if res.coords["coherent_artifact_order"][0] == 1: + if dataset.coords["coherent_artifact_order"][0] == 1: axes[0].legend( - [f"{int(ax_label)-1}" for ax_label in res.coords["coherent_artifact_order"]], + [f"{int(ax_label)-1}" for ax_label in dataset.coords["coherent_artifact_order"]], title="coherent_artifact_order", ) if title: diff --git a/pyglotaran_extras/plotting/plot_data.py b/pyglotaran_extras/plotting/plot_data.py index eaf4ed87..93b7b8fb 100644 --- a/pyglotaran_extras/plotting/plot_data.py +++ b/pyglotaran_extras/plotting/plot_data.py @@ -62,7 +62,7 @@ def plot_data_overview( tuple[Figure, Axes] Figure and axes which can then be refined by the user. """ - dataset = load_data(dataset) + dataset = load_data(dataset, _stacklevel=3) fig = plt.figure(figsize=figsize) data_ax = cast(Axis, plt.subplot2grid((4, 3), (0, 0), colspan=3, rowspan=3, fig=fig)) diff --git a/pyglotaran_extras/plotting/plot_guidance.py b/pyglotaran_extras/plotting/plot_guidance.py index 89c59cbd..3f1a29af 100644 --- a/pyglotaran_extras/plotting/plot_guidance.py +++ b/pyglotaran_extras/plotting/plot_guidance.py @@ -45,7 +45,7 @@ def plot_guidance( tuple[Figure, Axes] Figure and axes which can then be refined by the user. """ - res = load_data(result) + res = load_data(result, _stacklevel=3) fig, axes = plt.subplots(1, 2, figsize=figsize) for axis in axes: diff --git a/pyglotaran_extras/plotting/plot_overview.py b/pyglotaran_extras/plotting/plot_overview.py index b56a7f3c..6b17a978 100644 --- a/pyglotaran_extras/plotting/plot_overview.py +++ b/pyglotaran_extras/plotting/plot_overview.py @@ -104,7 +104,7 @@ def plot_overview( If ``figure_only`` is True, Figure object which contains the plots (deprecated). If ``figure_only`` is False, Figure object which contains the plots and the Axes. """ - res = load_data(result) + res = load_data(result, _stacklevel=3) if res.coords["time"].values.size == 1: fig, axes = plot_guidance(res) @@ -197,7 +197,7 @@ def plot_simple_overview( If ``figure_only`` is True, Figure object which contains the plots (deprecated). If ``figure_only`` is False, Figure object which contains the plots and the Axes. """ - res = load_data(result) + res = load_data(result, _stacklevel=3) fig, axes = plt.subplots(2, 3, figsize=figsize, constrained_layout=True) for ax in axes.flatten(): diff --git a/tests/conftest.py b/tests/conftest.py index 086a9f27..3f5ee1a4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,10 @@ +# isort: off +# Hack around https://github.com/pydata/xarray/issues/7259 which also affects pyglotaran <= 0.7.0 +import numpy # noqa +import netCDF4 # noqa + +# isort: on + from dataclasses import replace import pytest diff --git a/tests/io/test_load_data.py b/tests/io/test_load_data.py new file mode 100644 index 00000000..83d5ba56 --- /dev/null +++ b/tests/io/test_load_data.py @@ -0,0 +1,95 @@ +"""Tests for ``pyglotaran_extras.io.load_data``.""" +from __future__ import annotations + +from pathlib import Path +from typing import TYPE_CHECKING + +import pytest +import xarray as xr +from glotaran.io import load_result + +from pyglotaran_extras.io.load_data import load_data + +if TYPE_CHECKING: + + from _pytest.recwarn import WarningsRecorder + from glotaran.project.result import Result + +MULTI_DATASET_WARING = ( + "Result contains multiple datasets, auto selecting 'dataset_1'.\n" + "Pass the dataset set you want to plot (e.g. result.data['dataset_1']) , " + "to deactivate this Warning.\nPossible dataset names are: ['dataset_1', 'dataset_2']" +) + + +def run_load_data_test(result: xr.Dataset, compare: xr.Dataset | None = None): + """Factored out test runner function for ``test_load_data``.""" + assert isinstance(result, xr.Dataset) + assert hasattr(result, "data") + if compare is not None: + assert result.equals(compare) + + +def test_load_data( + result_sequential_spectral_decay: Result, tmp_path: Path, recwarn: WarningsRecorder +): + """All input_type permutations result in a ``xr.Dataset``.""" + compare_dataset = result_sequential_spectral_decay.data["dataset_1"] + + from_result = load_data(result_sequential_spectral_decay) + + run_load_data_test(from_result, compare_dataset) + + from_dataset = load_data(compare_dataset) + + run_load_data_test(from_dataset, compare_dataset) + + result_sequential_spectral_decay.save(tmp_path / "result.yml") + + from_file = load_data(tmp_path / "dataset_1.nc") + + run_load_data_test(from_file, compare_dataset) + + data_array = xr.DataArray([[1, 2], [3, 4]]) + from_data_array = load_data(data_array) + + run_load_data_test(from_data_array) + assert data_array.equals(from_data_array.data) + + # No warning til now + assert len(recwarn) == 0 + + # Ensure not to mutate original fixture + result_multi_dataset = load_result(tmp_path / "result.yml") + result_multi_dataset.data["dataset_2"] = xr.Dataset({"foo": [1]}) + + from_result_multi_dataset = load_data(result_multi_dataset) + + run_load_data_test(from_result_multi_dataset, compare_dataset) + + assert len(recwarn) == 1 + + assert recwarn[0].category == UserWarning + assert recwarn[0].message.args[0] == MULTI_DATASET_WARING + assert Path(recwarn[0].filename) == Path(__file__) + + def wrapped_call(result: Result): + return load_data(result, _stacklevel=3) + + result_wrapped_call = wrapped_call(result_multi_dataset) + + run_load_data_test(result_wrapped_call, compare_dataset) + + assert len(recwarn) == 2 + + assert recwarn[1].category == UserWarning + assert recwarn[1].message.args[0] == MULTI_DATASET_WARING + assert Path(recwarn[1].filename) == Path(__file__) + + with pytest.raises(TypeError) as excinfo: + load_data([1, 2]) + + assert str(excinfo.value) == ( + "Result needs to be of type typing.Union[xarray.core.dataset.Dataset, " + "xarray.core.dataarray.DataArray, str, pathlib.Path], but was [1, 2]." + )