Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More robust guess_can_open for netCDF4/scipy/h5netcdf entrypoints #5296

Merged
merged 3 commits into from
May 14, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ Deprecations
Bug fixes
~~~~~~~~~

- Opening netCDF files from a path that doesn't end in ``.nc`` without supplying
an explicit ``engine`` works again (:issue:`5295`), fixing a bug introduced in
0.18.0.
By `Stephan Hoyer <https://github.com/shoyer>`_

Documentation
~~~~~~~~~~~~~
Expand Down
18 changes: 11 additions & 7 deletions xarray/backends/h5netcdf_.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import numpy as np

from ..core import indexing
from ..core.utils import FrozenDict, is_remote_uri, read_magic_number
from ..core.utils import (
FrozenDict,
is_remote_uri,
read_magic_number_from_file,
try_read_magic_number_from_file_or_path,
)
from ..core.variable import Variable
from .common import (
BACKEND_ENTRYPOINTS,
Expand Down Expand Up @@ -141,10 +146,10 @@ def open(
"try passing a path or file-like object"
)
elif isinstance(filename, io.IOBase):
magic_number = read_magic_number(filename)
magic_number = read_magic_number_from_file(filename)
if not magic_number.startswith(b"\211HDF\r\n\032\n"):
raise ValueError(
f"{magic_number} is not the signature of a valid netCDF file"
f"{magic_number} is not the signature of a valid netCDF4 file"
)

if format not in [None, "NETCDF4"]:
Expand Down Expand Up @@ -336,10 +341,9 @@ def close(self, **kwargs):

class H5netcdfBackendEntrypoint(BackendEntrypoint):
def guess_can_open(self, filename_or_obj):
try:
return read_magic_number(filename_or_obj).startswith(b"\211HDF\r\n\032\n")
except TypeError:
pass
magic_number = try_read_magic_number_from_file_or_path(filename_or_obj)
if magic_number is not None:
return magic_number.startswith(b"\211HDF\r\n\032\n")

try:
_, ext = os.path.splitext(filename_or_obj)
Expand Down
11 changes: 10 additions & 1 deletion xarray/backends/netCDF4_.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
from .. import coding
from ..coding.variables import pop_to
from ..core import indexing
from ..core.utils import FrozenDict, close_on_error, is_remote_uri
from ..core.utils import (
FrozenDict,
close_on_error,
is_remote_uri,
try_read_magic_number_from_path,
)
from ..core.variable import Variable
from .common import (
BACKEND_ENTRYPOINTS,
Expand Down Expand Up @@ -517,6 +522,10 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint):
def guess_can_open(self, filename_or_obj):
if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj):
return True
magic_number = try_read_magic_number_from_path(filename_or_obj)
if magic_number is not None:
# netcdf 3 or HDF5
return magic_number.startswith((b"CDF", b"\211HDF\r\n\032\n"))
try:
_, ext = os.path.splitext(filename_or_obj)
except TypeError:
Expand Down
15 changes: 10 additions & 5 deletions xarray/backends/scipy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
import numpy as np

