Skip to content

Commit

Permalink
Fix support for missing values. Fixes #313
Browse files (browse the repository at this point in the history)
  • Loading branch information
iainrussell committed Sep 14, 2022
1 parent cd752a1 commit 2707106
Show file tree
Hide file tree
Showing 6 changed files with 22 additions and 20 deletions.
11 changes: 4 additions & 7 deletions cfgrib/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,10 +305,9 @@ def expand_item(item, shape):
return tuple(expanded_item)


def get_values_in_order(message, shape, missing_value):
def get_values_in_order(message, shape):
# type: (abc.Field, T.Tuple[int]) -> np.ndarray
# inform the data provider to return missing values as missing_value
message["missingValue"] = missing_value
values = message["values"]
if message.get("alternativeRowScanning", False):
values = values.copy().reshape(shape)
Expand All @@ -335,7 +334,7 @@ def build_array(self) -> np.ndarray:
for header_indexes, message_ids in self.field_id_index.items():
# NOTE: fill a single field as found in the message
message = self.index.get_field(message_ids[0]) # type: ignore
values = get_values_in_order(message, array[header_indexes].shape, self.missing_value)
values = get_values_in_order(message, array[header_indexes].shape)
array.__getitem__(header_indexes).flat[:] = values
array[array == self.missing_value] = np.nan
return array
Expand All @@ -353,9 +352,7 @@ def __getitem__(self, item):
continue
# NOTE: fill a single field as found in the message
message = self.index.get_field(message_ids[0]) # type: ignore
values = get_values_in_order(
message, array_field[tuple(array_field_indexes)].shape, self.missing_value
)
values = get_values_in_order(message, array_field[tuple(array_field_indexes)].shape)
array_field.__getitem__(tuple(array_field_indexes)).flat[:] = values

array = np.asarray(array_field[(Ellipsis,) + item[-self.geo_ndim :]])
Expand Down Expand Up @@ -564,7 +561,7 @@ def build_variable_components(

extra_coords_data[coord_name][header_value] = coord_value
offsets[tuple(header_indexes)] = message_ids
missing_value = data_var_attrs.get("missingValue", np.finfo(np.float32).max)
missing_value = data_var_attrs.get("missingValue", messages.MISSING_VAUE_INDICATOR)
on_disk_array = OnDiskArray(
index=index,
shape=shape,
Expand Down
12 changes: 6 additions & 6 deletions cfgrib/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

from . import abc

MISSING_VAUE_INDICATOR = np.finfo(np.float32).max

eccodes_version = eccodes.codes_get_api_version()

LOG = logging.getLogger(__name__)
Expand Down Expand Up @@ -114,6 +116,10 @@ def from_message(cls, message, **kwargs):
codes_id = eccodes.codes_clone(message.codes_id)
return cls(codes_id=codes_id, **kwargs)

# ensure that missing values in the values array are represented by MISSING_VAUE_INDICATOR
def __attrs_post_init__(self):
self["missingValue"] = MISSING_VAUE_INDICATOR

def __del__(self) -> None:
eccodes.codes_release(self.codes_id)

Expand Down Expand Up @@ -233,12 +239,6 @@ def __getitem__(self, item: str) -> T.Any:
else:
return self.context[item]

def __setitem__(self, item: str, value: T.Any) -> None:
if item in self.computed_keys:
self.computed_keys[item] = value
else:
self.context[item] = value

def __iter__(self) -> T.Iterator[str]:
seen = set()
for key in self.context:
Expand Down
7 changes: 6 additions & 1 deletion cfgrib/xarray_to_grib.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ def canonical_dataarray_to_grib(
# validate Dataset keys, DataArray names, and attr keys/values
detected_keys, suggested_keys = detect_grib_keys(data_var, default_grib_keys, grib_keys)
merged_grib_keys = merge_grib_keys(grib_keys, detected_keys, suggested_keys)
merged_grib_keys["missingValue"] = messages.MISSING_VAUE_INDICATOR

if "gridType" not in merged_grib_keys:
raise ValueError("required grib_key 'gridType' not passed nor auto-detected")
Expand All @@ -232,7 +233,7 @@ def canonical_dataarray_to_grib(
if invalid_field_values.all():
continue

missing_value = merged_grib_keys.get("missingValue", 9999)
missing_value = merged_grib_keys.get("GRIB_missingValue", messages.MISSING_VAUE_INDICATOR)
field_values[invalid_field_values] = missing_value

message = cfmessage.CfMessage.from_message(template_message)
Expand All @@ -241,6 +242,10 @@ def canonical_dataarray_to_grib(
coord_name = "level"
message[coord_name] = coord_value

if invalid_field_values.any():
message["bitmapPresent"] = 1
message["missingValue"] = missing_value

# OPTIMIZE: convert to list because Message.message_set doesn't support np.ndarray
message["values"] = field_values.tolist()

Expand Down
4 changes: 2 additions & 2 deletions tests/test_30_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,5 +322,5 @@ def test_alternating_rows() -> None:
def test_missing_field_values() -> None:
res = dataset.open_file(TEST_DATA_MISSING_VALS)
t2 = res.variables["t2m"]
assert np.isclose(np.nanmean(t2.data.build_array()[0, :, :]), 268.375)
assert np.isclose(np.nanmean(t2.data.build_array()[1, :, :]), 270.716)
assert np.isclose(np.nanmean(t2.data[0, :, :]), 268.375)
assert np.isclose(np.nanmean(t2.data[1, :, :]), 270.716)
4 changes: 2 additions & 2 deletions tests/test_50_sample_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,5 +117,5 @@ def test_dataset_missing_field_values() -> None:
os.path.join(SAMPLE_DATA_FOLDER, "fields_with_missing_values.grib")
)
t2 = res.variables["t2m"]
assert np.isclose(t2[0, :, :].mean(), 268.375)
assert np.isclose(t2[1, :, :].mean(), 270.716)
assert np.isclose(np.nanmean(t2[0, :, :]), 268.375)
assert np.isclose(np.nanmean(t2[1, :, :]), 270.716)
4 changes: 2 additions & 2 deletions tests/test_50_xarray_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,5 +89,5 @@ def test_read() -> None:
def test_xr_open_dataset_file_missing_vals() -> None:
ds = xr.open_dataset(TEST_DATA_MISSING_VALS, engine="cfgrib")
t2 = ds["t2m"]
assert np.isclose(t2[0, :, :].mean(), 268.375)
assert np.isclose(t2[1, :, :].mean(), 270.716)
assert np.isclose(np.nanmean(t2.values[0, :, :]), 268.375)
assert np.isclose(np.nanmean(t2.values[1, :, :]), 270.716)

0 comments on commit 2707106

Please sign in to comment.