From 2be3611058a9d68b6505cb485d6e517fd41d2c11 Mon Sep 17 00:00:00 2001 From: semohr <39738318+semohr@users.noreply.github.com> Date: Sat, 6 Feb 2021 17:36:06 +0100 Subject: [PATCH] Added methods to convert InferenceData to and from Zarr (#1518) * Added a to_zarr method which converts the InferenceData object to a hierarchical zarr group. * Added a from_zarr method which create an inferenceData object from a zarr store. * Fixed small typo in to_zarr and from_zarr * Added the ability to create InferenceData from zarr.hierarchy.Group * Oversight in zarr.hierachy.group.groups() generator call. * Forgot dictionary definition. * Cleanup zarr: - added reference to zarr docs - replaced type with isInstance - replaced MemoryStore with TempStore * Added to methods to CHANGELOG * PR-Comments: - Moved MutableMapping to top - Check if zarr is installed - Removed ifs observed_data,constant_data,predictions_constant_data - Renamed g to zarr_handle Removed depreciated docstring * Removed typo * Replaced last occurence of g with zarr_handle * Added from packaging import version * Fixed wrong import order, I did not know this is a thing in python, wow * Even later import of version * Yet another try to fix the import order * Local pylint is working with this formatting... * Docstring formatting changes and removed deprecated parts. * Fixed local black version mismatch. * Update arviz/data/inference_data.py Co-authored-by: Ari Hartikainen * Added tests and docs intersphinx_mapping. * Improved test coverage for to_zarr and from_zarr functions. * Added pytest.mark.skipif to zarr test class * Moved test class to new file called 'test_data_zarr' * Fixed small pylint wrong-import-position error * Yet another import-order fix * Pylint is still black magic for me... It is working locally without errors * Update arviz/tests/base_tests/test_data_zarr.py Co-authored-by: Oriol Abril-Pla * Removed running_on_ci * Reverted last change and removed running_on_ci in right place now * Moved zarr import down and added `# pylint: disable=wrong-import-position` * Update arviz/tests/base_tests/test_data_zarr.py Co-authored-by: Oriol Abril-Pla Co-authored-by: Ari Hartikainen Co-authored-by: Oriol Abril-Pla --- CHANGELOG.md | 2 + arviz/data/inference_data.py | 107 ++++++++++++++++++++++- arviz/tests/base_tests/test_data.py | 1 + arviz/tests/base_tests/test_data_zarr.py | 106 ++++++++++++++++++++++ doc/source/api/inference_data.rst | 2 + doc/source/conf.py | 7 +- requirements-optional.txt | 1 + 7 files changed, 220 insertions(+), 6 deletions(-) create mode 100644 arviz/tests/base_tests/test_data_zarr.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a12cbaad5..f4de451cc4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,8 @@ ## v0.x.x Unreleased ### New features +* Added `to_zarr` method to InferenceData +* Added `from_zarr` method to InferenceData ### Maintenance and fixes diff --git a/arviz/data/inference_data.py b/arviz/data/inference_data.py index 68439889ff..d20464ab15 100644 --- a/arviz/data/inference_data.py +++ b/arviz/data/inference_data.py @@ -1,14 +1,14 @@ -# pylint: disable=too-many-lines +# pylint: disable=too-many-lines,too-many-public-methods """Data structure for using netcdf groups with xarray.""" +import sys import uuid import warnings from collections import OrderedDict, defaultdict -from collections.abc import Sequence +from collections.abc import MutableMapping, Sequence from copy import copy as ccopy from copy import deepcopy from datetime import datetime from html import escape -import sys from typing import ( TYPE_CHECKING, Any, @@ -26,6 +26,7 @@ import netCDF4 as nc import numpy as np import xarray as xr +from packaging import version from ..rcparams import rcParams from ..utils import HtmlTemplate, _subset_list, either_dict_or_kwargs @@ -589,6 +590,106 @@ def to_dataframe( (dfs,) = dfs.values() return dfs + def to_zarr(self, store=None): + """Convert InferenceData to a :class:`zarr.hierarchy.Group`. + + The zarr storage is using the same group names as the InferenceData. + + Raises + ------ + TypeError + If no valid store is found. + + Parameters + ---------- + store: zarr.storage i.e MutableMapping or str, optional + Zarr storage class or path to desired DirectoryStore. + + Returns + ------- + zarr.hierarchy.group + A zarr hierarchy group containing the InferenceData. + + References + ---------- + https://zarr.readthedocs.io/ + """ + try: # Check zarr + import zarr + + assert version.parse(zarr.__version__) >= version.parse("2.5.0") + except (ImportError, AssertionError) as err: + raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err + + # Check store type and create store if necessary + if store is None: + store = zarr.storage.TempStore(suffix="arviz") + elif isinstance(store, str): + store = zarr.storage.DirectoryStore(path=store) + elif not isinstance(store, MutableMapping): + raise TypeError(f"No valid store found: {store}") + + groups = self.groups() + + if not groups: + raise TypeError("No valid groups found!") + + for group in groups: + # Create zarr group in store with same group name + getattr(self, group).to_zarr(store=store, group=group, mode="w") + + return zarr.open(store) # Open store to get overarching group + + @staticmethod + def from_zarr(store) -> "InferenceData": + """Initialize object from a zarr store or path. + + Expects that the zarr store will have groups, each of which can be loaded by xarray. + By default, the datasets of the InferenceData object will be lazily loaded instead + of being loaded into memory. This + behaviour is regulated by the value of ``az.rcParams["data.load"]``. + + Parameters + ---------- + store: MutableMapping or zarr.hierarchy.Group or str. + Zarr storage class or path to desired Store. + + Returns + ------- + InferenceData object + + References + ---------- + https://zarr.readthedocs.io/ + """ + try: + import zarr + + assert version.parse(zarr.__version__) >= version.parse("2.5.0") + except (ImportError, AssertionError) as err: + raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err + + # Check store type and create store if necessary + if isinstance(store, str): + store = zarr.storage.DirectoryStore(path=store) + elif isinstance(store, zarr.hierarchy.Group): + store = store.store + elif not isinstance(store, MutableMapping): + raise TypeError(f"No valid store found: {store}") + + groups = {} + zarr_handle = zarr.open(store, mode="r") + + # Open each group via xarray method + for key_group, _ in zarr_handle.groups(): + with xr.open_zarr(store=store, group=key_group) as data: + if rcParams["data.load"] == "eager": + groups[key_group] = data.load() + else: + groups[key_group] = data + + return InferenceData(**groups) + def __add__(self, other: "InferenceData") -> "InferenceData": """Concatenate two InferenceData objects.""" return concat(self, other, copy=True, inplace=False) diff --git a/arviz/tests/base_tests/test_data.py b/arviz/tests/base_tests/test_data.py index 4894075608..2ed3a9ff5a 100644 --- a/arviz/tests/base_tests/test_data.py +++ b/arviz/tests/base_tests/test_data.py @@ -1,5 +1,6 @@ # pylint: disable=no-member, invalid-name, redefined-outer-name # pylint: disable=too-many-lines + import os from collections import namedtuple from copy import deepcopy diff --git a/arviz/tests/base_tests/test_data_zarr.py b/arviz/tests/base_tests/test_data_zarr.py new file mode 100644 index 0000000000..7a77a2587f --- /dev/null +++ b/arviz/tests/base_tests/test_data_zarr.py @@ -0,0 +1,106 @@ +# pylint: disable=redefined-outer-name +import importlib +import os +import shutil +from collections.abc import MutableMapping + +import numpy as np +import pytest + + +from arviz import InferenceData, from_dict + +from ..helpers import ( # pylint: disable=unused-import + chains, + check_multiple_attrs, + draws, + eight_schools_params, + running_on_ci, +) + +pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name + importlib.util.find_spec("zarr") is None and not running_on_ci(), + reason="test requires zarr which is not installed", +) + +import zarr # pylint: disable=wrong-import-position, wrong-import-order + + +class TestDataZarr: + @pytest.fixture(scope="class") + def data(self, draws, chains): + class Data: + # fake 8-school output + obj = {} + for key, shape in {"mu": [], "tau": [], "eta": [8], "theta": [8]}.items(): + obj[key] = np.random.randn(chains, draws, *shape) + + return Data + + def get_inference_data(self, data, eight_schools_params): + return from_dict( + posterior=data.obj, + posterior_predictive=data.obj, + sample_stats=data.obj, + prior=data.obj, + prior_predictive=data.obj, + sample_stats_prior=data.obj, + observed_data=eight_schools_params, + coords={"school": np.arange(8)}, + dims={"theta": ["school"], "eta": ["school"]}, + ) + + @pytest.mark.parametrize("store", [0, 1, 2]) + def test_io_method(self, data, eight_schools_params, store): + # create InferenceData and check it has been properly created + inference_data = self.get_inference_data( # pylint: disable=W0612 + data, eight_schools_params + ) + test_dict = { + "posterior": ["eta", "theta", "mu", "tau"], + "posterior_predictive": ["eta", "theta", "mu", "tau"], + "sample_stats": ["eta", "theta", "mu", "tau"], + "prior": ["eta", "theta", "mu", "tau"], + "prior_predictive": ["eta", "theta", "mu", "tau"], + "sample_stats_prior": ["eta", "theta", "mu", "tau"], + "observed_data": ["J", "y", "sigma"], + } + fails = check_multiple_attrs(test_dict, inference_data) + assert not fails + + # check filename does not exist and use to_zarr method + here = os.path.dirname(os.path.abspath(__file__)) + data_directory = os.path.join(here, "..", "saved_models") + filepath = os.path.join(data_directory, "zarr") + assert not os.path.exists(filepath) + + # InferenceData method + if store == 0: + # Tempdir + store = inference_data.to_zarr(store=None) + assert isinstance(store, MutableMapping) + elif store == 1: + inference_data.to_zarr(store=filepath) + # assert file has been saved correctly + assert os.path.exists(filepath) + assert os.path.getsize(filepath) > 0 + elif store == 2: + store = zarr.storage.DirectoryStore(filepath) + inference_data.to_zarr(store=store) + # assert file has been saved correctly + assert os.path.exists(filepath) + assert os.path.getsize(filepath) > 0 + + if isinstance(store, MutableMapping): + inference_data2 = InferenceData.from_zarr(store) + else: + inference_data2 = InferenceData.from_zarr(filepath) + + # Everything in dict still available in inference_data2 ? + fails = check_multiple_attrs(test_dict, inference_data2) + assert not fails + + # Remove created folder structure + if os.path.exists(filepath): + shutil.rmtree(filepath) + assert not os.path.exists(filepath) diff --git a/doc/source/api/inference_data.rst b/doc/source/api/inference_data.rst index a086590100..ddb447c2be 100644 --- a/doc/source/api/inference_data.rst +++ b/doc/source/api/inference_data.rst @@ -33,6 +33,8 @@ IO / Conversion InferenceData.to_json InferenceData.from_netcdf InferenceData.to_netcdf + InferenceData.from_zarr + InferenceData.to_zarr InferenceData.chunk InferenceData.compute InferenceData.load diff --git a/doc/source/conf.py b/doc/source/conf.py index 7a6922b94f..91801d2e06 100644 --- a/doc/source/conf.py +++ b/doc/source/conf.py @@ -137,8 +137,8 @@ # documentation. html_theme_options = { - "github_url": "https://github.com/arviz-devs/arviz", - "twitter_url": "https://twitter.com/arviz_devs", + "github_url": "https://github.com/arviz-devs/arviz", + "twitter_url": "https://twitter.com/arviz_devs", } # Add any paths that contain custom static files (such as style sheets) here, @@ -153,7 +153,7 @@ # use additional pages to add a 404 page html_additional_pages = { - '404': '404.html', + "404": "404.html", } # -- Options for HTMLHelp output ------------------------------------------ @@ -252,4 +252,5 @@ "mpl": ("https://matplotlib.org/", None), "bokeh": ("https://docs.bokeh.org/en/latest/", None), "scipy": ("https://docs.scipy.org/doc/scipy/reference/", None), + "zarr": ("https://zarr.readthedocs.io/en/stable/", None), } diff --git a/requirements-optional.txt b/requirements-optional.txt index dfa9708469..ec7ef8ddd6 100644 --- a/requirements-optional.txt +++ b/requirements-optional.txt @@ -2,3 +2,4 @@ numba bokeh>=1.4.0 ujson dask +zarr>=2.5.0 \ No newline at end of file