from ..core.indexing import NumpyIndexingAdapter
from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number
from ..core.utils import (
Frozen,
FrozenDict,
close_on_error,
try_read_magic_number_from_file_or_path,
)
from ..core.variable import Variable
from .common import (
BACKEND_ENTRYPOINTS,
Expand Down Expand Up @@ -235,10 +240,10 @@ def close(self):

class ScipyBackendEntrypoint(BackendEntrypoint):
def guess_can_open(self, filename_or_obj):
try:
return read_magic_number(filename_or_obj).startswith(b"CDF")
except TypeError:
pass

magic_number = try_read_magic_number_from_file_or_path(filename_or_obj)
if magic_number is not None:
return magic_number.startswith(b"CDF")

try:
_, ext = os.path.splitext(filename_or_obj)
Expand Down
28 changes: 26 additions & 2 deletions xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import functools
import io
import itertools
import os
import re
import warnings
from enum import Enum
Expand Down Expand Up @@ -652,7 +653,7 @@ def is_remote_uri(path: str) -> bool:
return bool(re.search(r"^[a-z][a-z0-9]*(\://|\:\:)", path))


def read_magic_number(filename_or_obj, count=8):
def read_magic_number_from_file(filename_or_obj, count=8) -> bytes:
# check byte header to determine file type
if isinstance(filename_or_obj, bytes):
magic_number = filename_or_obj[:count]
Expand All @@ -663,13 +664,36 @@ def read_magic_number(filename_or_obj, count=8):
"file-like object read/write pointer not at the start of the file, "
"please close and reopen, or use a context manager"
)
magic_number = filename_or_obj.read(count)
magic_number = filename_or_obj.read(count) # type: ignore
filename_or_obj.seek(0)
else:
raise TypeError(f"cannot read the magic number form {type(filename_or_obj)}")
return magic_number


def try_read_magic_number_from_path(pathlike, count=8) -> Optional[bytes]:
    """Best-effort read of the leading bytes of the file at ``pathlike``.

    Parameters
    ----------
    pathlike : object
        Candidate path. Only ``str`` instances and objects exposing
        ``__fspath__`` are treated as paths; anything else is ignored.
    count : int, default: 8
        Number of leading bytes requested from the file.

    Returns
    -------
    bytes or None
        The magic number, or ``None`` when ``pathlike`` is not path-like or
        the file cannot be opened and read as a regular file.
    """
    if not (isinstance(pathlike, str) or hasattr(pathlike, "__fspath__")):
        return None

    path = os.fspath(pathlike)
    try:
        with open(path, "rb") as f:
            return read_magic_number_from_file(f, count)
    except (FileNotFoundError, IsADirectoryError, PermissionError, TypeError):
        # This helper only *tries*: a missing file, a directory (for example
        # a directory-based store; opening one raises IsADirectoryError on
        # POSIX and PermissionError on Windows) or an unreadable file simply
        # means "no magic number available", so the caller can fall back to
        # other heuristics instead of crashing while guessing.
        return None


def try_read_magic_number_from_file_or_path(
    filename_or_obj, count=8
) -> Optional[bytes]:
    """Try to read a magic number from a path, bytes or file-like object.

    ``filename_or_obj`` is first handed to ``try_read_magic_number_from_path``.
    If that yields nothing, it is passed to ``read_magic_number_from_file``.

    Returns
    -------
    bytes or None
        The leading ``count`` bytes, or ``None`` when the path lookup gives
        nothing and ``read_magic_number_from_file`` raises ``TypeError``.
        Any other exception from that call propagates to the caller.
    """
    from_path = try_read_magic_number_from_path(filename_or_obj, count)
    if from_path is not None:
        return from_path

    try:
        return read_magic_number_from_file(filename_or_obj, count)
    except TypeError:
        # Not an object the reader supports: report "unknown" instead.
        return None


