-
-
Notifications
You must be signed in to change notification settings - Fork 419
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added methods to convert InferenceData to and from Zarr (#1518)
* Added a to_zarr method which converts the InferenceData object to a hierarchical zarr group. * Added a from_zarr method which creates an InferenceData object from a zarr store. * Fixed small typo in to_zarr and from_zarr * Added the ability to create InferenceData from zarr.hierarchy.Group * Oversight in zarr.hierarchy.group.groups() generator call. * Forgot dictionary definition. * Cleanup zarr: - added reference to zarr docs - replaced type with isinstance - replaced MemoryStore with TempStore * Added two methods to CHANGELOG * PR-Comments: - Moved MutableMapping to top - Check if zarr is installed - Removed ifs observed_data,constant_data,predictions_constant_data - Renamed g to zarr_handle - Removed deprecated docstring * Removed typo * Replaced last occurrence of g with zarr_handle * Added from packaging import version * Fixed wrong import order, I did not know this is a thing in python, wow * Even later import of version * Yet another try to fix the import order * Local pylint is working with this formatting... * Docstring formatting changes and removed deprecated parts. * Fixed local black version mismatch. * Update arviz/data/inference_data.py Co-authored-by: Ari Hartikainen <ahartikainen@users.noreply.github.com> * Added tests and docs intersphinx_mapping. * Improved test coverage for to_zarr and from_zarr functions. * Added pytest.mark.skipif to zarr test class * Moved test class to new file called 'test_data_zarr' * Fixed small pylint wrong-import-position error * Yet another import-order fix * Pylint is still black magic for me... 
It is working locally without errors * Update arviz/tests/base_tests/test_data_zarr.py Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com> * Removed running_on_ci * Reverted last change and removed running_on_ci in right place now * Moved zarr import down and added `# pylint: disable=wrong-import-position` * Update arviz/tests/base_tests/test_data_zarr.py Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com> Co-authored-by: Ari Hartikainen <ahartikainen@users.noreply.github.com> Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com>
1 parent
d09d644
commit 2be3611
Showing
7 changed files
with
220 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# pylint: disable=redefined-outer-name | ||
import importlib | ||
import os | ||
import shutil | ||
from collections.abc import MutableMapping | ||
|
||
import numpy as np | ||
import pytest | ||
|
||
|
||
from arviz import InferenceData, from_dict | ||
|
||
from ..helpers import ( # pylint: disable=unused-import | ||
chains, | ||
check_multiple_attrs, | ||
draws, | ||
eight_schools_params, | ||
running_on_ci, | ||
) | ||
|
||
# True when the optional ``zarr`` dependency can be located by the import system.
_ZARR_AVAILABLE = importlib.util.find_spec("zarr") is not None

# Skip every test in this module when zarr is missing, except on CI where a
# missing zarr must surface as a hard failure rather than a silent skip.
pytestmark = pytest.mark.skipif(  # pylint: disable=invalid-name
    not _ZARR_AVAILABLE and not running_on_ci(),
    reason="test requires zarr which is not installed",
)

# Guard the import with the same condition as the skip marker: an unconditional
# ``import zarr`` raises ImportError while the module is being collected, which
# happens before ``pytestmark`` can take effect, so the skip would never apply.
if _ZARR_AVAILABLE or running_on_ci():
    import zarr  # pylint: disable=wrong-import-position, wrong-import-order
|
||
|
||
class TestDataZarr:
    """Round-trip tests for ``InferenceData.to_zarr`` and ``InferenceData.from_zarr``."""

    @pytest.fixture(scope="class")
    def data(self, draws, chains):
        """Return a class whose ``obj`` dict maps variable names to random draws.

        Each array has shape ``(chains, draws, *shape)`` where ``shape`` is
        ``[]`` for ``mu``/``tau`` and ``[8]`` for ``eta``/``theta``.
        """

        class Data:
            # fake 8-school output
            obj = {}
            for key, shape in {"mu": [], "tau": [], "eta": [8], "theta": [8]}.items():
                obj[key] = np.random.randn(chains, draws, *shape)

        return Data

    def get_inference_data(self, data, eight_schools_params):
        """Build an InferenceData object reusing ``data.obj`` for every sampled group."""
        return from_dict(
            posterior=data.obj,
            posterior_predictive=data.obj,
            sample_stats=data.obj,
            prior=data.obj,
            prior_predictive=data.obj,
            sample_stats_prior=data.obj,
            observed_data=eight_schools_params,
            coords={"school": np.arange(8)},
            dims={"theta": ["school"], "eta": ["school"]},
        )

    @pytest.mark.parametrize("store", [0, 1, 2])
    def test_io_method(self, data, eight_schools_params, store):
        """Write InferenceData to zarr and read it back.

        ``store`` selects the kind of target passed to ``to_zarr``:
        0 passes ``store=None``, 1 passes a filesystem path and 2 passes a
        ``zarr.storage.DirectoryStore`` built on that path.
        """
        # create InferenceData and check it has been properly created
        inference_data = self.get_inference_data(data, eight_schools_params)
        test_dict = {
            "posterior": ["eta", "theta", "mu", "tau"],
            "posterior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats": ["eta", "theta", "mu", "tau"],
            "prior": ["eta", "theta", "mu", "tau"],
            "prior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats_prior": ["eta", "theta", "mu", "tau"],
            "observed_data": ["J", "y", "sigma"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails

        # check filename does not exist and use to_zarr method
        here = os.path.dirname(os.path.abspath(__file__))
        data_directory = os.path.join(here, "..", "saved_models")
        filepath = os.path.join(data_directory, "zarr")
        assert not os.path.exists(filepath)

        try:
            # What from_zarr reads back from; defaults to the on-disk path.
            source = filepath

            # InferenceData method
            if store == 0:
                # No store given: to_zarr returns the store it created.
                source = inference_data.to_zarr(store=None)
                assert isinstance(source, MutableMapping)
            elif store == 1:
                inference_data.to_zarr(store=filepath)
                # assert file has been saved correctly
                assert os.path.exists(filepath)
                assert os.path.getsize(filepath) > 0
            elif store == 2:
                directory_store = zarr.storage.DirectoryStore(filepath)
                inference_data.to_zarr(store=directory_store)
                # assert file has been saved correctly
                assert os.path.exists(filepath)
                assert os.path.getsize(filepath) > 0
                # Read back through the store object only if it is a mapping,
                # otherwise fall back to the path.
                if isinstance(directory_store, MutableMapping):
                    source = directory_store

            inference_data2 = InferenceData.from_zarr(source)

            # Everything in dict still available in inference_data2 ?
            fails = check_multiple_attrs(test_dict, inference_data2)
            assert not fails
        finally:
            # Remove created folder structure even when an assertion above
            # failed; a leftover directory would break the "does not exist"
            # pre-check of every following parametrized case and later runs.
            if os.path.exists(filepath):
                shutil.rmtree(filepath)

        assert not os.path.exists(filepath)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ numba | |
bokeh>=1.4.0 | ||
ujson | ||
dask | ||
zarr>=2.5.0 |