Skip to content

Commit c6f2cf0

Browse files
authored
More robust guess_can_open for netCDF4/scipy/h5netcdf entrypoints (#5296)
* More robust guess_can_open for netCDF4/scipy/h5netcdf entrypoints The new version checks magic numbers in files on disk, not just already-open file objects. I've also added a bunch of unit-tests. Fixes GH5295 * Fix failures and warning in test_backends.py * format black
1 parent 9e84d09 commit c6f2cf0

File tree

6 files changed

+162
-20
lines changed

6 files changed

+162
-20
lines changed

doc/whats-new.rst

+4
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ Deprecations
4141
Bug fixes
4242
~~~~~~~~~
4343

44+
- Opening netCDF files from a path that doesn't end in ``.nc`` without supplying
45+
an explicit ``engine`` works again (:issue:`5295`), fixing a bug introduced in
46+
0.18.0.
47+
By `Stephan Hoyer <https://github.com/shoyer>`_
4448

4549
Documentation
4650
~~~~~~~~~~~~~

xarray/backends/h5netcdf_.py

+11-7
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,12 @@
66
import numpy as np
77

88
from ..core import indexing
9-
from ..core.utils import FrozenDict, is_remote_uri, read_magic_number
9+
from ..core.utils import (
10+
FrozenDict,
11+
is_remote_uri,
12+
read_magic_number_from_file,
13+
try_read_magic_number_from_file_or_path,
14+
)
1015
from ..core.variable import Variable
1116
from .common import (
1217
BACKEND_ENTRYPOINTS,
@@ -140,10 +145,10 @@ def open(
140145
"try passing a path or file-like object"
141146
)
142147
elif isinstance(filename, io.IOBase):
143-
magic_number = read_magic_number(filename)
148+
magic_number = read_magic_number_from_file(filename)
144149
if not magic_number.startswith(b"\211HDF\r\n\032\n"):
145150
raise ValueError(
146-
f"{magic_number} is not the signature of a valid netCDF file"
151+
f"{magic_number} is not the signature of a valid netCDF4 file"
147152
)
148153

149154
if format not in [None, "NETCDF4"]:
@@ -333,10 +338,9 @@ def close(self, **kwargs):
333338

334339
class H5netcdfBackendEntrypoint(BackendEntrypoint):
335340
def guess_can_open(self, filename_or_obj):
336-
try:
337-
return read_magic_number(filename_or_obj).startswith(b"\211HDF\r\n\032\n")
338-
except TypeError:
339-
pass
341+
magic_number = try_read_magic_number_from_file_or_path(filename_or_obj)
342+
if magic_number is not None:
343+
return magic_number.startswith(b"\211HDF\r\n\032\n")
340344

341345
try:
342346
_, ext = os.path.splitext(filename_or_obj)

xarray/backends/netCDF4_.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
from .. import coding
1010
from ..coding.variables import pop_to
1111
from ..core import indexing
12-
from ..core.utils import FrozenDict, close_on_error, is_remote_uri
12+
from ..core.utils import (
13+
FrozenDict,
14+
close_on_error,
15+
is_remote_uri,
16+
try_read_magic_number_from_path,
17+
)
1318
from ..core.variable import Variable
1419
from .common import (
1520
BACKEND_ENTRYPOINTS,
@@ -510,6 +515,10 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint):
510515
def guess_can_open(self, filename_or_obj):
511516
if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj):
512517
return True
518+
magic_number = try_read_magic_number_from_path(filename_or_obj)
519+
if magic_number is not None:
520+
# netcdf 3 or HDF5
521+
return magic_number.startswith((b"CDF", b"\211HDF\r\n\032\n"))
513522
try:
514523
_, ext = os.path.splitext(filename_or_obj)
515524
except TypeError:

xarray/backends/scipy_.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,16 @@
1+
import gzip
12
import io
23
import os
34

45
import numpy as np
56

67
from ..core.indexing import NumpyIndexingAdapter
7-
from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number
8+
from ..core.utils import (
9+
Frozen,
10+
FrozenDict,
11+
close_on_error,
12+
try_read_magic_number_from_file_or_path,
13+
)
814
from ..core.variable import Variable
915
from .common import (
1016
BACKEND_ENTRYPOINTS,
@@ -72,8 +78,6 @@ def __setitem__(self, key, value):
7278

7379

7480
def _open_scipy_netcdf(filename, mode, mmap, version):
75-
import gzip
76-
7781
# if the string ends with .gz, then gunzip and open as netcdf file
7882
if isinstance(filename, str) and filename.endswith(".gz"):
7983
try:
@@ -235,10 +239,13 @@ def close(self):
235239

236240
class ScipyBackendEntrypoint(BackendEntrypoint):
237241
def guess_can_open(self, filename_or_obj):
238-
try:
239-
return read_magic_number(filename_or_obj).startswith(b"CDF")
240-
except TypeError:
241-
pass
242+
243+
magic_number = try_read_magic_number_from_file_or_path(filename_or_obj)
244+
if magic_number is not None and magic_number.startswith(b"\x1f\x8b"):
245+
with gzip.open(filename_or_obj) as f:
246+
magic_number = try_read_magic_number_from_file_or_path(f)
247+
if magic_number is not None:
248+
return magic_number.startswith(b"CDF")
242249

243250
try:
244251
_, ext = os.path.splitext(filename_or_obj)

xarray/core/utils.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import functools
55
import io
66
import itertools
7+
import os
78
import re
89
import warnings
910
from enum import Enum
@@ -646,7 +647,7 @@ def is_remote_uri(path: str) -> bool:
646647
return bool(re.search(r"^[a-z][a-z0-9]*(\://|\:\:)", path))
647648

648649

649-
def read_magic_number(filename_or_obj, count=8):
650+
def read_magic_number_from_file(filename_or_obj, count=8) -> bytes:
650651
# check byte header to determine file type
651652
if isinstance(filename_or_obj, bytes):
652653
magic_number = filename_or_obj[:count]
@@ -657,13 +658,36 @@ def read_magic_number(filename_or_obj, count=8):
657658
"file-like object read/write pointer not at the start of the file, "
658659
"please close and reopen, or use a context manager"
659660
)
660-
magic_number = filename_or_obj.read(count)
661+
magic_number = filename_or_obj.read(count) # type: ignore
661662
filename_or_obj.seek(0)
662663
else:
663664
raise TypeError(f"cannot read the magic number form {type(filename_or_obj)}")
664665
return magic_number
665666

666667

668+
def try_read_magic_number_from_path(pathlike, count=8) -> Optional[bytes]:
669+
if isinstance(pathlike, str) or hasattr(pathlike, "__fspath__"):
670+
path = os.fspath(pathlike)
671+
try:
672+
with open(path, "rb") as f:
673+
return read_magic_number_from_file(f, count)
674+
except (FileNotFoundError, TypeError):
675+
pass
676+
return None
677+
678+
679+
def try_read_magic_number_from_file_or_path(
680+
filename_or_obj, count=8
681+
) -> Optional[bytes]:
682+
magic_number = try_read_magic_number_from_path(filename_or_obj, count)
683+
if magic_number is None:
684+
try:
685+
magic_number = read_magic_number_from_file(filename_or_obj, count)
686+
except TypeError:
687+
pass
688+
return magic_number
689+
690+
667691
def is_uniform_spaced(arr, **kwargs) -> bool:
668692
"""Return True if values of an array are uniformly spaced and sorted.
669693

xarray/tests/test_backends.py

+97-3
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import contextlib
2+
import gzip
23
import itertools
34
import math
45
import os.path
56
import pickle
7+
import re
68
import shutil
79
import sys
810
import tempfile
@@ -30,9 +32,14 @@
3032
save_mfdataset,
3133
)
3234
from xarray.backends.common import robust_getitem
35+
from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint
3336
from xarray.backends.netcdf3 import _nc3_dtype_coercions
34-
from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding
37+
from xarray.backends.netCDF4_ import (
38+
NetCDF4BackendEntrypoint,
39+
_extract_nc4_variable_encoding,
40+
)
3541
from xarray.backends.pydap_ import PydapDataStore
42+
from xarray.backends.scipy_ import ScipyBackendEntrypoint
3643
from xarray.coding.variables import SerializationWarning
3744
from xarray.conventions import encode_dataset_coordinates
3845
from xarray.core import indexes, indexing
@@ -2771,7 +2778,7 @@ def test_open_badbytes(self):
27712778
with open_dataset(b"garbage", engine="netcdf4"):
27722779
pass
27732780
with pytest.raises(
2774-
ValueError, match=r"not the signature of a valid netCDF file"
2781+
ValueError, match=r"not the signature of a valid netCDF4 file"
27752782
):
27762783
with open_dataset(BytesIO(b"garbage"), engine="h5netcdf"):
27772784
pass
@@ -2817,7 +2824,11 @@ def test_open_fileobj(self):
28172824
with open(tmp_file, "rb") as f:
28182825
f.seek(8)
28192826
with pytest.raises(ValueError, match="cannot guess the engine"):
2820-
open_dataset(f)
2827+
with pytest.warns(
2828+
RuntimeWarning,
2829+
match=re.escape("'h5netcdf' fails while guessing"),
2830+
):
2831+
open_dataset(f)
28212832

28222833

28232834
@requires_h5netcdf
@@ -5161,3 +5172,86 @@ def test_chunking_consintency(chunks, tmp_path):
51615172

51625173
with xr.open_dataset(tmp_path / "test.nc", chunks=chunks) as actual:
51635174
xr.testing.assert_chunks_equal(actual, expected)
5175+
5176+
5177+
def _check_guess_can_open_and_open(entrypoint, obj, engine, expected):
5178+
assert entrypoint.guess_can_open(obj)
5179+
with open_dataset(obj, engine=engine) as actual:
5180+
assert_identical(expected, actual)
5181+
5182+
5183+
@requires_netCDF4
5184+
def test_netcdf4_entrypoint(tmp_path):
5185+
entrypoint = NetCDF4BackendEntrypoint()
5186+
ds = create_test_data()
5187+
5188+
path = tmp_path / "foo"
5189+
ds.to_netcdf(path, format="netcdf3_classic")
5190+
_check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds)
5191+
_check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds)
5192+
5193+
path = tmp_path / "bar"
5194+
ds.to_netcdf(path, format="netcdf4_classic")
5195+
_check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds)
5196+
_check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds)
5197+
5198+
assert entrypoint.guess_can_open("http://something/remote")
5199+
assert entrypoint.guess_can_open("something-local.nc")
5200+
assert entrypoint.guess_can_open("something-local.nc4")
5201+
assert entrypoint.guess_can_open("something-local.cdf")
5202+
assert not entrypoint.guess_can_open("not-found-and-no-extension")
5203+
5204+
path = tmp_path / "baz"
5205+
with open(path, "wb") as f:
5206+
f.write(b"not-a-netcdf-file")
5207+
assert not entrypoint.guess_can_open(path)
5208+
5209+
5210+
@requires_scipy
5211+
def test_scipy_entrypoint(tmp_path):
5212+
entrypoint = ScipyBackendEntrypoint()
5213+
ds = create_test_data()
5214+
5215+
path = tmp_path / "foo"
5216+
ds.to_netcdf(path, engine="scipy")
5217+
_check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds)
5218+
_check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds)
5219+
with open(path, "rb") as f:
5220+
_check_guess_can_open_and_open(entrypoint, f, engine="scipy", expected=ds)
5221+
5222+
contents = ds.to_netcdf(engine="scipy")
5223+
_check_guess_can_open_and_open(entrypoint, contents, engine="scipy", expected=ds)
5224+
_check_guess_can_open_and_open(
5225+
entrypoint, BytesIO(contents), engine="scipy", expected=ds
5226+
)
5227+
5228+
path = tmp_path / "foo.nc.gz"
5229+
with gzip.open(path, mode="wb") as f:
5230+
f.write(contents)
5231+
_check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds)
5232+
_check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds)
5233+
5234+
assert entrypoint.guess_can_open("something-local.nc")
5235+
assert entrypoint.guess_can_open("something-local.nc.gz")
5236+
assert not entrypoint.guess_can_open("not-found-and-no-extension")
5237+
assert not entrypoint.guess_can_open(b"not-a-netcdf-file")
5238+
5239+
5240+
@requires_h5netcdf
5241+
def test_h5netcdf_entrypoint(tmp_path):
5242+
entrypoint = H5netcdfBackendEntrypoint()
5243+
ds = create_test_data()
5244+
5245+
path = tmp_path / "foo"
5246+
ds.to_netcdf(path, engine="h5netcdf")
5247+
_check_guess_can_open_and_open(entrypoint, path, engine="h5netcdf", expected=ds)
5248+
_check_guess_can_open_and_open(
5249+
entrypoint, str(path), engine="h5netcdf", expected=ds
5250+
)
5251+
with open(path, "rb") as f:
5252+
_check_guess_can_open_and_open(entrypoint, f, engine="h5netcdf", expected=ds)
5253+
5254+
assert entrypoint.guess_can_open("something-local.nc")
5255+
assert entrypoint.guess_can_open("something-local.nc4")
5256+
assert entrypoint.guess_can_open("something-local.cdf")
5257+
assert not entrypoint.guess_can_open("not-found-and-no-extension")

0 commit comments

Comments
 (0)