Skip to content

Commit

Permalink
Merge pull request #238 from bopen/fix_extra_coords
Browse files Browse the repository at this point in the history
Fix in extra_coords
  • Loading branch information
alexamici authored Jul 30, 2021
2 parents e7207b2 + 20582b9 commit 967001d
Show file tree
Hide file tree
Showing 5 changed files with 43 additions and 9 deletions.
25 changes: 17 additions & 8 deletions cfgrib/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,13 +521,18 @@ def build_variable_components(
extra_coords_data: T.Dict[str, T.Dict[str, T.Any]] = {
coord_name: {} for coord_name in extra_coords
}
for dim in header_dimensions:
header_value_index[dim] = {v: i for i, v in enumerate(coord_vars[dim].data.tolist())}
extra_dims = tuple(extra_coords.values())
for dim in header_dimensions + extra_dims:
if np.isscalar(coord_vars[dim].data):
header_value_index[dim] = {np.asscalar(coord_vars[dim].data): 0}
else:
header_value_index[dim] = {v: i for i, v in enumerate(coord_vars[dim].data.tolist())}
for header_values, offset in index.offsets:
header_indexes = [] # type: T.List[int]
for dim in header_dimensions:
for dim in header_dimensions + extra_dims:
header_value = header_values[index.index_keys.index(coord_name_key_map.get(dim, dim))]
header_indexes.append(header_value_index[dim][header_value])
if dim in header_dimensions:
header_indexes.append(header_value_index[dim][header_value])
for coord_name in extra_coords:
coord_value = header_values[
index.index_keys.index(coord_name_key_map.get(coord_name, coord_name))
Expand Down Expand Up @@ -563,10 +568,14 @@ def build_variable_components(
coord_vars["valid_time"] = Variable(dimensions=time_dims, data=time_data, attributes=attrs)

for coord_name in extra_coords:
coord_vars[coord_name] = Variable(
dimensions=(extra_coords[coord_name],),
data=np.array(list(extra_coords_data[coord_name].values())),
)
coord_data = np.array(list(extra_coords_data[coord_name].values()))
if extra_coords[coord_name] not in header_dimensions:
coord_dimensions: T.Tuple[str, ...] = ()
coord_data = coord_data.reshape(())
else:
coord_dimensions = (extra_coords[coord_name],)
coord_vars[coord_name] = Variable(dimensions=coord_dimensions, data=coord_data,)

data_var_attrs["coordinates"] = " ".join(coord_vars.keys())
# OnDiskArray is close enough to np.ndarray to work, but not to make mypy happy
data_var = Variable(dimensions=dimensions, data=on_disk_array, attributes=data_var_attrs) # type: ignore
Expand Down
2 changes: 2 additions & 0 deletions cfgrib/xarray_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def open_dataset(
squeeze: bool = True,
time_dims: T.Iterable[str] = ("time", "step"),
errors: str = "warn",
extra_coords: T.Dict[str, str] = {},
) -> xr.Dataset:

store = CfGribDataStore(
Expand All @@ -106,6 +107,7 @@ def open_dataset(
time_dims=time_dims,
lock=lock,
errors=errors,
extra_coords=extra_coords,
)
with xr.core.utils.close_on_error(store):
vars, attrs = store.load() # type: ignore
Expand Down
Binary file not shown.
11 changes: 10 additions & 1 deletion tests/test_30_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
SAMPLE_DATA_FOLDER = os.path.join(os.path.dirname(__file__), "sample-data")
TEST_DATA = os.path.join(SAMPLE_DATA_FOLDER, "era5-levels-members.grib")
TEST_DATA_UKMO = os.path.join(SAMPLE_DATA_FOLDER, "forecast_monthly_ukmo.grib")
TEST_DATA_SCALAR_TIME = os.path.join(SAMPLE_DATA_FOLDER, "era5-single-level-scalar-time.grib")


def test_enforce_unique_attributes():
Expand Down Expand Up @@ -206,7 +207,15 @@ def test_Dataset_extra_coords():
assert res.variables["experimentVersionNumber"].dimensions == ("time",)


def test_Dataet_extra_coords_error():
def test_Dataset_scalar_extra_coords():
res = dataset.open_file(
TEST_DATA_SCALAR_TIME, extra_coords={"experimentVersionNumber": "time"}
)
assert "experimentVersionNumber" in res.variables
assert res.variables["experimentVersionNumber"].dimensions == ()


def test_Dataset_extra_coords_error():
with pytest.raises(ValueError):
dataset.open_file(TEST_DATA, extra_coords={"validityDate": "number"})

Expand Down
14 changes: 14 additions & 0 deletions tests/test_50_sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,17 @@ def test_canonical_dataset_to_grib(grib_name, tmpdir):
xarray_to_grib.canonical_dataset_to_grib(res, out_path)
reread = xarray_store.open_dataset(out_path)
assert res.equals(reread)


@pytest.mark.parametrize(
"grib_name,ndims", [("era5-levels-members", 1), ("era5-single-level-scalar-time", 0),],
)
def test_open_dataset_extra_coords(grib_name, ndims):
grib_path = os.path.join(SAMPLE_DATA_FOLDER, grib_name + ".grib")
res = xarray_store.open_dataset(
grib_path,
backend_kwargs={"extra_coords": {"experimentVersionNumber": "time"}},
cache=False,
)
assert "experimentVersionNumber" in res.coords
assert len(res["experimentVersionNumber"].dims) == ndims

0 comments on commit 967001d

Please sign in to comment.