diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 58b37df..8b64bbb 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -30,5 +30,11 @@ jobs: uv python install ${{ matrix.python-version }} uv sync --python ${{ matrix.python-version }} --frozen + - name: Cache Zenodo datasets + uses: actions/cache@v4 + with: + path: ~/.cache/sdf_datasets + key: sdf-datasets-17991042 + - name: Test with pytest run: uv run pytest diff --git a/docs/animation.rst b/docs/animation.rst index 2275349..62dec15 100644 --- a/docs/animation.rst +++ b/docs/animation.rst @@ -3,6 +3,9 @@ .. |animate_accessor| replace:: `xarray.DataArray.epoch.animate ` +.. |animate_multiple_accessor| replace:: `xarray.Dataset.epoch.animate_multiple + ` + ========== Animations ========== @@ -131,10 +134,8 @@ Moving window ------------- EPOCH allows for simulations that have a moving simulation window -(changing x-axis over time). |animate_accessor| will -automatically detect when a simulation has a moving window by searching -for NaNs in the `xarray.DataArray` and change the x-axis limits -accordingly. +(changing x-axis over time). |animate_accessor| can accept the boolean parameter +``move_window`` and change the x-axis limits accordingly. .. warning:: `sdf_xarray.open_mfdataset` does not currently function with moving window data. @@ -152,7 +153,7 @@ accordingly. ) da = ds["Derived_Number_Density_Beam_Electrons"] - anim = da.epoch.animate(fps = 5) + anim = da.epoch.animate(move_window=True, fps = 5) anim.show() .. warning:: @@ -191,73 +192,70 @@ before plotting as in :ref:`sec-unit-conversion`. Some functionality such as ) anim.show() -Advanced usage --------------- +Combining multiple animations +----------------------------- -Multiple plots on the same axes -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +|animate_multiple_accessor| creates a `matplotlib.animation.FuncAnimation` +that contains multiple plots layered on top of each other. + +1D simulation +~~~~~~~~~~~~~ What follows is an example of how to combine multiple animations on the -same axis. This may be implemented in a more user-friendly function in -a future update. +same axis. .. jupyter-execute:: - # Open the SDF files ds = sdfxr.open_mfdataset("tutorial_dataset_1d/*.sdf") - # Create figure and axes - fig, ax = plt.subplots() - plt.close(fig) - - # Generate the animations independently - anim_1 = ds["Derived_Number_Density_Electron"].epoch.animate() - anim_2 = ds["Derived_Number_Density_Ion"].epoch.animate() - - # Extract the update functions from the animations - update_1 = anim_1._func - update_2 = anim_2._func - - # Create axes details for new animation - x_min, x_max = update_1(0)[0].axes.get_xlim() - y_min_1, y_max_1 = update_1(0)[0].axes.get_ylim() - y_min_2, y_max_2 = update_2(0)[0].axes.get_ylim() - y_min = min(y_min_1, y_min_2) - y_max = max(y_max_1, y_max_2) - x_label = update_1(0)[0].axes.get_xlabel() - y_label = "Number Density [m$^{-3}$]" - label_1 = "Electron" - label_2 = "Ion" - - # Create new update function - def update_combined(frame): - anim_1_fig = update_1(frame)[0] - anim_2_fig = update_2(frame)[0] - - title = anim_1_fig.axes.title._text - - ax.clear() - plot = ax.plot(anim_1_fig._x, anim_1_fig._y, label = label_1) - ax.plot(anim_2_fig._x, anim_2_fig._y, label = label_2) - ax.set_title(title) - ax.set_xlim(x_min, x_max) - ax.set_ylim(y_min, y_max) - ax.set_xlabel(x_label) - ax.set_ylabel(y_label) - ax.legend(loc = "upper left") - return plot - - N_frames = anim_1._save_count - interval = anim_1._interval - - # Create combined animation - anim_combined = FuncAnimation( - fig, - update_combined, - frames=range(N_frames), - interval = interval, - repeat=True, - ) + anim = ds.epoch.animate_multiple( + ds["Derived_Number_Density_Electron"], + ds["Derived_Number_Density_Ion"], + datasets_kwargs=[{"label": "Electron"}, {"label": "Ion"}], + ylim=(0e27,4e27), + ylabel="Derived Number Density [1/m$^3$]" + ) - # Display animation as jshtml - HTML(anim_combined.to_jshtml()) \ No newline at end of file + anim.show() + +2D simulation +~~~~~~~~~~~~~ + +.. tip:: + To correctly display 2D data on top of one another you need to specify + the ``alpha`` value which sets the opacity of the plot. + +This also works with 2 dimensional data. + +.. jupyter-execute:: + + import numpy as np + from matplotlib.colors import LogNorm + + ds = sdfxr.open_mfdataset("tutorial_dataset_2d/*.sdf") + + flux_magnitude = np.sqrt( + ds["Derived_Poynting_Flux_x"]**2 + + ds["Derived_Poynting_Flux_y"]**2 + + ds["Derived_Poynting_Flux_z"]**2 + ) + flux_magnitude.attrs["long_name"] = "Poynting Flux Magnitude" + flux_magnitude.attrs["units"] = "W/m$^2$" + + # Cut-off low energy values so that they will be rendered as transparent + # in the plot as they've been set to NaN + flux_masked = flux_magnitude.where(flux_magnitude > 0.2e23) + flux_norm = LogNorm( + vmin=float(flux_masked.min()), + vmax=float(flux_masked.max()) + ) + + anim = ds.epoch.animate_multiple( + ds["Derived_Number_Density_Electron"], + flux_masked, + datasets_kwargs=[ + {"alpha": 1.0}, + {"cmap": "hot", "norm": flux_norm, "alpha": 0.9}, + ], + ) + anim.show() \ No newline at end of file diff --git a/src/sdf_xarray/dataset_accessor.py b/src/sdf_xarray/dataset_accessor.py index 71eeb1b..c9f5433 100644 --- a/src/sdf_xarray/dataset_accessor.py +++ b/src/sdf_xarray/dataset_accessor.py @@ -1,5 +1,15 @@ +from __future__ import annotations + +from types import MethodType +from typing import TYPE_CHECKING + import xarray as xr +from .plotting import animate_multiple, show + +if TYPE_CHECKING: + from matplotlib.animation import FuncAnimation + @xr.register_dataset_accessor("epoch") class EpochAccessor: @@ -69,3 +79,46 @@ def rescale_coords( new_coords[coord_name] = coord_rescaled return ds.assign_coords(new_coords) + + def animate_multiple( + self, + *variables: str | xr.DataArray, + datasets_kwargs: list[dict] | None = None, + **kwargs, + ) -> FuncAnimation: + """ + Animate multiple Dataset variables on the same axes. + + Parameters + ---------- + variables + The variables to animate. + datasets_kwargs + Per-dataset keyword arguments passed to plotting. + kwargs + Common keyword arguments forwarded to animation. + + Examples + -------- + >>> anim = ds.epoch.animate_multiple( + ds["Derived_Number_Density_Electron"], + ds["Derived_Number_Density_Ion"], + datasets_kwargs=[{"label": "Electron"}, {"label": "Ion"}], + ylabel="Derived Number Density [1/m$^3$]" + ) + >>> anim.save("animation.gif") + >>> # Or in a jupyter notebook: + >>> anim.show() + """ + + dataarrays = [ + self._obj[var] if isinstance(var, str) else var for var in variables + ] + anim = animate_multiple( + *dataarrays, + datasets_kwargs=datasets_kwargs, + **kwargs, + ) + anim.show = MethodType(show, anim) + + return anim diff --git a/src/sdf_xarray/download.py b/src/sdf_xarray/download.py index 6fc992c..0a67b7d 100644 --- a/src/sdf_xarray/download.py +++ b/src/sdf_xarray/download.py @@ -54,7 +54,7 @@ def fetch_dataset( logger = pooch.get_logger() datasets = pooch.create( path=pooch.os_cache("sdf_datasets"), - base_url="doi:10.5281/zenodo.17618510", + base_url="https://zenodo.org/records/17991042/files", registry={ "test_array_no_grids.zip": "md5:583c85ed8c31d0e34e7766b6d9f2d6da", "test_dist_fn.zip": "md5:a582ff5e8c59bad62fe4897f65fc7a11", @@ -64,10 +64,11 @@ def fetch_dataset( "test_mismatched_files.zip": "md5:710fdc94666edf7777523e8fc9dd1bd4", "test_two_probes_2D.zip": "md5:0f2a4fefe84a15292d066b3320d4d533", "tutorial_dataset_1d.zip": "md5:7fad744d8b8b2b84bba5c0e705fdef7b", - "tutorial_dataset_2d.zip": "md5:1945ecdbc1ac1798164f83ea2b3d1b31", + "tutorial_dataset_2d.zip": "md5:b7f35c05703a48eb5128049cdd106ffa", "tutorial_dataset_2d_moving_window.zip": "md5:a795f40d18df69263842055de4559501", "tutorial_dataset_3d.zip": "md5:d9254648867016292440fdb028f717f7", }, + retry_if_failed=10, ) datasets.fetch( diff --git a/src/sdf_xarray/plotting.py b/src/sdf_xarray/plotting.py index feb4422..a3c2db9 100644 --- a/src/sdf_xarray/plotting.py +++ b/src/sdf_xarray/plotting.py @@ -1,6 +1,10 @@ from __future__ import annotations -from typing import TYPE_CHECKING +import warnings +from collections.abc import Callable +from dataclasses import dataclass +from types import MethodType +from typing import TYPE_CHECKING, Any import numpy as np import xarray as xr @@ -9,7 +13,11 @@ import matplotlib.pyplot as plt from matplotlib.animation import FuncAnimation -from types import MethodType + +@dataclass +class AnimationUnit: + update: Callable[[int], object] + n_frames: int def get_frame_title( @@ -55,8 +63,8 @@ def calculate_window_boundaries( x_axis_name: str = "X_Grid_mid", t: str = "time", ) -> np.ndarray: - """Calculate the bounderies a moving window frame. If the user specifies xlim, this will - be used as the initial bounderies and the window will move along acordingly. + """Calculate the boundaries a moving window frame. If the user specifies xlim, this will + be used as the initial boundaries and the window will move along acordingly. Parameters ---------- @@ -71,12 +79,12 @@ def calculate_window_boundaries( """ x_grid = data[x_axis_name].values x_half_cell = (x_grid[1] - x_grid[0]) / 2 - N_frames = data[t].size + n_frames = data[t].size - # Find the window bounderies by finding the first and last non-NaN values in the 0th lineout + # Find the window boundaries by finding the first and last non-NaN values in the 0th lineout # along the x-axis. - window_boundaries = np.zeros((N_frames, 2)) - for i in range(N_frames): + window_boundaries = np.zeros((n_frames, 2)) + for i in range(n_frames): # Check if data is 1D if data.ndim == 2: target_lineout = data[i].values @@ -87,7 +95,7 @@ def calculate_window_boundaries( window_boundaries[i, 0] = x_grid_non_nan[0] - x_half_cell window_boundaries[i, 1] = x_grid_non_nan[-1] + x_half_cell - # User's choice for initial window edge supercides the one calculated + # User's choice for initial window edge supercedes the one calculated if xlim is not None: window_boundaries = window_boundaries + xlim - window_boundaries[0] return window_boundaries @@ -120,25 +128,98 @@ def compute_global_limits( return global_min, global_max -def animate( +def _set_axes_labels(ax: plt.Axes, axis_kwargs: dict) -> None: + """Set the labels for the x and y axes""" + if "xlabel" in axis_kwargs: + ax.set_xlabel(axis_kwargs["xlabel"]) + if "ylabel" in axis_kwargs: + ax.set_ylabel(axis_kwargs["ylabel"]) + + +def _setup_2d_plot( data: xr.DataArray, - fps: float = 10, + ax: plt.Axes, + coord_names: list[str], + kwargs: dict, + axis_kwargs: dict, + min_percentile: float, + max_percentile: float, + t: str, +) -> tuple[float, float]: + """Setup 2D plot initialization.""" + + kwargs.setdefault("x", coord_names[0]) + + data.isel({t: 0}).plot(ax=ax, **kwargs) + + global_min, global_max = compute_global_limits(data, min_percentile, max_percentile) + + _set_axes_labels(ax, axis_kwargs) + + if "ylim" not in kwargs: + ax.set_ylim(global_min, global_max) + + return global_min, global_max + + +def _setup_3d_plot( + data: xr.DataArray, + ax: plt.Axes, + coord_names: list[str], + kwargs: dict, + kwargs_original: dict, + axis_kwargs: dict, + min_percentile: float, + max_percentile: float, + t: str, +) -> None: + """Setup 3D plot initialization.""" + import matplotlib.pyplot as plt # noqa: PLC0415 + + if "norm" not in kwargs: + global_min, global_max = compute_global_limits( + data, min_percentile, max_percentile + ) + kwargs["norm"] = plt.Normalize(vmin=global_min, vmax=global_max) + + kwargs["add_colorbar"] = False + kwargs.setdefault("x", coord_names[0]) + kwargs.setdefault("y", coord_names[1]) + + argmin_time = np.unravel_index(np.argmin(data.values), data.shape)[0] + plot = data.isel({t: argmin_time}).plot(ax=ax, **kwargs) + kwargs["cmap"] = plot.cmap + + _set_axes_labels(ax, axis_kwargs) + + if kwargs_original.get("add_colorbar", True): + long_name = data.attrs.get("long_name") + units = data.attrs.get("units") + fig = plot.get_figure() + fig.colorbar(plot, ax=ax, label=f"{long_name} [{units}]") + + +def _generate_animation( + data: xr.DataArray, + clear_axes: bool = False, min_percentile: float = 0, max_percentile: float = 100, title: str | None = None, display_sdf_name: bool = False, + move_window: bool = False, t: str | None = None, ax: plt.Axes | None = None, - **kwargs, -) -> FuncAnimation: - """Generate an animation using an xarray.DataArray + kwargs: dict | None = None, +) -> AnimationUnit: + """ + Internal function for generating the plotting logic required for animations. Parameters --------- data DataArray containing the target data - fps - Frames per second for the animation + clear_axes + Decide whether to run ``ax.clear()`` in every update min_percentile Minimum percentile of the data max_percentile @@ -147,8 +228,11 @@ def animate( Custom title to add to the plot display_sdf_name Display the sdf file name in the animation title + move_window + Update the ``xlim`` to be only values that are not NaNs at each time interval t - Coordinate for t axis (the coordinate which will be animated over). If `None`, use data.dims[0] + Coordinate for t axis (the coordinate which will be animated over). + If ``None``, use ``data.dims[0]`` ax Matplotlib axes on which to plot kwargs @@ -156,18 +240,18 @@ def animate( Examples -------- - >>> ds["Derived_Number_Density_Electron"].epoch.animate() + >>> anim = animate(ds["Derived_Number_Density_Electron"]) + >>> anim.save("animation.gif") """ - import matplotlib.pyplot as plt # noqa: PLC0415 - from matplotlib.animation import FuncAnimation # noqa: PLC0415 + if kwargs is None: + kwargs = {} kwargs_original = kwargs.copy() - # Create plot if no ax is provided - if ax is None: - fig, ax = plt.subplots() - # Prevents figure from prematurely displaying in Jupyter notebook - plt.close(fig) + axis_kwargs = {} + for key in ("xlabel", "ylabel"): + if key in kwargs: + axis_kwargs[key] = kwargs.pop(key) # Sets the animation coordinate (t) for iteration. If time is in the coords # then it will set time to be t. If it is not it will fallback to the last @@ -182,68 +266,258 @@ def animate( N_frames = data[t].size + global_min = global_max = None if data.ndim == 2: - kwargs.setdefault("x", coord_names[0]) - plot = data.isel({t: 0}).plot(ax=ax, **kwargs) - ax.set_title(get_frame_title(data, 0, display_sdf_name, title, t)) - global_min, global_max = compute_global_limits( - data, min_percentile, max_percentile + global_min, global_max = _setup_2d_plot( + data=data, + ax=ax, + coord_names=coord_names, + kwargs=kwargs, + axis_kwargs=axis_kwargs, + min_percentile=min_percentile, + max_percentile=max_percentile, + t=t, + ) + elif data.ndim == 3: + _setup_3d_plot( + data=data, + ax=ax, + coord_names=coord_names, + kwargs=kwargs, + kwargs_original=kwargs_original, + axis_kwargs=axis_kwargs, + min_percentile=min_percentile, + max_percentile=max_percentile, + t=t, ) - ax.set_ylim(global_min, global_max) - if data.ndim == 3: - if "norm" not in kwargs: - global_min, global_max = compute_global_limits( - data, min_percentile, max_percentile - ) - kwargs["norm"] = plt.Normalize(vmin=global_min, vmax=global_max) - kwargs["add_colorbar"] = False - # Set default x and y coordinates for 3D data if not provided - kwargs.setdefault("x", coord_names[0]) - kwargs.setdefault("y", coord_names[1]) - - # Finds the time step with the minimum data value - # This is needed so that the animation can use the correct colour bar - argmin_time = np.unravel_index(data.argmin(), data.shape)[0] - - # Initialize the plot, the final output will still start at the first time step - plot = data.isel({t: argmin_time}).plot(ax=ax, **kwargs) - ax.set_title(get_frame_title(data, 0, display_sdf_name, title, t)) - kwargs["cmap"] = plot.cmap - - # Add colorbar - if kwargs_original.get("add_colorbar", True): - long_name = data.attrs.get("long_name") - units = data.attrs.get("units") - fig = plot.get_figure() - fig.colorbar(plot, ax=ax, label=f"{long_name} [{units}]") - - # check if there is a moving window by finding NaNs in the data - move_window = np.isnan(np.sum(data.values)) + ax.set_title(get_frame_title(data, 0, display_sdf_name, title, t)) + + window_boundaries = None if move_window: window_boundaries = calculate_window_boundaries( data, kwargs.get("xlim"), kwargs["x"] ) def update(frame): + if clear_axes: + ax.clear() # Set the xlim for each frame in the case of a moving window if move_window: kwargs["xlim"] = window_boundaries[frame] - # Update plot for the new frame - ax.clear() - plot = data.isel({t: frame}).plot(ax=ax, **kwargs) ax.set_title(get_frame_title(data, frame, display_sdf_name, title, t)) + _set_axes_labels(ax, axis_kwargs) - if data.ndim == 2: + if data.ndim == 2 and "ylim" not in kwargs and global_min is not None: ax.set_ylim(global_min, global_max) + return plot + return AnimationUnit( + update=update, + n_frames=N_frames, + ) + + +def animate( + data: xr.DataArray, + fps: float = 10, + min_percentile: float = 0, + max_percentile: float = 100, + title: str | None = None, + display_sdf_name: bool = False, + move_window: bool = False, + t: str | None = None, + ax: plt.Axes | None = None, + **kwargs, +) -> FuncAnimation: + """ + Generate an animation using an `xarray.DataArray`. The intended use + of this function is via `sdf_xarray.plotting.EpochAccessor.animate`. + + Parameters + --------- + data + DataArray containing the target data + fps + Frames per second for the animation + min_percentile + Minimum percentile of the data + max_percentile + Maximum percentile of the data + title + Custom title to add to the plot + display_sdf_name + Display the sdf file name in the animation title + move_window + Update the ``xlim`` to be only values that are not NaNs at each time interval + t + Coordinate for t axis (the coordinate which will be animated over). + If ``None``, use ``data.dims[0]`` + ax + Matplotlib axes on which to plot + kwargs + Keyword arguments to be passed to matplotlib + + Examples + -------- + >>> anim = animate(ds["Derived_Number_Density_Electron"]) + >>> anim.save("animation.gif") + """ + import matplotlib.pyplot as plt # noqa: PLC0415 + from matplotlib.animation import FuncAnimation # noqa: PLC0415 + + # Create plot if no ax is provided + if ax is None: + fig, ax = plt.subplots() + # Prevents figure from prematurely displaying in Jupyter notebook + plt.close(fig) + + animation = _generate_animation( + data, + clear_axes=True, + min_percentile=min_percentile, + max_percentile=max_percentile, + title=title, + display_sdf_name=display_sdf_name, + move_window=move_window, + t=t, + ax=ax, + kwargs=kwargs, + ) + + return FuncAnimation( + ax.get_figure(), + animation.update, + frames=range(animation.n_frames), + interval=1000 / fps, + repeat=True, + ) + + +def animate_multiple( + *datasets: xr.DataArray, + datasets_kwargs: list[dict[str, Any]] | None = None, + fps: float = 10, + min_percentile: float = 0, + max_percentile: float = 100, + title: str | None = None, + display_sdf_name: bool = False, + move_window: bool = False, + t: str | None = None, + ax: plt.Axes | None = None, + **common_kwargs, +) -> FuncAnimation: + """ + Generate an animation using multiple `xarray.DataArray`. The intended use + of this function is via `sdf_xarray.dataset_accessor.EpochAccessor.animate_multiple`. + + Parameters + --------- + datasets + `xarray.DataArray` objects containing the data to be animated + datasets_kwargs + A list of dictionaries, following the same order as ``datasets``, containing + per-dataset matplotlib keyword arguments. The list does not need to be the same + length as ``datasets``; missing entries are initialised as empty dictionaries + fps + Frames per second for the animation + min_percentile + Minimum percentile of the data + max_percentile + Maximum percentile of the data + title + Custom title to add to the plot + display_sdf_name + Display the sdf file name in the animation title + move_window + Update the ``xlim`` to be only values that are not NaNs at each time interval + t + Coordinate for t axis (the coordinate which will be animated over). If ``None``, + use ``data.dims[0]`` + ax + Matplotlib axes on which to plot + common_kwargs + Matplotlib keyword arguments applied to all datasets. These are overridden by + per-dataset entries in ``datasets_kwargs`` + + Examples + -------- + >>> anim = animate_multiple( + ds["Derived_Number_Density_Electron"], + ds["Derived_Number_Density_Ion"], + datasets_kwargs=[{"label": "Electron"}, {"label": "Ion"}], + ylim=(0e27,4e27), + display_sdf_name=True, + ylabel="Derived Number Density [1/m$^3$]" + ) + >>> anim.save("animation.gif") + """ + import matplotlib.pyplot as plt # noqa: PLC0415 + from matplotlib.animation import FuncAnimation # noqa: PLC0415 + + if not datasets: + raise ValueError("At least one dataset must be provided") + + # Create plot if no ax is provided + if ax is None: + fig, ax = plt.subplots() + # Prevents figure from prematurely displaying in Jupyter notebook + plt.close(fig) + + n_datasets = len(datasets) + if datasets_kwargs is None: + # Initialise an empty series of dicts the same size as the number of datasets + datasets_kwargs = [{} for _ in range(n_datasets)] + else: + # The user might only want to use kwargs on some of the datasets so we make sure + # to initialise additional empty dicts and append them to the list + datasets_kwargs.extend({} for _ in range(n_datasets - len(datasets_kwargs))) + + animations: list[AnimationUnit] = [] + for da, kw in zip(datasets, datasets_kwargs): + animations.append( + _generate_animation( + da, + ax=ax, + min_percentile=min_percentile, + max_percentile=max_percentile, + title=title, + display_sdf_name=display_sdf_name, + move_window=move_window, + t=t, + # Per-dataset kwargs override common matplotlib kwargs + kwargs={**common_kwargs, **kw}, + ) + ) + + lengths = [anim.n_frames for anim in animations] + n_frames = min(lengths) + + if len(set(lengths)) > 1: + warnings.warn( + "Datasets have different frame counts; truncating to the shortest", + stacklevel=2, + ) + + # Render the legend if a label exists for any 2D dataset + show_legend = any( + "label" in kw and da.ndim == 2 for da, kw in zip(datasets, datasets_kwargs) + ) + + def update(frame): + ax.clear() + for anim in animations: + anim.update(frame) + if show_legend: + ax.legend(loc="upper right") + return FuncAnimation( ax.get_figure(), update, - frames=range(N_frames), + frames=range(n_frames), interval=1000 / fps, repeat=True, ) @@ -280,7 +554,7 @@ def animate(self, *args, **kwargs) -> FuncAnimation: Examples -------- >>> anim = ds["Electric_Field_Ey"].epoch.animate() - >>> anim.save("myfile.mp4") + >>> anim.save("animation.gif") >>> # Or in a jupyter notebook: >>> anim.show() """ diff --git a/tests/test_epoch_dataset_accessor.py b/tests/test_epoch_dataset_accessor.py index 7c6a59d..69b9e09 100644 --- a/tests/test_epoch_dataset_accessor.py +++ b/tests/test_epoch_dataset_accessor.py @@ -1,17 +1,32 @@ +import tempfile +from importlib.metadata import version + +import matplotlib as mpl import numpy as np import pytest import xarray as xr +from matplotlib.animation import PillowWriter +from packaging.version import Version from sdf_xarray import download, open_mfdataset -TEST_FILES_DIR = download.fetch_dataset("test_files_3D") +mpl.use("Agg") + +# TODO Remove this once the new kwarg options are fully implemented +if Version(version("xarray")) >= Version("2025.8.0"): + xr.set_options(use_new_combine_kwarg_defaults=True) + +TEST_FILES_DIR_1D = download.fetch_dataset("test_files_1D") +TEST_FILES_DIR_2D = download.fetch_dataset("test_two_probes_2D") +TEST_FILES_DIR_2D_MW = download.fetch_dataset("test_files_2D_moving_window") +TEST_FILES_DIR_3D = download.fetch_dataset("test_files_3D") def test_rescale_coords_X(): multiplier = 1e3 unit_label = "mm" - with xr.open_dataset(TEST_FILES_DIR / "0000.sdf") as ds: + with xr.open_dataset(TEST_FILES_DIR_3D / "0000.sdf") as ds: ds_rescaled = ds.epoch.rescale_coords( multiplier=multiplier, unit_label=unit_label, @@ -39,7 +54,7 @@ def test_rescale_coords_X_Y(): multiplier = 1e2 unit_label = "cm" - with xr.open_dataset(TEST_FILES_DIR / "0000.sdf") as ds: + with xr.open_dataset(TEST_FILES_DIR_3D / "0000.sdf") as ds: ds_rescaled = ds.epoch.rescale_coords( multiplier=multiplier, unit_label=unit_label, @@ -68,7 +83,7 @@ def test_rescale_coords_X_Y_tuple(): multiplier = 1e2 unit_label = "cm" - with xr.open_dataset(TEST_FILES_DIR / "0000.sdf") as ds: + with xr.open_dataset(TEST_FILES_DIR_3D / "0000.sdf") as ds: ds_rescaled = ds.epoch.rescale_coords( multiplier=multiplier, unit_label=unit_label, @@ -97,7 +112,7 @@ def test_rescale_coords_attributes_copied(): multiplier = 1e6 unit_label = "µm" - with xr.open_dataset(TEST_FILES_DIR / "0000.sdf") as ds: + with xr.open_dataset(TEST_FILES_DIR_3D / "0000.sdf") as ds: ds_rescaled = ds.epoch.rescale_coords( multiplier=multiplier, unit_label=unit_label, @@ -110,7 +125,7 @@ def test_rescale_coords_attributes_copied(): def test_rescale_coords_non_existent_coord(): - with xr.open_dataset(TEST_FILES_DIR / "0000.sdf") as ds: + with xr.open_dataset(TEST_FILES_DIR_3D / "0000.sdf") as ds: with pytest.raises(ValueError, match="Coordinate 'Time' not found"): ds.epoch.rescale_coords( multiplier=1.0, @@ -130,7 +145,7 @@ def test_rescale_coords_time(): multiplier = 1e-15 unit_label = "fs" - with open_mfdataset(TEST_FILES_DIR.glob("*.sdf")) as ds: + with open_mfdataset(TEST_FILES_DIR_3D.glob("*.sdf")) as ds: ds_rescaled = ds.epoch.rescale_coords( multiplier=multiplier, unit_label=unit_label, @@ -142,3 +157,88 @@ def test_rescale_coords_time(): assert ds_rescaled["time"].attrs["units"] == unit_label assert ds_rescaled["time"].attrs["long_name"] == "Time" assert ds_rescaled["time"].attrs["full_name"] == "time" + + +def test_animate_multiple_accessor(): + with open_mfdataset(TEST_FILES_DIR_1D.glob("*.sdf")) as ds: + assert hasattr(ds, "epoch") + assert hasattr(ds.epoch, "animate_multiple") + + +def test_animate_multiple_headless_single(): + with open_mfdataset(TEST_FILES_DIR_1D.glob("*.sdf")) as ds: + anim = ds.epoch.animate_multiple(ds["Derived_Number_Density_electron"]) + + # Specify a custom writable temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + temp_file_path = f"{temp_dir}/output.gif" + try: + anim.save(temp_file_path, writer=PillowWriter(fps=2)) + except Exception as e: + pytest.fail(f"animate().save() failed in headless mode: {e}") + + +def test_animate_multiple_headless_multiple(): + with open_mfdataset(TEST_FILES_DIR_1D.glob("*.sdf")) as ds: + anim = ds.epoch.animate_multiple( + ds["Derived_Number_Density_electron"], ds["Derived_Number_Density_proton"] + ) + + # Specify a custom writable temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + temp_file_path = f"{temp_dir}/output.gif" + try: + anim.save(temp_file_path, writer=PillowWriter(fps=2)) + except Exception as e: + pytest.fail(f"animate().save() failed in headless mode: {e}") + + +def test_animate_multiple_headless_single_kwargs(): + with open_mfdataset(TEST_FILES_DIR_2D.glob("*.sdf")) as ds: + anim = ds.epoch.animate_multiple( + ds["Derived_Number_Density_Electron"], datasets_kwargs=[{"cmap": "viridis"}] + ) + # Force the first frame to be drawn + anim._func(0) + ax = anim._fig.axes[0] + + mesh = next((m for m in ax.get_children() if hasattr(m, "get_cmap")), None) + assert mesh is not None, "No artist with a colormap found" + assert mesh.get_cmap().name == "viridis" + + # Specify a custom writable temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + temp_file_path = f"{temp_dir}/output.gif" + try: + anim.save(temp_file_path, writer=PillowWriter(fps=2)) + except Exception as e: + pytest.fail(f"animate().save() failed in headless mode: {e}") + + +def test_animate_multiple_headless_multiple_kwargs(): + with open_mfdataset(TEST_FILES_DIR_2D.glob("*.sdf")) as ds: + anim = ds.epoch.animate_multiple( + ds["Derived_Number_Density_Electron"], + ds["Derived_Number_Density_Ion_H"], + datasets_kwargs=[{"cmap": "viridis"}, {"cmap": "plasma"}], + ) + # Force the first frame to be drawn + anim._func(0) + ax = anim._fig.axes[0] + + # Collect all artists that have a colormap + meshes = [m for m in ax.get_children() if hasattr(m, "get_cmap")] + assert len(meshes) == 2, "Expected two artists with colormaps" + + # Check colormaps in order + expected_cm = ["viridis", "plasma"] + for mesh, expected in zip(meshes, expected_cm): + assert mesh.get_cmap().name == expected + + # Specify a custom writable temporary directory + with tempfile.TemporaryDirectory() as temp_dir: + temp_file_path = f"{temp_dir}/output.gif" + try: + anim.save(temp_file_path, writer=PillowWriter(fps=2)) + except Exception as e: + pytest.fail(f"animate().save() failed in headless mode: {e}")