Skip to content

Commit

Permalink
Add load_dataset and save_dataset functions sgkit-dev#392
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite committed Dec 15, 2020
1 parent c554882 commit 0a8ecb2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 8 deletions.
30 changes: 25 additions & 5 deletions sgkit/io/dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from typing import Any
from typing import Any, Dict, Optional

import fsspec
import xarray as xr
from xarray import Dataset

from sgkit.typing import PathType


def save_dataset(ds: Dataset, path: PathType, **kwargs: Any) -> None:
def save_dataset(
ds: Dataset,
path: PathType,
storage_options: Optional[Dict[str, str]] = None,
**kwargs: Any
) -> None:
"""Save a dataset to Zarr storage.
This function is a thin wrapper around :meth:`xarray.Dataset.to_zarr`
Expand All @@ -18,17 +24,25 @@ def save_dataset(ds: Dataset, path: PathType, **kwargs: Any) -> None:
Dataset to save.
path
Path to directory in file system to save to.
storage_options:
Any additional parameters for the storage backend (see ``fsspec.open``).
kwargs
Additional arguments to pass to :meth:`xarray.Dataset.to_zarr`.
"""
store = str(path)
if isinstance(path, str):
storage_options = storage_options or {}
store = fsspec.get_mapper(path, **storage_options)
else:
store = str(path)
for v in ds:
# Workaround for https://github.com/pydata/xarray/issues/4380
ds[v].encoding.pop("chunks", None)
ds.to_zarr(store, **kwargs)


def load_dataset(path: PathType) -> Dataset:
def load_dataset(
path: PathType, storage_options: Optional[Dict[str, str]] = None
) -> Dataset:
"""Load a dataset from Zarr storage.
This function is a thin wrapper around :meth:`xarray.open_zarr`
Expand All @@ -38,13 +52,19 @@ def load_dataset(path: PathType) -> Dataset:
----------
path
Path to directory in file system to load from.
storage_options:
Any additional parameters for the storage backend (see ``fsspec.open``).
Returns
-------
Dataset
The dataset loaded from the file system.
"""
store = str(path)
if isinstance(path, str):
storage_options = storage_options or {}
store = fsspec.get_mapper(path, **storage_options)
else:
store = str(path)
ds: Dataset = xr.open_zarr(store, concat_characters=False) # type: ignore[no-untyped-call]
for v in ds:
# Workaround for https://github.com/pydata/xarray/issues/4386
Expand Down
15 changes: 12 additions & 3 deletions sgkit/tests/io/test_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytest
import xarray as xr
from xarray import Dataset

Expand All @@ -11,14 +12,22 @@ def assert_identical(ds1: Dataset, ds2: Dataset) -> None:
assert all([ds1[v].dtype == ds2[v].dtype for v in ds1.data_vars])


def test_save_and_load_dataset(tmp_path):
path = str(tmp_path / "ds.zarr")
@pytest.mark.parametrize(
"is_path",
[True, False],
)
def test_save_and_load_dataset(tmp_path, is_path):
path = tmp_path / "ds.zarr"
if not is_path:
path = str(path)
ds = simulate_genotype_call_dataset(n_variant=10, n_sample=10)
save_dataset(ds, path)
ds2 = load_dataset(path)
assert_identical(ds, ds2)

# save and load again to test https://github.com/pydata/xarray/issues/4386
path2 = str(tmp_path / "ds2.zarr")
path2 = tmp_path / "ds2.zarr"
if not is_path:
path2 = str(path2)
save_dataset(ds2, path2)
assert_identical(ds, load_dataset(path2))

0 comments on commit 0a8ecb2

Please sign in to comment.