diff --git a/pyglotaran_extras/plotting/utils.py b/pyglotaran_extras/plotting/utils.py index 8139db46..86ac21c5 100644 --- a/pyglotaran_extras/plotting/utils.py +++ b/pyglotaran_extras/plotting/utils.py @@ -8,6 +8,7 @@ import numpy as np import xarray as xr +from pyglotaran_extras.inspect.utils import pretty_format_numerical_iterable from pyglotaran_extras.io.utils import result_dataset_mapping if TYPE_CHECKING: @@ -406,3 +407,57 @@ def abs_max( result_dims = (result_dims,) reduce_dims = (dim for dim in data.dims if dim not in result_dims) return np.abs(data).max(dim=reduce_dims) + + +def calculate_ticks_in_units_of_pi( + values: np.ndarray | xr.DataArray, *, step_size: float = 0.5 +) -> tuple[Iterable[float], Iterable[str]]: + """Calculate tick values and labels in units of Pi. + + Parameters + ---------- + values: np.ndarray + Values which the ticks should be calculated for. + step_size: float + Step size of the ticks in units of pi. Defaults to 0.5 + + Returns + ------- + tuple[Iterable[float], Iterable[str]] + Tick values and tick labels + + See Also + -------- + pyglotaran_extras.plotting.plot_doas.plot_doas + + Examples + -------- + If you have a case study that uses a ``damped-oscillation`` megacomplex you can plot the + ``damped_oscillation_phase`` with y-tick in units of Pi by the following code given that the + dataset is saved under ``dataset.nc``. + + .. code-block:: python + import matplotlib.pyplot as plt + + from glotaran.io import load_dataset + from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi + + dataset = load_dataset("dataset.nc") + + fig, ax = plt.subplots(1, 1) + + damped_oscillation_phase = dataset["damped_oscillation_phase"].sel( + damped_oscillation=["osc1"] + ) + damped_oscillation_phase.plot.line(x="spectral", ax=ax) + + ax.set_yticks( + *calculate_ticks_in_units_of_pi(damped_oscillation_phase), rotation="horizontal" + ) + """ + values = np.array(values) + int_values_over_pi = np.round(values / np.pi / step_size) + tick_labels = np.arange(int_values_over_pi.min(), int_values_over_pi.max() + 1) * step_size + return tick_labels * np.pi, ( + str(val) for val in pretty_format_numerical_iterable(tick_labels, decimal_places=1) + ) diff --git a/tests/plotting/test_utils.py b/tests/plotting/test_utils.py index c5681207..aede5a88 100644 --- a/tests/plotting/test_utils.py +++ b/tests/plotting/test_utils.py @@ -6,6 +6,7 @@ import matplotlib import matplotlib.pyplot as plt +import numpy as np import pytest import xarray as xr from cycler import Cycler @@ -14,6 +15,7 @@ from pyglotaran_extras.plotting.style import PlotStyle from pyglotaran_extras.plotting.utils import abs_max from pyglotaran_extras.plotting.utils import add_cycler_if_not_none +from pyglotaran_extras.plotting.utils import calculate_ticks_in_units_of_pi matplotlib.use("Agg") DEFAULT_CYCLER = plt.rcParams["axes.prop_cycle"] @@ -65,3 +67,21 @@ def test_abs_max(result_dims: Hashable | Iterable[Hashable], expected: xr.DataAr """Result values are positive and dimensions are preserved if result_dims is not empty.""" data = xr.DataArray([[-10, 20], [-30, 40]], coords={"dim1": [1, 2], "dim2": [3, 4]}) assert abs_max(data, result_dims=result_dims).equals(expected) + + +@pytest.mark.parametrize( + "step_size, expected_tick_values,expected_tick_labels", + ( + (0.5, np.linspace(-np.pi, 2 * np.pi, num=7), ["-1", "-0.5", "0", "0.5", "1", "1.5", "2"]), + (1, np.linspace(-np.pi, 2 * np.pi, num=4), ["-1", "0", "1", "2"]), + ), +) +def test_calculate_ticks_in_units_of_pi( + step_size: float, expected_tick_values: list[float], expected_tick_labels: list[str] +): + """Different values depending on ``step_size``.""" + values = np.linspace(-np.pi, 2 * np.pi) + tick_values, tick_labels = calculate_ticks_in_units_of_pi(values, step_size=step_size) + + assert np.allclose(list(tick_values), expected_tick_values) + assert list(tick_labels) == expected_tick_labels