Skip to content

Commit

Permalink
Added support for numpy.bool_ (#4986)
Browse files Browse the repository at this point in the history
  • Loading branch information
caenrigen authored Mar 12, 2021
1 parent 6ff27ca commit 213e352
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 11 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Deprecations

Bug fixes
~~~~~~~~~
- Added support for ``numpy.bool_`` attributes in roundtrips using the ``h5netcdf`` engine with ``invalid_netcdf=True`` (which casts ``bool`` values to ``numpy.bool_``) (:issue:`4981`, :pull:`4986`).
By `Victor Negîrneac <https://github.com/caenrigen>`_.
- Don't allow passing ``axis`` to :py:meth:`Dataset.reduce` methods (:issue:`3510`, :pull:`4940`).
By `Justus Magin <https://github.com/keewis>`_.

Expand Down
30 changes: 19 additions & 11 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,21 @@ def check_name(name):
check_name(k)


def _validate_attrs(dataset):
def _validate_attrs(dataset, invalid_netcdf=False):
"""`attrs` must have a string key and a value which is either: a number,
a string, an ndarray or a list/tuple of numbers/strings.
a string, an ndarray, a list/tuple of numbers/strings, or a numpy.bool_.
Notes
-----
A numpy.bool_ is only allowed when using the h5netcdf engine with
`invalid_netcdf=True`.
"""

def check_attr(name, value):
valid_types = (str, Number, np.ndarray, np.number, list, tuple)
if invalid_netcdf:
valid_types += (np.bool_,)

def check_attr(name, value, valid_types):
if isinstance(name, str):
if not name:
raise ValueError(
Expand All @@ -160,22 +169,21 @@ def check_attr(name, value):
"serialization to netCDF files"
)

if not isinstance(value, (str, Number, np.ndarray, np.number, list, tuple)):
if not isinstance(value, valid_types):
raise TypeError(
f"Invalid value for attr {name!r}: {value!r} must be a number, "
"a string, an ndarray or a list/tuple of "
"numbers/strings for serialization to netCDF "
"files"
f"Invalid value for attr {name!r}: {value!r}. For serialization to "
"netCDF files, its value must be of one of the following types: "
f"{', '.join([vtype.__name__ for vtype in valid_types])}"
)

# Check attrs on the dataset itself
for k, v in dataset.attrs.items():
check_attr(k, v)
check_attr(k, v, valid_types)

# Check attrs on each variable within the dataset
for variable in dataset.variables.values():
for k, v in variable.attrs.items():
check_attr(k, v)
check_attr(k, v, valid_types)


def _resolve_decoders_kwargs(decode_cf, open_backend_dataset_parameters, **decoders):
Expand Down Expand Up @@ -1019,7 +1027,7 @@ def to_netcdf(

# validate Dataset keys, DataArray names, and attr keys/values
_validate_dataset_names(dataset)
_validate_attrs(dataset)
_validate_attrs(dataset, invalid_netcdf=invalid_netcdf and engine == "h5netcdf")

try:
store_open = WRITEABLE_STORES[engine]
Expand Down
8 changes: 8 additions & 0 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -2541,6 +2541,14 @@ def test_complex(self, invalid_netcdf, warntype, num_warns):

assert recorded_num_warns == num_warns

def test_numpy_bool_(self):
# h5netcdf loads boolean attributes as numpy.bool_, so that type must be
# accepted when writing with invalid_netcdf=True for a roundtrip to succeed
expected = Dataset({"x": ("y", np.ones(5), {"numpy_bool": np.bool_(True)})})
save_kwargs = {"invalid_netcdf": True}
with self.roundtrip(expected, save_kwargs=save_kwargs) as actual:
assert_identical(expected, actual)

def test_cross_engine_read_write_netcdf4(self):
# Drop dim3, because its labels include strings. These appear to be
# not properly read with python-netCDF4, which converts them into
Expand Down

0 comments on commit 213e352

Please sign in to comment.