Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Swapped the netCDF4 dependency to h5netcdf #2122

Merged
merged 12 commits into from
Dec 28, 2022
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
### New features
- Adds Savage-Dickey density ratio plot for Bayes factor approximation. ([2037](https://github.com/arviz-devs/arviz/pull/2037), [2152](https://github.com/arviz-devs/arviz/pull/2152))
- Add `CmdStanPySamplingWrapper` and `PyMCSamplingWrapper` classes ([2158](https://github.com/arviz-devs/arviz/pull/2158))
- Changed dependency on netcdf4-python to h5netcdf ([2122](https://github.com/arviz-devs/arviz/pull/2122))

### Maintenance and fixes
- Fix `reloo` outdated usage of `ELPDData` ([2158](https://github.com/arviz-devs/arviz/pull/2158))
Expand Down
17 changes: 5 additions & 12 deletions arviz/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def _sha256(path):
return sha256hash.hexdigest()


def load_arviz_data(dataset=None, data_home=None, regex=False, **kwargs):
def load_arviz_data(dataset=None, data_home=None, **kwargs):
"""Load a local or remote pre-made dataset.

Run with no parameters to get a list of all available models.
Expand All @@ -100,17 +100,10 @@ def load_arviz_data(dataset=None, data_home=None, regex=False, **kwargs):
----------
dataset : str
Name of dataset to load.

data_home : str, optional
Where to save remote datasets

regex : bool, optional
Specifies regex support for chunking information in
:func:`arviz.from_netcdf`. This feature is currently experimental.

**kwargs : dict of {str: dict}, optional
Keyword arguments to be passed to :func:`arviz.from_netcdf`.
This feature is currently experimental.
**kwargs : dict, optional
Keyword arguments passed to :func:`arviz.from_netcdf`.

Returns
-------
Expand All @@ -119,7 +112,7 @@ def load_arviz_data(dataset=None, data_home=None, regex=False, **kwargs):
"""
if dataset in LOCAL_DATASETS:
resource = LOCAL_DATASETS[dataset]
return from_netcdf(resource.filename)
return from_netcdf(resource.filename, **kwargs)

elif dataset in REMOTE_DATASETS:
remote = REMOTE_DATASETS[dataset]
Expand All @@ -140,7 +133,7 @@ def load_arviz_data(dataset=None, data_home=None, regex=False, **kwargs):
"({remote.checksum}), file may be corrupted. "
"Run `arviz.clear_data_home()` and try again, or please open an issue."
)
return from_netcdf(file_path, kwargs, regex)
return from_netcdf(file_path, **kwargs)
else:
if dataset is None:
return dict(itertools.chain(LOCAL_DATASETS.items(), REMOTE_DATASETS.items()))
Expand Down
50 changes: 40 additions & 10 deletions arviz/data/inference_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
overload,
)

import netCDF4 as nc
import numpy as np
import xarray as xr
from packaging import version
Expand Down Expand Up @@ -337,7 +336,9 @@ def items(self) -> "InferenceData.InferenceDataItemsView":
return InferenceData.InferenceDataItemsView(self)

@staticmethod
def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":
def from_netcdf(
filename, *, engine="h5netcdf", group_kwargs=None, regex=False
) -> "InferenceData":
"""Initialize object from a netcdf file.

Expects that the file will have groups, each of which can be loaded by xarray.
Expand All @@ -349,6 +350,8 @@ def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":
----------
filename : str
location of netcdf file
engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
Library used to read the netcdf file.
group_kwargs : dict of {str: dict}, optional
Keyword arguments to be passed into each call of :func:`xarray.open_dataset`.
The keys of the higher level should be group names or regex matching group
Expand All @@ -360,30 +363,44 @@ def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":

Returns
-------
InferenceData object
InferenceData
"""
groups = {}
attrs = {}

if engine == "h5netcdf":
import h5netcdf
elif engine == "netcdf4":
import netCDF4 as nc
else:
raise ValueError(
f"Invalid value for engine: {engine}. Valid options are: h5netcdf or netcdf4"
)

try:
with nc.Dataset(filename, mode="r") as data:
with h5netcdf.File(filename, mode="r") if engine == "h5netcdf" else nc.Dataset(
filename, mode="r"
) as data:
data_groups = list(data.groups)

for group in data_groups:
group_kws = {}

group_kws = {}
if group_kwargs is not None and regex is False:
group_kws = group_kwargs.get(group, {})
if group_kwargs is not None and regex is True:
for key, kws in group_kwargs.items():
if re.search(key, group):
group_kws = kws
group_kws.setdefault("engine", engine)
with xr.open_dataset(filename, group=group, **group_kws) as data:
if rcParams["data.load"] == "eager":
groups[group] = data.load()
else:
groups[group] = data

with xr.open_dataset(filename, mode="r") as data:
with xr.open_dataset(filename, engine=engine) as data:
attrs.update(data.load().attrs)

return InferenceData(attrs=attrs, **groups)
Expand All @@ -402,9 +419,13 @@ def from_netcdf(filename, group_kwargs=None, regex=False) -> "InferenceData":
raise err

def to_netcdf(
self, filename: str, compress: bool = True, groups: Optional[List[str]] = None
self,
filename: str,
compress: bool = True,
groups: Optional[List[str]] = None,
engine: str = "h5netcdf",
) -> str:
"""Write InferenceData to file using netcdf4.
"""Write InferenceData to netcdf4 file.

Parameters
----------
Expand All @@ -415,6 +436,8 @@ def to_netcdf(
saving and loading somewhat slower (default: True).
groups : list, optional
Write only these groups to netcdf file.
engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
Library used to read the netcdf file.

Returns
-------
Expand All @@ -423,7 +446,7 @@ def to_netcdf(
"""
mode = "w" # overwrite first, then append
if self._attrs:
xr.Dataset(attrs=self._attrs).to_netcdf(filename, mode=mode)
xr.Dataset(attrs=self._attrs).to_netcdf(filename, mode=mode, engine=engine)
mode = "a"

if self._groups_all: # check's whether a group is present or not.
Expand All @@ -434,7 +457,7 @@ def to_netcdf(

for group in groups:
data = getattr(self, group)
kwargs = {}
kwargs = {"engine": engine}
if compress:
kwargs["encoding"] = {
var_name: {"zlib": True}
Expand All @@ -445,7 +468,14 @@ def to_netcdf(
data.close()
mode = "a"
elif not self._attrs: # creates a netcdf file for an empty InferenceData object.
empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4")
if engine == "h5netcdf":
import h5netcdf

empty_netcdf_file = h5netcdf.File(filename, mode="w")
elif engine == "netcdf4":
import netCDF4 as nc

empty_netcdf_file = nc.Dataset(filename, mode="w", format="NETCDF4")
empty_netcdf_file.close()
return filename

Expand Down
14 changes: 10 additions & 4 deletions arviz/data/io_netcdf.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
from .inference_data import InferenceData


def from_netcdf(filename, group_kwargs=None, regex=False):
def from_netcdf(filename, *, engine="h5netcdf", group_kwargs=None, regex=False):
"""Load netcdf file back into an arviz.InferenceData.

Parameters
----------
filename : str
name or path of the file to load trace
engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
Library used to read the netcdf file.
group_kwargs : dict of {str: dict}
Keyword arguments to be passed into each call of :func:`xarray.open_dataset`.
The keys of the higher level should be group names or regex matching group
Expand All @@ -31,10 +33,12 @@ def from_netcdf(filename, group_kwargs=None, regex=False):
"""
if group_kwargs is None:
group_kwargs = {}
return InferenceData.from_netcdf(filename, group_kwargs, regex)
return InferenceData.from_netcdf(
filename, engine=engine, group_kwargs=group_kwargs, regex=regex
)


def to_netcdf(data, filename, *, group="posterior", coords=None, dims=None):
def to_netcdf(data, filename, *, group="posterior", engine="h5netcdf", coords=None, dims=None):
"""Save dataset as a netcdf file.

WARNING: Only idempotent in case `data` is InferenceData
Expand All @@ -47,6 +51,8 @@ def to_netcdf(data, filename, *, group="posterior", coords=None, dims=None):
name or path of the file to load trace
group : str (optional)
In case `data` is not InferenceData, this is the group it will be saved to
engine : {"h5netcdf", "netcdf4"}, default "h5netcdf"
Library used to read the netcdf file.
coords : dict (optional)
See `convert_to_inference_data`
dims : dict (optional)
Expand All @@ -58,5 +64,5 @@ def to_netcdf(data, filename, *, group="posterior", coords=None, dims=None):
filename saved to
"""
inference_data = convert_to_inference_data(data, group=group, coords=coords, dims=dims)
file_name = inference_data.to_netcdf(filename)
file_name = inference_data.to_netcdf(filename, engine=engine)
return file_name
13 changes: 12 additions & 1 deletion arviz/tests/base_tests/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -1296,11 +1296,22 @@ def test_io_function(self, data, eight_schools_params):

@pytest.mark.parametrize("groups_arg", [False, True])
@pytest.mark.parametrize("compress", [True, False])
def test_io_method(self, data, eight_schools_params, groups_arg, compress):
@pytest.mark.parametrize("engine", ["h5netcdf", "netcdf4"])
def test_io_method(self, data, eight_schools_params, groups_arg, compress, engine):
# create InferenceData and check it has been properly created
inference_data = self.get_inference_data( # pylint: disable=W0612
data, eight_schools_params
)
if engine == "h5netcdf":
try:
import h5netcdf # pylint: disable=unused-import
except ImportError:
pytest.skip("h5netcdf not installed")
elif engine == "netcdf4":
try:
import netCDF4 # pylint: disable=unused-import
except ImportError:
pytest.skip("netcdf4 not installed")
test_dict = {
"posterior": ["eta", "theta", "mu", "tau"],
"posterior_predictive": ["eta", "theta", "mu", "tau"],
Expand Down
1 change: 1 addition & 0 deletions requirements-optional.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
numba
netcdf4
bokeh>=1.4.0,<3.0
contourpy
ujson
Expand Down
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@ pytest-cov
cloudpickle

-r requirements-optional.txt
-r requirements-external.txt
-r requirements-external.txt
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ scipy>=1.8.0
packaging
pandas>=1.4.0
xarray>=0.21.0
netcdf4
h5netcdf>=1.0.2
typing_extensions>=4.1.0
xarray-einstats>=0.3