def is_uniform_spaced(arr, **kwargs) -> bool:
"""Return True if values of an array are uniformly spaced and sorted.

Expand Down
84 changes: 83 additions & 1 deletion xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,14 @@
save_mfdataset,
)
from xarray.backends.common import robust_getitem
from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint
from xarray.backends.netcdf3 import _nc3_dtype_coercions
from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding
from xarray.backends.netCDF4_ import (
NetCDF4BackendEntrypoint,
_extract_nc4_variable_encoding,
)
from xarray.backends.pydap_ import PydapDataStore
from xarray.backends.scipy_ import ScipyBackendEntrypoint
from xarray.coding.variables import SerializationWarning
from xarray.conventions import encode_dataset_coordinates
from xarray.core import indexes, indexing
Expand Down Expand Up @@ -5161,3 +5166,80 @@ def test_chunking_consintency(chunks, tmp_path):

with xr.open_dataset(tmp_path / "test.nc", chunks=chunks) as actual:
xr.testing.assert_chunks_equal(actual, expected)


def _check_guess_can_open_and_open(entrypoint, obj, engine, expected):
    """Assert ``entrypoint`` claims ``obj`` and that it opens to ``expected``.

    First checks that ``entrypoint.guess_can_open(obj)`` is truthy, then opens
    ``obj`` with ``open_dataset`` using ``engine`` and compares the result to
    ``expected`` with ``assert_identical``.
    """
    assert entrypoint.guess_can_open(obj)
    with open_dataset(obj, engine=engine) as loaded:
        assert_identical(expected, loaded)


@requires_netCDF4
def test_netcdf4_entrypoint(tmp_path):
    """Check ``NetCDF4BackendEntrypoint.guess_can_open`` on paths and names."""
    entrypoint = NetCDF4BackendEntrypoint()
    ds = create_test_data()

    # Files written without any extension, given both as Path and as str.
    for stem, fmt in (("foo", "netcdf3_classic"), ("bar", "netcdf4_classic")):
        path = tmp_path / stem
        ds.to_netcdf(path, format=fmt)
        for target in (path, str(path)):
            _check_guess_can_open_and_open(
                entrypoint, target, engine="netcdf4", expected=ds
            )

    # Names that are accepted without an existing file behind them.
    assert entrypoint.guess_can_open("http://something/remote")
    for name in (
        "something-local.nc",
        "something-local.nc4",
        "something-local.cdf",
    ):
        assert entrypoint.guess_can_open(name)
    assert not entrypoint.guess_can_open("not-found-and-no-extension")

    # An existing file whose leading bytes are not a netCDF signature.
    bogus = tmp_path / "baz"
    with open(bogus, "wb") as handle:
        handle.write(b"not-a-netcdf-file")
    assert not entrypoint.guess_can_open(bogus)


@requires_scipy
def test_scipy_entrypoint(tmp_path):
    """Check ``ScipyBackendEntrypoint.guess_can_open`` on several inputs."""
    entrypoint = ScipyBackendEntrypoint()
    ds = create_test_data()

    def check(obj):
        # Shared assertion: the entrypoint claims ``obj`` and it opens to ``ds``.
        _check_guess_can_open_and_open(entrypoint, obj, engine="scipy", expected=ds)

    # Extension-less file on disk: as Path, as str, and as an open binary file.
    path = tmp_path / "foo"
    ds.to_netcdf(path, engine="scipy")
    for target in (path, str(path)):
        check(target)
    with open(path, "rb") as f:
        check(f)

    # Serialized contents: as returned by ``to_netcdf`` and wrapped in BytesIO.
    contents = ds.to_netcdf(engine="scipy")
    check(contents)
    check(BytesIO(contents))

    for name in ("something-local.nc", "something-local.nc.gz"):
        assert entrypoint.guess_can_open(name)
    for rejected in ("not-found-and-no-extension", b"not-a-netcdf-file"):
        assert not entrypoint.guess_can_open(rejected)


@requires_h5netcdf
def test_h5netcdf_entrypoint(tmp_path):
    """Check ``H5netcdfBackendEntrypoint.guess_can_open`` on several inputs."""
    entrypoint = H5netcdfBackendEntrypoint()
    ds = create_test_data()

    def check(obj):
        # Shared assertion: the entrypoint claims ``obj`` and it opens to ``ds``.
        _check_guess_can_open_and_open(
            entrypoint, obj, engine="h5netcdf", expected=ds
        )

    # Extension-less file on disk: as Path, as str, and as an open binary file.
    path = tmp_path / "foo"
    ds.to_netcdf(path, engine="h5netcdf")
    for target in (path, str(path)):
        check(target)
    with open(path, "rb") as f:
        check(f)

    # Names that are accepted without an existing file behind them.
    for name in (
        "something-local.nc",
        "something-local.nc4",
        "something-local.cdf",
    ):
        assert entrypoint.guess_can_open(name)
    assert not entrypoint.guess_can_open("not-found-and-no-extension")