Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding tests to #506 #3

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 34 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ def make_http_paths(netcdf_local_paths, request):
def daily_xarray_dataset():
    """Build the synthetic dataset by calling ``make_ds`` with ``nt=10``."""
    dataset = make_ds(nt=10)
    return dataset


@pytest.fixture(scope="session")
def daily_xarray_dataset_with_coordinateless_dimension(daily_xarray_dataset):
"""
Expand All @@ -215,6 +214,12 @@ def daily_xarray_dataset_with_coordinateless_dimension(daily_xarray_dataset):
del ds["lon"]
return ds

@pytest.fixture(scope="session")
def daily_xarray_dataset_with_extra_dimension_coordinates():
    """Synthetic dataset from ``make_ds`` with the extra-dimension coordinates enabled.

    The dataset is created with ``nt=11`` and ``add_extra_dim_coords=True``; its
    ``extra_dim_coord`` entry is then also assigned under the ``extra_dim_var`` key.
    """
    dataset = make_ds(nt=11, add_extra_dim_coords=True)
    extra_coord = dataset["extra_dim_coord"]
    dataset["extra_dim_var"] = extra_coord
    return dataset


@pytest.fixture(scope="session")
def netcdf_local_paths_sequential_1d(daily_xarray_dataset, tmpdir_factory):
Expand Down Expand Up @@ -303,6 +308,28 @@ def netcdf_local_paths_sequential_multivariable_with_coordinateless_dimension(
file_type="netcdf4",
)

@pytest.fixture(scope="session")
def netcdf_local_paths_sequential_with_extra_dimension_coordinate(
    daily_xarray_dataset_with_extra_dimension_coordinates, tmpdir_factory
):
    """Return what ``make_local_paths`` produces for the extra-dimension-coordinate dataset.

    The dataset fixture and ``tmpdir_factory`` are forwarded together with ``"D"``,
    ``split_up_files_by_day`` and ``file_type="netcdf4"``.
    """
    source_ds = daily_xarray_dataset_with_extra_dimension_coordinates
    local_paths = make_local_paths(
        source_ds,
        tmpdir_factory,
        "D",
        split_up_files_by_day,
        file_type="netcdf4",
    )
    return local_paths

@pytest.fixture(
    scope="session",
    params=[lazy_fixture("netcdf_local_paths_sequential_with_extra_dimension_coordinate")],
)
def netcdf_local_paths_sequential_extra_dimension_coordinate(request):
    """Parametrized passthrough: hand back whatever ``request.param`` holds.

    The only parameter is the lazily resolved
    ``netcdf_local_paths_sequential_with_extra_dimension_coordinate`` fixture.
    """
    selected = request.param
    return selected



@pytest.fixture(
scope="session",
Expand Down Expand Up @@ -380,7 +407,6 @@ def netcdf_local_paths_sequential_with_coordinateless_dimension(
def netcdf_local_file_pattern_sequential(netcdf_local_paths_sequential):
    """Pass the sequential local paths to ``make_file_pattern`` and return its result."""
    file_pattern = make_file_pattern(netcdf_local_paths_sequential)
    return file_pattern


@pytest.fixture(scope="session")
def netcdf_local_file_pattern_sequential_multivariable(
netcdf_local_paths_sequential_multivariable,
Expand Down Expand Up @@ -418,6 +444,12 @@ def netcdf_local_file_pattern_sequential_with_coordinateless_dimension(
"""
return make_file_pattern(netcdf_local_paths_sequential_with_coordinateless_dimension)

@pytest.fixture(scope="session")
def netcdf_local_file_pattern_sequential_extra_dimension_coordinate(
    netcdf_local_paths_sequential_extra_dimension_coordinate,
):
    """Pass the extra-dimension-coordinate paths to ``make_file_pattern`` and return its result."""
    paths = netcdf_local_paths_sequential_extra_dimension_coordinate
    file_pattern = make_file_pattern(paths)
    return file_pattern


# Storage fixtures --------------------------------------------------------------------------------

Expand Down
9 changes: 7 additions & 2 deletions tests/data_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
import xarray as xr


