diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9a12cbaad5..f4de451cc4 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,8 @@
 ## v0.x.x Unreleased
 
 ### New features
+* Added `to_zarr` method to InferenceData
+* Added `from_zarr` method to InferenceData
 
 ### Maintenance and fixes
 
diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py
index 68439889ff..d20464ab15 100644
--- a/arviz/data/inference_data.py
+++ b/arviz/data/inference_data.py
@@ -1,14 +1,14 @@
-# pylint: disable=too-many-lines
+# pylint: disable=too-many-lines,too-many-public-methods
 """Data structure for using netcdf groups with xarray."""
+import sys
 import uuid
 import warnings
 from collections import OrderedDict, defaultdict
-from collections.abc import Sequence
+from collections.abc import MutableMapping, Sequence
 from copy import copy as ccopy
 from copy import deepcopy
 from datetime import datetime
 from html import escape
-import sys
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -26,6 +26,7 @@
 import netCDF4 as nc
 import numpy as np
 import xarray as xr
+from packaging import version
 
 from ..rcparams import rcParams
 from ..utils import HtmlTemplate, _subset_list, either_dict_or_kwargs
@@ -589,6 +590,106 @@ def to_dataframe(
             (dfs,) = dfs.values()
         return dfs
 
+    def to_zarr(self, store=None):
+        """Convert InferenceData to a :class:`zarr.hierarchy.Group`.
+
+        The zarr storage is using the same group names as the InferenceData.
+
+        Parameters
+        ----------
+        store: zarr.storage i.e MutableMapping or str, optional
+            Zarr storage class or path to desired DirectoryStore.
+
+        Returns
+        -------
+        zarr.hierarchy.group
+            A zarr hierarchy group containing the InferenceData.
+
+        Raises
+        ------
+        TypeError
+            If no valid store is found.
+
+        References
+        ----------
+        https://zarr.readthedocs.io/
+        """
+        try:  # Check zarr
+            import zarr
+
+            assert version.parse(zarr.__version__) >= version.parse("2.5.0")
+        except (ImportError, AssertionError) as err:
+            raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err
+
+        # Check store type and create store if necessary
+        if store is None:
+            store = zarr.storage.TempStore(suffix="arviz")
+        elif isinstance(store, str):
+            store = zarr.storage.DirectoryStore(path=store)
+        elif not isinstance(store, MutableMapping):
+            raise TypeError(f"No valid store found: {store}")
+
+        groups = self.groups()
+
+        if not groups:
+            raise TypeError("No valid groups found!")
+
+        for group in groups:
+            # Create zarr group in store with same group name
+            getattr(self, group).to_zarr(store=store, group=group, mode="w")
+
+        return zarr.open(store)  # Open store to get overarching group
+
+    @staticmethod
+    def from_zarr(store) -> "InferenceData":
+        """Initialize object from a zarr store or path.
+
+        Expects that the zarr store will have groups, each of which can be loaded by xarray.
+        By default, the datasets of the InferenceData object will be lazily loaded instead
+        of being loaded into memory. This
+        behaviour is regulated by the value of ``az.rcParams["data.load"]``.
+
+        Parameters
+        ----------
+        store: MutableMapping or zarr.hierarchy.Group or str.
+            Zarr storage class or path to desired Store.
+
+        Returns
+        -------
+        InferenceData object
+
+        References
+        ----------
+        https://zarr.readthedocs.io/
+        """
+        try:
+            import zarr
+
+            assert version.parse(zarr.__version__) >= version.parse("2.5.0")
+        except (ImportError, AssertionError) as err:
+            raise ImportError("'from_zarr' method needs Zarr (2.5.0+) installed.") from err
+
+        # Check store type and create store if necessary
+        if isinstance(store, str):
+            store = zarr.storage.DirectoryStore(path=store)
+        elif isinstance(store, zarr.hierarchy.Group):
+            store = store.store
+        elif not isinstance(store, MutableMapping):
+            raise TypeError(f"No valid store found: {store}")
+
+        groups = {}
+        zarr_handle = zarr.open(store, mode="r")
+
+        # Open each group via xarray method
+        for key_group, _ in zarr_handle.groups():
+            with xr.open_zarr(store=store, group=key_group) as data:
+                if rcParams["data.load"] == "eager":
+                    groups[key_group] = data.load()
+                else:
+                    groups[key_group] = data
+
+        return InferenceData(**groups)
+
     def __add__(self, other: "InferenceData") -> "InferenceData":
         """Concatenate two InferenceData objects."""
         return concat(self, other, copy=True, inplace=False)
diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py
index 4894075608..2ed3a9ff5a 100644
--- a/arviz/tests/base_tests/test_data.py
+++ b/arviz/tests/base_tests/test_data.py
@@ -1,5 +1,6 @@
 # pylint: disable=no-member, invalid-name, redefined-outer-name
 # pylint: disable=too-many-lines
+
 import os
 from collections import namedtuple
 from copy import deepcopy
diff --git a/arviz/tests/base_tests/test_data_zarr.py b/arviz/tests/base_tests/test_data_zarr.py
new file mode 100644
index 0000000000..7a77a2587f
--- /dev/null
+++ b/arviz/tests/base_tests/test_data_zarr.py
@@ -0,0 +1,106 @@
+# pylint: disable=redefined-outer-name
+import importlib
+import os
+import shutil
+from collections.abc import MutableMapping
+
+import numpy as np
+import pytest
+
+
+from arviz import InferenceData, from_dict
+
+from ..helpers import (  # pylint: disable=unused-import
+    chains,
+    check_multiple_attrs,
+    draws,
+    eight_schools_params,
+    running_on_ci,
+)
+
+pytestmark = pytest.mark.skipif(  # pylint: disable=invalid-name
+    importlib.util.find_spec("zarr") is None and not running_on_ci(),
+    reason="test requires zarr which is not installed",
+)
+
+import zarr  # pylint: disable=wrong-import-position, wrong-import-order
+
+
+class TestDataZarr:
+    @pytest.fixture(scope="class")
+    def data(self, draws, chains):
+        class Data:
+            # fake 8-school output
+            obj = {}
+            for key, shape in {"mu": [], "tau": [], "eta": [8], "theta": [8]}.items():
+                obj[key] = np.random.randn(chains, draws, *shape)
+
+        return Data
+
+    def get_inference_data(self, data, eight_schools_params):
+        return from_dict(
+            posterior=data.obj,
+            posterior_predictive=data.obj,
+            sample_stats=data.obj,
+            prior=data.obj,
+            prior_predictive=data.obj,
+            sample_stats_prior=data.obj,
+            observed_data=eight_schools_params,
+            coords={"school": np.arange(8)},
+            dims={"theta": ["school"], "eta": ["school"]},
+        )
+
+    @pytest.mark.parametrize("store", [0, 1, 2])
+    def test_io_method(self, data, eight_schools_params, store):
+        # create InferenceData and check it has been properly created
+        inference_data = self.get_inference_data(  # pylint: disable=W0612
+            data, eight_schools_params
+        )
+        test_dict = {
+            "posterior": ["eta", "theta", "mu", "tau"],
+            "posterior_predictive": ["eta", "theta", "mu", "tau"],
+            "sample_stats": ["eta", "theta", "mu", "tau"],
+            "prior": ["eta", "theta", "mu", "tau"],
+            "prior_predictive": ["eta", "theta", "mu", "tau"],
+            "sample_stats_prior": ["eta", "theta", "mu", "tau"],
+            "observed_data": ["J", "y", "sigma"],
+        }
+        fails = check_multiple_attrs(test_dict, inference_data)
+        assert not fails
+
+        # check filename does not exist and use to_zarr method
+        here = os.path.dirname(os.path.abspath(__file__))
+        data_directory = os.path.join(here, "..", "saved_models")
+        filepath = os.path.join(data_directory, "zarr")
+        assert not os.path.exists(filepath)
+
+        # InferenceData method
+        if store == 0:
+            # Tempdir
+            store = inference_data.to_zarr(store=None)
+            assert isinstance(store, MutableMapping)
+        elif store == 1:
+            inference_data.to_zarr(store=filepath)
+            # assert file has been saved correctly
+            assert os.path.exists(filepath)
+            assert os.path.getsize(filepath) > 0
+        elif store == 2:
+            store = zarr.storage.DirectoryStore(filepath)
+            inference_data.to_zarr(store=store)
+            # assert file has been saved correctly
+            assert os.path.exists(filepath)
+            assert os.path.getsize(filepath) > 0
+
+        if isinstance(store, MutableMapping):
+            inference_data2 = InferenceData.from_zarr(store)
+        else:
+            inference_data2 = InferenceData.from_zarr(filepath)
+
+        # Everything in dict still available in inference_data2 ?
+        fails = check_multiple_attrs(test_dict, inference_data2)
+        assert not fails
+
+        # Remove created folder structure
+        if os.path.exists(filepath):
+            shutil.rmtree(filepath)
+        assert not os.path.exists(filepath)
diff --git a/doc/source/api/inference_data.rst b/doc/source/api/inference_data.rst
index a086590100..ddb447c2be 100644
--- a/doc/source/api/inference_data.rst
+++ b/doc/source/api/inference_data.rst
@@ -33,6 +33,8 @@ IO / Conversion
    InferenceData.to_json
    InferenceData.from_netcdf
    InferenceData.to_netcdf
+   InferenceData.from_zarr
+   InferenceData.to_zarr
    InferenceData.chunk
    InferenceData.compute
    InferenceData.load
diff --git a/doc/source/conf.py b/doc/source/conf.py
index 7a6922b94f..91801d2e06 100644
--- a/doc/source/conf.py
+++ b/doc/source/conf.py
@@ -137,8 +137,8 @@
 # documentation.
 
 html_theme_options = {
-  "github_url": "https://github.com/arviz-devs/arviz",
-  "twitter_url": "https://twitter.com/arviz_devs",
+    "github_url": "https://github.com/arviz-devs/arviz",
+    "twitter_url": "https://twitter.com/arviz_devs",
 }
 
 # Add any paths that contain custom static files (such as style sheets) here,
@@ -153,7 +153,7 @@
 
 # use additional pages to add a 404 page
 html_additional_pages = {
-    '404': '404.html',
+    "404": "404.html",
 }
 
 # -- Options for HTMLHelp output ------------------------------------------
@@ -252,4 +252,5 @@
     "mpl": ("https://matplotlib.org/", None),
     "bokeh": ("https://docs.bokeh.org/en/latest/", None),
     "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
+    "zarr": ("https://zarr.readthedocs.io/en/stable/", None),
 }
diff --git a/requirements-optional.txt b/requirements-optional.txt
index dfa9708469..ec7ef8ddd6 100644
--- a/requirements-optional.txt
+++ b/requirements-optional.txt
@@ -2,3 +2,4 @@ numba
 bokeh>=1.4.0
 ujson
 dask
+zarr>=2.5.0
\ No newline at end of file