diff --git a/cfgrib/dataset.py b/cfgrib/dataset.py index 1dbe3f09..cea3f810 100644 --- a/cfgrib/dataset.py +++ b/cfgrib/dataset.py @@ -521,13 +521,18 @@ def build_variable_components( extra_coords_data: T.Dict[str, T.Dict[str, T.Any]] = { coord_name: {} for coord_name in extra_coords } - for dim in header_dimensions: - header_value_index[dim] = {v: i for i, v in enumerate(coord_vars[dim].data.tolist())} + extra_dims = tuple(extra_coords.values()) + for dim in header_dimensions + extra_dims: + if np.isscalar(coord_vars[dim].data): + header_value_index[dim] = {np.asscalar(coord_vars[dim].data): 0} + else: + header_value_index[dim] = {v: i for i, v in enumerate(coord_vars[dim].data.tolist())} for header_values, offset in index.offsets: header_indexes = [] # type: T.List[int] - for dim in header_dimensions: + for dim in header_dimensions + extra_dims: header_value = header_values[index.index_keys.index(coord_name_key_map.get(dim, dim))] - header_indexes.append(header_value_index[dim][header_value]) + if dim in header_dimensions: + header_indexes.append(header_value_index[dim][header_value]) for coord_name in extra_coords: coord_value = header_values[ index.index_keys.index(coord_name_key_map.get(coord_name, coord_name)) @@ -563,10 +568,14 @@ def build_variable_components( coord_vars["valid_time"] = Variable(dimensions=time_dims, data=time_data, attributes=attrs) for coord_name in extra_coords: - coord_vars[coord_name] = Variable( - dimensions=(extra_coords[coord_name],), - data=np.array(list(extra_coords_data[coord_name].values())), - ) + coord_data = np.array(list(extra_coords_data[coord_name].values())) + if extra_coords[coord_name] not in header_dimensions: + coord_dimensions: T.Tuple[str, ...] = () + coord_data = coord_data.reshape(()) + else: + coord_dimensions = (extra_coords[coord_name],) + coord_vars[coord_name] = Variable(dimensions=coord_dimensions, data=coord_data,) + data_var_attrs["coordinates"] = " ".join(coord_vars.keys()) # OnDiskArray is close enough to np.ndarray to work, but not to make mypy happy data_var = Variable(dimensions=dimensions, data=on_disk_array, attributes=data_var_attrs) # type: ignore diff --git a/cfgrib/xarray_plugin.py b/cfgrib/xarray_plugin.py index 74b5aea8..474514f4 100644 --- a/cfgrib/xarray_plugin.py +++ b/cfgrib/xarray_plugin.py @@ -94,6 +94,7 @@ def open_dataset( squeeze: bool = True, time_dims: T.Iterable[str] = ("time", "step"), errors: str = "warn", + extra_coords: T.Dict[str, str] = {}, ) -> xr.Dataset: store = CfGribDataStore( @@ -106,6 +107,7 @@ def open_dataset( time_dims=time_dims, lock=lock, errors=errors, + extra_coords=extra_coords, ) with xr.core.utils.close_on_error(store): vars, attrs = store.load() # type: ignore diff --git a/tests/sample-data/era5-single-level-scalar-time.grib b/tests/sample-data/era5-single-level-scalar-time.grib new file mode 100644 index 00000000..8ca53650 Binary files /dev/null and b/tests/sample-data/era5-single-level-scalar-time.grib differ diff --git a/tests/test_30_dataset.py b/tests/test_30_dataset.py index 36b9f9b7..57d5fd92 100644 --- a/tests/test_30_dataset.py +++ b/tests/test_30_dataset.py @@ -9,6 +9,7 @@ SAMPLE_DATA_FOLDER = os.path.join(os.path.dirname(__file__), "sample-data") TEST_DATA = os.path.join(SAMPLE_DATA_FOLDER, "era5-levels-members.grib") TEST_DATA_UKMO = os.path.join(SAMPLE_DATA_FOLDER, "forecast_monthly_ukmo.grib") +TEST_DATA_SCALAR_TIME = os.path.join(SAMPLE_DATA_FOLDER, "era5-single-level-scalar-time.grib") def test_enforce_unique_attributes(): @@ -206,7 +207,15 @@ def test_Dataset_extra_coords(): assert res.variables["experimentVersionNumber"].dimensions == ("time",) -def test_Dataet_extra_coords_error(): +def test_Dataset_scalar_extra_coords(): + res = dataset.open_file( + TEST_DATA_SCALAR_TIME, extra_coords={"experimentVersionNumber": "time"} + ) + assert "experimentVersionNumber" in res.variables + assert res.variables["experimentVersionNumber"].dimensions == () + + +def test_Dataset_extra_coords_error(): with pytest.raises(ValueError): dataset.open_file(TEST_DATA, extra_coords={"validityDate": "number"}) diff --git a/tests/test_50_sample_data.py b/tests/test_50_sample_data.py index b7149637..a99417de 100644 --- a/tests/test_50_sample_data.py +++ b/tests/test_50_sample_data.py @@ -89,3 +89,17 @@ def test_canonical_dataset_to_grib(grib_name, tmpdir): xarray_to_grib.canonical_dataset_to_grib(res, out_path) reread = xarray_store.open_dataset(out_path) assert res.equals(reread) + + +@pytest.mark.parametrize( + "grib_name,ndims", [("era5-levels-members", 1), ("era5-single-level-scalar-time", 0),], +) +def test_open_dataset_extra_coords(grib_name, ndims): + grib_path = os.path.join(SAMPLE_DATA_FOLDER, grib_name + ".grib") + res = xarray_store.open_dataset( + grib_path, + backend_kwargs={"extra_coords": {"experimentVersionNumber": "time"}}, + cache=False, + ) + assert "experimentVersionNumber" in res.coords + assert len(res["experimentVersionNumber"].dims) == ndims