def make_ds(nt=10, non_dim_coords=False):
def make_ds(nt=10, non_dim_coords=False, add_extra_dim_coords=False):
"""Return a synthetic random xarray dataset."""
np.random.seed(2)
# TODO: change nt to 11 in order to catch the edge case where
# items_per_input does not evenly divide the length of the sequence dimension
ny, nx = 18, 36
ny, nx, ne = 18, 36, 2
time = pd.date_range(start="2010-01-01", periods=nt, freq="D")
lon = (np.arange(nx) + 0.5) * 360 / nx
lon_attrs = {"units": "degrees_east", "long_name": "longitude"}
Expand All @@ -28,6 +28,11 @@ def make_ds(nt=10, non_dim_coords=False):
if non_dim_coords:
coords["timestep"] = ("time", np.arange(nt))
coords["baz"] = (("lat", "lon"), np.random.rand(ny, nx))

if add_extra_dim_coords:
# introduce a coordinate with a dimension not used in the data variables
coords["extra_dim_coord"] = (("extra_dim", "time"), np.random.rand(ne, nt))
coords["extra_dim"] = ("extra_dim", np.arange(ne))

ds = xr.Dataset(
{"bar": (dims, bar, bar_attrs), "foo": (dims, foo, foo_attrs)},
Expand Down
33 changes: 31 additions & 2 deletions tests/test_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def pipeline():
yield p


@pytest.mark.parametrize("target_chunks", [{"time": 1}, {"time": 2}, {"time": 3}])
@pytest.mark.parametrize("target_chunks", [{"time": 1}, {"time": 2}, {"time": 3}, {'time':1, 'lon': 18}])
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ensures that providing target chunks for an additional dimension that is part of the data variables does not trigger the failure.

def test_xarray_zarr(
daily_xarray_dataset,
netcdf_local_file_pattern_sequential,
Expand All @@ -45,7 +45,6 @@ def test_xarray_zarr(
assert ds.time.encoding["chunks"] == (target_chunks["time"],)
xr.testing.assert_equal(ds.load(), daily_xarray_dataset)


def test_xarray_zarr_subpath(
daily_xarray_dataset,
netcdf_local_file_pattern_sequential,
Expand All @@ -67,3 +66,33 @@ def test_xarray_zarr_subpath(

ds = xr.open_dataset(os.path.join(tmp_target_url, "subpath"), engine="zarr")
xr.testing.assert_equal(ds.load(), daily_xarray_dataset)

@pytest.mark.parametrize("target_chunks", [{"time": 1}, {"time": 2}, {"time": 3}])
def test_xarray_zarr_extra_dimension_coordinate(
    daily_xarray_dataset_with_extra_dimension_coordinates,
    netcdf_local_file_pattern_sequential_extra_dimension_coordinate,
    pipeline,
    tmp_target_url,
    target_chunks,
):
    """Store the extra-dimension-coordinate dataset to Zarr and check the round trip.

    The parametrized ``target_chunks`` is extended with an ``extra_dim`` entry, the
    file pattern is run through ``OpenWithXarray`` and ``StoreToZarr``, and the
    resulting store is reopened to verify the ``time`` chunk encoding and equality
    with the source dataset fixture.
    """
    # triggers https://github.com/pangeo-forge/pangeo-forge-recipes/issues/504
    # Build a fresh dict instead of assigning into the parametrized one: pytest hands
    # parameter values to tests without copying, so an in-place edit would persist on
    # the shared object.
    target_chunks = {**target_chunks, "extra_dim": 2}

    pattern = netcdf_local_file_pattern_sequential_extra_dimension_coordinate

    with pipeline as p:
        (
            p
            | beam.Create(pattern.items())
            | OpenWithXarray(file_type=pattern.file_type)
            | StoreToZarr(
                target_root=tmp_target_url,
                store_name="store",
                target_chunks=target_chunks,
                combine_dims=pattern.combine_dim_keys,
            )
        )

    # The store is read back only after the ``with`` block has exited.
    ds = xr.open_dataset(os.path.join(tmp_target_url, "store"), engine="zarr")
    assert ds.time.encoding["chunks"] == (target_chunks["time"],)
    xr.testing.assert_equal(ds.load(), daily_xarray_dataset_with_extra_dimension_coordinates)