diff --git a/pde/fields/collection.py b/pde/fields/collection.py index c28225fc..41152333 100644 --- a/pde/fields/collection.py +++ b/pde/fields/collection.py @@ -886,50 +886,6 @@ def _plot_merged_image( } return PlotReference(ax, axes_image, parameters) - @plot_on_axes(update_method="_update_rgb_image_plot") - def _plot_rgb_image( - self, - ax, - transpose: bool = False, - vmin: float | list[float | None] | None = None, - vmax: float | list[float | None] | None = None, - **kwargs, - ) -> PlotReference: - r"""Visualize fields by mapping to different color chanels in a 2d density plot. - - Args: - ax (:class:`matplotlib.axes.Axes`): - Figure axes to be used for plotting. - transpose (bool): - Determines whether the transpose of the data is plotted - vmin, vmax (float, list of float): - Define the data range that the color chanels cover. By default, they - cover the complete value range of the supplied data. - \**kwargs: - Additional keyword arguments that affect the image. Non-Cartesian grids - might support `performance_goal` to influence how an image is created - from raw data. Finally, remaining arguments are passed to - :func:`matplotlib.pyplot.imshow` to affect the appearance. - - Returns: - :class:`PlotReference`: Instance that contains information to update the - plot with new data later. - """ - # since 2024-01-25 - warnings.warn( - "`rgb_image` is deprecated in favor of `merged`", DeprecationWarning - ) - return self._plot_merged_image( # type: ignore - ax=ax, - colors="rgb", - background_color="k", - projection="max", - transpose=transpose, - vmin=vmin, - vmax=vmax, - **kwargs, - ) - def _update_plot(self, reference: list[PlotReference]) -> None: """Update a plot collection with the current field values. @@ -984,7 +940,7 @@ def plot( List of :class:`PlotReference`: Instances that contain information to update all the plots with new data later. """ - if kind in {"merged", "rgb", "rgb_image", "rgb-image"}: + if kind in {"merged"}: num_panels = 1 else: num_panels = len(self) @@ -1024,14 +980,6 @@ def plot( ) ] - elif kind in {"rgb", "rgb_image", "rgb-image"}: - # plot a single RGB representation - reference = [ - self._plot_rgb_image( - ax=axs[0], action="none", **kwargs, **subplot_args[0] - ) - ] - else: # plot all the elements onto the respective axes if isinstance(kind, str): diff --git a/pde/grids/boundaries/local.py b/pde/grids/boundaries/local.py index 81bf7547..9563f478 100644 --- a/pde/grids/boundaries/local.py +++ b/pde/grids/boundaries/local.py @@ -73,12 +73,7 @@ from ...tools.cache import cached_method from ...tools.docstrings import fill_in_docstring from ...tools.numba import address_as_void_pointer, jit, numba_dict -from ...tools.typing import ( - AdjacentEvaluator, - FloatNumerical, - GhostCellSetter, - VirtualPointEvaluator, -) +from ...tools.typing import FloatNumerical, GhostCellSetter, VirtualPointEvaluator from ..base import GridBase, PeriodicityError if TYPE_CHECKING: @@ -592,23 +587,6 @@ def make_virtual_point_evaluator(self) -> VirtualPointEvaluator: the boundary condition. """ - def make_adjacent_evaluator(self) -> AdjacentEvaluator: - """Returns a function evaluating the value adjacent to a given point. - - .. deprecated:: Since 2023-12-19 - - Returns: - function: A function with signature (arr_1d, i_point, bc_idx), where - `arr_1d` is the one-dimensional data array (the data points along the axis - perpendicular to the boundary), `i_point` is the index into this array for - the current point and bc_idx are the remaining indices of the current point, - which indicate the location on the boundary plane. The result of the - function is the data value at the adjacent point along the axis associated - with this boundary condition in the upper (lower) direction when `upper` is - True (False). - """ - raise NotImplementedError - @abstractmethod def set_ghost_cells(self, data_full: np.ndarray, *, args=None) -> None: """Set the ghost cell values for this boundary. @@ -1529,9 +1507,6 @@ def get_sparse_matrix_data( def get_virtual_point(self, arr, idx: tuple[int, ...] | None = None) -> float: raise NotImplementedError - def make_adjacent_evaluator(self) -> AdjacentEvaluator: - raise NotImplementedError - def _get_value_cell_index(self, with_ghost_cells: bool) -> int: if self.value_cell is None: # pick adjacent cell by default @@ -2159,69 +2134,6 @@ def virtual_point( return virtual_point # type: ignore - def make_adjacent_evaluator(self) -> AdjacentEvaluator: - # method deprecated since 2023-12-19 - warnings.warn("`make_adjacent_evaluator` is deprecated", DeprecationWarning) - # get values distinguishing upper from lower boundary - if self.upper: - i_bndry = self.grid.shape[self.axis] - 1 - i_dx = 1 - else: - i_bndry = 0 - i_dx = -1 - - if self.homogeneous: - # the boundary condition does not depend on space - - # calculate necessary constants - const, factor, index = self.get_virtual_point_data(compiled=True) - zeros = np.zeros(self._shape_tensor) - ones = np.ones(self._shape_tensor) - - @register_jitable(inline="always") - def adjacent_point( - arr_1d: np.ndarray, i_point: int, bc_idx: tuple[int, ...] - ) -> FloatNumerical: - """Evaluate the value adjacent to the current point.""" - # determine the parameters for evaluating adjacent point. Note - # that defining the variables c and f for the interior points - # seems needless, but it turns out that this results in a 10x - # faster function (because of branch prediction?). - if i_point == i_bndry: - c, f, i = const(), factor(), index - else: - c, f, i = zeros, ones, i_point + i_dx # INTENTIONAL - - # calculate the values - return c + f * arr_1d[..., i] # type: ignore - - else: - # the boundary condition is a function of space - - # calculate necessary constants - const, factor, index = self.get_virtual_point_data(compiled=True) - zeros = np.zeros(self._shape_tensor + self._shape_boundary) - ones = np.ones(self._shape_tensor + self._shape_boundary) - - @register_jitable(inline="always") - def adjacent_point(arr_1d, i_point, bc_idx) -> float: - """Evaluate the value adjacent to the current point.""" - # determine the parameters for evaluating adjacent point. Note - # that defining the variables c and f for the interior points - # seems needless, but it turns out that this results in a 10x - # faster function (because of branch prediction?). This is - # surprising, because it uses arrays zeros and ones that are - # quite pointless - if i_point == i_bndry: - c, f, i = const(), factor(), index - else: - c, f, i = zeros, ones, i_point + i_dx # INTENTIONAL - - # calculate the values - return c[bc_idx] + f[bc_idx] * arr_1d[..., i] # type: ignore - - return adjacent_point # type: ignore - def set_ghost_cells(self, data_full: np.ndarray, *, args=None) -> None: # calculate necessary constants const, factor, index = self.get_virtual_point_data() @@ -2754,73 +2666,6 @@ def virtual_point(arr: np.ndarray, idx: tuple[int, ...], args=None): return virtual_point # type: ignore - def make_adjacent_evaluator(self) -> AdjacentEvaluator: - # method deprecated since 2023-12-19 - warnings.warn("`make_adjacent_evaluator` is deprecated", DeprecationWarning) - size = self.grid.shape[self.axis] - if size < 2: - raise ValueError( - f"Need at least two support points along axis {self.axis} to apply " - "boundary conditions" - ) - - # get values distinguishing upper from lower boundary - if self.upper: - i_bndry = size - 1 - i_dx = 1 - else: - i_bndry = 0 - i_dx = -1 - - # calculate necessary constants - data_vp = self.get_virtual_point_data() - - zeros = np.zeros_like(self.value) - ones = np.ones_like(self.value) - - if self.homogeneous: - # the boundary condition does not depend on space - - @register_jitable - def adjacent_point( - arr_1d: np.ndarray, i_point: int, bc_idx: tuple[int, ...] - ) -> float: - """Evaluate the value adjacent to the current point.""" - # determine the parameters for evaluating adjacent point - if i_point == i_bndry: - data = data_vp - else: - data = (zeros, ones, i_point + i_dx, zeros, 0) - - # calculate the values - return ( # type: ignore - data[0] - + data[1] * arr_1d[..., data[2]] - + data[3] * arr_1d[..., data[4]] - ) - - else: - # the boundary condition is a function of space - - @register_jitable - def adjacent_point( - arr_1d: np.ndarray, i_point: int, bc_idx: tuple[int, ...] - ) -> float: - """Evaluate the value adjacent to the current point.""" - # determine the parameters for evaluating adjacent point - if i_point == i_bndry: - data = data_vp - else: - data = (zeros, ones, i_point + i_dx, zeros, 0) - - return ( # type: ignore - data[0][bc_idx] - + data[1][bc_idx] * arr_1d[..., data[2]] - + data[3][bc_idx] * arr_1d[..., data[4]] - ) - - return adjacent_point # type: ignore - def set_ghost_cells(self, data_full: np.ndarray, *, args=None) -> None: # calculate necessary constants data = self.get_virtual_point_data() diff --git a/pde/storage/base.py b/pde/storage/base.py index 46baf370..4dbf64ce 100644 --- a/pde/storage/base.py +++ b/pde/storage/base.py @@ -284,7 +284,6 @@ def tracker( interrupts: InterruptData = 1, *, transformation: Callable[[FieldBase, float], FieldBase] | None = None, - interval=None, ) -> StorageTracker: """Create object that can be used as a tracker to fill this storage. @@ -321,10 +320,7 @@ def add_to_state(state): possible by defining appropriate :func:`add_to_state` """ return StorageTracker( - storage=self, - interrupts=interrupts, - transformation=transformation, - interval=interval, + storage=self, interrupts=interrupts, transformation=transformation ) def start_writing(self, field: FieldBase, info: InfoDict | None = None) -> None: @@ -566,7 +562,6 @@ def __init__( interrupts: InterruptData = 1, *, transformation: Callable[[FieldBase, float], FieldBase] | None = None, - interval=None, ): """ Args: @@ -582,7 +577,7 @@ def __init__( the current field, while the optional second argument is the associated time. """ - super().__init__(interrupts=interrupts, interval=interval) + super().__init__(interrupts=interrupts) self.storage = storage if transformation is not None and not callable(transformation): raise TypeError("`transformation` must be callable") diff --git a/pde/tools/numba.py b/pde/tools/numba.py index 300d34e8..c6f7e6f5 100644 --- a/pde/tools/numba.py +++ b/pde/tools/numba.py @@ -13,30 +13,13 @@ import numba as nb import numpy as np from numba.core.types import npytypes, scalars -from numba.extending import overload, register_jitable +from numba.extending import is_jitted, overload, register_jitable from numba.typed import Dict as NumbaDict from .. import config from ..tools.misc import decorator_arguments from .typing import Number -try: - # is_jitted has been added in numba 0.53 on 2021-03-11 - from numba.extending import is_jitted - -except ImportError: - # for earlier version of numba, we need to define the function - - def is_jitted(function: Callable) -> bool: - """Determine whether a function has already been jitted.""" - try: - from numba.core.dispatcher import Dispatcher - except ImportError: - # assume older numba module structure - from numba.dispatcher import Dispatcher - return isinstance(function, Dispatcher) - - # numba version as a list of integers NUMBA_VERSION = [int(v) for v in nb.__version__.split(".")[:2]] @@ -177,7 +160,6 @@ def jit(function: TFunc, signature=None, parallel: bool = False, **kwargs) -> TF return function # prepare the compilation arguments - kwargs.setdefault("nopython", True) if config["numba.fastmath"] is True: # enable some (but not all) fastmath flags. We skip the flags that affect # handling of infinities and NaN for safety by default. Use "fast" to enable all @@ -192,10 +174,10 @@ def jit(function: TFunc, signature=None, parallel: bool = False, **kwargs) -> TF # log some details logger = logging.getLogger(__name__) name = getattr(function, "__name__", "") - if kwargs["nopython"]: # standard case - logger.info("Compile `%s` with parallel=%s", name, kwargs["parallel"]) - else: # this might imply numba falls back to object mode - logger.warning("Compile `%s` with nopython=False", name) + if kwargs["parallel"]: + logger.info("Compile `%s`", name) + else: + logger.info("Compile `%s` with parallel=True", name) # increase the compilation counter by one JIT_COUNT.increment() @@ -308,7 +290,7 @@ def get_common_numba_dtype(*args): return nb.double -@jit(nopython=True, nogil=True) +@jit(nogil=True) def _random_seed_compiled(seed: int) -> None: """Sets the seed of the random number generator of numba.""" np.random.seed(seed) @@ -325,7 +307,7 @@ def random_seed(seed: int = 0) -> None: _random_seed_compiled(seed) -if NUMBA_VERSION < [0, 45]: +if NUMBA_VERSION < [0, 59]: warnings.warn( - "Your numba version is outdated. Please install at least version 0.45" + "Your numba version is outdated. Please install at least version 0.59" ) diff --git a/pde/tools/typing.py b/pde/tools/typing.py index 318728d4..d202be16 100644 --- a/pde/tools/typing.py +++ b/pde/tools/typing.py @@ -44,13 +44,6 @@ def __call__(self, arr: np.ndarray, idx: tuple[int, ...], args=None) -> float: """Evaluate the virtual point at the given position.""" -class AdjacentEvaluator(Protocol): - def __call__( - self, arr_1d: np.ndarray, i_point: int, bc_idx: tuple[int, ...] - ) -> float: - """Evaluate the values at adjecent points.""" - - class GhostCellSetter(Protocol): def __call__(self, data_full: np.ndarray, args=None) -> None: """Set the ghost cells.""" diff --git a/pde/trackers/base.py b/pde/trackers/base.py index d8ba3a56..b5bdec19 100644 --- a/pde/trackers/base.py +++ b/pde/trackers/base.py @@ -37,19 +37,12 @@ class TrackerBase(metaclass=ABCMeta): _subclasses: dict[str, type[TrackerBase]] = {} # all inheriting classes @fill_in_docstring - def __init__(self, interrupts: InterruptData = 1, *, interval=None): + def __init__(self, interrupts: InterruptData = 1): """ Args: interrupts: {ARG_TRACKER_INTERRUPT} """ - if interval is not None: - # deprecated on 2023-12-23 - warnings.warn( - "Argument `interval` has been renamed to `interrupts`", - DeprecationWarning, - ) - interrupts = interval self.interrupt = parse_interrupt(interrupts) def __init_subclass__(cls, **kwargs): diff --git a/pde/trackers/interactive.py b/pde/trackers/interactive.py index 01128380..435a6296 100644 --- a/pde/trackers/interactive.py +++ b/pde/trackers/interactive.py @@ -250,7 +250,6 @@ def __init__( *, close: bool = True, show_time: bool = False, - interval=None, ): """ Args: @@ -263,7 +262,7 @@ def __init__( show_time (bool): Whether to indicate the time """ - super().__init__(interrupts=interrupts, interval=interval) + super().__init__(interrupts=interrupts) self.close = close self.show_time = show_time diff --git a/pde/trackers/trackers.py b/pde/trackers/trackers.py index 05533ed0..445a6d5e 100644 --- a/pde/trackers/trackers.py +++ b/pde/trackers/trackers.py @@ -60,7 +60,7 @@ def check_simulation(state, time): raise StopIteration - tracker = CallbackTracker(check_simulation, interval="0:10") + tracker = CallbackTracker(check_simulation, interrupts="0:10") Adding :code:`tracker` to the simulation will perform a check every 10 real time seconds. If the integral of the entire state falls below zero, the simulation @@ -68,13 +68,7 @@ def check_simulation(state, time): """ @fill_in_docstring - def __init__( - self, - func: Callable, - interrupts: InterruptData = 1, - *, - interval=None, - ): + def __init__(self, func: Callable, interrupts: InterruptData = 1): """ Args: func: @@ -89,7 +83,7 @@ def __init__( interrupts: {ARG_TRACKER_INTERRUPT} """ - super().__init__(interrupts=interrupts, interval=interval) + super().__init__(interrupts=interrupts) self._callback = func self._num_args = len(inspect.signature(func).parameters) if not 0 < self._num_args < 3: @@ -126,7 +120,6 @@ def __init__( fancy: bool = True, ndigits: int = 5, leave: bool = True, - interval=None, ): """ Args: @@ -146,7 +139,7 @@ def __init__( if interrupts is None: interrupts = RealtimeInterrupts(duration=1) # print every second by default - super().__init__(interrupts=interrupts, interval=interval) + super().__init__(interrupts=interrupts) self.fancy = fancy self.ndigits = ndigits self.leave = leave @@ -243,13 +236,7 @@ class PrintTracker(TrackerBase): name = "print" @fill_in_docstring - def __init__( - self, - interrupts: InterruptData = 1, - stream: IO[str] = sys.stdout, - *, - interval=None, - ): + def __init__(self, interrupts: InterruptData = 1, stream: IO[str] = sys.stdout): """ Args: @@ -258,7 +245,7 @@ def __init__( stream: The stream used for printing """ - super().__init__(interrupts=interrupts, interval=interval) + super().__init__(interrupts=interrupts) self.stream = stream def handle(self, field: FieldBase, t: float) -> None: @@ -288,7 +275,7 @@ class PlotTracker(TrackerBase): .. code-block:: python - movie_tracker = PlotTracker(interval=10, movie="my_movie.mp4") + movie_tracker = PlotTracker(interrupts=10, movie="my_movie.mp4") eq.solve(..., tracker=movie_tracker) This will create the file `my_movie.mp4` during the simulation. Note that you @@ -307,7 +294,6 @@ def __init__( tight_layout: bool = False, max_fps: float = math.inf, plot_args: dict[str, Any] | None = None, - interval=None, ): """ Args: @@ -366,7 +352,7 @@ def __init__( from ..visualization.movies import Movie # initialize the tracker - super().__init__(interrupts=interrupts, interval=interval) + super().__init__(interrupts=interrupts) self.title = title self.output_file = output_file self.tight_layout = tight_layout @@ -533,7 +519,7 @@ class LivePlotTracker(PlotTracker): """PlotTracker with defaults for live plotting. The only difference to :class:`PlotTracker` are the changed default values, where - output is by default shown on screen and the `interval` is set something more + output is by default shown on screen and the `interrupts` is set something more suitable for interactive plotting. In particular, this tracker can be enabled by simply listing 'plot' as a tracker. """ @@ -547,7 +533,6 @@ def __init__( *, show: bool = True, max_fps: float = 2, - interval=None, **kwargs, ): """ @@ -581,13 +566,7 @@ def __init__( instance, the value `{'ax_style': {'ylim': (0, 1)}}` enforces the y-axis to lie between 0 and 1. """ - super().__init__( - interrupts=interrupts, - interval=interval, - show=show, - max_fps=max_fps, - **kwargs, - ) + super().__init__(interrupts=interrupts, show=show, max_fps=max_fps, **kwargs) class DataTracker(CallbackTracker): @@ -602,10 +581,10 @@ def get_statistics(state, time): return {"mean": state.data.mean(), "variance": state.data.var()} - data_tracker = DataTracker(get_statistics, interval=10) + data_tracker = DataTracker(get_statistics, interrupts=10) Adding :code:`data_tracker` to the simulation will gather the statistics every - 10 time units. After the simulation, the final result will be accessable via the + 10 time units. After the simulation, the final result will be accessible via the :attr:`data` attribute or conveniently as a pandas from the :attr:`dataframe` attribute. @@ -624,7 +603,6 @@ def __init__( interrupts: InterruptData = 1, *, filename: str | None = None, - interval=None, ): """ Args: @@ -648,7 +626,7 @@ def __init__( storing a tuple `(self.times, self.data)`, whereas any other data format requires :mod:`pandas`. """ - super().__init__(func=func, interrupts=interrupts, interval=interval) + super().__init__(func=func, interrupts=interrupts) self.filename = filename self.times: list[float] = [] self.data: list[Any] = [] @@ -756,7 +734,6 @@ def __init__( *, progress: bool = False, evolution_rate: Callable[[np.ndarray, float], np.ndarray] | None = None, - interval=None, ): """ Args: @@ -779,7 +756,7 @@ def __init__( """ if interrupts is None: interrupts = RealtimeInterrupts(duration=1) - super().__init__(interrupts=interrupts, interval=interval) + super().__init__(interrupts=interrupts) self.atol = atol self.rtol = rtol self.evolution_rate = evolution_rate @@ -861,9 +838,7 @@ class RuntimeTracker(TrackerBase): """Tracker interrupting the simulation once a duration has passed.""" @fill_in_docstring - def __init__( - self, max_runtime: Real | str, interrupts: InterruptData = 1, *, interval=None - ): + def __init__(self, max_runtime: Real | str, interrupts: InterruptData = 1): """ Args: max_runtime (float or str): @@ -874,7 +849,7 @@ def __init__( interrupts: {ARG_TRACKER_INTERRUPT} """ - super().__init__(interrupts=interrupts, interval=interval) + super().__init__(interrupts=interrupts) try: self.max_runtime = float(max_runtime) @@ -916,7 +891,7 @@ class ConsistencyTracker(TrackerBase): name = "consistency" @fill_in_docstring - def __init__(self, interrupts: InterruptData | None = None, *, interval=None): + def __init__(self, interrupts: InterruptData | None = None): """ Args: interrupts: @@ -926,7 +901,7 @@ def __init__(self, interrupts: InterruptData | None = None, *, interval=None): """ if interrupts is None: interrupts = RealtimeInterrupts(duration=1) - super().__init__(interrupts=interrupts, interval=interval) + super().__init__(interrupts=interrupts) def handle(self, field: FieldBase, t: float) -> None: """Handle data supplied to this tracker. @@ -948,12 +923,7 @@ class MaterialConservationTracker(TrackerBase): @fill_in_docstring def __init__( - self, - interrupts: InterruptData = 1, - atol: float = 1e-4, - rtol: float = 1e-4, - *, - interval=None, + self, interrupts: InterruptData = 1, atol: float = 1e-4, rtol: float = 1e-4 ): """ Args: @@ -964,7 +934,7 @@ def __init__( rtol (float): Relative tolerance for amount deviations """ - super().__init__(interrupts=interrupts, interval=interval) + super().__init__(interrupts=interrupts) self.atol = atol self.rtol = rtol diff --git a/tests/conftest.py b/tests/conftest.py index 4f93a2b0..4309aff6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,6 +3,8 @@ .. codeauthor:: David Zwicker """ +import gc + import matplotlib.pyplot as plt import numpy as np import pytest diff --git a/tests/fields/test_field_collections.py b/tests/fields/test_field_collections.py index 4ceecfa7..c27e5c2d 100644 --- a/tests/fields/test_field_collections.py +++ b/tests/fields/test_field_collections.py @@ -348,16 +348,6 @@ def test_collection_apply(rng): np.testing.assert_allclose(f1.apply("s1 * v2").data, v.data * 2) -@pytest.mark.parametrize("num", [1, 2, 3]) -def test_rgb_image_plotting(num): - """Test plotting of collections as rgb fields.""" - grid = UnitGrid([16, 8]) - fc = FieldCollection([ScalarField.random_uniform(grid) for _ in range(num)]) - - refs = fc.plot("rgb_image") - fc._update_plot(refs) - - @pytest.mark.parametrize("num", [1, 2, 3, 4]) def test_merged_image_plotting(num): """Test plotting of collections as merged images.""" diff --git a/tests/grids/boundaries/test_local_boundaries.py b/tests/grids/boundaries/test_local_boundaries.py index 27b8ecd9..789cc229 100644 --- a/tests/grids/boundaries/test_local_boundaries.py +++ b/tests/grids/boundaries/test_local_boundaries.py @@ -258,20 +258,9 @@ def test_inhomogeneous_bcs_2d(): assert ev(data, (1, 0)) == pytest.approx(1.5) assert ev(data, (1, 1)) == pytest.approx(2.5) - ev = bc_x.make_adjacent_evaluator() - assert ev(*_get_arr_1d(data, (0, 0), axis=0)) == pytest.approx(1) - assert ev(*_get_arr_1d(data, (0, 1), axis=0)) == pytest.approx(1) - assert ev(*_get_arr_1d(data, (1, 0), axis=0)) == pytest.approx(1.5) - assert ev(*_get_arr_1d(data, (1, 1), axis=0)) == pytest.approx(2.5) - # test lower bc bc_x = BCBase.from_data(g, 0, False, {"curvature": "y"}) assert bc_x.axis_coord == 0 - ev = bc_x.make_adjacent_evaluator() - assert ev(*_get_arr_1d(data, (1, 0), axis=0)) == pytest.approx(1) - assert ev(*_get_arr_1d(data, (1, 1), axis=0)) == pytest.approx(1) - assert ev(*_get_arr_1d(data, (0, 0), axis=0)) == pytest.approx(1.5) - assert ev(*_get_arr_1d(data, (0, 1), axis=0)) == pytest.approx(2.5) @pytest.mark.parametrize("expr", ["1", "x + y**2"]) diff --git a/tests/storage/test_generic_storages.py b/tests/storage/test_generic_storages.py index 368da7f9..db7b4070 100644 --- a/tests/storage/test_generic_storages.py +++ b/tests/storage/test_generic_storages.py @@ -2,6 +2,7 @@ .. codeauthor:: David Zwicker """ +import contextlib import functools import platform @@ -33,11 +34,11 @@ if mr is not None: STORAGE_CLASSES_ALL.append((0, False, mr.storage.MemoryStorage)) STORAGE_CLASSES_ALL.append((0, False, mr.storage.JSONStorage)) - if module_available("yaml"): - STORAGE_CLASSES_ALL.append((0, False, mr.storage.YAMLStorage)) - if module_available("zarr"): + with contextlib.suppress(AttributeError): STORAGE_CLASSES_ALL.append((0, False, mr.storage.ZarrStorage)) - if module_available("h5py"): + with contextlib.suppress(AttributeError): + STORAGE_CLASSES_ALL.append((0, False, mr.storage.YAMLStorage)) + with contextlib.suppress(AttributeError): STORAGE_CLASSES_ALL.append((0, False, mr.storage.HDFStorage))