Code cleanup #5234

Merged: 19 commits, May 13, 2021. Showing changes from 13 commits.
xarray/backends/api.py (45 changes: 19 additions & 26 deletions)
@@ -102,12 +102,11 @@ def _get_default_engine_netcdf():

def _get_default_engine(path: str, allow_remote: bool = False):
    if allow_remote and is_remote_uri(path):
-        engine = _get_default_engine_remote_uri()
+        return _get_default_engine_remote_uri()
    elif path.endswith(".gz"):
-        engine = _get_default_engine_gz()
+        return _get_default_engine_gz()
    else:
-        engine = _get_default_engine_netcdf()
-        return engine
+        return _get_default_engine_netcdf()


def _validate_dataset_names(dataset):
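The early-return rewrite above is behavior-preserving: each branch now returns directly instead of binding an `engine` temporary. A minimal standalone sketch of the same pattern, with toy engine names and a toy remote check rather than xarray's actual helpers:

    def pick_engine(path: str, allow_remote: bool = False) -> str:
        # each branch returns immediately; no temporary variable needed
        if allow_remote and path.startswith(("http://", "https://")):
            return "remote-engine"
        elif path.endswith(".gz"):
            return "gz-engine"
        else:
            return "netcdf-engine"

    assert pick_engine("data.nc.gz") == "gz-engine"
    assert pick_engine("https://host/data.nc", allow_remote=True) == "remote-engine"
    assert pick_engine("data.nc") == "netcdf-engine"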
@@ -282,7 +281,7 @@ def _chunk_ds(

    mtime = _get_mtime(filename_or_obj)
    token = tokenize(filename_or_obj, mtime, engine, chunks, **extra_tokens)
-    name_prefix = "open_dataset-%s" % token
+    name_prefix = f"open_dataset-{token}"

    variables = {}
    for name, var in backend_ds.variables.items():
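The %-to-f-string conversions that make up most of this PR are behavior-preserving. A quick check of the pattern, using a made-up stand-in for the dask tokenize() output:

    token = "8a5a262d"  # hypothetical token value
    assert "open_dataset-%s" % token == f"open_dataset-{token}" == "open_dataset-8a5a262d"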
@@ -295,8 +294,7 @@
            name_prefix=name_prefix,
            token=token,
        )
-    ds = backend_ds._replace(variables)
-    return ds
+    return backend_ds._replace(variables)


def _dataset_from_backend_dataset(
@@ -308,12 +306,10 @@
    overwrite_encoded_chunks,
    **extra_tokens,
):
-    if not (isinstance(chunks, (int, dict)) or chunks is None):
-        if chunks != "auto":
-            raise ValueError(
-                "chunks must be an int, dict, 'auto', or None. "
-                "Instead found %s. " % chunks
-            )
+    if not isinstance(chunks, (int, dict)) and chunks not in {None, "auto"}:
+        raise ValueError(
+            f"chunks must be an int, dict, 'auto', or None. Instead found {chunks}."
+        )

    _protect_dataset_variables_inplace(backend_ds, cache)
    if chunks is None:
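The collapsed check accepts exactly int, dict, None, and "auto". Order matters in the combined condition: isinstance is tested first, so an (unhashable) dict short-circuits before it can reach the set-membership test and raise TypeError. A sketch of the predicate, under a hypothetical helper name:

    def valid_chunks(chunks) -> bool:
        # mirrors the new condition, un-negated; isinstance short-circuits for dicts
        return isinstance(chunks, (int, dict)) or chunks in {None, "auto"}

    assert valid_chunks(5) and valid_chunks({"x": 10})
    assert valid_chunks(None) and valid_chunks("auto")
    assert not valid_chunks("100MB")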
@@ -331,9 +327,8 @@
    ds.set_close(backend_ds._close)

    # Ensure source filename always stored in dataset object (GH issue #2550)
-    if "source" not in ds.encoding:
-        if isinstance(filename_or_obj, str):
-            ds.encoding["source"] = filename_or_obj
+    if "source" not in ds.encoding and isinstance(filename_or_obj, str):
+        ds.encoding["source"] = filename_or_obj

    return ds

@@ -515,7 +510,6 @@ def open_dataset(
        **decoders,
        **kwargs,
    )
-
    return ds


@@ -1015,8 +1009,8 @@ def to_netcdf(
        elif engine != "scipy":
            raise ValueError(
                "invalid engine for creating bytes with "
-                "to_netcdf: %r. Only the default engine "
-                "or engine='scipy' is supported" % engine
+                f"to_netcdf: {engine!r}. Only the default engine "
+                "or engine='scipy' is supported"
            )
        if not compute:
            raise NotImplementedError(
@@ -1037,7 +1031,7 @@
    try:
        store_open = WRITEABLE_STORES[engine]
    except KeyError:
-        raise ValueError("unrecognized engine for to_netcdf: %r" % engine)
+        raise ValueError(f"unrecognized engine for to_netcdf: {engine!r}")

    if format is not None:
        format = format.upper()
@@ -1049,9 +1043,8 @@
    autoclose = have_chunks and scheduler in ["distributed", "multiprocessing"]
    if autoclose and engine == "scipy":
        raise NotImplementedError(
-            "Writing netCDF files with the %s backend "
-            "is not currently supported with dask's %s "
-            "scheduler" % (engine, scheduler)
+            f"Writing netCDF files with the {engine} backend "
+            f"is not currently supported with dask's {scheduler} scheduler"
        )

    target = path_or_file if path_or_file is not None else BytesIO()
@@ -1061,7 +1054,7 @@
            kwargs["invalid_netcdf"] = invalid_netcdf
        else:
            raise ValueError(
-                "unrecognized option 'invalid_netcdf' for engine %s" % engine
+                f"unrecognized option 'invalid_netcdf' for engine {engine}"
            )
    store = store_open(target, mode, format, group, **kwargs)
@@ -1203,7 +1196,7 @@ def save_mfdataset(
    Data variables:
        a        (time) float64 0.0 0.02128 0.04255 0.06383 ... 0.9574 0.9787 1.0
    >>> years, datasets = zip(*ds.groupby("time.year"))
-    >>> paths = ["%s.nc" % y for y in years]
+    >>> paths = [f"{y}.nc" for y in years]
    >>> xr.save_mfdataset(datasets, paths)
    """
    if mode == "w" and len(set(paths)) < len(paths):
@@ -1215,7 +1208,7 @@
        if not isinstance(obj, Dataset):
            raise TypeError(
                "save_mfdataset only supports writing Dataset "
-                "objects, received type %s" % type(obj)
+                f"objects, received type {type(obj)}"
            )

    if groups is None:
xarray/backends/cfgrib_.py (3 changes: 1 addition & 2 deletions)
@@ -90,8 +90,7 @@ def get_dimensions(self):

    def get_encoding(self):
        dims = self.get_dimensions()
-        encoding = {"unlimited_dims": {k for k, v in dims.items() if v is None}}
-        return encoding
+        return {"unlimited_dims": {k for k, v in dims.items() if v is None}}


class CfgribfBackendEntrypoint(BackendEntrypoint):
xarray/backends/common.py (7 changes: 3 additions & 4 deletions)
@@ -69,9 +69,8 @@ def robust_getitem(array, key, catch=Exception, max_retries=6, initial_delay=500
            base_delay = initial_delay * 2 ** n
            next_delay = base_delay + np.random.randint(base_delay)
            msg = (
-                "getitem failed, waiting %s ms before trying again "
-                "(%s tries remaining). Full traceback: %s"
-                % (next_delay, max_retries - n, traceback.format_exc())
+                f"getitem failed, waiting {next_delay} ms before trying again "
+                f"({max_retries - n} tries remaining). Full traceback: {traceback.format_exc()}"
            )
            logger.debug(msg)
            time.sleep(1e-3 * next_delay)
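The surrounding context implements exponential backoff with jitter. A standalone sketch of the delay schedule, reusing the arithmetic above with the defaults from the signature (initial_delay=500 ms, max_retries=6):

    import numpy as np

    initial_delay = 500  # ms
    for n in range(6):
        base_delay = initial_delay * 2 ** n
        # np.random.randint(base_delay) draws jitter from [0, base_delay)
        next_delay = base_delay + np.random.randint(base_delay)
        print(f"retry {n}: sleep {next_delay} ms")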
@@ -336,7 +335,7 @@ def set_dimensions(self, variables, unlimited_dims=None):
            if dim in existing_dims and length != existing_dims[dim]:
                raise ValueError(
                    "Unable to update size for existing dimension"
-                    "%r (%d != %d)" % (dim, length, existing_dims[dim])
+                    f"{dim!r} ({length} != {existing_dims[dim]})"
                )
            elif dim not in existing_dims:
                is_unlimited = dim in unlimited_dims
xarray/backends/h5netcdf_.py (18 changes: 7 additions & 11 deletions)
@@ -37,8 +37,7 @@
class H5NetCDFArrayWrapper(BaseNetCDF4Array):
    def get_array(self, needs_lock=True):
        ds = self.datastore._acquire(needs_lock)
-        variable = ds.variables[self.variable_name]
-        return variable
+        return ds.variables[self.variable_name]

    def __getitem__(self, key):
        return indexing.explicit_indexing_adapter(
@@ -102,7 +101,7 @@ def __init__(self, manager, group=None, mode=None, lock=HDF5_LOCK, autoclose=Fal
        if group is None:
            root, group = find_root_and_group(manager)
        else:
-            if not type(manager) is h5netcdf.File:
+            if type(manager) is not h5netcdf.File:
                raise ValueError(
                    "must supply a h5netcdf.File if the group "
                    "argument is provided"
@@ -233,11 +232,9 @@ def get_dimensions(self):
        return self.ds.dimensions

    def get_encoding(self):
-        encoding = {}
-        encoding["unlimited_dims"] = {
-            k for k, v in self.ds.dimensions.items() if v is None
+        return {
+            "unlimited_dims": {k for k, v in self.ds.dimensions.items() if v is None}
        }
-        return encoding

    def set_dimension(self, name, length, is_unlimited=False):
        if is_unlimited:
@@ -266,9 +263,9 @@ def prepare_variable(
                "h5netcdf does not yet support setting a fill value for "
                "variable-length strings "
                "(https://github.com/shoyer/h5netcdf/issues/37). "
-                "Either remove '_FillValue' from encoding on variable %r "
+                f"Either remove '_FillValue' from encoding on variable {name!r} "
                "or set {'dtype': 'S1'} in encoding to use the fixed width "
-                "NC_CHAR type." % name
+                "NC_CHAR type."
            )

        if dtype is str:
@@ -380,7 +377,7 @@ def open_dataset(

        store_entrypoint = StoreBackendEntrypoint()

-        ds = store_entrypoint.open_dataset(
+        return store_entrypoint.open_dataset(
            store,
            mask_and_scale=mask_and_scale,
            decode_times=decode_times,
@@ -390,7 +387,6 @@
            use_cftime=use_cftime,
            decode_timedelta=decode_timedelta,
        )
-        return ds


if has_h5netcdf:
xarray/backends/locks.py (2 changes: 1 addition & 1 deletion)
@@ -167,7 +167,7 @@ def locked(self):
        return any(lock.locked for lock in self.locks)

    def __repr__(self):
-        return "CombinedLock(%r)" % list(self.locks)
+        return f"CombinedLock({list(self.locks)!r})"


class DummyLock:
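Besides readability, the f-string form sidesteps a %-formatting trap: `% value` unpacks a tuple into multiple arguments, while `{value!r}` never does. An illustration with plain strings standing in for lock objects:

    locks = ["read-lock", "write-lock"]
    # a list passes through %-formatting as a single argument, so both forms agree
    assert "CombinedLock(%r)" % locks == f"CombinedLock({locks!r})"
    # a tuple, however, would be unpacked by %-formatting
    try:
        "CombinedLock(%r)" % tuple(locks)
    except TypeError:
        print("%-formatting treated the tuple as two arguments")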
xarray/backends/netCDF4_.py (68 changes: 29 additions & 39 deletions)
@@ -122,25 +122,23 @@ def _encode_nc4_variable(var):
def _check_encoding_dtype_is_vlen_string(dtype):
    if dtype is not str:
        raise AssertionError(  # pragma: no cover
-            "unexpected dtype encoding %r. This shouldn't happen: please "
-            "file a bug report at github.com/pydata/xarray" % dtype
+            f"unexpected dtype encoding {dtype!r}. This shouldn't happen: please "
+            "file a bug report at github.com/pydata/xarray"
        )


def _get_datatype(var, nc_format="NETCDF4", raise_on_invalid_encoding=False):
    if nc_format == "NETCDF4":
-        datatype = _nc4_dtype(var)
-    else:
-        if "dtype" in var.encoding:
-            encoded_dtype = var.encoding["dtype"]
-            _check_encoding_dtype_is_vlen_string(encoded_dtype)
-            if raise_on_invalid_encoding:
-                raise ValueError(
-                    "encoding dtype=str for vlen strings is only supported "
-                    "with format='NETCDF4'."
-                )
-        datatype = var.dtype
-    return datatype
+        return _nc4_dtype(var)
+    if "dtype" in var.encoding:
+        encoded_dtype = var.encoding["dtype"]
+        _check_encoding_dtype_is_vlen_string(encoded_dtype)
+        if raise_on_invalid_encoding:
+            raise ValueError(
+                "encoding dtype=str for vlen strings is only supported "
+                "with format='NETCDF4'."
+            )
+    return var.dtype


def _nc4_dtype(var):
@@ -178,7 +176,7 @@ def _nc4_require_group(ds, group, mode, create_group=_netcdf4_create_group):
                    ds = create_group(ds, key)
                else:
                    # wrap error to provide slightly more helpful message
-                    raise OSError("group not found: %s" % key, e)
+                    raise OSError(f"group not found: {key}", e)
        return ds


@@ -203,7 +201,7 @@ def _force_native_endianness(var):
    # if endian exists, remove it from the encoding.
    var.encoding.pop("endian", None)
    # check to see if encoding has a value for endian its 'native'
-    if not var.encoding.get("endian", "native") == "native":
+    if var.encoding.get("endian", "native") != "native":
        raise NotImplementedError(
            "Attempt to write non-native endian type, "
            "this is not supported by the netCDF4 "
@@ -270,8 +268,8 @@ def _extract_nc4_variable_encoding(
        invalid = [k for k in encoding if k not in valid_encodings]
        if invalid:
            raise ValueError(
-                "unexpected encoding parameters for %r backend: %r. Valid "
-                "encodings are: %r" % (backend, invalid, valid_encodings)
+                f"unexpected encoding parameters for {backend!r} backend: {invalid!r}. Valid "
+                f"encodings are: {valid_encodings!r}"
            )
    else:
        for k in list(encoding):
@@ -282,10 +280,8 @@


def _is_list_of_strings(value):
-    if np.asarray(value).dtype.kind in ["U", "S"] and np.asarray(value).size > 1:
-        return True
-    else:
-        return False
+    arr = np.asarray(value)
+    return arr.dtype.kind in ["U", "S"] and arr.size > 1


class NetCDF4DataStore(WritableCFDataStore):
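Binding np.asarray(value) once avoids converting the input twice, and returning the boolean expression drops the if/else. The dtype kinds involved are "U" (unicode) and "S" (bytes); a few illustrative cases:

    import numpy as np

    assert np.asarray(["a", "b"]).dtype.kind == "U"    # unicode strings
    assert np.asarray([b"a", b"b"]).dtype.kind == "S"  # byte strings
    assert np.asarray([1, 2]).dtype.kind == "i"        # integers, not strings
    assert np.asarray("abc").size == 1                 # a lone string: size 1, excluded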
@@ -313,7 +309,7 @@ def __init__(
        if group is None:
            root, group = find_root_and_group(manager)
        else:
-            if not type(manager) is netCDF4.Dataset:
+            if type(manager) is not netCDF4.Dataset:
                raise ValueError(
                    "must supply a root netCDF4.Dataset if the group "
                    "argument is provided"
@@ -359,10 +355,7 @@ def open(

        if lock is None:
            if mode == "r":
-                if is_remote_uri(filename):
-                    lock = NETCDFC_LOCK
-                else:
-                    lock = NETCDF4_PYTHON_LOCK
+                lock = NETCDFC_LOCK if is_remote_uri(filename) else NETCDF4_PYTHON_LOCK
            else:
                if format is None or format.startswith("NETCDF4"):
                    base_lock = NETCDF4_PYTHON_LOCK
@@ -417,25 +410,22 @@ def open_store_variable(self, name, var):
        return Variable(dimensions, data, attributes, encoding)

    def get_variables(self):
-        dsvars = FrozenDict(
+        return FrozenDict(
            (k, self.open_store_variable(k, v)) for k, v in self.ds.variables.items()
        )
-        return dsvars

    def get_attrs(self):
-        attrs = FrozenDict((k, self.ds.getncattr(k)) for k in self.ds.ncattrs())
-        return attrs
+        return FrozenDict((k, self.ds.getncattr(k)) for k in self.ds.ncattrs())

    def get_dimensions(self):
-        dims = FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items())
-        return dims
+        return FrozenDict((k, len(v)) for k, v in self.ds.dimensions.items())

    def get_encoding(self):
-        encoding = {}
-        encoding["unlimited_dims"] = {
-            k for k, v in self.ds.dimensions.items() if v.isunlimited()
+        return {
+            "unlimited_dims": {
+                k for k, v in self.ds.dimensions.items() if v.isunlimited()
+            }
        }
-        return encoding

    def set_dimension(self, name, length, is_unlimited=False):
        dim_length = length if not is_unlimited else None
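Each getter now builds and returns its mapping in a single expression. FrozenDict here is xarray's read-only mapping helper; the closest standard-library analogue of the same pattern (shown for illustration only, not what xarray uses) is wrapping a comprehension in MappingProxyType:

    from types import MappingProxyType

    dimensions = {"time": range(10), "x": range(5)}
    # build and return the read-only view in one expression
    dims = MappingProxyType({k: len(v) for k, v in dimensions.items()})
    assert dims["time"] == 10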
@@ -473,9 +463,9 @@ def prepare_variable(
                "netCDF4 does not yet support setting a fill value for "
                "variable-length strings "
                "(https://github.com/Unidata/netcdf4-python/issues/730). "
-                "Either remove '_FillValue' from encoding on variable %r "
+                f"Either remove '_FillValue' from encoding on variable {name!r} "
                "or set {'dtype': 'S1'} in encoding to use the fixed width "
-                "NC_CHAR type." % name
+                "NC_CHAR type."
            )

        encoding = _extract_nc4_variable_encoding(
xarray/backends/netcdf3.py (2 changes: 0 additions & 2 deletions)
@@ -125,8 +125,6 @@ def is_valid_nc3_name(s):
    """
    if not isinstance(s, str):
        return False
-    if not isinstance(s, str):
-        s = s.decode("utf-8")
    num_bytes = len(s.encode("utf-8"))
    return (
        (unicodedata.normalize("NFC", s) == s)
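The two deleted lines were dead code: once the first isinstance guard returns False for any non-str input, s is guaranteed to be a str, so the s.decode("utf-8") branch could never run. A condensed sketch of the control flow (hypothetical helper, just the head of the function):

    def name_byte_length(s):
        if not isinstance(s, str):
            return False
        # s is always a str here, so the removed branch was unreachable:
        # if not isinstance(s, str):
        #     s = s.decode("utf-8")
        return len(s.encode("utf-8"))

    assert name_byte_length(b"abc") is False
    assert name_byte_length("abc") == 3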