Skip to content

Commit

Permalink
Added methods to convert InferenceData to and from Zarr (#1518)
Browse files Browse the repository at this point in the history
* Added a to_zarr method which converts the InferenceData object
to a hierarchical zarr group.

* Added a from_zarr method which create an inferenceData object from a zarr store.

* Fixed small typo in to_zarr and from_zarr

* Added the ability to create InferenceData from zarr.hierarchy.Group

* Oversight in zarr.hierachy.group.groups() generator call.

* Forgot dictionary definition.

* Cleanup zarr:
- added reference to zarr docs
- replaced type with isInstance
- replaced MemoryStore with TempStore

* Added to methods to CHANGELOG

* PR-Comments:
- Moved MutableMapping to top
- Check if zarr is installed
- Removed ifs observed_data,constant_data,predictions_constant_data
- Renamed g to zarr_handle

Removed depreciated docstring

* Removed typo

* Replaced last occurence of g  with zarr_handle

* Added from packaging import version

* Fixed wrong import order, I did not know this is a thing in python, wow

* Even later import of version

* Yet another try to fix the import order

* Local pylint is working with this formatting...

* Docstring formatting changes and removed deprecated parts.

* Fixed local black version mismatch.

* Update arviz/data/inference_data.py

Co-authored-by: Ari Hartikainen <ahartikainen@users.noreply.github.com>

* Added tests and docs intersphinx_mapping.

* Improved test coverage for to_zarr and from_zarr functions.

* Added pytest.mark.skipif to zarr test class

* Moved test class to new file called 'test_data_zarr'

* Fixed small pylint wrong-import-position error

* Yet another import-order fix

* Pylint is still black magic for me...
It is working locally without errors

* Update arviz/tests/base_tests/test_data_zarr.py

Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com>

* Removed running_on_ci

* Reverted last change and removed running_on_ci in right place now

* Moved zarr import down and added `# pylint: disable=wrong-import-position`

* Update arviz/tests/base_tests/test_data_zarr.py

Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com>

Co-authored-by: Ari Hartikainen <ahartikainen@users.noreply.github.com>
Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com>
3 people authored Feb 6, 2021

Verified

This commit was signed with the committer’s verified signature.
Cruikshanks Alan Cruikshanks
1 parent d09d644 commit 2be3611
Showing 7 changed files with 220 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -2,6 +2,8 @@

## v0.x.x Unreleased
### New features
* Added `to_zarr` method to InferenceData
* Added `from_zarr` method to InferenceData

### Maintenance and fixes

107 changes: 104 additions & 3 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# pylint: disable=too-many-lines
# pylint: disable=too-many-lines,too-many-public-methods
"""Data structure for using netcdf groups with xarray."""
import sys
import uuid
import warnings
from collections import OrderedDict, defaultdict
from collections.abc import Sequence
from collections.abc import MutableMapping, Sequence
from copy import copy as ccopy
from copy import deepcopy
from datetime import datetime
from html import escape
import sys
from typing import (
TYPE_CHECKING,
Any,
@@ -26,6 +26,7 @@
import netCDF4 as nc
import numpy as np
import xarray as xr
from packaging import version

from ..rcparams import rcParams
from ..utils import HtmlTemplate, _subset_list, either_dict_or_kwargs
@@ -589,6 +590,106 @@ def to_dataframe(
(dfs,) = dfs.values()
return dfs

def to_zarr(self, store=None):
"""Convert InferenceData to a :class:`zarr.hierarchy.Group`.
The zarr storage is using the same group names as the InferenceData.
Raises
------
TypeError
If no valid store is found.
Parameters
----------
store: zarr.storage i.e MutableMapping or str, optional
Zarr storage class or path to desired DirectoryStore.
Returns
-------
zarr.hierarchy.group
A zarr hierarchy group containing the InferenceData.
References
----------
https://zarr.readthedocs.io/
"""
try: # Check zarr
import zarr

assert version.parse(zarr.__version__) >= version.parse("2.5.0")
except (ImportError, AssertionError) as err:
raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err

# Check store type and create store if necessary
if store is None:
store = zarr.storage.TempStore(suffix="arviz")
elif isinstance(store, str):
store = zarr.storage.DirectoryStore(path=store)
elif not isinstance(store, MutableMapping):
raise TypeError(f"No valid store found: {store}")

groups = self.groups()

if not groups:
raise TypeError("No valid groups found!")

for group in groups:
# Create zarr group in store with same group name
getattr(self, group).to_zarr(store=store, group=group, mode="w")

return zarr.open(store) # Open store to get overarching group

@staticmethod
def from_zarr(store) -> "InferenceData":
"""Initialize object from a zarr store or path.
Expects that the zarr store will have groups, each of which can be loaded by xarray.
By default, the datasets of the InferenceData object will be lazily loaded instead
of being loaded into memory. This
behaviour is regulated by the value of ``az.rcParams["data.load"]``.
Parameters
----------
store: MutableMapping or zarr.hierarchy.Group or str.
Zarr storage class or path to desired Store.
Returns
-------
InferenceData object
References
----------
https://zarr.readthedocs.io/
"""
try:
import zarr

assert version.parse(zarr.__version__) >= version.parse("2.5.0")
except (ImportError, AssertionError) as err:
raise ImportError("'to_zarr' method needs Zarr (2.5.0+) installed.") from err

# Check store type and create store if necessary
if isinstance(store, str):
store = zarr.storage.DirectoryStore(path=store)
elif isinstance(store, zarr.hierarchy.Group):
store = store.store
elif not isinstance(store, MutableMapping):
raise TypeError(f"No valid store found: {store}")

groups = {}
zarr_handle = zarr.open(store, mode="r")

# Open each group via xarray method
for key_group, _ in zarr_handle.groups():
with xr.open_zarr(store=store, group=key_group) as data:
if rcParams["data.load"] == "eager":
groups[key_group] = data.load()
else:
groups[key_group] = data

return InferenceData(**groups)

def __add__(self, other: "InferenceData") -> "InferenceData":
"""Concatenate two InferenceData objects."""
return concat(self, other, copy=True, inplace=False)
1 change: 1 addition & 0 deletions arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=no-member, invalid-name, redefined-outer-name
# pylint: disable=too-many-lines

import os
from collections import namedtuple
from copy import deepcopy
106 changes: 106 additions & 0 deletions arviz/tests/base_tests/test_data_zarr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# pylint: disable=redefined-outer-name
import importlib
import os
import shutil
from collections.abc import MutableMapping

import numpy as np
import pytest


from arviz import InferenceData, from_dict

from ..helpers import ( # pylint: disable=unused-import
chains,
check_multiple_attrs,
draws,
eight_schools_params,
running_on_ci,
)

pytestmark = pytest.mark.skipif( # pylint: disable=invalid-name
importlib.util.find_spec("zarr") is None and not running_on_ci(),
reason="test requires zarr which is not installed",
)

import zarr # pylint: disable=wrong-import-position, wrong-import-order


class TestDataZarr:
@pytest.fixture(scope="class")
def data(self, draws, chains):
class Data:
# fake 8-school output
obj = {}
for key, shape in {"mu": [], "tau": [], "eta": [8], "theta": [8]}.items():
obj[key] = np.random.randn(chains, draws, *shape)

return Data

def get_inference_data(self, data, eight_schools_params):
return from_dict(
posterior=data.obj,
posterior_predictive=data.obj,
sample_stats=data.obj,
prior=data.obj,
prior_predictive=data.obj,
sample_stats_prior=data.obj,
observed_data=eight_schools_params,
coords={"school": np.arange(8)},
dims={"theta": ["school"], "eta": ["school"]},
)

@pytest.mark.parametrize("store", [0, 1, 2])
def test_io_method(self, data, eight_schools_params, store):
# create InferenceData and check it has been properly created
inference_data = self.get_inference_data( # pylint: disable=W0612
data, eight_schools_params
)
test_dict = {
"posterior": ["eta", "theta", "mu", "tau"],
"posterior_predictive": ["eta", "theta", "mu", "tau"],
"sample_stats": ["eta", "theta", "mu", "tau"],
"prior": ["eta", "theta", "mu", "tau"],
"prior_predictive": ["eta", "theta", "mu", "tau"],
"sample_stats_prior": ["eta", "theta", "mu", "tau"],
"observed_data": ["J", "y", "sigma"],
}
fails = check_multiple_attrs(test_dict, inference_data)
assert not fails

# check filename does not exist and use to_zarr method
here = os.path.dirname(os.path.abspath(__file__))
data_directory = os.path.join(here, "..", "saved_models")
filepath = os.path.join(data_directory, "zarr")
assert not os.path.exists(filepath)

# InferenceData method
if store == 0:
# Tempdir
store = inference_data.to_zarr(store=None)
assert isinstance(store, MutableMapping)
elif store == 1:
inference_data.to_zarr(store=filepath)
# assert file has been saved correctly
assert os.path.exists(filepath)
assert os.path.getsize(filepath) > 0
elif store == 2:
store = zarr.storage.DirectoryStore(filepath)
inference_data.to_zarr(store=store)
# assert file has been saved correctly
assert os.path.exists(filepath)
assert os.path.getsize(filepath) > 0

if isinstance(store, MutableMapping):
inference_data2 = InferenceData.from_zarr(store)
else:
inference_data2 = InferenceData.from_zarr(filepath)

# Everything in dict still available in inference_data2 ?
fails = check_multiple_attrs(test_dict, inference_data2)
assert not fails

# Remove created folder structure
if os.path.exists(filepath):
shutil.rmtree(filepath)
assert not os.path.exists(filepath)
2 changes: 2 additions & 0 deletions doc/source/api/inference_data.rst
Original file line number Diff line number Diff line change
@@ -33,6 +33,8 @@ IO / Conversion
InferenceData.to_json
InferenceData.from_netcdf
InferenceData.to_netcdf
InferenceData.from_zarr
InferenceData.to_zarr
InferenceData.chunk
InferenceData.compute
InferenceData.load
7 changes: 4 additions & 3 deletions doc/source/conf.py
Original file line number Diff line number Diff line change
@@ -137,8 +137,8 @@
# documentation.

html_theme_options = {
"github_url": "https://github.com/arviz-devs/arviz",
"twitter_url": "https://twitter.com/arviz_devs",
"github_url": "https://github.com/arviz-devs/arviz",
"twitter_url": "https://twitter.com/arviz_devs",
}

# Add any paths that contain custom static files (such as style sheets) here,
@@ -153,7 +153,7 @@

# use additional pages to add a 404 page
html_additional_pages = {
'404': '404.html',
"404": "404.html",
}

# -- Options for HTMLHelp output ------------------------------------------
@@ -252,4 +252,5 @@
"mpl": ("https://matplotlib.org/", None),
"bokeh": ("https://docs.bokeh.org/en/latest/", None),
"scipy": ("https://docs.scipy.org/doc/scipy/reference/", None),
"zarr": ("https://zarr.readthedocs.io/en/stable/", None),
}
1 change: 1 addition & 0 deletions requirements-optional.txt
Original file line number Diff line number Diff line change
@@ -2,3 +2,4 @@ numba
bokeh>=1.4.0
ujson
dask
zarr>=2.5.0

0 comments on commit 2be3611

Please sign in to comment.