From 19f2b8f8e6e6b1a3e0e4b62f27cb652ebef32bad Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 12 May 2021 16:34:48 -0700 Subject: [PATCH 1/3] More robust guess_can_open for netCDF4/scipy/h5netcdf entrypoints The new version checks magic numbers in files on disk, not just already open file objects. I've also added a bunch of unit tests. Fixes GH5295 --- doc/whats-new.rst | 4 ++ xarray/backends/h5netcdf_.py | 18 +++++--- xarray/backends/netCDF4_.py | 11 ++++- xarray/backends/scipy_.py | 15 ++++--- xarray/core/utils.py | 28 +++++++++++- xarray/tests/test_backends.py | 84 ++++++++++++++++++++++++++++++++++- 6 files changed, 144 insertions(+), 16 deletions(-) diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 3f81678b8d5..1465c5a0955 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -34,6 +34,10 @@ Deprecations Bug fixes ~~~~~~~~~ +- Opening netCDF files from a path that doesn't end in ``.nc`` without supplying + an explicit ``engine`` works again (:issue:`5295`), fixing a bug introduced in + 0.18.0. 
+ By `Stephan Hoyer <https://github.com/shoyer>`_ Documentation ~~~~~~~~~~~~~ diff --git a/xarray/backends/h5netcdf_.py b/xarray/backends/h5netcdf_.py index 84e89f80dae..fd5368c476b 100644 --- a/xarray/backends/h5netcdf_.py +++ b/xarray/backends/h5netcdf_.py @@ -6,7 +6,12 @@ import numpy as np from ..core import indexing -from ..core.utils import FrozenDict, is_remote_uri, read_magic_number +from ..core.utils import ( + FrozenDict, + is_remote_uri, + read_magic_number_from_file, + try_read_magic_number_from_file_or_path, +) from ..core.variable import Variable from .common import ( BACKEND_ENTRYPOINTS, @@ -141,10 +146,10 @@ def open( "try passing a path or file-like object" ) elif isinstance(filename, io.IOBase): - magic_number = read_magic_number(filename) + magic_number = read_magic_number_from_file(filename) if not magic_number.startswith(b"\211HDF\r\n\032\n"): raise ValueError( - f"{magic_number} is not the signature of a valid netCDF file" + f"{magic_number} is not the signature of a valid netCDF4 file" ) if format not in [None, "NETCDF4"]: @@ -336,10 +341,9 @@ def close(self, **kwargs): class H5netcdfBackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj): - try: - return read_magic_number(filename_or_obj).startswith(b"\211HDF\r\n\032\n") - except TypeError: - pass + magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) + if magic_number is not None: + return magic_number.startswith(b"\211HDF\r\n\032\n") try: _, ext = os.path.splitext(filename_or_obj) diff --git a/xarray/backends/netCDF4_.py b/xarray/backends/netCDF4_.py index a60c940c3c4..c26fab788a8 100644 --- a/xarray/backends/netCDF4_.py +++ b/xarray/backends/netCDF4_.py @@ -9,7 +9,12 @@ from .. 
import coding from ..coding.variables import pop_to from ..core import indexing -from ..core.utils import FrozenDict, close_on_error, is_remote_uri +from ..core.utils import ( + FrozenDict, + close_on_error, + is_remote_uri, + try_read_magic_number_from_path, +) from ..core.variable import Variable from .common import ( BACKEND_ENTRYPOINTS, @@ -517,6 +522,10 @@ class NetCDF4BackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj): if isinstance(filename_or_obj, str) and is_remote_uri(filename_or_obj): return True + magic_number = try_read_magic_number_from_path(filename_or_obj) + if magic_number is not None: + # netcdf 3 or HDF5 + return magic_number.startswith((b"CDF", b"\211HDF\r\n\032\n")) try: _, ext = os.path.splitext(filename_or_obj) except TypeError: diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index c27716ea44d..0ac84e2e347 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -4,7 +4,12 @@ import numpy as np from ..core.indexing import NumpyIndexingAdapter -from ..core.utils import Frozen, FrozenDict, close_on_error, read_magic_number +from ..core.utils import ( + Frozen, + FrozenDict, + close_on_error, + try_read_magic_number_from_file_or_path, +) from ..core.variable import Variable from .common import ( BACKEND_ENTRYPOINTS, @@ -235,10 +240,10 @@ def close(self): class ScipyBackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj): - try: - return read_magic_number(filename_or_obj).startswith(b"CDF") - except TypeError: - pass + + magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) + if magic_number is not None: + return magic_number.startswith(b"CDF") try: _, ext = os.path.splitext(filename_or_obj) diff --git a/xarray/core/utils.py b/xarray/core/utils.py index d3b4cd39c53..62423ce0132 100644 --- a/xarray/core/utils.py +++ b/xarray/core/utils.py @@ -4,6 +4,7 @@ import functools import io import itertools +import os import re import warnings from 
enum import Enum @@ -652,7 +653,7 @@ def is_remote_uri(path: str) -> bool: return bool(re.search(r"^[a-z][a-z0-9]*(\://|\:\:)", path)) -def read_magic_number(filename_or_obj, count=8): +def read_magic_number_from_file(filename_or_obj, count=8) -> bytes: # check byte header to determine file type if isinstance(filename_or_obj, bytes): magic_number = filename_or_obj[:count] @@ -663,13 +664,36 @@ def read_magic_number(filename_or_obj, count=8): "file-like object read/write pointer not at the start of the file, " "please close and reopen, or use a context manager" ) - magic_number = filename_or_obj.read(count) + magic_number = filename_or_obj.read(count) # type: ignore filename_or_obj.seek(0) else: raise TypeError(f"cannot read the magic number form {type(filename_or_obj)}") return magic_number +def try_read_magic_number_from_path(pathlike, count=8) -> Optional[bytes]: + if isinstance(pathlike, str) or hasattr(pathlike, "__fspath__"): + path = os.fspath(pathlike) + try: + with open(path, "rb") as f: + return read_magic_number_from_file(f, count) + except (FileNotFoundError, TypeError): + pass + return None + + +def try_read_magic_number_from_file_or_path( + filename_or_obj, count=8 +) -> Optional[bytes]: + magic_number = try_read_magic_number_from_path(filename_or_obj, count) + if magic_number is None: + try: + magic_number = read_magic_number_from_file(filename_or_obj, count) + except TypeError: + pass + return magic_number + + def is_uniform_spaced(arr, **kwargs) -> bool: """Return True if values of an array are uniformly spaced and sorted. 
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 3e3d6e8b8d0..6199d9e554a 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -30,9 +30,14 @@ save_mfdataset, ) from xarray.backends.common import robust_getitem +from xarray.backends.h5netcdf_ import H5netcdfBackendEntrypoint from xarray.backends.netcdf3 import _nc3_dtype_coercions -from xarray.backends.netCDF4_ import _extract_nc4_variable_encoding +from xarray.backends.netCDF4_ import ( + NetCDF4BackendEntrypoint, + _extract_nc4_variable_encoding, +) from xarray.backends.pydap_ import PydapDataStore +from xarray.backends.scipy_ import ScipyBackendEntrypoint from xarray.coding.variables import SerializationWarning from xarray.conventions import encode_dataset_coordinates from xarray.core import indexes, indexing @@ -5161,3 +5166,80 @@ def test_chunking_consintency(chunks, tmp_path): with xr.open_dataset(tmp_path / "test.nc", chunks=chunks) as actual: xr.testing.assert_chunks_equal(actual, expected) + + +def _check_guess_can_open_and_open(entrypoint, obj, engine, expected): + assert entrypoint.guess_can_open(obj) + with open_dataset(obj, engine=engine) as actual: + assert_identical(expected, actual) + + +@requires_netCDF4 +def test_netcdf4_entrypoint(tmp_path): + entrypoint = NetCDF4BackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, format="netcdf3_classic") + _check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds) + + path = tmp_path / "bar" + ds.to_netcdf(path, format="netcdf4_classic") + _check_guess_can_open_and_open(entrypoint, path, engine="netcdf4", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="netcdf4", expected=ds) + + assert entrypoint.guess_can_open("http://something/remote") + assert entrypoint.guess_can_open("something-local.nc") + assert 
entrypoint.guess_can_open("something-local.nc4") + assert entrypoint.guess_can_open("something-local.cdf") + assert not entrypoint.guess_can_open("not-found-and-no-extension") + + path = tmp_path / "baz" + with open(path, "wb") as f: + f.write(b"not-a-netcdf-file") + assert not entrypoint.guess_can_open(path) + + +@requires_scipy +def test_scipy_entrypoint(tmp_path): + entrypoint = ScipyBackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, engine="scipy") + _check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds) + with open(path, "rb") as f: + _check_guess_can_open_and_open(entrypoint, f, engine="scipy", expected=ds) + + contents = ds.to_netcdf(engine="scipy") + _check_guess_can_open_and_open(entrypoint, contents, engine="scipy", expected=ds) + _check_guess_can_open_and_open( + entrypoint, BytesIO(contents), engine="scipy", expected=ds + ) + + assert entrypoint.guess_can_open("something-local.nc") + assert entrypoint.guess_can_open("something-local.nc.gz") + assert not entrypoint.guess_can_open("not-found-and-no-extension") + assert not entrypoint.guess_can_open(b"not-a-netcdf-file") + + +@requires_h5netcdf +def test_h5netcdf_entrypoint(tmp_path): + entrypoint = H5netcdfBackendEntrypoint() + ds = create_test_data() + + path = tmp_path / "foo" + ds.to_netcdf(path, engine="h5netcdf") + _check_guess_can_open_and_open(entrypoint, path, engine="h5netcdf", expected=ds) + _check_guess_can_open_and_open( + entrypoint, str(path), engine="h5netcdf", expected=ds + ) + with open(path, "rb") as f: + _check_guess_can_open_and_open(entrypoint, f, engine="h5netcdf", expected=ds) + + assert entrypoint.guess_can_open("something-local.nc") + assert entrypoint.guess_can_open("something-local.nc4") + assert entrypoint.guess_can_open("something-local.cdf") + assert not entrypoint.guess_can_open("not-found-and-no-extension") From 
d8ca03859bdbd9d605dd26611e834f723d7c04dd Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 12 May 2021 17:25:54 -0700 Subject: [PATCH 2/3] Fix failures and warning in test_backends.py --- xarray/backends/scipy_.py | 6 ++++-- xarray/tests/test_backends.py | 13 +++++++++++-- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index 0ac84e2e347..a40a28f2ed6 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -1,3 +1,4 @@ +import gzip import io import os @@ -77,8 +78,6 @@ def __setitem__(self, key, value): def _open_scipy_netcdf(filename, mode, mmap, version): - import gzip - # if the string ends with .gz, then gunzip and open as netcdf file if isinstance(filename, str) and filename.endswith(".gz"): try: @@ -242,6 +241,9 @@ class ScipyBackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj): magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) + if magic_number is not None and magic_number.startswith(b'\x1f\x8b'): + with gzip.open(filename_or_obj) as f: + magic_number = try_read_magic_number_from_file_or_path(f) if magic_number is not None: return magic_number.startswith(b"CDF") diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index 6199d9e554a..b435f76ae25 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -1,8 +1,10 @@ import contextlib +import gzip import itertools import math import os.path import pickle +import re import shutil import sys import tempfile @@ -2776,7 +2778,7 @@ def test_open_badbytes(self): with open_dataset(b"garbage", engine="netcdf4"): pass with pytest.raises( - ValueError, match=r"not the signature of a valid netCDF file" + ValueError, match=r"not the signature of a valid netCDF4 file" ): with open_dataset(BytesIO(b"garbage"), engine="h5netcdf"): pass @@ -2822,7 +2824,8 @@ def test_open_fileobj(self): with open(tmp_file, "rb") as f: f.seek(8) with 
pytest.raises(ValueError, match="cannot guess the engine"): - open_dataset(f) + with pytest.warns(RuntimeWarning, match=re.escape("'h5netcdf' fails while guessing")): + open_dataset(f) @requires_h5netcdf @@ -5219,6 +5222,12 @@ def test_scipy_entrypoint(tmp_path): entrypoint, BytesIO(contents), engine="scipy", expected=ds ) + path = tmp_path / "foo.nc.gz" + with gzip.open(path, mode='wb') as f: + f.write(contents) + _check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds) + _check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds) + assert entrypoint.guess_can_open("something-local.nc") assert entrypoint.guess_can_open("something-local.nc.gz") assert not entrypoint.guess_can_open("not-found-and-no-extension") From 90b5f114d47b5dba86cab3fa975a73ede9f6ba64 Mon Sep 17 00:00:00 2001 From: Stephan Hoyer Date: Wed, 12 May 2021 18:37:12 -0700 Subject: [PATCH 3/3] format black --- xarray/backends/scipy_.py | 2 +- xarray/tests/test_backends.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/xarray/backends/scipy_.py b/xarray/backends/scipy_.py index a40a28f2ed6..3a10c56de3e 100644 --- a/xarray/backends/scipy_.py +++ b/xarray/backends/scipy_.py @@ -241,7 +241,7 @@ class ScipyBackendEntrypoint(BackendEntrypoint): def guess_can_open(self, filename_or_obj): magic_number = try_read_magic_number_from_file_or_path(filename_or_obj) - if magic_number is not None and magic_number.startswith(b'\x1f\x8b'): + if magic_number is not None and magic_number.startswith(b"\x1f\x8b"): with gzip.open(filename_or_obj) as f: magic_number = try_read_magic_number_from_file_or_path(f) if magic_number is not None: diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py index b435f76ae25..60eb5b924ca 100644 --- a/xarray/tests/test_backends.py +++ b/xarray/tests/test_backends.py @@ -2824,7 +2824,10 @@ def test_open_fileobj(self): with open(tmp_file, "rb") as f: f.seek(8) with pytest.raises(ValueError, 
match="cannot guess the engine"): - with pytest.warns(RuntimeWarning, match=re.escape("'h5netcdf' fails while guessing")): + with pytest.warns( + RuntimeWarning, + match=re.escape("'h5netcdf' fails while guessing"), + ): open_dataset(f) @@ -5223,7 +5226,7 @@ def test_scipy_entrypoint(tmp_path): ) path = tmp_path / "foo.nc.gz" - with gzip.open(path, mode='wb') as f: + with gzip.open(path, mode="wb") as f: f.write(contents) _check_guess_can_open_and_open(entrypoint, path, engine="scipy", expected=ds) _check_guess_can_open_and_open(entrypoint, str(path), engine="scipy", expected=ds)