From 020e9d8d2bc0a61e14010c03939470f7763d64b9 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Thu, 25 Jan 2024 16:41:23 +0000 Subject: [PATCH 01/18] Generic array based usage --- earthkit/data/core/fieldlist.py | 39 ++++- earthkit/data/sources/array_list.py | 225 ++++++++++++++++++++++++++++ 2 files changed, 262 insertions(+), 2 deletions(-) create mode 100644 earthkit/data/sources/array_list.py diff --git a/earthkit/data/core/fieldlist.py b/earthkit/data/core/fieldlist.py index e125a5aa..21d2d7aa 100644 --- a/earthkit/data/core/fieldlist.py +++ b/earthkit/data/core/fieldlist.py @@ -17,11 +17,32 @@ from earthkit.data.utils.metadata import metadata_argument +class ArrayMaker: + def to_numpy(self): + pass + + +class PytorchArrayMaker: + def from_numpy(self, v): + import torch + + return torch.from_numpy(v) + + class Field(Base): r"""Represents a Field.""" def __init__(self, metadata=None): self.__metadata = metadata + self.__array_ns = None + + @property + def _array_ns(self): + if self.__array_ns is None: + import array_api_compat + + self.__array_ns = array_api_compat.array_namespace(self.values) + return self.__array_ns @abstractmethod def _values(self, dtype=None): @@ -48,7 +69,7 @@ def values(self): v = self._values() if len(v.shape) != 1: n = math.prod(v.shape) - return v.reshape(n) + return self._array_ns.reshape(v, n) return v def _make_metadata(self): @@ -83,9 +104,16 @@ def to_numpy(self, flatten=False, dtype=None): v = self._values(dtype=dtype) shape = self._required_shape(flatten) if shape != v.shape: - return v.reshape(shape) + return self._array_ns.reshape(v, shape) return v + def to_array(self, flatten=False, dtype=None, backend="numpy"): + if backend == "pytorch": + v = self.to_numpy() + import torch + + return torch.from_numpy(v) + def _required_shape(self, flatten): return self.shape if not flatten else (math.prod(self.shape),) @@ -684,6 +712,13 @@ def to_numpy(self, **kwargs): return np.array([f.to_numpy(**kwargs) for f in self]) + def to_array(self, backend="cupy"): + import array_api_compat + + x = [f.to_array(backend=backend) for f in self] + xp = array_api_compat.array_namespace(x[0]) + return xp.stack(x) + @property def values(self): r"""ndarray: Get the field values as a 2D ndarray. It is formed as the array of diff --git a/earthkit/data/sources/array_list.py b/earthkit/data/sources/array_list.py new file mode 100644 index 00000000..2d957879 --- /dev/null +++ b/earthkit/data/sources/array_list.py @@ -0,0 +1,225 @@ +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging + +import numpy as np + +from earthkit.data.core.fieldlist import Field, FieldList +from earthkit.data.core.index import MaskIndex, MultiIndex +from earthkit.data.readers.grib.pandas import PandasMixIn +from earthkit.data.readers.grib.xarray import XarrayMixIn + +LOG = logging.getLogger(__name__) + + +class ArrayField(Field): + r"""Represent a field consisting of an ndarray and metadata object. + + Parameters + ---------- + array: ndarray + Array storing the values of the field + metadata: :class:`Metadata` + Metadata object describing the field metadata. + """ + + def __init__(self, array, metadata): + self._array = array + super().__init__(metadata=metadata) + import array_api_compat + + self.__array_ns = array_api_compat.array_namespace(self._array) + + def _make_metadata(self): + pass + + def _values(self, dtype=None): + if dtype is None: + return self._array + else: + return self._array_ns.astype(self._array, dtype, copy=False) + + def __repr__(self): + return f"{self.__class__.__name__}()" + + def write(self, f, **kwargs): + r"""Write the field to a file object. + + Parameters + ---------- + f: file object + The target file object. + **kwargs: dict, optional + Other keyword arguments passed to :meth:`data.writers.grib.GribWriter.write`. + """ + from earthkit.data.writers import write + + write(f, self.to_numpy(flatten=True), self._metadata, **kwargs) + # write(f, self.values, self._metadata, **kwargs) + + +class ArrayFieldListCore(PandasMixIn, XarrayMixIn, FieldList): + def __init__(self, array, metadata, *args, **kwargs): + self._array = array + self._metadata = metadata + + if not isinstance(self._metadata, list): + self._metadata = [self._metadata] + + if isinstance(self._array, np.ndarray): + if self._array.shape[0] != len(self._metadata): + # we have a single array and a single metadata + if len(self._metadata) == 1 and self._shape_match( + self._array.shape, self._metadata[0].geography.shape() + ): + self._array = np.array([self._array]) + else: + raise ValueError( + ( + f"first array dimension ({self._array.shape[0]}) differs " + f"from number of metadata objects ({len(self._metadata)})" + ) + ) + elif isinstance(self._array, list): + if len(self._array) != len(self._metadata): + raise ValueError( + ( + f"array len ({len(self._array)}) differs " + f"from number of metadata objects ({len(self._metadata)})" + ) + ) + + for i, a in enumerate(self._array): + if not isinstance(a, np.ndarray): + raise ValueError( + f"All array element must be an ndarray. Type at position={i} is {type(a)}" + ) + + else: + raise TypeError("array must be an ndarray or a list of ndarrays") + + # hide internal metadata related to values + self._metadata = [md._hide_internal_keys() for md in self._metadata] + + super().__init__(*args, **kwargs) + + def _shape_match(self, shape1, shape2): + if shape1 == shape2: + return True + if len(shape1) == 1 and shape1[0] == np.prod(shape2): + return True + return False + + @classmethod + def new_mask_index(self, *args, **kwargs): + return ArrayMaskFieldList(*args, **kwargs) + + @classmethod + def merge(cls, sources): + assert all(isinstance(_, ArrayFieldListCore) for _ in sources) + merger = ListMerger(sources) + # merger = MultiUnwindMerger(sources) + return merger.to_fieldlist() + + def __repr__(self): + return f"{self.__class__.__name__}(fields={len(self)})" + + def _to_numpy_fieldlist(self, **kwargs): + if self[0]._array_matches(self._array[0], **kwargs): + return self + else: + return type(self)(self.to_numpy(**kwargs), self._metadata) + + def save(self, filename, append=False, check_nans=True, bits_per_value=16): + r"""Write all the fields into a file. + + Parameters + ---------- + filename: str + The target file path. + append: bool + When it is true append data to the target file. Otherwise + the target file be overwritten if already exists. + check_nans: bool + Replace nans in the values with GRIB missing values when generating the output. + bits_per_value: int + Set the ``bitsPerValue`` GRIB key in the generated output. + """ + super().save( + filename, + append=append, + check_nans=check_nans, + bits_per_value=bits_per_value, + ) + + +class MultiUnwindMerger: + def __init__(self, sources): + self.sources = list(self._flatten(sources)) + + def _flatten(self, sources): + if isinstance(sources, ArrayMultiFieldList): + for s in sources.indexes: + yield from self._flatten(s) + elif isinstance(sources, list): + for s in sources: + yield from self._flatten(s) + else: + yield sources + + def to_fieldlist(self): + return ArrayMultiFieldList(self.sources) + + +class ListMerger: + def __init__(self, sources): + self.sources = sources + + def to_fieldlist(self): + array = [] + metadata = [] + for s in self.sources: + for f in s: + array.append(f._array) + metadata.append(f._metadata) + return ArrayFieldList(array, metadata) + + +class ArrayFieldList(ArrayFieldListCore): + r"""Represent a list of :obj:`NumpyField `\ s. + + The preferred way to create a NumpyFieldList is to use either the + static :obj:`from_numpy` method or the :obj:`to_fieldlist` method. + + See Also + -------- + from_numpy + to_fieldlist + + """ + + def _getitem(self, n): + if isinstance(n, int): + return ArrayField(self._array[n], self._metadata[n]) + + def __len__(self): + return ( + len(self._array) if isinstance(self._array, list) else self._array.shape[0] + ) + + +class ArrayMaskFieldList(ArrayFieldListCore, MaskIndex): + def __init__(self, *args, **kwargs): + MaskIndex.__init__(self, *args, **kwargs) + + +class ArrayMultiFieldList(ArrayFieldListCore, MultiIndex): + def __init__(self, *args, **kwargs): + MultiIndex.__init__(self, *args, **kwargs) From 2f9d3c1e562f65f4b6fb0156c76b7c2af9191b11 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 6 Feb 2024 21:16:55 +0000 Subject: [PATCH 02/18] Array backend --- earthkit/data/core/array.py | 80 +++++++++++++++++++++++++++++++++ earthkit/data/core/fieldlist.py | 61 ++++++++++++++++++++----- 2 files changed, 131 insertions(+), 10 deletions(-) create mode 100644 earthkit/data/core/array.py diff --git a/earthkit/data/core/array.py b/earthkit/data/core/array.py new file mode 100644 index 00000000..3920ec95 --- /dev/null +++ b/earthkit/data/core/array.py @@ -0,0 +1,80 @@ +try: + import array_api_compat +except Exception: + array_api_compat = None + + +class ArrayBackend: + _array_ns = None + + def to_numpy(self): + pass + + def _get_backend(self, v): + for k, b in array_backends: + if b.match(v): + return b + + def mutate(self): + if self._array_ns is None: + return EmptyBackend() + return self + + +class EmptyBackend(ArrayBackend): + def match(self, v): + return False + + +class NumpyBackend(ArrayBackend): + def __init__(self): + import numpy as np + + self._array_ns = array_api_compat.array_namespace(np.ones(2)) + + def match(self, v): + import numpy as np + + return isinstance(v, np.ndarray) + + @property + def array_ns(self): + return self._array_ns + + def to_array(self, v, backend=None): + if backend is not None: + if backend == self: + return v + return backend.to_numpy(v) + # else: + # try: + # import array_api_compat + + # __array_ns = array_api_compat.array_namespace(v) + # return v + + def from_numpy(self, v): + return v + + def from_pytorch(self, v): + import torch + + return torch.to_numpy(v) + + +class PytorchBackend(ArrayBackend): + def __init__(self): + try: + import torch + + self._array_ns = array_api_compat.array_namespace(torch.ones(2)) + except Exception: + pass + + def match(self, v): + import numpy as np + + return isinstance(v, np.ndarray) + + +array_backends = {"numpy": NumpyBackend, "pytorch": PytorchBackend().mutate()} diff --git a/earthkit/data/core/fieldlist.py b/earthkit/data/core/fieldlist.py index 21d2d7aa..872baf54 100644 --- a/earthkit/data/core/fieldlist.py +++ b/earthkit/data/core/fieldlist.py @@ -12,38 +12,79 @@ from collections import defaultdict from earthkit.data.core import Base +from earthkit.data.core.array import NumpyBackend from earthkit.data.core.index import Index from earthkit.data.decorators import cached_method from earthkit.data.utils.metadata import metadata_argument +# class ArrayMaker: +# def to_numpy(self): +# pass -class ArrayMaker: - def to_numpy(self): - pass +# class NumpyBackend: +# def __init__(self): +# import numpy as np +# self._array_ns = np -class PytorchArrayMaker: - def from_numpy(self, v): - import torch +# @property +# def array_ns(self): +# return self._array_ns - return torch.from_numpy(v) +# def to_array(self, v): +# try: +# import array_api_compat + +# __array_ns = array_api_compat.array_namespace(v) +# return v + +# def from_numpy(self, v): +# return v + +# def from_pytorch(self, v): +# import torch + +# return torch.to_numpy(v) + + +# class PytorchArrayMaker: +# def from_numpy(self, v): +# import torch + +# return torch.from_numpy(v) + + +# array_backends = {"numpy": NumpyArrayMaker, "pytorch": PytorchArrayMaker} class Field(Base): r"""Represents a Field.""" - def __init__(self, metadata=None): + raw_backend = NumpyBackend + + def __init__(self, metadata=None, backend=None): self.__metadata = metadata self.__array_ns = None + if backend is None: + backend = "numpy" + # self.backend = array_backends[backend] @property def _array_ns(self): if self.__array_ns is None: - import array_api_compat + try: + import array_api_compat + + self.__array_ns = array_api_compat.array_namespace(self.values) + except ImportError: + import numpy as np - self.__array_ns = array_api_compat.array_namespace(self.values) + self.__array_ns = np return self.__array_ns + def _to_array(self, v): + return self.backend.to_array(v, self.raw_backend) + @abstractmethod def _values(self, dtype=None): r"""Return the values stored in the field as an ndarray. From facbd229f451e37c20fc82fa509a7c228d533e3d Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Mon, 12 Feb 2024 08:24:06 +0000 Subject: [PATCH 03/18] Array backend --- earthkit/data/core/array.py | 80 ------------ earthkit/data/core/array/__init__.py | 109 ++++++++++++++++ earthkit/data/core/array/numpy.py | 0 earthkit/data/core/array/pytorch.py | 0 earthkit/data/core/fieldlist.py | 130 +++++++------------ earthkit/data/readers/grib/codes.py | 4 +- earthkit/data/readers/grib/index/__init__.py | 7 +- earthkit/data/readers/grib/reader.py | 4 +- earthkit/data/sources/url.py | 2 - 9 files changed, 168 insertions(+), 168 deletions(-) delete mode 100644 earthkit/data/core/array.py create mode 100644 earthkit/data/core/array/__init__.py create mode 100644 earthkit/data/core/array/numpy.py create mode 100644 earthkit/data/core/array/pytorch.py diff --git a/earthkit/data/core/array.py b/earthkit/data/core/array.py deleted file mode 100644 index 3920ec95..00000000 --- a/earthkit/data/core/array.py +++ /dev/null @@ -1,80 +0,0 @@ -try: - import array_api_compat -except Exception: - array_api_compat = None - - -class ArrayBackend: - _array_ns = None - - def to_numpy(self): - pass - - def _get_backend(self, v): - for k, b in array_backends: - if b.match(v): - return b - - def mutate(self): - if self._array_ns is None: - return EmptyBackend() - return self - - -class EmptyBackend(ArrayBackend): - def match(self, v): - return False - - -class NumpyBackend(ArrayBackend): - def __init__(self): - import numpy as np - - self._array_ns = array_api_compat.array_namespace(np.ones(2)) - - def match(self, v): - import numpy as np - - return isinstance(v, np.ndarray) - - @property - def array_ns(self): - return self._array_ns - - def to_array(self, v, backend=None): - if backend is not None: - if backend == self: - return v - return backend.to_numpy(v) - # else: - # try: - # import array_api_compat - - # __array_ns = array_api_compat.array_namespace(v) - # return v - - def from_numpy(self, v): - return v - - def from_pytorch(self, v): - import torch - - return torch.to_numpy(v) - - -class PytorchBackend(ArrayBackend): - def __init__(self): - try: - import torch - - self._array_ns = array_api_compat.array_namespace(torch.ones(2)) - except Exception: - pass - - def match(self, v): - import numpy as np - - return isinstance(v, np.ndarray) - - -array_backends = {"numpy": NumpyBackend, "pytorch": PytorchBackend().mutate()} diff --git a/earthkit/data/core/array/__init__.py b/earthkit/data/core/array/__init__.py new file mode 100644 index 00000000..b4e0a084 --- /dev/null +++ b/earthkit/data/core/array/__init__.py @@ -0,0 +1,109 @@ +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import threading +from abc import ABCMeta, abstractmethod + + +class ArrayBackendManager: + def __init__(self): + self.backends = {} + self.lock = threading.Lock() + + def find(self, name): + with self.lock: + b = self.backends.get(name, None) + if b is None: + b = array_backend_types[name]() + self.backends[name] = b + return b + + +MANAGER = ArrayBackendManager() + + +class ArrayBackend(metaclass=ABCMeta): + _array_ns = None + + @property + def array_ns(self): + return self._array_ns + + @staticmethod + def find(name): + return MANAGER.find(name) + + def to_array(self, v, backend=None): + if backend is not None: + if backend is self: + return v + + return backend.to_backend(self, v) + + @abstractmethod + def to_backend(self, backend, v): + pass + + @abstractmethod + def from_numpy(self, v): + pass + + @abstractmethod + def from_pytorch(self, v): + pass + + +class NumpyBackend(ArrayBackend): + def __init__(self): + import numpy as np + + try: + import array_api_compat + + self._array_ns = array_api_compat.array_namespace(np.ones(2)) + except Exception: + self._array_ns = np + + def to_backend(self, backend, v): + return backend.from_numpy(v) + + def from_numpy(self, v): + return v + + def from_pytorch(self, v): + import torch + + return torch.to_numpy(v) + + +class PytorchBackend(ArrayBackend): + def __init__(self): + try: + import array_api_compat + + except Exception: + raise ImportError("array_api_compat is required to use pytorch backend") + + import torch + + self._array_ns = array_api_compat.array_namespace(torch.ones(2)) + + def to_backend(self, backend, v): + return backend.from_pytorch(v) + + def from_numpy(self, v): + import torch + + return torch.from_numpy(v) + + def from_pytorch(self, v): + return v + + +array_backend_types = {"numpy": NumpyBackend, "pytorch": PytorchBackend} diff --git a/earthkit/data/core/array/numpy.py b/earthkit/data/core/array/numpy.py new file mode 100644 index 00000000..e69de29b diff --git a/earthkit/data/core/array/pytorch.py b/earthkit/data/core/array/pytorch.py new file mode 100644 index 00000000..e69de29b diff --git a/earthkit/data/core/fieldlist.py b/earthkit/data/core/fieldlist.py index 872baf54..bd998ea2 100644 --- a/earthkit/data/core/fieldlist.py +++ b/earthkit/data/core/fieldlist.py @@ -12,75 +12,20 @@ from collections import defaultdict from earthkit.data.core import Base -from earthkit.data.core.array import NumpyBackend +from earthkit.data.core.array import ArrayBackend from earthkit.data.core.index import Index from earthkit.data.decorators import cached_method from earthkit.data.utils.metadata import metadata_argument -# class ArrayMaker: -# def to_numpy(self): -# pass - - -# class NumpyBackend: -# def __init__(self): -# import numpy as np -# self._array_ns = np - -# @property -# def array_ns(self): -# return self._array_ns - -# def to_array(self, v): -# try: -# import array_api_compat - -# __array_ns = array_api_compat.array_namespace(v) -# return v - -# def from_numpy(self, v): -# return v - -# def from_pytorch(self, v): -# import torch - -# return torch.to_numpy(v) - - -# class PytorchArrayMaker: -# def from_numpy(self, v): -# import torch - -# return torch.from_numpy(v) - - -# array_backends = {"numpy": NumpyArrayMaker, "pytorch": PytorchArrayMaker} - class Field(Base): r"""Represents a Field.""" - raw_backend = NumpyBackend + raw_backend = ArrayBackend.find("numpy") - def __init__(self, metadata=None, backend=None): + def __init__(self, backend, metadata=None): self.__metadata = metadata - self.__array_ns = None - if backend is None: - backend = "numpy" - # self.backend = array_backends[backend] - - @property - def _array_ns(self): - if self.__array_ns is None: - try: - import array_api_compat - - self.__array_ns = array_api_compat.array_namespace(self.values) - except ImportError: - import numpy as np - - self.__array_ns = np - return self.__array_ns + self.backend = backend def _to_array(self, v): return self.backend.to_array(v, self.raw_backend) @@ -107,10 +52,10 @@ def _values(self, dtype=None): @property def values(self): r"""ndarray: Get the values stored in the field as a 1D ndarray.""" - v = self._values() + v = self._to_array(self._values()) if len(v.shape) != 1: n = math.prod(v.shape) - return self._array_ns.reshape(v, n) + return self.backend.array_ns.reshape(v, n) return v def _make_metadata(self): @@ -143,17 +88,35 @@ def to_numpy(self, flatten=False, dtype=None): """ v = self._values(dtype=dtype) + ArrayBackend.find("numpy").to_array(v, self.raw_backend) shape = self._required_shape(flatten) if shape != v.shape: - return self._array_ns.reshape(v, shape) + return v.reshape(shape) return v - def to_array(self, flatten=False, dtype=None, backend="numpy"): - if backend == "pytorch": - v = self.to_numpy() - import torch + def to_array(self, flatten=False, dtype=None): + r"""Return the values stored in the field as an ndarray. + + Parameters + ---------- + flatten: bool + When it is True a flat ndarray is returned. Otherwise an ndarray with the field's + :obj:`shape` is returned. + dtype: str, numpy.dtype or None + Typecode or data-type of the array. When it is :obj:`None` the default + type used by the underlying data accessor is used. For GRIB it is ``np.float64``. + + Returns + ------- + ndarray + Field values - return torch.from_numpy(v) + """ + v = self._to_array(self._values(dtype=dtype)) + shape = self._required_shape(flatten) + if shape != v.shape: + return self.backend.array_ns.reshape(v, shape) + return v def _required_shape(self, flatten): return self.shape if not flatten else (math.prod(self.shape),) @@ -228,17 +191,19 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): if k not in _keys: raise ValueError(f"data: invalid argument: {k}") - r = [_keys[k](dtype=dtype) for k in keys] + r = [self._to_array(_keys[k](dtype=dtype)) for k in keys] shape = self._required_shape(flatten) if shape != r[0].shape: - r = [x.reshape(shape) for x in r] + # r = [x.reshape(shape) for x in r] + r = [self.backend.array_ns.reshape(x, shape) for x in r] if len(r) == 1: return r[0] else: - import numpy as np + return self.backend.array_ns.stack(r) + # import numpy as np - return np.array(r) + # return np.array(r) def to_points(self, flatten=False, dtype=None): r"""Return the geographical coordinates in the data's original @@ -608,7 +573,11 @@ class FieldList(Index): _md_indices = {} - def __init__(self, *args, **kwargs): + def __init__(self, *args, backend=None, **kwargs): + if backend is None: + backend = "numpy" + self.backend = ArrayBackend.find(backend) + super().__init__(*args, **kwargs) @staticmethod @@ -753,12 +722,9 @@ def to_numpy(self, **kwargs): return np.array([f.to_numpy(**kwargs) for f in self]) - def to_array(self, backend="cupy"): - import array_api_compat - - x = [f.to_array(backend=backend) for f in self] - xp = array_api_compat.array_namespace(x[0]) - return xp.stack(x) + def to_array(self): + x = [f.to_array() for f in self] + return self.backend.array_ns.stack(x) @property def values(self): @@ -784,9 +750,13 @@ def values(self): array([262.78027344, 267.44726562, 268.61230469]) """ - import numpy as np + # import numpy as np + + x = [f.values for f in self] + return self.backend.array_ns.stack(x) + # return self.backend.to_array(x) - return np.array([f.values for f in self]) + # return np.array([f.values for f in self]) def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): r"""Return the values and/or the geographical coordinates. diff --git a/earthkit/data/readers/grib/codes.py b/earthkit/data/readers/grib/codes.py index 62a3b5c4..9061b9d9 100644 --- a/earthkit/data/readers/grib/codes.py +++ b/earthkit/data/readers/grib/codes.py @@ -245,8 +245,8 @@ class GribField(Field): Size of the message (in bytes) """ - def __init__(self, path, offset, length): - super().__init__() + def __init__(self, path, offset, length, backend): + super().__init__(backend) self.path = path self._offset = offset self._length = length diff --git a/earthkit/data/readers/grib/index/__init__.py b/earthkit/data/readers/grib/index/__init__.py index bdb91282..09b0a3e0 100644 --- a/earthkit/data/readers/grib/index/__init__.py +++ b/earthkit/data/readers/grib/index/__init__.py @@ -13,7 +13,7 @@ from abc import abstractmethod from earthkit.data.core.fieldlist import FieldList -from earthkit.data.core.index import Index, MaskIndex, MultiIndex +from earthkit.data.core.index import MaskIndex, MultiIndex from earthkit.data.decorators import alias_argument from earthkit.data.indexing.database import ( FILEPARTS_KEY_NAMES, @@ -109,7 +109,8 @@ def __init__(self, *args, **kwargs): ): self._availability = Availability(self.availability_path) - Index.__init__(self, *args, **kwargs) + # Index.__init__(self, *args, **kwargs) + FieldList.__init__(self, *args, **kwargs) @classmethod def new_mask_index(self, *args, **kwargs): @@ -194,7 +195,7 @@ class GribFieldListInFiles(GribFieldList): def _getitem(self, n): if isinstance(n, int): part = self.part(n if n >= 0 else len(self) + n) - return GribField(part.path, part.offset, part.length) + return GribField(part.path, part.offset, part.length, self.backend) def __len__(self): return self.number_of_parts() diff --git a/earthkit/data/readers/grib/reader.py b/earthkit/data/readers/grib/reader.py index 2fa973e8..5ee083b7 100644 --- a/earthkit/data/readers/grib/reader.py +++ b/earthkit/data/readers/grib/reader.py @@ -20,8 +20,10 @@ class GRIBReader(GribFieldListInOneFile, Reader): appendable = True # GRIB messages can be added to the same file def __init__(self, source, path, parts=None): + backend = source._kwargs.get("backend", None) + Reader.__init__(self, source, path) - GribFieldListInOneFile.__init__(self, path, parts=parts) + GribFieldListInOneFile.__init__(self, path, parts=parts, backend=backend) def __repr__(self): return "GRIBReader(%s)" % (self.path,) diff --git a/earthkit/data/sources/url.py b/earthkit/data/sources/url.py index efa97e1a..511468db 100644 --- a/earthkit/data/sources/url.py +++ b/earthkit/data/sources/url.py @@ -219,8 +219,6 @@ def __init__( download_file_extension=".download", ) - print(f"downloader={self.downloader}") - if extension and extension[0] != ".": extension = "." + extension From c6c5c28bf19debd763f4bc810abaff2862a5b00a Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 13 Feb 2024 14:06:34 +0000 Subject: [PATCH 04/18] Impelement array backends for fieldlist --- docs/examples.rst | 1 + docs/examples/grib_array_backends.ipynb | 959 +++++++++++++++++++ docs/examples/numpy_fieldlist.ipynb | 8 +- earthkit/data/core/array.py | 234 +++++ earthkit/data/core/array/__init__.py | 109 --- earthkit/data/core/array/numpy.py | 0 earthkit/data/core/array/pytorch.py | 0 earthkit/data/core/fieldlist.py | 83 +- earthkit/data/readers/__init__.py | 8 +- earthkit/data/readers/grib/__init__.py | 8 +- earthkit/data/readers/grib/index/__init__.py | 12 +- earthkit/data/readers/grib/memory.py | 28 +- earthkit/data/readers/grib/reader.py | 8 - earthkit/data/readers/netcdf.py | 37 +- earthkit/data/sources/array_list.py | 92 +- earthkit/data/sources/constants.py | 8 +- earthkit/data/sources/list_of_dicts.py | 6 +- earthkit/data/sources/numpy_list.py | 214 +---- earthkit/data/sources/stream.py | 38 +- tests/documentation/test_notebooks.py | 1 + tests/grib/test_grib_sel.py | 1 + tests/grib/test_grib_stream.py | 2 +- tests/grib/test_grib_url_stream.py | 2 +- 23 files changed, 1400 insertions(+), 459 deletions(-) create mode 100644 docs/examples/grib_array_backends.ipynb create mode 100644 earthkit/data/core/array.py delete mode 100644 earthkit/data/core/array/__init__.py delete mode 100644 earthkit/data/core/array/numpy.py delete mode 100644 earthkit/data/core/array/pytorch.py diff --git a/docs/examples.rst b/docs/examples.rst index fabff3bc..51dc5f84 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -38,6 +38,7 @@ Here is a list of example notebooks to illustrate how to use earthkit-data. examples/grib_url_stream.ipynb examples/grib_to_netcdf.ipynb examples/numpy_fieldlist.ipynb + examples/grib_array_backends.ipynb examples/grib_nearest_gridpoint.ipynb examples/grib_time_series.ipynb examples/grib_fdb_write.ipynb diff --git a/docs/examples/grib_array_backends.ipynb b/docs/examples/grib_array_backends.ipynb new file mode 100644 index 00000000..bab9310e --- /dev/null +++ b/docs/examples/grib_array_backends.ipynb @@ -0,0 +1,959 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "22408e0c-ddae-4924-b352-e7d796602c14", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "## Using array backends for GRIB data" + ] + }, + { + "cell_type": "markdown", + "id": "05675a78-99c4-404f-aae4-12c1d8ee1ced", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "In this example we will use a GRIB file containing 4 messages. First we ensure the file is available." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "076ae474-e6c4-4a0d-a66c-91b03965627a", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import earthkit.data\n", + "earthkit.data.download_example_file(\"test4.grib\")" + ] + }, + { + "cell_type": "raw", + "id": "847c90c4-5928-4481-abc9-c2ed9aada29f", + "metadata": { + "editable": true, + "raw_mimetype": "text/restructuredtext", + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "When reading GRIB data with :func:`from_source` we can specify the array ``backend`` we want to use when extracting the field values. The default backend is \"numpy\". For this example we choose the \"pytorch\" backend. Since pytorch is an optional dependency for earthkit-data we need to ensure it is installed in the environment." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "b3b30d9f-0edb-4938-baec-7026acd70192", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "!pip install torch --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5d2174d7-0f36-4b20-8ad5-bd93fd12f91b", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "ds = earthkit.data.from_source(\"file\", \"test6.grib\", backend=\"pytorch\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3c108a25-5f41-422f-9adb-98c932205dce", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
centreshortNametypeOfLevelleveldataDatedataTimestepRangedataTypenumbergridType
0ecmftisobaricInhPa10002018080112000an0regular_ll
1ecmfuisobaricInhPa10002018080112000an0regular_ll
2ecmfvisobaricInhPa10002018080112000an0regular_ll
3ecmftisobaricInhPa8502018080112000an0regular_ll
4ecmfuisobaricInhPa8502018080112000an0regular_ll
5ecmfvisobaricInhPa8502018080112000an0regular_ll
\n", + "
" + ], + "text/plain": [ + " centre shortName typeOfLevel level dataDate dataTime stepRange \\\n", + "0 ecmf t isobaricInhPa 1000 20180801 1200 0 \n", + "1 ecmf u isobaricInhPa 1000 20180801 1200 0 \n", + "2 ecmf v isobaricInhPa 1000 20180801 1200 0 \n", + "3 ecmf t isobaricInhPa 850 20180801 1200 0 \n", + "4 ecmf u isobaricInhPa 850 20180801 1200 0 \n", + "5 ecmf v isobaricInhPa 850 20180801 1200 0 \n", + "\n", + " dataType number gridType \n", + "0 an 0 regular_ll \n", + "1 an 0 regular_ll \n", + "2 an 0 regular_ll \n", + "3 an 0 regular_ll \n", + "4 an 0 regular_ll \n", + "5 an 0 regular_ll " + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.ls()" + ] + }, + { + "cell_type": "markdown", + "id": "fae43976-e6c9-4520-a28a-13099a44f08d", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "#### values()" + ] + }, + { + "cell_type": "raw", + "id": "0019f9b6-3607-48df-a018-d8435cdac15e", + "metadata": { + "editable": true, + "raw_mimetype": "text/restructuredtext", + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "When we use either :py:attr:`Field.values ` or :py:attr:`FieldList.values ` now we get a pytorch Tensor." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "21a1e8b1-f0e5-4de6-bdbf-e92c5df5989e", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([272.5642, 272.5642, 272.5642, 272.5642, 272.5642, 272.5642, 272.5642,\n", + " 272.5642, 272.5642, 272.5642], dtype=torch.float64)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds[0].values[:10]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fbd92126-bb5b-47e5-80a7-d4bed3097764", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([84])" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds[0].values.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "714b320c-90ea-4326-bc5f-340dda66daab", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([6, 84])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.values.shape" + ] + }, + { + "cell_type": "markdown", + "id": "c9197df9-3b35-4220-b189-2439de0a4ea9", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "#### to_array()" + ] + }, + { + "cell_type": "raw", + "id": "fe0f41b3-15f6-4a09-a4ef-8b6b36686142", + "metadata": { + "editable": true, + "raw_mimetype": "text/restructuredtext", + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "The :py:meth:`Field.to_array() ` and :py:meth:`FieldList.values ` methods return the values in the underlying backend. " + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f60797eb-d578-4638-b1d0-bd18949dd249", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[272.5642, 272.5642],\n", + " [288.5642, 296.5642]], dtype=torch.float64)" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds[0].to_array()[:2,:2]" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "f4d053fa-2acf-4949-9bbd-a0b2ccf30318", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([6, 7, 12])" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.to_array().shape" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "04cd31df-34cd-47a9-90cc-833b9805bd55", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([6, 84])" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.to_array(flatten=True).shape" + ] + }, + { + "cell_type": "markdown", + "id": "f0efdd2b-e078-43e2-936e-35db3a645a6c", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "#### Array fieldlists" + ] + }, + { + "cell_type": "raw", + "id": "ea2a5619-6022-4166-85f6-227b9282ffa7", + "metadata": { + "editable": true, + "raw_mimetype": "text/restructuredtext", + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "Our fieldlist can be converted into an in-memory :py:class:`~data.sources.array_list.ArrayFieldList` where each message consists of a :py:class:`~data.readers.grib.metadata.GribMetadata` object and an array with the given backend storing the field values." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c5d3efc1-1299-4e0d-b59a-301e795bffc5", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "ArrayFieldList(fields=6)" + ], + "text/plain": [ + "ArrayFieldList(fields=6)" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "r = ds.to_fieldlist(\"pytorch\")\n", + "r" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "f87b0384-60cb-4bf2-8669-e193427c28e1", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
centreshortNametypeOfLevelleveldataDatedataTimestepRangedataTypenumbergridType
0ecmftisobaricInhPa10002018080112000an0regular_ll
1ecmfuisobaricInhPa10002018080112000an0regular_ll
2ecmfvisobaricInhPa10002018080112000an0regular_ll
3ecmftisobaricInhPa8502018080112000an0regular_ll
4ecmfuisobaricInhPa8502018080112000an0regular_ll
5ecmfvisobaricInhPa8502018080112000an0regular_ll
\n", + "
" + ], + "text/plain": [ + " centre shortName typeOfLevel level dataDate dataTime stepRange \\\n", + "0 ecmf t isobaricInhPa 1000 20180801 1200 0 \n", + "1 ecmf u isobaricInhPa 1000 20180801 1200 0 \n", + "2 ecmf v isobaricInhPa 1000 20180801 1200 0 \n", + "3 ecmf t isobaricInhPa 850 20180801 1200 0 \n", + "4 ecmf u isobaricInhPa 850 20180801 1200 0 \n", + "5 ecmf v isobaricInhPa 850 20180801 1200 0 \n", + "\n", + " dataType number gridType \n", + "0 an 0 regular_ll \n", + "1 an 0 regular_ll \n", + "2 an 0 regular_ll \n", + "3 an 0 regular_ll \n", + "4 an 0 regular_ll \n", + "5 an 0 regular_ll " + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "r.ls()" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "871e9c13-06e8-4ed6-90a9-8696c95ede8b", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([272.5642, 272.5642, 272.5642, 272.5642, 272.5642, 272.5642, 272.5642,\n", + " 272.5642, 272.5642, 272.5642], dtype=torch.float64)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "r[0].values[:10]" + ] + }, + { + "cell_type": "markdown", + "id": "a78665fd-9a37-456a-8fda-a11c358aba64", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "Whe can build a new ArrayFiedlList straight from metadata and array values. This can be used for computations, when we want to alter the values and store the result in a new FieldList." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "a92db7cb-e2e1-472e-92b6-0f42360d2105", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
centreshortNametypeOfLevelleveldataDatedataTimestepRangedataTypenumbergridType
0ecmftisobaricInhPa10002018080112000an0regular_ll
1ecmfuisobaricInhPa10002018080112000an0regular_ll
2ecmfvisobaricInhPa10002018080112000an0regular_ll
3ecmftisobaricInhPa8502018080112000an0regular_ll
4ecmfuisobaricInhPa8502018080112000an0regular_ll
5ecmfvisobaricInhPa8502018080112000an0regular_ll
\n", + "
" + ], + "text/plain": [ + " centre shortName typeOfLevel level dataDate dataTime stepRange \\\n", + "0 ecmf t isobaricInhPa 1000 20180801 1200 0 \n", + "1 ecmf u isobaricInhPa 1000 20180801 1200 0 \n", + "2 ecmf v isobaricInhPa 1000 20180801 1200 0 \n", + "3 ecmf t isobaricInhPa 850 20180801 1200 0 \n", + "4 ecmf u isobaricInhPa 850 20180801 1200 0 \n", + "5 ecmf v isobaricInhPa 850 20180801 1200 0 \n", + "\n", + " dataType number gridType \n", + "0 an 0 regular_ll \n", + "1 an 0 regular_ll \n", + "2 an 0 regular_ll \n", + "3 an 0 regular_ll \n", + "4 an 0 regular_ll \n", + "5 an 0 regular_ll " + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "md = ds.metadata()\n", + "v = ds.to_array() + 2\n", + "r1 = earthkit.data.FieldList.from_array(v, md)\n", + "r1.ls()" + ] + }, + { + "cell_type": "markdown", + "id": "920f6b98-f0e6-4ffc-a5e4-f8f643c23f76", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "As expected, the values are now differing by 2 from the ones in the originial FieldList." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "5b78ea8a-64e9-4bfa-9d11-8b390995994b", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([274.5642, 274.5642, 274.5642, 274.5642, 274.5642, 274.5642, 274.5642,\n", + " 274.5642, 274.5642, 274.5642], dtype=torch.float64)" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "r1[0].values[:10]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3e44a0c8-dbf8-4454-8519-c57330cc4d71", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dev_ecc", + "language": "python", + "name": "dev_ecc" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/numpy_fieldlist.ipynb b/docs/examples/numpy_fieldlist.ipynb index 7cd65b01..8be82809 100644 --- a/docs/examples/numpy_fieldlist.ipynb +++ b/docs/examples/numpy_fieldlist.ipynb @@ -390,7 +390,13 @@ "cell_type": "code", "execution_count": 11, "id": "4113f204-dd96-422d-acb5-e2e0901cfb7a", - "metadata": {}, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, "outputs": [ { "data": { diff --git a/earthkit/data/core/array.py b/earthkit/data/core/array.py new file mode 100644 index 00000000..b290df09 --- /dev/null +++ b/earthkit/data/core/array.py @@ -0,0 +1,234 @@ +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import threading +from abc import ABCMeta, abstractmethod + + +class ArrayBackendItem: + def __init__(self, backend_type): + self.type = backend_type + self._obj = None + self._avail = None + self.lock = threading.Lock() + + def obj(self): + if self._obj is None: + with self.lock: + if self._obj is None: + self._obj = self.type() + return self._obj + + def available(self): + if self._avail is None: + if self._obj is not None: + self._avail = True + else: + try: + self.obj() + self._avail = True + except Exception: + self._avail = False + return self._avail + + +class ArrayBackendManager: + def __init__(self): + """The backend objects are created on demand to avoid unnecessary imports""" + self.backends = {k: ArrayBackendItem(v) for k, v in array_backend_types.items()} + self._np_backend = None + + def find_for_name(self, name): + b = self.backends.get(name, None) + if b is None: + raise TypeError(f"No backend found for name={name}") + + # this will try to create the backend if it does not exist yet and + # throw an exception when it is not possible + return b.obj() + + def find_for_array(self, v, guess=None): + if guess is not None: + if guess.is_native_array(v): + return guess + + # try all the backends + for _, b in self.backends.items(): + # this will try create the backend if it does not exist yest. + # If it fails available() will return False from this moment on. + if b.available() and b.obj().is_native_array(v): + return b.obj() + + raise TypeError(f"No backend found for array type={type(v)}") + + def numpy_backend(self): + if self._np_backend is None: + self._np_backend = self.find_for_name("numpy") + return self._np_backend + + +class ArrayBackend(metaclass=ABCMeta): + _array_ns = None + _default = "numpy" + _name = None + _array_name = "array" + + def __init__(self): + self.lock = threading.Lock() + + @property + def array_ns(self): + """Delayed construction of array namespace""" + if self._array_ns is None: + with self.lock: + if self._array_ns is None: + self._array_ns = self._make_array_ns() + return self._array_ns + + @abstractmethod + def _make_array_ns(self): + pass + + @property + def name(self): + return self._name + + @property + def array_name(self): + return f"{self._name} {self._array_name}" + + def to_array(self, v, backend=None): + if backend is not None: + if backend is self: + return v + + return backend.to_backend(self, v) + + @abstractmethod + def is_native_array(self, v): + pass + + @abstractmethod + def to_backend(self, backend, v): + pass + + @abstractmethod + def from_numpy(self, v): + pass + + @abstractmethod + def from_pytorch(self, v): + pass + + +class NumpyBackend(ArrayBackend): + _name = "numpy" + + def __init__(self): + super().__init__() + + def _make_array_ns(self): + import numpy as np + + try: + import array_api_compat + + ns = array_api_compat.array_namespace(np.ones(2)) + except Exception: + ns = np + + return ns + + def is_native_array(self, v): + import numpy as np + + return isinstance(v, np.ndarray) + + def to_backend(self, backend, v): + return backend.from_numpy(v) + + def from_numpy(self, v): + return v + + def from_pytorch(self, v): + import torch + + return torch.to_numpy(v) + + +class PytorchBackend(ArrayBackend): + _name = "pytroch" + _array_name = "tensor" + + def __init__(self): + super().__init__() + # pytorch is an optional dependency, we need to see on init + # if we can load it + self.array_ns + + def _make_array_ns(self): + try: + import array_api_compat + + except Exception: + raise ImportError("array_api_compat is required to use pytorch backend") + + try: + import torch + except Exception: + raise ImportError("pytorch is required to use pytorch backend") + + return array_api_compat.array_namespace(torch.ones(2)) + + def is_native_array(self, v): + import torch + + return torch.is_tensor(v) + + def to_backend(self, backend, v): + return backend.from_pytorch(v) + + def from_numpy(self, v): + import torch + + return torch.from_numpy(v) + + def from_pytorch(self, v): + return v + + +array_backend_types = {"numpy": NumpyBackend, "pytorch": PytorchBackend} + +_MANAGER = ArrayBackendManager() + +NUMPY_BACKEND = _MANAGER.numpy_backend() + + +def ensure_backend(backend): + if backend is None: + return _MANAGER.find_for_name(ArrayBackend._default) + elif isinstance(backend, str): + return _MANAGER.find_for_name(backend) + else: + return backend + + +def get_backend(array, guess=None, strict=True): + if isinstance(array, list): + array = array[0] + + if guess is not None: + guess = ensure_backend(guess) + + b = _MANAGER.find_for_array(array, guess=guess) + if strict and guess is not None and b is not guess: + raise ValueError( + f"array type={b.array_name} and specified backend={guess} do not match" + ) + return b diff --git a/earthkit/data/core/array/__init__.py b/earthkit/data/core/array/__init__.py deleted file mode 100644 index b4e0a084..00000000 --- a/earthkit/data/core/array/__init__.py +++ /dev/null @@ -1,109 +0,0 @@ -# (C) Copyright 2020 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# - -import threading -from abc import ABCMeta, abstractmethod - - -class ArrayBackendManager: - def __init__(self): - self.backends = {} - self.lock = threading.Lock() - - def find(self, name): - with self.lock: - b = self.backends.get(name, None) - if b is None: - b = array_backend_types[name]() - self.backends[name] = b - return b - - -MANAGER = ArrayBackendManager() - - -class ArrayBackend(metaclass=ABCMeta): - _array_ns = None - - @property - def array_ns(self): - return self._array_ns - - @staticmethod - def find(name): - return MANAGER.find(name) - - def to_array(self, v, backend=None): - if backend is not None: - if backend is self: - return v - - return backend.to_backend(self, v) - - @abstractmethod - def to_backend(self, backend, v): - pass - - @abstractmethod - def from_numpy(self, v): - pass - - @abstractmethod - def from_pytorch(self, v): - pass - - -class NumpyBackend(ArrayBackend): - def __init__(self): - import numpy as np - - try: - import array_api_compat - - self._array_ns = array_api_compat.array_namespace(np.ones(2)) - except Exception: - self._array_ns = np - - def to_backend(self, backend, v): - return backend.from_numpy(v) - - def from_numpy(self, v): - return v - - def from_pytorch(self, v): - import torch - - return torch.to_numpy(v) - - -class PytorchBackend(ArrayBackend): - def __init__(self): - try: - import array_api_compat - - except Exception: - raise ImportError("array_api_compat is required to use pytorch backend") - - import torch - - self._array_ns = array_api_compat.array_namespace(torch.ones(2)) - - def to_backend(self, backend, v): - return backend.from_pytorch(v) - - def from_numpy(self, v): - import torch - - return torch.from_numpy(v) - - def from_pytorch(self, v): - return v - - -array_backend_types = {"numpy": NumpyBackend, "pytorch": PytorchBackend} diff --git a/earthkit/data/core/array/numpy.py b/earthkit/data/core/array/numpy.py deleted file mode 100644 index e69de29b..00000000 diff --git a/earthkit/data/core/array/pytorch.py b/earthkit/data/core/array/pytorch.py deleted file mode 100644 index e69de29b..00000000 diff --git a/earthkit/data/core/fieldlist.py b/earthkit/data/core/fieldlist.py index bd998ea2..092d622c 100644 --- a/earthkit/data/core/fieldlist.py +++ b/earthkit/data/core/fieldlist.py @@ -12,7 +12,7 @@ from collections import defaultdict from earthkit.data.core import Base -from earthkit.data.core.array import ArrayBackend +from earthkit.data.core.array import NUMPY_BACKEND, ensure_backend from earthkit.data.core.index import Index from earthkit.data.decorators import cached_method from earthkit.data.utils.metadata import metadata_argument @@ -21,18 +21,22 @@ class Field(Base): r"""Represents a Field.""" - raw_backend = ArrayBackend.find("numpy") + raw_backend = NUMPY_BACKEND def __init__(self, backend, metadata=None): self.__metadata = metadata self.backend = backend - def _to_array(self, v): - return self.backend.to_array(v, self.raw_backend) + def _to_array(self, v, backend=None): + if backend is None: + return self.backend.to_array(v, self.raw_backend) + else: + backend = ensure_backend(backend) + return backend.to_array(v, self.raw_backend) @abstractmethod def _values(self, dtype=None): - r"""Return the values stored in the field as an ndarray. + r"""Return the values as stored in the field as an ndarray. Parameters ---------- @@ -41,6 +45,8 @@ def _values(self, dtype=None): type used by the underlying data accessor is used. For GRIB it is ``np.float64``. + The original shape and backend type of the values is kept. + Returns ------- ndarray @@ -55,6 +61,7 @@ def values(self): v = self._to_array(self._values()) if len(v.shape) != 1: n = math.prod(v.shape) + n = (n,) return self.backend.array_ns.reshape(v, n) return v @@ -88,13 +95,13 @@ def to_numpy(self, flatten=False, dtype=None): """ v = self._values(dtype=dtype) - ArrayBackend.find("numpy").to_array(v, self.raw_backend) + NUMPY_BACKEND.to_array(v, self.raw_backend) shape = self._required_shape(flatten) if shape != v.shape: return v.reshape(shape) return v - def to_array(self, flatten=False, dtype=None): + def to_array(self, flatten=False, dtype=None, backend=None): r"""Return the values stored in the field as an ndarray. Parameters @@ -112,7 +119,7 @@ def to_array(self, flatten=False, dtype=None): Field values """ - v = self._to_array(self._values(dtype=dtype)) + v = self._to_array(self._values(dtype=dtype), backend=backend) shape = self._required_shape(flatten) if shape != v.shape: return self.backend.array_ns.reshape(v, shape) @@ -238,10 +245,14 @@ def to_points(self, flatten=False, dtype=None): x = self._metadata.geography.x(dtype=dtype) y = self._metadata.geography.y(dtype=dtype) if x is not None and y is not None: + x = self._to_array(x) + y = self._to_array(y) shape = self._required_shape(flatten) if shape != x.shape: - x = x.reshape(shape) - y = y.reshape(shape) + # x = x.reshape(shape) + # y = y.reshape(shape) + x = self.backend.array_ns.reshape(x, shape) + y = self.backend.array_ns.reshape(y, shape) return dict(x=x, y=y) elif self.projection().CARTOPY_CRS == "PlateCarree": lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype) @@ -573,18 +584,27 @@ class FieldList(Index): _md_indices = {} - def __init__(self, *args, backend=None, **kwargs): - if backend is None: - backend = "numpy" - self.backend = ArrayBackend.find(backend) + def __init__(self, backend=None, **kwargs): + self.backend = ensure_backend(backend) + super().__init__(**kwargs) + + def _init_from_mask(self, index): + self.backend = index._index.backend - super().__init__(*args, **kwargs) + def _init_from_multi(self, index): + self.backend = index._indexes[0].backend @staticmethod def from_numpy(array, metadata): - from earthkit.data.sources.numpy_list import NumpyFieldList + from earthkit.data.sources.array_list import ArrayFieldList + + return ArrayFieldList(array, metadata, backend=NUMPY_BACKEND) + + @staticmethod + def from_array(array, metadata): + from earthkit.data.sources.array_list import ArrayFieldList - return NumpyFieldList(array, metadata) + return ArrayFieldList(array, metadata) def ignore(self): # When the concrete type is Fieldlist we assume the object was @@ -722,8 +742,8 @@ def to_numpy(self, **kwargs): return np.array([f.to_numpy(**kwargs) for f in self]) - def to_array(self): - x = [f.to_array() for f in self] + def to_array(self, **kwargs): + x = [f.to_array(**kwargs) for f in self] return self.backend.array_ns.stack(x) @property @@ -750,13 +770,8 @@ def values(self): array([262.78027344, 267.44726562, 268.61230469]) """ - # import numpy as np - x = [f.values for f in self] return self.backend.array_ns.stack(x) - # return self.backend.to_array(x) - - # return np.array([f.values for f in self]) def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): r"""Return the values and/or the geographical coordinates. @@ -824,8 +839,6 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): values """ - import numpy as np - if self._is_shared_grid(): if isinstance(keys, str): keys = [keys] @@ -840,14 +853,14 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): elif k == "lon": r.append(latlon["lon"]) elif k == "value": - r.extend([f.to_numpy(flatten=flatten, dtype=dtype) for f in self]) + r.extend([f.to_array(flatten=flatten, dtype=dtype) for f in self]) else: raise ValueError(f"data: invalid argument: {k}") - return np.array(r) + return self.backend.array_ns.stack(r) elif len(self) == 0: - return np.array([]) + return self.backend.array_ns.stack([]) else: raise ValueError("Fields do not have the same grid geometry") @@ -1286,13 +1299,9 @@ def to_fieldlist(self, backend, **kwargs): dtype('float32') """ - converter = fieldlist_converters.get(backend, None) - if converter is not None: - return getattr(self, converter)(**kwargs) + backend = ensure_backend(backend) + return self._to_array_fieldlist(backend=backend, **kwargs) - def _to_numpy_fieldlist(self, **kwargs): + def _to_array_fieldlist(self, **kwargs): md = [f.metadata() for f in self] - return self.from_numpy(self.to_numpy(**kwargs), md) - - -fieldlist_converters = {"numpy": "_to_numpy_fieldlist"} + return self.from_array(self.to_array(**kwargs), md) diff --git a/earthkit/data/readers/__init__.py b/earthkit/data/readers/__init__.py index e4e62621..ec659829 100644 --- a/earthkit/data/readers/__init__.py +++ b/earthkit/data/readers/__init__.py @@ -182,16 +182,16 @@ def reader(source, path, **kwargs): ) -def memory_reader(source, buffer): +def memory_reader(source, buffer, **kwargs): """Create a reader for data held in a memory buffer""" assert isinstance(buffer, (bytes, bytearray)), source n_bytes = SETTINGS.get("reader-type-check-bytes") magic = buffer[: min(n_bytes, len(buffer) - 1)] - return _find_reader("memory_reader", source, buffer, magic=magic) + return _find_reader("memory_reader", source, buffer, magic=magic, **kwargs) -def stream_reader(source, stream, memory, content_type=None): +def stream_reader(source, stream, memory, **kwargs): """Create a reader for a stream""" magic = None if hasattr(stream, "peek") and callable(stream.peek): @@ -209,5 +209,5 @@ def stream_reader(source, stream, memory, content_type=None): stream, magic=magic, memory=memory, - content_type=content_type, + **kwargs, ) diff --git a/earthkit/data/readers/grib/__init__.py b/earthkit/data/readers/grib/__init__.py index 5c8c6362..1368518f 100644 --- a/earthkit/data/readers/grib/__init__.py +++ b/earthkit/data/readers/grib/__init__.py @@ -38,7 +38,9 @@ def memory_reader(source, buffer, *, magic=None, deeper_check=False, **kwargs): if _match_magic(magic, deeper_check): from .memory import GribFieldListInMemory, GribMessageMemoryReader - return GribFieldListInMemory(source, GribMessageMemoryReader(buffer)) + return GribFieldListInMemory( + source, GribMessageMemoryReader(buffer, **kwargs), **kwargs + ) def stream_reader( @@ -54,7 +56,7 @@ def stream_reader( if _is_default(magic, content_type) or _match_magic(magic, deeper_check): from .memory import GribFieldListInMemory, GribStreamReader - r = GribStreamReader(stream) + r = GribStreamReader(stream, **kwargs) if memory: - r = GribFieldListInMemory(source, r) + r = GribFieldListInMemory(source, r, **kwargs) return r diff --git a/earthkit/data/readers/grib/index/__init__.py b/earthkit/data/readers/grib/index/__init__.py index 09b0a3e0..e68c7c6e 100644 --- a/earthkit/data/readers/grib/index/__init__.py +++ b/earthkit/data/readers/grib/index/__init__.py @@ -113,7 +113,7 @@ def __init__(self, *args, **kwargs): FieldList.__init__(self, *args, **kwargs) @classmethod - def new_mask_index(self, *args, **kwargs): + def new_mask_index(cls, *args, **kwargs): return GribMaskFieldList(*args, **kwargs) @property @@ -122,7 +122,13 @@ def availability_path(self): @classmethod def merge(cls, sources): - assert all(isinstance(_, GribFieldList) for _ in sources) + if not all(isinstance(_, GribFieldList) for _ in sources): + raise ValueError( + "GribFieldList can only be merged to another GribFieldLists" + ) + if not all(s.backend is s[0].backend for s in sources): + raise ValueError("Only fieldlists with the same backend can be merged") + return GribMultiFieldList(sources) def _custom_availability(self, ignore_keys=None, filter_keys=lambda k: True): @@ -184,11 +190,13 @@ def _normalize_kwargs_names(self, **kwargs): class GribMaskFieldList(GribFieldList, MaskIndex): def __init__(self, *args, **kwargs): MaskIndex.__init__(self, *args, **kwargs) + FieldList._init_from_mask(self, self) class GribMultiFieldList(GribFieldList, MultiIndex): def __init__(self, *args, **kwargs): MultiIndex.__init__(self, *args, **kwargs) + FieldList._init_from_multi(self, self) class GribFieldListInFiles(GribFieldList): diff --git a/earthkit/data/readers/grib/memory.py b/earthkit/data/readers/grib/memory.py index 2e2e9d52..d1f230c3 100644 --- a/earthkit/data/readers/grib/memory.py +++ b/earthkit/data/readers/grib/memory.py @@ -11,6 +11,7 @@ import eccodes +from earthkit.data.core.array import ensure_backend from earthkit.data.readers import Reader from earthkit.data.readers.grib.codes import GribCodesHandle, GribField from earthkit.data.readers.grib.index import GribFieldList @@ -19,8 +20,9 @@ class GribMemoryReader(Reader): - def __init__(self): + def __init__(self, backend=None): self._peeked = None + self.backend = ensure_backend(backend) def __iter__(self): return self @@ -41,7 +43,7 @@ def _next_handle(self): def _message_from_handle(self, handle): if handle is not None: - return GribFieldInMemory(GribCodesHandle(handle, None, None)) + return GribFieldInMemory(GribCodesHandle(handle, None, None), self.backend) def peek(self): """Returns the next available message without consuming it""" @@ -87,8 +89,8 @@ def read_group(self, group): class GribFileMemoryReader(GribMemoryReader): - def __init__(self, path): - super().__init__() + def __init__(self, path, **kwargs): + super().__init__(**kwargs) self.fp = open(path, "rb") def __del__(self): @@ -99,8 +101,8 @@ def _next_handle(self): class GribMessageMemoryReader(GribMemoryReader): - def __init__(self, buf): - super().__init__() + def __init__(self, buf, **kwargs): + super().__init__(**kwargs) self.buf = buf def __del__(self): @@ -121,10 +123,10 @@ class GribStreamReader(GribMemoryReader): using _next_handle """ - def __init__(self, stream): + def __init__(self, stream, **kwargs): super().__init__() self._stream = stream - self._reader = eccodes.StreamReader(stream) + self._reader = eccodes.StreamReader(stream, **kwargs) def __del__(self): self._stream.close() @@ -142,8 +144,8 @@ def mutate_source(self): class GribFieldInMemory(GribField): """Represents a GRIB message in memory""" - def __init__(self, handle): - super().__init__(None, None, None) + def __init__(self, handle, backend=None): + super().__init__(None, None, None, backend) self._handle = handle @GribField.handle.getter @@ -159,8 +161,10 @@ class GribFieldListInMemory(GribFieldList, Reader): """Represent a GRIB field list in memory""" @staticmethod - def from_fields(fields): - fs = GribFieldListInMemory(None, None) + def from_fields(fields, backend=None): + if backend is None and len(fields) > 0: + backend = fields[0].backend + fs = GribFieldListInMemory(None, None, backend=backend) fs._fields = fields fs._loaded = True return fs diff --git a/earthkit/data/readers/grib/reader.py b/earthkit/data/readers/grib/reader.py index 5ee083b7..1186b0af 100644 --- a/earthkit/data/readers/grib/reader.py +++ b/earthkit/data/readers/grib/reader.py @@ -10,7 +10,6 @@ import logging from earthkit.data.readers import Reader -from earthkit.data.readers.grib.index import GribMultiFieldList from earthkit.data.readers.grib.index.file import GribFieldListInOneFile LOG = logging.getLogger(__name__) @@ -28,13 +27,6 @@ def __init__(self, source, path, parts=None): def __repr__(self): return "GRIBReader(%s)" % (self.path,) - @classmethod - def merge(cls, readers): - assert all(isinstance(s, GRIBReader) for s in readers), readers - assert len(readers) > 1 - - return GribMultiFieldList(readers) - def mutate_source(self): # A GRIBReader is a source itself return self diff --git a/earthkit/data/readers/netcdf.py b/earthkit/data/readers/netcdf.py index d847d680..32c068e0 100644 --- a/earthkit/data/readers/netcdf.py +++ b/earthkit/data/readers/netcdf.py @@ -16,7 +16,7 @@ from earthkit.data.core.fieldlist import Field, FieldList from earthkit.data.core.geography import Geography -from earthkit.data.core.index import Index, MaskIndex, MultiIndex +from earthkit.data.core.index import MaskIndex, MultiIndex from earthkit.data.core.metadata import RawMetadata from earthkit.data.utils.bbox import BoundingBox from earthkit.data.utils.dates import to_datetime @@ -151,6 +151,7 @@ def bbox(self, variable): def get_fields_from_ds( ds, + backend, field_type=None, check_only=False, ): # noqa C901 @@ -260,7 +261,7 @@ def get_fields_from_ds( if check_only: return True - fields.append(field_type(ds, name, slices, non_dim_coords)) + fields.append(field_type(ds, name, slices, non_dim_coords, backend)) # if not fields: # raise Exception("NetCDFReader no 2D fields found in %s" % (self.path,)) @@ -377,8 +378,8 @@ def _valid_datetime(self): class XArrayField(Field): - def __init__(self, ds, variable, slices, non_dim_coords): - super().__init__() + def __init__(self, ds, variable, slices, non_dim_coords, backend): + super().__init__(backend) self._ds = ds self._da = ds[variable] @@ -452,7 +453,9 @@ class XArrayFieldListCore(FieldList): def __init__(self, ds, *args, **kwargs): self.ds = ds self._fields = None - Index.__init__(self, *args, **kwargs) + # Index.__init__(self, *args, **kwargs) + # super().__init__(self, *args, **kwargs) + super().__init__(*kwargs) @property def fields(self): @@ -463,7 +466,10 @@ def fields(self): def has_fields(self): if self._fields is None: return get_fields_from_ds( - DataSet(self.ds), field_type=self.FIELD_TYPE, check_only=True + DataSet(self.ds), + self.backend, + field_type=self.FIELD_TYPE, + check_only=True, ) else: return len(self._fields) @@ -473,7 +479,9 @@ def _scan(self): self._fields = self._get_fields() def _get_fields(self): - return get_fields_from_ds(DataSet(self.ds), field_type=self.FIELD_TYPE) + return get_fields_from_ds( + DataSet(self.ds), self.backend, field_type=self.FIELD_TYPE + ) def to_pandas(self): return self.to_xarray().to_pandas() @@ -519,11 +527,13 @@ def __len__(self): class XArrayMaskFieldList(XArrayFieldListCore, MaskIndex): def __init__(self, *args, **kwargs): MaskIndex.__init__(self, *args, **kwargs) + FieldList._init_from_mask(self, self) class XArrayMultiFieldList(XArrayFieldListCore, MultiIndex): def __init__(self, *args, **kwargs): MultiIndex.__init__(self, *args, **kwargs) + FieldList._init_from_multi(self, self) def to_xarray(self, **kwargs): import xarray as xr @@ -554,7 +564,9 @@ def _get_fields(self): with closing( xr.open_mfdataset(self.path, combine="by_coords") ) as ds: # or nested - return get_fields_from_ds(DataSet(ds), field_type=self.FIELD_TYPE) + return get_fields_from_ds( + DataSet(ds), self.backend, field_type=self.FIELD_TYPE + ) def has_fields(self): if self._fields is None: @@ -564,7 +576,10 @@ def has_fields(self): xr.open_mfdataset(self.path, combine="by_coords") ) as ds: # or nested return get_fields_from_ds( - DataSet(ds), field_type=self.FIELD_TYPE, check_only=True + DataSet(ds), + self.backend, + field_type=self.FIELD_TYPE, + check_only=True, ) else: return len(self._fields) @@ -575,7 +590,7 @@ def merge(cls, sources): return NetCDFMultiFieldList(sources) @classmethod - def new_mask_index(self, *args, **kwargs): + def new_mask_index(cls, *args, **kwargs): return NetCDFMaskFieldList(*args, **kwargs) def to_xarray(self, **kwargs): @@ -632,6 +647,7 @@ def __len__(self): class NetCDFMaskFieldList(NetCDFFieldList, MaskIndex): def __init__(self, *args, **kwargs): MaskIndex.__init__(self, *args, **kwargs) + FieldList._init_from_mask(self, self) # TODO: Implement this, but discussion required def to_xarray(self, *args, **kwargs): @@ -641,6 +657,7 @@ def to_xarray(self, *args, **kwargs): class NetCDFMultiFieldList(NetCDFFieldList, MultiIndex): def __init__(self, *args, **kwargs): MultiIndex.__init__(self, *args, **kwargs) + FieldList._init_from_multi(self, self) def to_xarray(self, **kwargs): try: diff --git a/earthkit/data/sources/array_list.py b/earthkit/data/sources/array_list.py index 2d957879..7683ab68 100644 --- a/earthkit/data/sources/array_list.py +++ b/earthkit/data/sources/array_list.py @@ -11,6 +11,7 @@ import numpy as np +from earthkit.data.core.array import get_backend from earthkit.data.core.fieldlist import Field, FieldList from earthkit.data.core.index import MaskIndex, MultiIndex from earthkit.data.readers.grib.pandas import PandasMixIn @@ -20,22 +21,22 @@ class ArrayField(Field): - r"""Represent a field consisting of an ndarray and metadata object. + r"""Represent a field consisting of an array and metadata object. Parameters ---------- - array: ndarray + array: array Array storing the values of the field metadata: :class:`Metadata` Metadata object describing the field metadata. + backend: str, ArrayBackend + Array backend. """ - def __init__(self, array, metadata): + def __init__(self, array, metadata, backend): + super().__init__(backend, metadata=metadata) self._array = array - super().__init__(metadata=metadata) - import array_api_compat - - self.__array_ns = array_api_compat.array_namespace(self._array) + self.raw_backend = backend def _make_metadata(self): pass @@ -44,7 +45,7 @@ def _values(self, dtype=None): if dtype is None: return self._array else: - return self._array_ns.astype(self._array, dtype, copy=False) + return self.backend.array_ns.astype(self._array, dtype, copy=False) def __repr__(self): return f"{self.__class__.__name__}()" @@ -66,20 +67,25 @@ def write(self, f, **kwargs): class ArrayFieldListCore(PandasMixIn, XarrayMixIn, FieldList): - def __init__(self, array, metadata, *args, **kwargs): + def __init__(self, array, metadata, *args, backend=None, **kwargs): self._array = array self._metadata = metadata if not isinstance(self._metadata, list): self._metadata = [self._metadata] - if isinstance(self._array, np.ndarray): + # get backend and check consistency + backend = get_backend(self._array, guess=backend, strict=True) + + FieldList.__init__(self, *args, backend=backend, **kwargs) + + if self.backend.is_native_array(self._array): if self._array.shape[0] != len(self._metadata): # we have a single array and a single metadata if len(self._metadata) == 1 and self._shape_match( self._array.shape, self._metadata[0].geography.shape() ): - self._array = np.array([self._array]) + self._array = self.backend.array_ns.stack([self._array]) else: raise ValueError( ( @@ -97,19 +103,22 @@ def __init__(self, array, metadata, *args, **kwargs): ) for i, a in enumerate(self._array): - if not isinstance(a, np.ndarray): + if not self.backend.is_native_array(a): raise ValueError( - f"All array element must be an ndarray. Type at position={i} is {type(a)}" + ( + f"All array element must be an {self.backend.array_name}." + " Type at position={i} is {type(a)}" + ) ) else: - raise TypeError("array must be an ndarray or a list of ndarrays") + raise TypeError( + f"array must be an {self.backend.array_name} or a list of {self.backend.array_name}s" + ) # hide internal metadata related to values self._metadata = [md._hide_internal_keys() for md in self._metadata] - super().__init__(*args, **kwargs) - def _shape_match(self, shape1, shape2): if shape1 == shape2: return True @@ -123,19 +132,24 @@ def new_mask_index(self, *args, **kwargs): @classmethod def merge(cls, sources): - assert all(isinstance(_, ArrayFieldListCore) for _ in sources) + if not all(isinstance(_, ArrayFieldListCore) for _ in sources): + raise ValueError( + "ArrayFieldList can only be merged to another ArrayFieldLists" + ) + if not all(s.backend is s[0].backend for s in sources): + raise ValueError("Only fieldlists with the same backend can be merged") + merger = ListMerger(sources) - # merger = MultiUnwindMerger(sources) return merger.to_fieldlist() def __repr__(self): return f"{self.__class__.__name__}(fields={len(self)})" - def _to_numpy_fieldlist(self, **kwargs): + def _to_array_fieldlist(self, backend=None, **kwargs): if self[0]._array_matches(self._array[0], **kwargs): return self else: - return type(self)(self.to_numpy(**kwargs), self._metadata) + return type(self)(self.to_array(backend=backend, **kwargs), self._metadata) def save(self, filename, append=False, check_nans=True, bits_per_value=16): r"""Write all the fields into a file. @@ -160,22 +174,23 @@ def save(self, filename, append=False, check_nans=True, bits_per_value=16): ) -class MultiUnwindMerger: - def __init__(self, sources): - self.sources = list(self._flatten(sources)) - - def _flatten(self, sources): - if isinstance(sources, ArrayMultiFieldList): - for s in sources.indexes: - yield from self._flatten(s) - elif isinstance(sources, list): - for s in sources: - yield from self._flatten(s) - else: - yield sources +# class MultiUnwindMerger: +# def __init__(self, sources): +# self.sources = list(self._flatten(sources)) - def to_fieldlist(self): - return ArrayMultiFieldList(self.sources) +# def _flatten(self, sources): +# if isinstance(sources, ArrayMultiFieldList): +# for s in sources.indexes: +# yield from self._flatten(s) +# elif isinstance(sources, list): +# for s in sources: +# yield from self._flatten(s) +# else: +# yield sources + +# def to_fieldlist(self): + +# return ArrayMultiFieldList(self.sources) class ListMerger: @@ -189,7 +204,8 @@ def to_fieldlist(self): for f in s: array.append(f._array) metadata.append(f._metadata) - return ArrayFieldList(array, metadata) + backend = None if len(self.sources) == 0 else self.sources[0].backend + return ArrayFieldList(array, metadata, backend=backend) class ArrayFieldList(ArrayFieldListCore): @@ -207,7 +223,7 @@ class ArrayFieldList(ArrayFieldListCore): def _getitem(self, n): if isinstance(n, int): - return ArrayField(self._array[n], self._metadata[n]) + return ArrayField(self._array[n], self._metadata[n], self.backend) def __len__(self): return ( @@ -218,8 +234,10 @@ def __len__(self): class ArrayMaskFieldList(ArrayFieldListCore, MaskIndex): def __init__(self, *args, **kwargs): MaskIndex.__init__(self, *args, **kwargs) + FieldList._init_from_mask(self, self) class ArrayMultiFieldList(ArrayFieldListCore, MultiIndex): def __init__(self, *args, **kwargs): MultiIndex.__init__(self, *args, **kwargs) + FieldList._init_from_multi(self, self) diff --git a/earthkit/data/sources/constants.py b/earthkit/data/sources/constants.py index 8f4e4024..91107733 100644 --- a/earthkit/data/sources/constants.py +++ b/earthkit/data/sources/constants.py @@ -186,7 +186,7 @@ def cos_solar_zenith_angle(self, date): class ConstantField(Field): - def __init__(self, date, param, proc, shape, geometry): + def __init__(self, date, param, proc, shape, geometry, backend): self.date = date self.param = param self.proc = proc @@ -199,7 +199,7 @@ def __init__(self, date, param, proc, shape, geometry): levelist=None, number=None, ) - super().__init__(metadata=ConstantMetadata(d, geometry)) + super().__init__(backend, metadata=ConstantMetadata(d, geometry)) def _make_metadata(self): pass @@ -287,6 +287,8 @@ def __init__(self, source_or_dataset, request={}, repeat=1, **kwargs): self.procs = {param: getattr(self.maker, param) for param in self.params} self._len = len(self.dates) * len(self.params) * self.repeat + super().__init__(**kwargs) + @normalize("date", "date-list") @normalize("time", "int-list") @normalize("number", "int-list") @@ -326,12 +328,14 @@ def _getitem(self, n): self.procs[param], self.maker.shape, self.maker.field.metadata().geography, + self.backend, ) class ConstantsMaskFieldList(ConstantsFieldListCore, MaskIndex): def __init__(self, *args, **kwargs): MaskIndex.__init__(self, *args, **kwargs) + FieldList._init_from_mask(self, self) source = ConstantsFieldList diff --git a/earthkit/data/sources/list_of_dicts.py b/earthkit/data/sources/list_of_dicts.py index 782b7da0..67e11202 100644 --- a/earthkit/data/sources/list_of_dicts.py +++ b/earthkit/data/sources/list_of_dicts.py @@ -171,8 +171,8 @@ def bounding_box(self): class VirtualGribField(Field): - def __init__(self, d): - super().__init__(metadata=VirtualGribMetadata(d)) + def __init__(self, d, backend): + super().__init__(backend, metadata=VirtualGribMetadata(d)) def _values(self, dtype=None): v = self._metadata["values"] @@ -191,7 +191,7 @@ def __init__(self, list_of_dicts, *args, **kwargs): super().__init__(*args, **kwargs) def __getitem__(self, n): - return VirtualGribField(self.list_of_dicts[n]) + return VirtualGribField(self.list_of_dicts[n], self.backend) def __len__(self): return len(self.list_of_dicts) diff --git a/earthkit/data/sources/numpy_list.py b/earthkit/data/sources/numpy_list.py index 7a959942..2f3767c9 100644 --- a/earthkit/data/sources/numpy_list.py +++ b/earthkit/data/sources/numpy_list.py @@ -7,215 +7,11 @@ # nor does it submit to any jurisdiction. # -import logging +from earthkit.data.core.array import NUMPY_BACKEND +from earthkit.data.sources.array_list import ArrayFieldList -import numpy as np -from earthkit.data.core.fieldlist import Field, FieldList -from earthkit.data.core.index import MaskIndex, MultiIndex -from earthkit.data.readers.grib.pandas import PandasMixIn -from earthkit.data.readers.grib.xarray import XarrayMixIn - -LOG = logging.getLogger(__name__) - - -class NumpyField(Field): - r"""Represent a field consisting of an ndarray and metadata object. - - Parameters - ---------- - array: ndarray - Array storing the values of the field - metadata: :class:`Metadata` - Metadata object describing the field metadata. - """ - - def __init__(self, array, metadata): - self._array = array - super().__init__(metadata=metadata) - - def _make_metadata(self): - pass - - def _values(self, dtype=None): - if dtype is None: - return self._array - else: - return self._array.astype(dtype, copy=False) - - def __repr__(self): - return f"{self.__class__.__name__}()" - - def write(self, f, **kwargs): - r"""Write the field to a file object. - - Parameters - ---------- - f: file object - The target file object. - **kwargs: dict, optional - Other keyword arguments passed to :meth:`data.writers.grib.GribWriter.write`. - """ - from earthkit.data.writers import write - - write(f, self.values, self._metadata, **kwargs) - - -class NumpyFieldListCore(PandasMixIn, XarrayMixIn, FieldList): - def __init__(self, array, metadata, *args, **kwargs): - self._array = array - self._metadata = metadata - - if not isinstance(self._metadata, list): - self._metadata = [self._metadata] - - if isinstance(self._array, np.ndarray): - if self._array.shape[0] != len(self._metadata): - # we have a single array and a single metadata - if len(self._metadata) == 1 and self._shape_match( - self._array.shape, self._metadata[0].geography.shape() - ): - self._array = np.array([self._array]) - else: - raise ValueError( - ( - f"first array dimension ({self._array.shape[0]}) differs " - f"from number of metadata objects ({len(self._metadata)})" - ) - ) - elif isinstance(self._array, list): - if len(self._array) != len(self._metadata): - raise ValueError( - ( - f"array len ({len(self._array)}) differs " - f"from number of metadata objects ({len(self._metadata)})" - ) - ) - - for i, a in enumerate(self._array): - if not isinstance(a, np.ndarray): - raise ValueError( - f"All array element must be an ndarray. Type at position={i} is {type(a)}" - ) - - else: - raise TypeError("array must be an ndarray or a list of ndarrays") - - # hide internal metadata related to values - self._metadata = [md._hide_internal_keys() for md in self._metadata] - - super().__init__(*args, **kwargs) - - def _shape_match(self, shape1, shape2): - if shape1 == shape2: - return True - if len(shape1) == 1 and shape1[0] == np.prod(shape2): - return True - return False - - @classmethod - def new_mask_index(self, *args, **kwargs): - return NumpyMaskFieldList(*args, **kwargs) - - @classmethod - def merge(cls, sources): - assert all(isinstance(_, NumpyFieldListCore) for _ in sources) - merger = ListMerger(sources) - # merger = MultiUnwindMerger(sources) - return merger.to_fieldlist() - - def __repr__(self): - return f"{self.__class__.__name__}(fields={len(self)})" - - def _to_numpy_fieldlist(self, **kwargs): - if self[0]._array_matches(self._array[0], **kwargs): - return self - else: - return type(self)(self.to_numpy(**kwargs), self._metadata) - - def save(self, filename, append=False, check_nans=True, bits_per_value=16): - r"""Write all the fields into a file. - - Parameters - ---------- - filename: str - The target file path. - append: bool - When it is true append data to the target file. Otherwise - the target file be overwritten if already exists. - check_nans: bool - Replace nans in the values with GRIB missing values when generating the output. - bits_per_value: int - Set the ``bitsPerValue`` GRIB key in the generated output. - """ - super().save( - filename, - append=append, - check_nans=check_nans, - bits_per_value=bits_per_value, - ) - - -class MultiUnwindMerger: - def __init__(self, sources): - self.sources = list(self._flatten(sources)) - - def _flatten(self, sources): - if isinstance(sources, NumpyMultiFieldList): - for s in sources.indexes: - yield from self._flatten(s) - elif isinstance(sources, list): - for s in sources: - yield from self._flatten(s) - else: - yield sources - - def to_fieldlist(self): - return NumpyMultiFieldList(self.sources) - - -class ListMerger: - def __init__(self, sources): - self.sources = sources - - def to_fieldlist(self): - array = [] - metadata = [] - for s in self.sources: - for f in s: - array.append(f._array) - metadata.append(f._metadata) - return NumpyFieldList(array, metadata) - - -class NumpyFieldList(NumpyFieldListCore): - r"""Represent a list of :obj:`NumpyField `\ s. - - The preferred way to create a NumpyFieldList is to use either the - static :obj:`from_numpy` method or the :obj:`to_fieldlist` method. - - See Also - -------- - from_numpy - to_fieldlist - - """ - - def _getitem(self, n): - if isinstance(n, int): - return NumpyField(self._array[n], self._metadata[n]) - - def __len__(self): - return ( - len(self._array) if isinstance(self._array, list) else self._array.shape[0] - ) - - -class NumpyMaskFieldList(NumpyFieldListCore, MaskIndex): - def __init__(self, *args, **kwargs): - MaskIndex.__init__(self, *args, **kwargs) - - -class NumpyMultiFieldList(NumpyFieldListCore, MultiIndex): +class NumpyFieldList(ArrayFieldList): def __init__(self, *args, **kwargs): - MultiIndex.__init__(self, *args, **kwargs) + kwargs.pop("backend", None) + super().__init__(*args, backend=NUMPY_BACKEND, **kwargs) diff --git a/earthkit/data/sources/stream.py b/earthkit/data/sources/stream.py index 90e01463..37e9ca56 100644 --- a/earthkit/data/sources/stream.py +++ b/earthkit/data/sources/stream.py @@ -59,7 +59,7 @@ def __init__(self, stream, **kwargs): @property def _reader(self): if self._reader_ is None: - self._reader_ = stream_reader(self, self._stream, True) + self._reader_ = stream_reader(self, self._stream, True, **self._kwargs) if self._reader_ is None: raise TypeError(f"could not create reader for stream={self._stream}") return self._reader_ @@ -73,11 +73,12 @@ def mutate(self): class StreamSourceBase(Source): - def __init__(self, stream, *, batch_size=1, group_by=None): + def __init__(self, stream, *, batch_size=1, group_by=None, **kwargs): super().__init__() self._reader_ = None self._stream = stream self.batch_size, self.group_by = check_stream_kwargs(batch_size, group_by) + self._kwargs = kwargs def __iter__(self): return self @@ -88,7 +89,7 @@ def mutate(self): @property def _reader(self): if self._reader_ is None: - self._reader_ = stream_reader(self, self._stream, False) + self._reader_ = stream_reader(self, self._stream, False, **self._kwargs) if self._reader_ is None: raise TypeError(f"could not create reader for stream={self._stream}") return self._reader_ @@ -214,28 +215,28 @@ class StreamSource(StreamSourceBase): def __init__(self, stream, **kwargs): super().__init__(stream, **kwargs) - # print(f"kwargs={kwargs} {id(kwargs)}") - # if kwargs: - # raise TypeError(f"got invalid keyword argument(s): {list(kwargs.keys())}") - def mutate(self): assert self._reader_ is None return _from_stream( - self._stream, batch_size=self.batch_size, group_by=self.group_by + self._stream, + batch_size=self.batch_size, + group_by=self.group_by, + **self._kwargs, ) class StreamSourceMaker: - def __init__(self, source, stream_kwargs): + def __init__(self, source, stream_kwargs, **kwargs): self.in_source = source + self._kwargs = kwargs self.stream_kwargs = dict(stream_kwargs) self.source = None def __call__(self): if self.source is None: stream = self.in_source.to_stream() - self.source = _from_stream(stream, **self.stream_kwargs) + self.source = _from_stream(stream, **self.stream_kwargs, **self._kwargs) prev = None src = self.source @@ -247,17 +248,17 @@ def __call__(self): return self.source -def _from_stream(stream, group_by, batch_size): +def _from_stream(stream, group_by, batch_size, **kwargs): _kwargs = dict(batch_size=batch_size, group_by=group_by) if group_by: - return StreamGroupSource(stream, **_kwargs) + return StreamGroupSource(stream, **_kwargs, **kwargs) elif batch_size == 0: - return StreamMemorySource(stream) + return StreamMemorySource(stream, **kwargs) elif batch_size > 1: - return StreamBatchSource(stream, **_kwargs) + return StreamBatchSource(stream, **_kwargs, **kwargs) elif batch_size == 1: - return StreamSingleSource(stream, **_kwargs) + return StreamSingleSource(stream, **_kwargs, **kwargs) raise ValueError(f"Unsupported stream parameters {batch_size=} {group_by=}") @@ -265,17 +266,14 @@ def _from_stream(stream, group_by, batch_size): def _from_source(source, **kwargs): stream_kwargs, kwargs = parse_stream_kwargs(**kwargs) - if kwargs: - raise TypeError(f"got invalid keyword argument(s): {list(kwargs.keys())}") - if not isinstance(source, (list, tuple)): source = [source] if len(source) == 1: - maker = StreamSourceMaker(source[0], stream_kwargs) + maker = StreamSourceMaker(source[0], stream_kwargs, **kwargs) return maker() else: - sources = [StreamSourceMaker(s, stream_kwargs) for s in source] + sources = [StreamSourceMaker(s, stream_kwargs, **kwargs) for s in source] return MultiStreamSource(sources, **stream_kwargs) diff --git a/tests/documentation/test_notebooks.py b/tests/documentation/test_notebooks.py index c2390990..411d8170 100644 --- a/tests/documentation/test_notebooks.py +++ b/tests/documentation/test_notebooks.py @@ -31,6 +31,7 @@ "polytope.ipynb", "grib_fdb_write.ipynb", "demo_source_plugin.ipynb", + "grib_array_backends.ipynb", ] diff --git a/tests/grib/test_grib_sel.py b/tests/grib/test_grib_sel.py index 20690187..a7642a5e 100644 --- a/tests/grib/test_grib_sel.py +++ b/tests/grib/test_grib_sel.py @@ -138,6 +138,7 @@ def test_grib_sel_multi_file(mode): # single resulting field g = f.sel(shortName="t", level=61) + print(f"{g=}") assert len(g) == 1 assert g.metadata(["shortName", "level:l", "typeOfLevel"]) == [["t", 61, "hybrid"]] diff --git a/tests/grib/test_grib_stream.py b/tests/grib/test_grib_stream.py index 508daa0a..b004cf6a 100644 --- a/tests/grib/test_grib_stream.py +++ b/tests/grib/test_grib_stream.py @@ -24,7 +24,7 @@ def repeat_list_items(items, count): @pytest.mark.parametrize( "_kwargs,error", [ - (dict(order_by="level"), TypeError), + # (dict(order_by="level"), TypeError), (dict(group_by=1), TypeError), (dict(group_by=["level", 1]), TypeError), # (dict(group_by="level", batch_size=1), TypeError), diff --git a/tests/grib/test_grib_url_stream.py b/tests/grib/test_grib_url_stream.py index ac5e1095..82612eea 100644 --- a/tests/grib/test_grib_url_stream.py +++ b/tests/grib/test_grib_url_stream.py @@ -24,7 +24,7 @@ def repeat_list_items(items, count): @pytest.mark.parametrize( "_kwargs,error", [ - (dict(order_by="level"), TypeError), + # (dict(order_by="level"), TypeError), (dict(group_by=1), TypeError), (dict(group_by=["level", 1]), TypeError), # (dict(group_by="level", batch_size=1), TypeError), From a40151114aaa1e2dfe1a5e2b952ba8c55c4ea890 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 13 Feb 2024 14:26:14 +0000 Subject: [PATCH 05/18] Impelement array backends for fieldlist --- docs/examples/grib_array_backends.ipynb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/examples/grib_array_backends.ipynb b/docs/examples/grib_array_backends.ipynb index bab9310e..93b28f64 100644 --- a/docs/examples/grib_array_backends.ipynb +++ b/docs/examples/grib_array_backends.ipynb @@ -258,7 +258,7 @@ "tags": [] }, "source": [ - "#### values()" + "#### values" ] }, { @@ -384,7 +384,7 @@ "tags": [] }, "source": [ - "The :py:meth:`Field.to_array() ` and :py:meth:`FieldList.values ` methods return the values in the underlying backend. " + "The :py:meth:`Field.to_array() ` and :py:meth:`FieldList.to_array() ` methods return the values based on the underlying backend. " ] }, { From 8f76af9ba30c6577f8cb31184bc04a2dd54fb573 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Wed, 14 Feb 2024 13:20:52 +0000 Subject: [PATCH 06/18] Impelement array backends for fieldlist --- earthkit/data/core/array.py | 29 ++++-- earthkit/data/core/fieldlist.py | 31 +++--- earthkit/data/sources/array_list.py | 18 ++-- earthkit/data/testing.py | 1 + tests/core/test_array.py | 75 ++++++++++++++ tests/grib/grib_fixtures.py | 79 ++++++++++++-- tests/grib/test_grib_backend.py | 149 +++++++++++++++++++++++++++ tests/grib/test_grib_convert.py | 23 +++-- tests/grib/test_grib_geography.py | 106 ++++++++++--------- tests/grib/test_grib_inidces.py | 41 ++++---- tests/grib/test_grib_metadata.py | 154 ++++++++++++++++------------ tests/grib/test_grib_order_by.py | 40 +++++--- tests/grib/test_grib_sel.py | 115 ++++++++++++--------- tests/grib/test_grib_slice.py | 65 +++++++----- tests/grib/test_grib_summary.py | 72 +++++++------ tests/grib/test_grib_values.py | 106 +++++++++++-------- 16 files changed, 754 insertions(+), 350 deletions(-) create mode 100644 tests/core/test_array.py create mode 100644 tests/grib/test_grib_backend.py diff --git a/earthkit/data/core/array.py b/earthkit/data/core/array.py index b290df09..9c6fdfca 100644 --- a/earthkit/data/core/array.py +++ b/earthkit/data/core/array.py @@ -108,14 +108,17 @@ def to_array(self, v, backend=None): if backend is self: return v - return backend.to_backend(self, v) + return backend.to_backend(v, self) + else: + b = get_backend(v, strict=False) + return b.to_backend(v, self) @abstractmethod def is_native_array(self, v): pass @abstractmethod - def to_backend(self, backend, v): + def to_backend(self, v, backend): pass @abstractmethod @@ -126,6 +129,10 @@ def from_numpy(self, v): def from_pytorch(self, v): pass + @abstractmethod + def from_other(self, v, **kwargs): + pass + class NumpyBackend(ArrayBackend): _name = "numpy" @@ -150,20 +157,23 @@ def is_native_array(self, v): return isinstance(v, np.ndarray) - def to_backend(self, backend, v): + def to_backend(self, v, backend): return backend.from_numpy(v) def from_numpy(self, v): return v def from_pytorch(self, v): - import torch + return v.numpy() - return torch.to_numpy(v) + def from_other(self, v, **kwargs): + import numpy as np + + return np.array(v, **kwargs) class PytorchBackend(ArrayBackend): - _name = "pytroch" + _name = "pytorch" _array_name = "tensor" def __init__(self): @@ -191,7 +201,7 @@ def is_native_array(self, v): return torch.is_tensor(v) - def to_backend(self, backend, v): + def to_backend(self, v, backend): return backend.from_pytorch(v) def from_numpy(self, v): @@ -202,6 +212,11 @@ def from_numpy(self, v): def from_pytorch(self, v): return v + def from_other(self, v, **kwargs): + import torch + + return torch.tensor(v, **kwargs) + array_backend_types = {"numpy": NumpyBackend, "pytorch": PytorchBackend} diff --git a/earthkit/data/core/fieldlist.py b/earthkit/data/core/fieldlist.py index 092d622c..731b586d 100644 --- a/earthkit/data/core/fieldlist.py +++ b/earthkit/data/core/fieldlist.py @@ -21,18 +21,19 @@ class Field(Base): r"""Represents a Field.""" - raw_backend = NUMPY_BACKEND + raw_values_backend = NUMPY_BACKEND + raw_other_backend = NUMPY_BACKEND def __init__(self, backend, metadata=None): self.__metadata = metadata self.backend = backend - def _to_array(self, v, backend=None): + def _to_array(self, v, backend=None, raw=None): if backend is None: - return self.backend.to_array(v, self.raw_backend) + return self.backend.to_array(v, raw) else: backend = ensure_backend(backend) - return backend.to_array(v, self.raw_backend) + return backend.to_array(v, raw) @abstractmethod def _values(self, dtype=None): @@ -58,7 +59,7 @@ def _values(self, dtype=None): @property def values(self): r"""ndarray: Get the values stored in the field as a 1D ndarray.""" - v = self._to_array(self._values()) + v = self._to_array(self._values(), raw=self.raw_values_backend) if len(v.shape) != 1: n = math.prod(v.shape) n = (n,) @@ -95,7 +96,7 @@ def to_numpy(self, flatten=False, dtype=None): """ v = self._values(dtype=dtype) - NUMPY_BACKEND.to_array(v, self.raw_backend) + v = NUMPY_BACKEND.to_array(v, self.raw_values_backend) shape = self._required_shape(flatten) if shape != v.shape: return v.reshape(shape) @@ -119,7 +120,9 @@ def to_array(self, flatten=False, dtype=None, backend=None): Field values """ - v = self._to_array(self._values(dtype=dtype), backend=backend) + v = self._to_array( + self._values(dtype=dtype), backend=backend, raw=self.raw_values_backend + ) shape = self._required_shape(flatten) if shape != v.shape: return self.backend.array_ns.reshape(v, shape) @@ -186,9 +189,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): """ _keys = dict( - lat=self._metadata.geography.latitudes, - lon=self._metadata.geography.longitudes, - value=self._values, + lat=(self._metadata.geography.latitudes, self.raw_other_backend), + lon=(self._metadata.geography.longitudes, self.raw_other_backend), + value=(self._values, self.raw_values_backend), ) if isinstance(keys, str): @@ -198,7 +201,7 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): if k not in _keys: raise ValueError(f"data: invalid argument: {k}") - r = [self._to_array(_keys[k](dtype=dtype)) for k in keys] + r = [self._to_array(_keys[k][0](dtype=dtype), raw=_keys[k][1]) for k in keys] shape = self._required_shape(flatten) if shape != r[0].shape: # r = [x.reshape(shape) for x in r] @@ -245,12 +248,10 @@ def to_points(self, flatten=False, dtype=None): x = self._metadata.geography.x(dtype=dtype) y = self._metadata.geography.y(dtype=dtype) if x is not None and y is not None: - x = self._to_array(x) - y = self._to_array(y) + x = self._to_array(x, raw=self.raw_other_backend) + y = self._to_array(y, raw=self.raw_other_backend) shape = self._required_shape(flatten) if shape != x.shape: - # x = x.reshape(shape) - # y = y.reshape(shape) x = self.backend.array_ns.reshape(x, shape) y = self.backend.array_ns.reshape(y, shape) return dict(x=x, y=y) diff --git a/earthkit/data/sources/array_list.py b/earthkit/data/sources/array_list.py index 7683ab68..122e72d9 100644 --- a/earthkit/data/sources/array_list.py +++ b/earthkit/data/sources/array_list.py @@ -8,8 +8,7 @@ # import logging - -import numpy as np +import math from earthkit.data.core.array import get_backend from earthkit.data.core.fieldlist import Field, FieldList @@ -30,18 +29,19 @@ class ArrayField(Field): metadata: :class:`Metadata` Metadata object describing the field metadata. backend: str, ArrayBackend - Array backend. + Array backend. Must match the type of ``array``. """ def __init__(self, array, metadata, backend): super().__init__(backend, metadata=metadata) self._array = array - self.raw_backend = backend + self.raw_values_backend = backend def _make_metadata(self): pass def _values(self, dtype=None): + """native array type""" if dtype is None: return self._array else: @@ -122,7 +122,7 @@ def __init__(self, array, metadata, *args, backend=None, **kwargs): def _shape_match(self, shape1, shape2): if shape1 == shape2: return True - if len(shape1) == 1 and shape1[0] == np.prod(shape2): + if len(shape1) == 1 and shape1[0] == math.prod(shape2): return True return False @@ -209,14 +209,14 @@ def to_fieldlist(self): class ArrayFieldList(ArrayFieldListCore): - r"""Represent a list of :obj:`NumpyField `\ s. + r"""Represent a list of :obj:`ArrayField `\ s. - The preferred way to create a NumpyFieldList is to use either the - static :obj:`from_numpy` method or the :obj:`to_fieldlist` method. + The preferred way to create a ArrayFieldList is to use either the + static :obj:`from_array` method or the :obj:`to_fieldlist` method. See Also -------- - from_numpy + from_array to_fieldlist """ diff --git a/earthkit/data/testing.py b/earthkit/data/testing.py index 655720f1..b5402581 100644 --- a/earthkit/data/testing.py +++ b/earthkit/data/testing.py @@ -101,6 +101,7 @@ def modules_installed(*modules): NO_POLYTOPE = not os.path.exists(os.path.expanduser("~/.polytopeapirc")) NO_ECCOVJSON = not modules_installed("eccovjson") +NO_PYTORCH = not modules_installed("torch") def MISSING(*modules): diff --git a/tests/core/test_array.py b/tests/core/test_array.py new file mode 100644 index 00000000..6b6e59fb --- /dev/null +++ b/tests/core/test_array.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import pytest + +from earthkit.data.core.array import ensure_backend, get_backend +from earthkit.data.testing import NO_PYTORCH + + +def test_core_array_backend_numpy(): + b = ensure_backend("numpy") + assert b.name == "numpy" + + import numpy as np + + v = np.ones(10) + v_lst = [1.0] * 10 + + assert b.is_native_array(v) + assert id(b.from_numpy(v)) == id(v) + assert np.allclose(b.from_other(v_lst, dtype=np.float64), v) + assert get_backend(v) is b + assert get_backend(v, guess=b) is b + + assert np.isclose(b.array_ns.mean(v), 1.0) + + if not NO_PYTORCH: + import torch + + v_pt = torch.ones(10, dtype=torch.float64) + pt_b = ensure_backend("pytorch") + r = b.to_backend(v, pt_b) + assert torch.is_tensor(r) + assert torch.allclose(r, v_pt) + + +@pytest.mark.skipif(NO_PYTORCH, reason="No pytorch installed") +def test_core_array_backend_pytorch(): + b = ensure_backend("pytorch") + assert b.name == "pytorch" + + import numpy as np + import torch + + v = torch.ones(10, dtype=torch.float64) + v_np = np.ones(10, dtype=np.float64) + v_lst = [1.0] * 10 + + assert b.is_native_array(v) + assert id(b.from_pytorch(v)) == id(v) + assert torch.allclose(b.from_numpy(v_np), v) + assert torch.allclose(b.from_other(v_lst, dtype=torch.float64), v) + assert get_backend(v) is b + assert get_backend(v, guess=b) is b + + np_b = ensure_backend("numpy") + r = b.to_backend(v, np_b) + assert isinstance(r, np.ndarray) + assert np.allclose(r, v_np) + + assert np.isclose(b.array_ns.mean(v), 1.0) + + +if __name__ == "__main__": + from earthkit.data.testing import main + + main(__file__) diff --git a/tests/grib/grib_fixtures.py b/tests/grib/grib_fixtures.py index c923cc52..9ad29033 100644 --- a/tests/grib/grib_fixtures.py +++ b/tests/grib/grib_fixtures.py @@ -12,17 +12,21 @@ from earthkit.data import from_source from earthkit.data.core.fieldlist import FieldList -from earthkit.data.testing import earthkit_examples_file, earthkit_test_data_file +from earthkit.data.testing import ( + NO_PYTORCH, + earthkit_examples_file, + earthkit_test_data_file, +) -def load_numpy_fieldlist(path): - ds = from_source("file", path) - return FieldList.from_numpy( +def load_array_fieldlist(path, backend): + ds = from_source("file", path, backend=backend) + return FieldList.from_array( ds.values, [m.override(generatingProcessIdentifier=120) for m in ds.metadata()] ) -def load_file_or_numpy_fs(filename, mode, folder="example"): +def load_grib_data(filename, fl_type, backend, folder="example"): if folder == "example": path = earthkit_examples_file(filename) elif folder == "data": @@ -30,7 +34,66 @@ def load_file_or_numpy_fs(filename, mode, folder="example"): else: raise ValueError("Invalid folder={folder}") - if mode == "file": - return from_source("file", path) + if fl_type == "file": + return from_source("file", path, backend=backend) + elif fl_type == "array": + return load_array_fieldlist(path, backend) else: - return load_numpy_fieldlist(path) + raise ValueError("Invalid fl_type={fl_type}") + + +def check_numpy_array_type(v, dtype=None): + import numpy as np + + assert isinstance(v, np.ndarray) + if dtype is not None: + if dtype == "float64": + dtype = np.float64 + elif dtype == "float32": + dtype = np.float32 + else: + raise ValueError("Unsupported dtype={dtype}") + assert v.dtype == dtype + + +def check_pytorch_array_type(v, dtype=None): + import torch + + assert torch.is_tensor(v) + if dtype is not None: + if dtype == "float64": + dtype = torch.float64 + elif dtype == "float32": + dtype = torch.float32 + else: + raise ValueError("Unsupported dtype={dtype}") + assert v.dtype == dtype + + +def check_array_type(v, backend, **kwargs): + if backend is None or backend == "numpy": + check_numpy_array_type(v, **kwargs) + elif backend == "pytorch": + check_pytorch_array_type(v, **kwargs) + else: + raise ValueError("Invalid backend={backend}") + + +def get_array_namespace(backend): + from earthkit.data.core.array import ensure_backend + + return ensure_backend(backend).array_ns + + +def get_array(v, backend): + from earthkit.data.core.array import ensure_backend + + b = ensure_backend(backend) + return b.from_other(v) + + +FL_TYPES = ["file", "array"] + +ARRAY_BACKENDS = ["numpy"] +if not NO_PYTORCH: + ARRAY_BACKENDS.append("pytorch") diff --git a/tests/grib/test_grib_backend.py b/tests/grib/test_grib_backend.py new file mode 100644 index 00000000..1c0d0a15 --- /dev/null +++ b/tests/grib/test_grib_backend.py @@ -0,0 +1,149 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import numpy as np +import pytest + +from earthkit.data import FieldList, from_source +from earthkit.data.testing import NO_PYTORCH, earthkit_examples_file + + +@pytest.mark.parametrize("_kwargs", [{}, {"backend": "numpy"}]) +def test_grib_file_numpy_backend(_kwargs): + ds = from_source("file", earthkit_examples_file("test6.grib"), **_kwargs) + + assert len(ds) == 6 + + assert isinstance(ds[0].values, np.ndarray) + assert ds[0].values.shape == (84,) + + assert isinstance(ds.values, np.ndarray) + assert ds.values.shape == ( + 6, + 84, + ) + + assert isinstance(ds[0].to_array(), np.ndarray) + assert ds[0].to_array().shape == (7, 12) + + assert isinstance(ds.to_array(), np.ndarray) + assert ds.to_array().shape == (6, 7, 12) + + assert isinstance(ds[0].to_numpy(), np.ndarray) + assert ds[0].to_numpy().shape == (7, 12) + + assert isinstance(ds.to_numpy(), np.ndarray) + assert ds.to_numpy().shape == (6, 7, 12) + + +@pytest.mark.skipif(NO_PYTORCH, reason="No pytorch installed") +def test_grib_file_pytorch_backend(): + import torch + + ds = from_source("file", earthkit_examples_file("test6.grib"), backend="pytorch") + + assert len(ds) == 6 + + assert torch.is_tensor(ds[0].values) + assert ds[0].values.shape == (84,) + + assert torch.is_tensor(ds.values) + assert ds.values.shape == ( + 6, + 84, + ) + + assert torch.is_tensor(ds[0].to_array()) + assert ds[0].to_array().shape == (7, 12) + + assert torch.is_tensor(ds.to_array()) + assert ds.to_array().shape == (6, 7, 12) + + assert isinstance(ds[0].to_numpy(), np.ndarray) + assert ds[0].to_numpy().shape == (7, 12) + + assert isinstance(ds.to_numpy(), np.ndarray) + assert ds.to_numpy().shape == (6, 7, 12) + + +def test_grib_array_numpy_backend(): + s = from_source("file", earthkit_examples_file("test6.grib")) + + ds = FieldList.from_array( + s.values, + [m for m in s.metadata()], + ) + assert len(ds) == 6 + with pytest.raises(AttributeError): + ds.path + + assert isinstance(ds[0].values, np.ndarray) + assert ds[0].values.shape == (84,) + + assert isinstance(ds.values, np.ndarray) + assert ds.values.shape == ( + 6, + 84, + ) + + assert isinstance(ds[0].to_array(), np.ndarray) + assert ds[0].to_array().shape == (7, 12) + + assert isinstance(ds.to_array(), np.ndarray) + assert ds.to_array().shape == (6, 7, 12) + + assert isinstance(ds[0].to_numpy(), np.ndarray) + assert ds[0].to_numpy().shape == (7, 12) + + assert isinstance(ds.to_numpy(), np.ndarray) + assert ds.to_numpy().shape == (6, 7, 12) + + +@pytest.mark.skipif(NO_PYTORCH, reason="No pytorch installed") +def test_grib_array_pytorch_backend(): + import torch + + s = from_source("file", earthkit_examples_file("test6.grib"), backend="pytorch") + + ds = FieldList.from_array( + s.values, + [m for m in s.metadata()], + ) + assert len(ds) == 6 + with pytest.raises(AttributeError): + ds.path + + assert torch.is_tensor(ds[0].values) + assert ds[0].values.shape == (84,) + + assert torch.is_tensor(ds.values) + assert ds.values.shape == ( + 6, + 84, + ) + + assert torch.is_tensor(ds[0].to_array()) + assert ds[0].to_array().shape == (7, 12) + + assert torch.is_tensor(ds.to_array()) + assert ds.to_array().shape == (6, 7, 12) + + assert isinstance(ds[0].to_numpy(), np.ndarray) + assert ds[0].to_numpy().shape == (7, 12) + + assert isinstance(ds.to_numpy(), np.ndarray) + assert ds.to_numpy().shape == (6, 7, 12) + + +if __name__ == "__main__": + from earthkit.data.testing import main + + main() diff --git a/tests/grib/test_grib_convert.py b/tests/grib/test_grib_convert.py index ee9a27cc..2780e073 100644 --- a/tests/grib/test_grib_convert.py +++ b/tests/grib/test_grib_convert.py @@ -17,13 +17,14 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import load_file_or_numpy_fs # noqa: E402 +from grib_fixtures import FL_TYPES, load_grib_data # noqa: E402 -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_icon_to_xarray(mode): +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ["numpy"]) +def test_icon_to_xarray(fl_type, backend): # test the conversion to xarray for an icon (unstructured grid) grib file. - g = load_file_or_numpy_fs("test_icon.grib", mode, folder="data") + g = load_grib_data("test_icon.grib", fl_type, backend, folder="data") ds = g.to_xarray() assert len(ds.data_vars) == 1 @@ -33,9 +34,10 @@ def test_icon_to_xarray(mode): assert ds["pres"].sizes["values"] == 6 -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_to_xarray_filter_by_keys(mode): - g = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ["numpy"]) +def test_to_xarray_filter_by_keys(fl_type, backend): + g = load_grib_data("tuv_pl.grib", fl_type, backend) g = g.sel(param="t", level=500) + g.sel(param="u") assert len(g) > 1 @@ -50,9 +52,10 @@ def test_to_xarray_filter_by_keys(mode): assert r["t"].sizes["isobaricInhPa"] == 1 -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_to_pandas(mode): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ["numpy"]) +def test_grib_to_pandas(fl_type, backend): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") # all points df = f.to_pandas() diff --git a/tests/grib/test_grib_geography.py b/tests/grib/test_grib_geography.py index d589a858..b5babd2a 100644 --- a/tests/grib/test_grib_geography.py +++ b/tests/grib/test_grib_geography.py @@ -19,7 +19,12 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import load_file_or_numpy_fs # noqa: E402 +from grib_fixtures import ( # noqa: E402 + ARRAY_BACKENDS, + FL_TYPES, + check_array_type, + load_grib_data, +) def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): @@ -29,19 +34,18 @@ def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): assert np.isclose(v.mean(), meanv, eps) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize("index", [0, None]) -def test_grib_to_latlon_single(mode, index): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +def test_grib_to_latlon_single(fl_type, backend, index): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") eps = 1e-5 g = f[index] if index is not None else f v = g.to_latlon(flatten=True) assert isinstance(v, dict) - assert isinstance(v["lon"], np.ndarray) - assert isinstance(v["lat"], np.ndarray) - assert v["lon"].dtype == np.float64 - assert v["lat"].dtype == np.float64 + check_array_type(v["lon"], backend, dtype="float64") + check_array_type(v["lat"], backend, dtype="float64") check_array( v["lon"], (84,), @@ -60,34 +64,34 @@ def test_grib_to_latlon_single(mode, index): ) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize("index", [0, None]) -def test_grib_to_latlon_single_shape(mode, index): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +def test_grib_to_latlon_single_shape(fl_type, backend, index): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") g = f[index] if index is not None else f v = g.to_latlon() assert isinstance(v, dict) - assert isinstance(v["lon"], np.ndarray) - assert isinstance(v["lat"], np.ndarray) + check_array_type(v["lon"], backend, dtype="float64") + check_array_type(v["lat"], backend, dtype="float64") # x assert v["lon"].shape == (7, 12) - assert v["lon"].dtype == np.float64 for x in v["lon"]: assert np.allclose(x, np.linspace(0, 330, 12)) # y assert v["lat"].shape == (7, 12) - assert v["lon"].dtype == np.float64 for i, y in enumerate(v["lat"]): assert np.allclose(y, np.ones(12) * (90 - i * 30)) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ["numpy"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_latlon_multi(mode, dtype): - f = load_file_or_numpy_fs("test.grib", mode) +def test_grib_to_latlon_multi(fl_type, backend, dtype): + f = load_grib_data("test.grib", fl_type, backend) v_ref = f[0].to_latlon(flatten=True, dtype=dtype) v = f.to_latlon(flatten=True, dtype=dtype) @@ -101,29 +105,29 @@ def test_grib_to_latlon_multi(mode, dtype): assert v["lon"].dtype == dtype -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_to_latlon_multi_non_shared_grid(mode): - f1 = load_file_or_numpy_fs("test.grib", mode) - f2 = load_file_or_numpy_fs("test4.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_to_latlon_multi_non_shared_grid(fl_type, backend): + f1 = load_grib_data("test.grib", fl_type, backend) + f2 = load_grib_data("test4.grib", fl_type, backend) f = f1 + f2 with pytest.raises(ValueError): f.to_latlon() -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize("index", [0, None]) -def test_grib_to_points_single(mode, index): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +def test_grib_to_points_single(fl_type, backend, index): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") eps = 1e-5 g = f[index] if index is not None else f v = g.to_points(flatten=True) assert isinstance(v, dict) - assert isinstance(v["x"], np.ndarray) - assert isinstance(v["y"], np.ndarray) - assert v["x"].dtype == np.float64 - assert v["y"].dtype == np.float64 + check_array_type(v["x"], backend, dtype="float64") + check_array_type(v["y"], backend, dtype="float64") check_array( v["x"], (84,), @@ -142,17 +146,19 @@ def test_grib_to_points_single(mode, index): ) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_to_points_unsupported_grid(mode): - f = load_file_or_numpy_fs("mercator.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_to_points_unsupported_grid(fl_type, backend): + f = load_grib_data("mercator.grib", fl_type, backend, folder="data") with pytest.raises(ValueError): f[0].to_points() -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ["numpy"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_points_multi(mode, dtype): - f = load_file_or_numpy_fs("test.grib", mode) +def test_grib_to_points_multi(fl_type, backend, dtype): + f = load_grib_data("test.grib", fl_type, backend) v_ref = f[0].to_points(flatten=True, dtype=dtype) v = f.to_points(flatten=True, dtype=dtype) @@ -166,29 +172,32 @@ def test_grib_to_points_multi(mode, dtype): assert v["y"].dtype == dtype -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_to_points_multi_non_shared_grid(mode): - f1 = load_file_or_numpy_fs("test.grib", mode) - f2 = load_file_or_numpy_fs("test4.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_to_points_multi_non_shared_grid(fl_type, backend): + f1 = load_grib_data("test.grib", fl_type, backend) + f2 = load_grib_data("test4.grib", fl_type, backend) f = f1 + f2 with pytest.raises(ValueError): f.to_points() -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_bbox(mode): - ds = load_file_or_numpy_fs("test.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_bbox(fl_type, backend): + ds = load_grib_data("test.grib", fl_type, backend) bb = ds.bounding_box() assert len(bb) == 2 for b in bb: assert b.as_tuple() == (73, -27, 33, 45) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize("index", [0, None]) -def test_grib_projection_ll(mode, index): - f = load_file_or_numpy_fs("test.grib", mode) +def test_grib_projection_ll(fl_type, backend, index): + f = load_grib_data("test.grib", fl_type, backend) if index is not None: g = f[index] @@ -199,9 +208,10 @@ def test_grib_projection_ll(mode, index): ) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_projection_mercator(mode): - f = load_file_or_numpy_fs("mercator.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_projection_mercator(fl_type, backend): + f = load_grib_data("mercator.grib", fl_type, backend, folder="data") projection = f[0].projection() assert isinstance(projection, projections.Mercator) assert projection.parameters == { diff --git a/tests/grib/test_grib_inidces.py b/tests/grib/test_grib_inidces.py index 9d4f52b8..635594ab 100644 --- a/tests/grib/test_grib_inidces.py +++ b/tests/grib/test_grib_inidces.py @@ -16,12 +16,13 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import load_file_or_numpy_fs # noqa: E402 +from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_indices_base(mode): - ds = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_indices_base(fl_type, backend): + ds = load_grib_data("tuv_pl.grib", fl_type, backend) ref = { "class": ["od"], @@ -52,9 +53,10 @@ def test_grib_indices_base(mode): assert r == ref -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_indices_sel(mode): - ds = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_indices_sel(fl_type, backend): + ds = load_grib_data("tuv_pl.grib", fl_type, backend) ref = { "class": ["od"], @@ -81,10 +83,11 @@ def test_grib_indices_sel(mode): assert r == ref -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_indices_multi(mode): - f1 = load_file_or_numpy_fs("tuv_pl.grib", mode) - f2 = load_file_or_numpy_fs("ml_data.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_indices_multi(fl_type, backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, backend) + f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") ds = f1 + f2 ref = { @@ -147,10 +150,11 @@ def test_grib_indices_multi(mode): assert r == ref -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_indices_multi_Del(mode): - f1 = load_file_or_numpy_fs("tuv_pl.grib", mode) - f2 = load_file_or_numpy_fs("ml_data.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_indices_multi_Del(fl_type, backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, backend) + f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") ds = f1 + f2 ref = { @@ -172,9 +176,10 @@ def test_grib_indices_multi_Del(mode): assert r == ref -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_indices_order_by(mode): - ds = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_indices_order_by(fl_type, backend): + ds = load_grib_data("tuv_pl.grib", fl_type, backend) ref = { "class": ["od"], diff --git a/tests/grib/test_grib_metadata.py b/tests/grib/test_grib_metadata.py index 660ed612..65a8030d 100644 --- a/tests/grib/test_grib_metadata.py +++ b/tests/grib/test_grib_metadata.py @@ -21,7 +21,7 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import load_file_or_numpy_fs # noqa: E402 +from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): @@ -35,7 +35,8 @@ def repeat_list_items(items, count): return sum([[x] * count for x in items], []) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,expected_value", [ @@ -53,15 +54,16 @@ def repeat_list_items(items, count): (("shortName", "level"), ("2t", 0)), ], ) -def test_grib_metadata_grib(mode, key, expected_value): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +def test_grib_metadata_grib(fl_type, backend, key, expected_value): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") sn = f.metadata(key) assert sn == [expected_value] sn = f[0].metadata(key) assert sn == expected_value -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,astype,expected_value", [ @@ -75,15 +77,16 @@ def test_grib_metadata_grib(mode, key, expected_value): ("level", int, 0), ], ) -def test_grib_metadata_astype_1(mode, key, astype, expected_value): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +def test_grib_metadata_astype_1(fl_type, backend, key, astype, expected_value): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") sn = f.metadata(key, astype=astype) assert sn == [expected_value] sn = f[0].metadata(key, astype=astype) assert sn == expected_value -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fs_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,expected_value", [ @@ -95,13 +98,15 @@ def test_grib_metadata_astype_1(mode, key, astype, expected_value): ("level:int", repeat_list_items([1000, 850, 700, 500, 400, 300], 3)), ], ) -def test_grib_metadata_18(mode, key, expected_value): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) - sn = f.metadata(key) +def test_grib_metadata_18(fs_type, backend, key, expected_value): + # f = load_grib_data("tuv_pl.grib", mode) + ds = load_grib_data("tuv_pl.grib", fs_type, backend) + sn = ds.metadata(key) assert sn == expected_value -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,astype,expected_value", [ @@ -119,13 +124,14 @@ def test_grib_metadata_18(mode, key, expected_value): ), ], ) -def test_grib_metadata_astype_18(mode, key, astype, expected_value): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_metadata_astype_18(fl_type, backend, key, astype, expected_value): + f = load_grib_data("tuv_pl.grib", fl_type, backend) sn = f.metadata(key, astype=astype) assert sn == expected_value -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,expected_value", [ @@ -134,14 +140,15 @@ def test_grib_metadata_astype_18(mode, key, astype, expected_value): ("latitudeOfFirstGridPointInDegrees:float", 90.0), ], ) -def test_grib_metadata_double_1(mode, key, expected_value): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +def test_grib_metadata_double_1(fl_type, backend, key, expected_value): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") r = f.metadata(key) assert len(r) == 1 assert np.isclose(r[0], expected_value) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key", [ @@ -150,15 +157,16 @@ def test_grib_metadata_double_1(mode, key, expected_value): ("latitudeOfFirstGridPointInDegrees:float"), ], ) -def test_grib_metadata_double_18(mode, key): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_metadata_double_18(fl_type, backend, key): + f = load_grib_data("tuv_pl.grib", fl_type, backend) ref = [90.0] * 18 r = f.metadata(key) np.testing.assert_allclose(r, ref, 0.001) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,astype", [ @@ -166,8 +174,8 @@ def test_grib_metadata_double_18(mode, key): ("latitudeOfFirstGridPointInDegrees", float), ], ) -def test_grib_metadata_double_astype_18(mode, key, astype): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_metadata_double_astype_18(fl_type, backend, key, astype): + f = load_grib_data("tuv_pl.grib", fl_type, backend) ref = [90.0] * 18 @@ -175,10 +183,11 @@ def test_grib_metadata_double_astype_18(mode, key, astype): np.testing.assert_allclose(r, ref, 0.001) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_get_long_array_1(mode): - f = load_file_or_numpy_fs( - "rgg_small_subarea_cellarea_ref.grib", mode, folder="data" +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_get_long_array_1(fl_type, backend): + f = load_grib_data( + "rgg_small_subarea_cellarea_ref.grib", fl_type, backend, folder="data" ) assert len(f) == 1 @@ -193,9 +202,10 @@ def test_grib_get_long_array_1(mode): assert pl[72] == 312 -@pytest.mark.parametrize("mode", ["file"]) -def test_grib_get_double_array_values_1(mode): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", ["file"]) +@pytest.mark.parametrize("backend", [None]) +def test_grib_get_double_array_values_1(fl_type, backend): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") v = f.metadata("values") assert len(v) == 1 @@ -212,9 +222,10 @@ def test_grib_get_double_array_values_1(mode): ) -@pytest.mark.parametrize("mode", ["file"]) -def test_grib_get_double_array_values_18(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", ["file"]) +@pytest.mark.parametrize("backend", [None]) +def test_grib_get_double_array_values_18(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) v = f.metadata("values") assert isinstance(v, list) assert len(v) == 18 @@ -242,9 +253,10 @@ def test_grib_get_double_array_values_18(mode): ) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_get_double_array_1(mode): - f = load_file_or_numpy_fs("ml_data.grib", mode, folder="data")[0] +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_get_double_array_1(fl_type, backend): + f = load_grib_data("ml_data.grib", fl_type, backend, folder="data")[0] # f is now a field! v = f.metadata("pv") assert isinstance(v, np.ndarray) @@ -255,9 +267,10 @@ def test_grib_get_double_array_1(mode): assert np.isclose(v[275], 1.0) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_get_double_array_18(mode): - f = load_file_or_numpy_fs("ml_data.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_get_double_array_18(fl_type, backend): + f = load_grib_data("ml_data.grib", fl_type, backend, folder="data") v = f.metadata("pv") assert isinstance(v, list) assert len(v) == 36 @@ -272,9 +285,10 @@ def test_grib_get_double_array_18(mode): assert np.isclose(v[17][20], 316.4207458496094, eps) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_metadata_type_qualifier(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode)[0:4] +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_metadata_type_qualifier(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend)[0:4] # to str r = f.metadata("centre:s") @@ -311,9 +325,10 @@ def test_grib_metadata_type_qualifier(mode): assert all(isinstance(x, float) for x in r) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_metadata_astype(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode)[0:4] +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_metadata_astype(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend)[0:4] # to str r = f.metadata("centre", astype=None) @@ -345,9 +360,10 @@ def test_grib_metadata_astype(mode): f.metadata(["level", "cfVarName", "centre"], astype=(int, None)) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_metadata_generic(mode): - f_full = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_metadata_generic(fl_type, backend): + f_full = load_grib_data("tuv_pl.grib", fl_type, backend) f = f_full[0:4] @@ -374,9 +390,10 @@ def test_grib_metadata_generic(mode): assert lg == [1000, "t"] -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_metadata_missing_value(mode): - f = load_file_or_numpy_fs("ml_data.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_metadata_missing_value(fl_type, backend): + f = load_grib_data("ml_data.grib", fl_type, backend, folder="data") with pytest.raises(KeyError): f[0].metadata("scaleFactorOfSecondFixedSurface") @@ -385,9 +402,10 @@ def test_grib_metadata_missing_value(mode): assert v is None -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_metadata_missing_key(mode): - f = load_file_or_numpy_fs("test.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_metadata_missing_key(fl_type, backend): + f = load_grib_data("test.grib", fl_type, backend) with pytest.raises(KeyError): f[0].metadata("_badkey_") @@ -396,9 +414,10 @@ def test_grib_metadata_missing_key(mode): assert v == 0 -@pytest.mark.parametrize("mode", ["file"]) -def test_grib_metadata_namespace(mode): - f = load_file_or_numpy_fs("test6.grib", mode) +@pytest.mark.parametrize("fl_type", ["file"]) +@pytest.mark.parametrize("backend", [None]) +def test_grib_metadata_namespace(fl_type, backend): + f = load_grib_data("test6.grib", fl_type, backend) r = f[0].metadata(namespace="vertical") ref = {"level": 1000, "typeOfLevel": "isobaricInhPa"} @@ -476,9 +495,10 @@ def test_grib_metadata_namespace(mode): assert "must be a str when key specified" in str(excinfo.value) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_datetime(mode): - s = load_file_or_numpy_fs("test.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_datetime(fl_type, backend): + s = load_grib_data("test.grib", fl_type, backend) ref = { "base_time": [datetime.datetime(2020, 5, 13, 12)], @@ -506,17 +526,19 @@ def test_grib_datetime(mode): assert s.datetime() == ref -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_valid_datetime(mode): - ds = load_file_or_numpy_fs("t_time_series.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_valid_datetime(fl_type, backend): + ds = load_grib_data("t_time_series.grib", fl_type, backend, folder="data") f = ds[4] assert f.metadata("valid_datetime") == datetime.datetime(2020, 12, 21, 18) -@pytest.mark.parametrize("mode", ["file"]) -def test_message(mode): - f = load_file_or_numpy_fs("test.grib", mode) +@pytest.mark.parametrize("fl_type", ["file"]) +@pytest.mark.parametrize("backend", [None]) +def test_message(fl_type, backend): + f = load_grib_data("test.grib", fl_type, backend) v = f[0].message() assert len(v) == 526 assert v[:4] == b"GRIB" diff --git a/tests/grib/test_grib_order_by.py b/tests/grib/test_grib_order_by.py index 414d2cbd..b89dc4ee 100644 --- a/tests/grib/test_grib_order_by.py +++ b/tests/grib/test_grib_order_by.py @@ -19,13 +19,14 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import load_file_or_numpy_fs # noqa: E402 +from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 # @pytest.mark.skipif(("GITHUB_WORKFLOW" in os.environ) or True, reason="Not yet ready") -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_order_by_single_message(mode): - s = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_order_by_single_message(fl_type, backend): + s = load_grib_data("test_single.grib", fl_type, backend, folder="data") r = s.order_by("shortName") assert len(r) == 1 @@ -53,7 +54,8 @@ def __call__(self, x, y): return -1 -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta", [ @@ -100,11 +102,12 @@ def __call__(self, x, y): ], ) def test_grib_order_by_single_file_( - mode, + fl_type, + backend, params, expected_meta, ): - f = load_file_or_numpy_fs("test6.grib", mode) + f = load_grib_data("test6.grib", fl_type, backend) g = f.order_by(params) assert len(g) == len(f) @@ -113,7 +116,8 @@ def test_grib_order_by_single_file_( assert g.metadata(k) == v -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta", [ @@ -142,9 +146,9 @@ def test_grib_order_by_single_file_( ), ], ) -def test_grib_order_by_multi_file(mode, params, expected_meta): - f1 = load_file_or_numpy_fs("test4.grib", mode) - f2 = load_file_or_numpy_fs("test6.grib", mode) +def test_grib_order_by_multi_file(fl_type, backend, params, expected_meta): + f1 = load_grib_data("test4.grib", fl_type, backend) + f2 = load_grib_data("test6.grib", fl_type, backend) f = from_source("multi", [f1, f2]) g = f.order_by(params) @@ -154,9 +158,10 @@ def test_grib_order_by_multi_file(mode, params, expected_meta): assert g.metadata(k) == v -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_order_by_with_sel(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_order_by_with_sel(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) g = f.sel(level=500) assert len(g) == 3 @@ -171,9 +176,10 @@ def test_grib_order_by_with_sel(mode): assert r.metadata("shortName") == ["v", "u", "t"] -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_order_by_valid_datetime(mode): - f = load_file_or_numpy_fs("t_time_series.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_order_by_valid_datetime(fl_type, backend): + f = load_grib_data("t_time_series.grib", fl_type, backend, folder="data") g = f.order_by(valid_datetime="descending") assert len(g) == 10 diff --git a/tests/grib/test_grib_sel.py b/tests/grib/test_grib_sel.py index a7642a5e..e667a688 100644 --- a/tests/grib/test_grib_sel.py +++ b/tests/grib/test_grib_sel.py @@ -20,21 +20,23 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import load_file_or_numpy_fs # noqa: E402 +from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 # @pytest.mark.skipif(("GITHUB_WORKFLOW" in os.environ) or True, reason="Not yet ready") -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_sel_single_message(mode): - s = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_sel_single_message(fl_type, backend): + s = load_grib_data("test_single.grib", fl_type, backend, folder="data") r = s.sel(shortName="2t") assert len(r) == 1 assert r[0].metadata("shortName") == "2t" -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta,metadata_keys", [ @@ -62,8 +64,8 @@ def test_grib_sel_single_message(mode): ), ], ) -def test_grib_sel_single_file_1(mode, params, expected_meta, metadata_keys): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_sel_single_file_1(fl_type, backend, params, expected_meta, metadata_keys): + f = load_grib_data("tuv_pl.grib", fl_type, backend) g = f.sel(**params) assert len(g) == len(expected_meta) @@ -76,9 +78,10 @@ def test_grib_sel_single_file_1(mode, params, expected_meta, metadata_keys): return -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_sel_single_file_2(mode): - f = load_file_or_numpy_fs("t_time_series.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_sel_single_file_2(fl_type, backend): + f = load_grib_data("t_time_series.grib", fl_type, backend, folder="data") g = f.sel(shortName=["t"], step=[3, 6]) assert len(g) == 2 @@ -97,9 +100,10 @@ def test_grib_sel_single_file_2(mode): ] -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_sel_single_file_as_dict(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_sel_single_file_as_dict(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) g = f.sel({"shortName": "t", "level": [500, 700], "mars.type": "an"}) assert len(g) == 2 @@ -109,7 +113,8 @@ def test_grib_sel_single_file_as_dict(mode): ] -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "param_id,level,expected_meta", [ @@ -121,8 +126,8 @@ def test_grib_sel_single_file_as_dict(mode): (131, (slice(510, 520)), []), ], ) -def test_grib_sel_slice_single_file(mode, param_id, level, expected_meta): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_sel_slice_single_file(fl_type, backend, param_id, level, expected_meta): + f = load_grib_data("tuv_pl.grib", fl_type, backend) g = f.sel(paramId=param_id, level=level) assert len(g) == len(expected_meta) @@ -130,10 +135,11 @@ def test_grib_sel_slice_single_file(mode, param_id, level, expected_meta): assert g.metadata(["paramId", "level"]) == expected_meta -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_sel_multi_file(mode): - f1 = load_file_or_numpy_fs("tuv_pl.grib", mode) - f2 = load_file_or_numpy_fs("ml_data.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_sel_multi_file(fl_type, backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, backend) + f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") f = from_source("multi", [f1, f2]) # single resulting field @@ -147,10 +153,11 @@ def test_grib_sel_multi_file(mode): assert np.allclose(d, np.zeros(len(d))) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_sel_slice_multi_file(mode): - f1 = load_file_or_numpy_fs("tuv_pl.grib", mode) - f2 = load_file_or_numpy_fs("ml_data.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_sel_slice_multi_file(fl_type, backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, backend) + f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") f = from_source("multi", [f1, f2]) @@ -162,10 +169,11 @@ def test_grib_sel_slice_multi_file(mode): ] -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_sel_date(mode): +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_sel_date(fl_type, backend): # date and time - f = load_file_or_numpy_fs("t_time_series.grib", mode, folder="data") + f = load_grib_data("t_time_series.grib", fl_type, backend, folder="data") g = f.sel(date=20201221, time=1200, step=9) # g = f.sel(date="20201221", time="12", step="9") @@ -180,9 +188,10 @@ def test_grib_sel_date(mode): assert g.metadata(ref_keys) == ref -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_sel_valid_datetime(mode): - f = load_file_or_numpy_fs("t_time_series.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_sel_valid_datetime(fl_type, backend): + f = load_grib_data("t_time_series.grib", fl_type, backend, folder="data") g = f.sel(valid_datetime=datetime.datetime(2020, 12, 21, 21)) assert len(g) == 2 @@ -196,16 +205,18 @@ def test_grib_sel_valid_datetime(mode): assert g.metadata(ref_keys) == ref -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_isel_single_message(mode): - s = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_isel_single_message(fl_type, backend): + s = load_grib_data("test_single.grib", fl_type, backend, folder="data") r = s.isel(shortName=0) assert len(r) == 1 assert r[0].metadata("shortName") == "2t" -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta,metadata_keys", [ @@ -242,8 +253,8 @@ def test_grib_isel_single_message(mode): ), ], ) -def test_grib_isel_single_file(mode, params, expected_meta, metadata_keys): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_isel_single_file(fl_type, backend, params, expected_meta, metadata_keys): + f = load_grib_data("tuv_pl.grib", fl_type, backend) g = f.isel(**params) assert len(g) == len(expected_meta) @@ -255,7 +266,8 @@ def test_grib_isel_single_file(mode, params, expected_meta, metadata_keys): assert g.metadata(keys) == expected_meta -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "param_id,level,expected_meta", [ @@ -267,8 +279,8 @@ def test_grib_isel_single_file(mode, params, expected_meta, metadata_keys): (1, (slice(None, None, 2)), [[131, 850], [131, 500], [131, 300]]), ], ) -def test_grib_isel_slice_single_file(mode, param_id, level, expected_meta): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_isel_slice_single_file(fl_type, backend, param_id, level, expected_meta): + f = load_grib_data("tuv_pl.grib", fl_type, backend) g = f.isel(paramId=param_id, level=level) assert len(g) == len(expected_meta) @@ -276,9 +288,10 @@ def test_grib_isel_slice_single_file(mode, param_id, level, expected_meta): assert g.metadata(["paramId", "level"]) == expected_meta -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_isel_slice_invalid(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_isel_slice_invalid(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) with pytest.raises(IndexError): f.isel(level=500) @@ -287,10 +300,11 @@ def test_grib_isel_slice_invalid(mode): f.isel(level="a") -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_isel_multi_file(mode): - f1 = load_file_or_numpy_fs("tuv_pl.grib", mode) - f2 = load_file_or_numpy_fs("ml_data.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_isel_multi_file(fl_type, backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, backend) + f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") f = from_source("multi", [f1, f2]) # single resulting field @@ -303,10 +317,11 @@ def test_grib_isel_multi_file(mode): assert np.allclose(d, np.zeros(len(d))) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_isel_slice_multi_file(mode): - f1 = load_file_or_numpy_fs("tuv_pl.grib", mode) - f2 = load_file_or_numpy_fs("ml_data.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_isel_slice_multi_file(fl_type, backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, backend) + f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") f = from_source("multi", [f1, f2]) g = f.isel(shortName=1, level=slice(20, 22)) diff --git a/tests/grib/test_grib_slice.py b/tests/grib/test_grib_slice.py index 0dbcbd5a..dca31958 100644 --- a/tests/grib/test_grib_slice.py +++ b/tests/grib/test_grib_slice.py @@ -20,10 +20,11 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import load_file_or_numpy_fs # noqa: E402 +from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "index,expected_meta", [ @@ -34,8 +35,8 @@ (-5, ["u", 400]), ], ) -def test_grib_single_index(mode, index, expected_meta): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_single_index(fl_type, backend, index, expected_meta): + f = load_grib_data("tuv_pl.grib", fl_type, backend) # f = from_source("file", earthkit_examples_file("tuv_pl.grib")) r = f[index] @@ -46,14 +47,16 @@ def test_grib_single_index(mode, index, expected_meta): # assert np.isclose(v[1088], 304.5642, eps) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_single_index_bad(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_single_index_bad(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) with pytest.raises(IndexError): f[27] -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "indexes,expected_meta", [ @@ -65,8 +68,8 @@ def test_grib_single_index_bad(mode): (slice(14, None), [["v", 400], ["t", 300], ["u", 300], ["v", 300]]), ], ) -def test_grib_slice_single_file(mode, indexes, expected_meta): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_slice_single_file(fl_type, backend, indexes, expected_meta): + f = load_grib_data("tuv_pl.grib", fl_type, backend) r = f[indexes] assert len(r) == 4 assert r.metadata(["shortName", "level"]) == expected_meta @@ -101,13 +104,14 @@ def test_grib_slice_multi_file(indexes, expected_meta): assert f.metadata("shortName") == ["2t", "msl", "t", "z", "t", "z"] -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "indexes1,indexes2", [(np.array([1, 16, 5, 9]), np.array([1, 3])), ([1, 16, 5, 9], [1, 3])], ) -def test_grib_array_indexing(mode, indexes1, indexes2): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_array_indexing(fl_type, backend, indexes1, indexes2): + f = load_grib_data("tuv_pl.grib", fl_type, backend) r = f[indexes1] assert len(r) == 4 @@ -118,17 +122,19 @@ def test_grib_array_indexing(mode, indexes1, indexes2): assert r1.metadata("shortName") == ["u", "t"] -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize("indexes", [(np.array([1, 19, 5, 9])), ([1, 19, 5, 9])]) -def test_grib_array_indexing_bad(mode, indexes): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_array_indexing_bad(fl_type, backend, indexes): + f = load_grib_data("tuv_pl.grib", fl_type, backend) with pytest.raises(IndexError): f[indexes] -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_fieldlist_iterator(mode): - g = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_fieldlist_iterator(fl_type, backend): + g = load_grib_data("tuv_pl.grib", fl_type, backend) sn = g.metadata("shortName") assert len(sn) == 18 iter_sn = [f.metadata("shortName") for f in g] @@ -138,12 +144,13 @@ def test_grib_fieldlist_iterator(mode): assert iter_sn == sn -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_fieldlist_iterator_with_zip(mode): +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_fieldlist_iterator_with_zip(fl_type, backend): # test something different to the iterator - does not try to # 'go off the edge' of the fieldlist, because the length is determined by # the list of levels - g = load_file_or_numpy_fs("tuv_pl.grib", mode) + g = load_grib_data("tuv_pl.grib", fl_type, backend) ref_levs = g.metadata("level") assert len(ref_levs) == 18 levs1 = [] @@ -155,10 +162,11 @@ def test_grib_fieldlist_iterator_with_zip(mode): assert levs2 == ref_levs -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_fieldlist_iterator_with_zip_multiple(mode): +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_fieldlist_iterator_with_zip_multiple(fl_type, backend): # same as test_fieldlist_iterator_with_zip() but multiple times - g = load_file_or_numpy_fs("tuv_pl.grib", mode) + g = load_grib_data("tuv_pl.grib", fl_type, backend) ref_levs = g.metadata("level") assert len(ref_levs) == 18 for i in range(2): @@ -171,9 +179,10 @@ def test_grib_fieldlist_iterator_with_zip_multiple(mode): assert levs2 == ref_levs, i -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_fieldlist_reverse_iterator(mode): - g = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_fieldlist_reverse_iterator(fl_type, backend): + g = load_grib_data("tuv_pl.grib", fl_type, backend) sn = g.metadata("shortName") sn_reversed = list(reversed(sn)) assert sn_reversed[0] == "v" diff --git a/tests/grib/test_grib_summary.py b/tests/grib/test_grib_summary.py index 11c5b67a..904f3024 100644 --- a/tests/grib/test_grib_summary.py +++ b/tests/grib/test_grib_summary.py @@ -15,12 +15,13 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import load_file_or_numpy_fs # noqa: E402 +from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_describe(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_describe(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) # full contents df = f.describe() @@ -143,9 +144,10 @@ def test_grib_describe(mode): assert ref[0] == df[0].to_dict() -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_ls(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_ls(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) # default keys f1 = f[0:4] @@ -197,9 +199,10 @@ def test_grib_ls(mode): assert ref == df.to_dict() -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_ls_keys(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_ls_keys(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) # default keys # positive num (=head) @@ -223,9 +226,10 @@ def test_grib_ls_keys(mode): assert ref == df.to_dict() -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_ls_namespace(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_ls_namespace(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) df = f.ls(n=2, namespace="vertical") ref = { @@ -244,9 +248,10 @@ def test_grib_ls_namespace(mode): assert ref == df.to_dict() -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_ls_invalid_num(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_ls_invalid_num(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) with pytest.raises(ValueError): f.ls(n=0) @@ -255,16 +260,18 @@ def test_grib_ls_invalid_num(mode): f.ls(0) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_ls_invalid_arg(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_ls_invalid_arg(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) with pytest.raises(TypeError): f.ls(invalid=1) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_ls_num(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_ls_num(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) # default keys @@ -309,9 +316,10 @@ def test_grib_ls_num(mode): assert ref == df.to_dict() -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_head_num(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_head_num(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) # default keys df = f.head(n=2) @@ -334,9 +342,10 @@ def test_grib_head_num(mode): assert ref == df.to_dict() -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_tail_num(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_tail_num(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) # default keys df = f.tail(n=2) @@ -359,9 +368,10 @@ def test_grib_tail_num(mode): assert ref == df.to_dict() -@pytest.mark.parametrize("mode", ["file"]) -def test_grib_dump(mode): - f = load_file_or_numpy_fs("test6.grib", mode) +@pytest.mark.parametrize("fl_type", ["file"]) +@pytest.mark.parametrize("backend", [None]) +def test_grib_dump(fl_type, backend): + f = load_grib_data("test6.grib", fl_type, backend) namespaces = ( "default", diff --git a/tests/grib/test_grib_values.py b/tests/grib/test_grib_values.py index e57fddd5..eb78496b 100644 --- a/tests/grib/test_grib_values.py +++ b/tests/grib/test_grib_values.py @@ -17,7 +17,14 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import load_file_or_numpy_fs # noqa: E402 +from grib_fixtures import ( # noqa: E402 + ARRAY_BACKENDS, + FL_TYPES, + check_array_type, + get_array, + get_array_namespace, + load_grib_data, +) def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): @@ -27,15 +34,15 @@ def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): assert np.isclose(v.mean(), meanv, eps) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_values_1(mode): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_values_1(fl_type, backend): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") eps = 1e-5 # whole file v = f.values - assert isinstance(v, np.ndarray) - assert v.dtype == np.float64 + check_array_type(v, backend, dtype="float64") assert v.shape == (1, 84) v = v[0].flatten() check_array( @@ -49,20 +56,21 @@ def test_grib_values_1(mode): # field v1 = f[0].values - assert isinstance(v1, np.ndarray) + + check_array_type(v1, backend) assert v1.shape == (84,) assert np.allclose(v, v1, eps) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_values_18(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_values_18(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) eps = 1e-5 # whole file v = f.values - assert isinstance(v, np.ndarray) - assert v.dtype == np.float64 + check_array_type(v, backend, dtype="float64") assert v.shape == (18, 84) vf = v[0].flatten() check_array( @@ -85,9 +93,10 @@ def test_grib_values_18(mode): ) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_to_numpy_1(mode): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_to_numpy_1(fl_type, backend): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") eps = 1e-5 v = f.to_numpy() @@ -104,7 +113,8 @@ def test_grib_to_numpy_1(mode): ) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "first,options, expected_shape", [ @@ -116,8 +126,8 @@ def test_grib_to_numpy_1(mode): (True, {"flatten": False}, (7, 12)), ], ) -def test_grib_to_numpy_1_shape(mode, first, options, expected_shape): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +def test_grib_to_numpy_1_shape(fl_type, backend, first, options, expected_shape): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") v_ref = f[0].to_numpy().flatten() eps = 1e-5 @@ -131,9 +141,10 @@ def test_grib_to_numpy_1_shape(mode, first, options, expected_shape): assert np.allclose(v_ref, v1, eps) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_to_numpy_18(mode): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_to_numpy_18(fl_type, backend): + f = load_grib_data("tuv_pl.grib", fl_type, backend) eps = 1e-5 @@ -163,7 +174,8 @@ def test_grib_to_numpy_18(mode): ) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "options, expected_shape", [ @@ -185,8 +197,8 @@ def test_grib_to_numpy_18(mode): ({"flatten": False}, (18, 7, 12)), ], ) -def test_grib_to_numpy_18_shape(mode, options, expected_shape): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_to_numpy_18_shape(fl_type, backend, options, expected_shape): + f = load_grib_data("tuv_pl.grib", fl_type, backend) eps = 1e-5 @@ -210,10 +222,11 @@ def test_grib_to_numpy_18_shape(mode, options, expected_shape): assert np.allclose(vf15, vr, eps) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ["numpy"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_numpy_1_dtype(mode, dtype): - f = load_file_or_numpy_fs("test_single.grib", mode, folder="data") +def test_grib_to_numpy_1_dtype(fl_type, backend, dtype): + f = load_grib_data("test_single.grib", fl_type, backend, folder="data") v = f[0].to_numpy(dtype=dtype) assert v.dtype == dtype @@ -222,10 +235,11 @@ def test_grib_to_numpy_1_dtype(mode, dtype): assert v.dtype == dtype -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ["numpy"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_numpy_18_dtype(mode, dtype): - f = load_file_or_numpy_fs("tuv_pl.grib", mode) +def test_grib_to_numpy_18_dtype(fl_type, backend, dtype): + f = load_grib_data("tuv_pl.grib", fl_type, backend) v = f[0].to_numpy(dtype=dtype) assert v.dtype == dtype @@ -234,7 +248,8 @@ def test_grib_to_numpy_18_dtype(mode, dtype): assert v.dtype == dtype -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ["numpy"]) @pytest.mark.parametrize( "kwarg,expected_shape,expected_dtype", [ @@ -247,8 +262,8 @@ def test_grib_to_numpy_18_dtype(mode, dtype): ({"flatten": False, "dtype": np.float64}, (11, 19), np.float64), ], ) -def test_grib_field_data(mode, kwarg, expected_shape, expected_dtype): - ds = load_file_or_numpy_fs("test.grib", mode) +def test_grib_field_data(fl_type, backend, kwarg, expected_shape, expected_dtype): + ds = load_grib_data("test.grib", fl_type, backend) latlon = ds[0].to_latlon(**kwarg) v = ds[0].to_numpy(**kwarg) @@ -285,7 +300,8 @@ def test_grib_field_data(mode, kwarg, expected_shape, expected_dtype): assert np.allclose(d[1], latlon["lon"]) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ["numpy"]) @pytest.mark.parametrize( "kwarg,expected_shape,expected_dtype", [ @@ -298,8 +314,8 @@ def test_grib_field_data(mode, kwarg, expected_shape, expected_dtype): ({"flatten": False, "dtype": np.float64}, (11, 19), np.float64), ], ) -def test_grib_fieldlist_data(mode, kwarg, expected_shape, expected_dtype): - ds = load_file_or_numpy_fs("test.grib", mode) +def test_grib_fieldlist_data(fl_type, backend, kwarg, expected_shape, expected_dtype): + ds = load_grib_data("test.grib", fl_type, backend) latlon = ds.to_latlon(**kwarg) v = ds.to_numpy(**kwarg) @@ -337,22 +353,26 @@ def test_grib_fieldlist_data(mode, kwarg, expected_shape, expected_dtype): assert np.allclose(d[2], latlon["lon"]) -@pytest.mark.parametrize("mode", ["file", "numpy_fs"]) -def test_grib_values_with_missing(mode): - f = load_file_or_numpy_fs("test_single_with_missing.grib", mode, folder="data") +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_values_with_missing(fl_type, backend): + f = load_grib_data("test_single_with_missing.grib", fl_type, backend, folder="data") v = f[0].values - assert isinstance(v, np.ndarray) + check_array_type(v, backend) assert v.shape == (84,) eps = 0.001 - assert np.count_nonzero(np.isnan(v)) == 38 - mask = np.array([12, 14, 15, 24, 25, 26] + list(range(28, 60))) + + ns = get_array_namespace(backend) + + assert ns.count_nonzero(ns.isnan(v)) == 38 + mask = get_array([12, 14, 15, 24, 25, 26] + list(range(28, 60)), backend) assert np.isclose(v[0], 260.4356, eps) assert np.isclose(v[11], 260.4356, eps) assert np.isclose(v[-1], 227.1856, eps) m = v[mask] assert len(m) == 38 - assert np.count_nonzero(np.isnan(m)) == 38 + assert ns.count_nonzero(ns.isnan(m)) == 38 if __name__ == "__main__": From ecffa7beb5f31388b859eca87d19143688c01226 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Wed, 14 Feb 2024 18:39:02 +0000 Subject: [PATCH 07/18] Impelement array backends for fieldlist --- earthkit/data/core/array.py | 33 ++++++-- earthkit/data/sources/array_list.py | 1 - earthkit/data/testing.py | 53 ++++++++++++ .../array_fl_fixtures.py} | 56 +++++++------ .../test_numpy_fl_write.py} | 70 +++++++++++----- .../test_numpy_fs.py | 26 +++--- .../test_numpy_fs_concat.py | 54 ++++++------ .../test_numpy_fs_metadata.py | 18 ++-- .../test_numpy_fs_summary.py | 6 +- tests/grib/grib_fixtures.py | 84 +++++++++---------- tests/grib/test_grib_geography.py | 8 +- tests/grib/test_grib_inidces.py | 4 +- tests/grib/test_grib_metadata.py | 4 +- tests/grib/test_grib_order_by.py | 3 +- tests/grib/test_grib_output.py | 7 +- tests/grib/test_grib_sel.py | 3 +- tests/grib/test_grib_slice.py | 4 +- tests/grib/test_grib_summary.py | 4 +- tests/grib/test_grib_values.py | 10 +-- 19 files changed, 279 insertions(+), 169 deletions(-) rename tests/{numpy_fs/numpy_fs_fixtures.py => array_fieldlist/array_fl_fixtures.py} (70%) rename tests/{numpy_fs/test_numpy_fs_write.py => array_fieldlist/test_numpy_fl_write.py} (68%) rename tests/{numpy_fs => array_fieldlist}/test_numpy_fs.py (86%) rename tests/{numpy_fs => array_fieldlist}/test_numpy_fs_concat.py (67%) rename tests/{numpy_fs => array_fieldlist}/test_numpy_fs_metadata.py (91%) rename tests/{numpy_fs => array_fieldlist}/test_numpy_fs_summary.py (97%) diff --git a/earthkit/data/core/array.py b/earthkit/data/core/array.py index 9c6fdfca..e58ba53d 100644 --- a/earthkit/data/core/array.py +++ b/earthkit/data/core/array.py @@ -78,6 +78,7 @@ class ArrayBackend(metaclass=ABCMeta): _default = "numpy" _name = None _array_name = "array" + _dtypes = {} def __init__(self): self.lock = threading.Lock() @@ -113,8 +114,19 @@ def to_array(self, v, backend=None): b = get_backend(v, strict=False) return b.to_backend(v, self) + def to_dtype(self, dtype): + if isinstance(dtype, str): + return self._dtypes.get(dtype, None) + return dtype + + def match_dtype(self, v, dtype): + if dtype is not None: + dtype = self.to_dtype(dtype) + return v.dtype == dtype if dtype is not None else False + return True + @abstractmethod - def is_native_array(self, v): + def is_native_array(self, v, **kwargs): pass @abstractmethod @@ -152,10 +164,17 @@ def _make_array_ns(self): return ns - def is_native_array(self, v): + def to_dtype(self, dtype): + return dtype + + def is_native_array(self, v, dtype=None): import numpy as np - return isinstance(v, np.ndarray) + if not isinstance(v, np.ndarray): + return False + if dtype is not None: + return v.dtype == dtype + return True def to_backend(self, v, backend): return backend.from_numpy(v) @@ -194,12 +213,16 @@ def _make_array_ns(self): except Exception: raise ImportError("pytorch is required to use pytorch backend") + self._dtypes = {"float64": torch.float64, "float32": torch.float32} + return array_api_compat.array_namespace(torch.ones(2)) - def is_native_array(self, v): + def is_native_array(self, v, dtype=None): import torch - return torch.is_tensor(v) + if not torch.is_tensor(v): + return False + return self.match_dtype(v, dtype) def to_backend(self, v, backend): return backend.from_pytorch(v) diff --git a/earthkit/data/sources/array_list.py b/earthkit/data/sources/array_list.py index 122e72d9..87255aa6 100644 --- a/earthkit/data/sources/array_list.py +++ b/earthkit/data/sources/array_list.py @@ -63,7 +63,6 @@ def write(self, f, **kwargs): from earthkit.data.writers import write write(f, self.to_numpy(flatten=True), self._metadata, **kwargs) - # write(f, self.values, self._metadata, **kwargs) class ArrayFieldListCore(PandasMixIn, XarrayMixIn, FieldList): diff --git a/earthkit/data/testing.py b/earthkit/data/testing.py index b5402581..a3ecff86 100644 --- a/earthkit/data/testing.py +++ b/earthkit/data/testing.py @@ -150,6 +150,59 @@ def load_nc_or_xr_source(path, mode): return from_object(xarray.open_dataset(path)) +# def check_numpy_array_type(v, dtype=None): +# import numpy as np + +# assert isinstance(v, np.ndarray) +# if dtype is not None: +# if dtype == "float64": +# dtype = np.float64 +# elif dtype == "float32": +# dtype = np.float32 +# else: +# raise ValueError("Unsupported dtype={dtype}") +# assert v.dtype == dtype + + +# def check_pytorch_array_type(v, dtype=None): +# import torch + +# assert torch.is_tensor(v) +# if dtype is not None: +# if dtype == "float64": +# dtype = torch.float64 +# elif dtype == "float32": +# dtype = torch.float32 +# else: +# raise ValueError("Unsupported dtype={dtype}") +# assert v.dtype == dtype + + +def check_array_type(v, backend, **kwargs): + from earthkit.data.core.array import ensure_backend + + b = ensure_backend(backend) + assert b.is_native_array(v, **kwargs) + + +def get_array_namespace(backend): + from earthkit.data.core.array import ensure_backend + + return ensure_backend(backend).array_ns + + +def get_array(v, backend, **kwargs): + from earthkit.data.core.array import ensure_backend + + b = ensure_backend(backend) + return b.from_other(v, **kwargs) + + +ARRAY_BACKENDS = ["numpy"] +if not NO_PYTORCH: + ARRAY_BACKENDS.append("pytorch") + + def main(path): import sys diff --git a/tests/numpy_fs/numpy_fs_fixtures.py b/tests/array_fieldlist/array_fl_fixtures.py similarity index 70% rename from tests/numpy_fs/numpy_fs_fixtures.py rename to tests/array_fieldlist/array_fl_fixtures.py index 8eb5b5cb..276e18df 100644 --- a/tests/numpy_fs/numpy_fs_fixtures.py +++ b/tests/array_fieldlist/array_fl_fixtures.py @@ -11,15 +11,13 @@ import os -import numpy as np - from earthkit.data import from_source from earthkit.data.core.fieldlist import FieldList from earthkit.data.core.temporary import temp_file -from earthkit.data.testing import earthkit_examples_file +from earthkit.data.testing import earthkit_examples_file, get_array_namespace -def load_numpy_fs(num): +def load_array_fl(num, backend=None): assert num in [1, 2, 3] files = ["test.grib", "test6.grib", "tuv_pl.grib"] files = files[:num] @@ -27,13 +25,15 @@ def load_numpy_fs(num): ds_in = [] md = [] for fname in files: - ds_in.append(from_source("file", earthkit_examples_file(fname))) + ds_in.append( + from_source("file", earthkit_examples_file(fname), backend=backend) + ) md += ds_in[-1].metadata("param") ds = [] for x in ds_in: ds.append( - FieldList.from_numpy( + FieldList.from_array( x.values, [m.override(edition=1) for m in x.metadata()] ) ) @@ -41,23 +41,25 @@ def load_numpy_fs(num): return (*ds, md) -def load_numpy_fs_file(fname): - ds_in = from_source("file", earthkit_examples_file(fname)) +def load_array_fl_file(fname, backend=None): + ds_in = from_source("file", earthkit_examples_file(fname), backend=backend) md = ds_in.metadata("param") - ds = FieldList.from_numpy( + ds = FieldList.from_array( ds_in.values, [m.override(edition=1) for m in ds_in.metadata()] ) return (ds, md) -def check_numpy_fs(ds, ds_input, md_full): +def check_array_fl(ds, ds_input, md_full, backend=None): assert len(ds_input) in [1, 2, 3] + ns = get_array_namespace(backend) + assert len(ds) == len(md_full) assert ds.metadata("param") == md_full - assert np.allclose(ds[0].values, ds_input[0][0].values) + assert ns.allclose(ds[0].values, ds_input[0][0].values) # # values metadata # keys = ["min", "max"] @@ -74,15 +76,15 @@ def check_numpy_fs(ds, ds_input, md_full): assert r.metadata("param") == ["msl", "t"] assert r[0].metadata("param") == "msl" assert r[1].metadata("param") == "t" - assert np.allclose(r[0].values, ds_input[0][1].values) - assert np.allclose(r[1].values, ds_input[1][0].values) + assert ns.allclose(r[0].values, ds_input[0][1].values) + assert ns.allclose(r[1].values, ds_input[1][0].values) # check sel r = ds.sel(shortName="msl") assert len(r) == 1 assert r.metadata("shortName") == ["msl"] assert r[0].metadata("param") == "msl" - assert np.allclose(r[0].values, ds_input[0][1].values) + assert ns.allclose(r[0].values, ds_input[0][1].values) if len(ds_input) == 3: r = ds[1:13:4] @@ -93,19 +95,23 @@ def check_numpy_fs(ds, ds_input, md_full): assert r[2].metadata("param") == "u" -def check_numpy_fs_from_to_fieldlist(ds, ds_input, md_full, flatten=False, dtype=None): +def check_array_fl_from_to_fieldlist( + ds, ds_input, md_full, backend=None, flatten=False, dtype=None +): assert len(ds_input) in [1, 2, 3] assert len(ds) == len(md_full) assert ds.metadata("param") == md_full + ns = get_array_namespace(backend) + np_kwargs = {"flatten": flatten, "dtype": dtype} - assert np.allclose( - ds[0].to_numpy(**np_kwargs), ds_input[0][0].to_numpy(**np_kwargs) + assert ns.allclose( + ds[0].to_array(**np_kwargs), ds_input[0][0].to_array(**np_kwargs) ) - assert ds.to_numpy(**np_kwargs).shape == ds_input[0].to_numpy(**np_kwargs).shape - assert ds._array.shape == ds_input[0].to_numpy(**np_kwargs).shape + assert ds.to_array(**np_kwargs).shape == ds_input[0].to_array(**np_kwargs).shape + assert ds._array.shape == ds_input[0].to_array(**np_kwargs).shape # check slice r = ds[1] @@ -117,11 +123,11 @@ def check_numpy_fs_from_to_fieldlist(ds, ds_input, md_full, flatten=False, dtype assert r.metadata("param") == ["msl", "t"] assert r[0].metadata("param") == "msl" assert r[1].metadata("param") == "t" - assert np.allclose( - r[0].to_numpy(**np_kwargs), ds_input[0][1].to_numpy(**np_kwargs) + assert ns.allclose( + r[0].to_array(**np_kwargs), ds_input[0][1].to_array(**np_kwargs) ) - assert np.allclose( - r[1].to_numpy(**np_kwargs), ds_input[1][0].to_numpy(**np_kwargs) + assert ns.allclose( + r[1].to_array(**np_kwargs), ds_input[1][0].to_array(**np_kwargs) ) # check sel @@ -129,8 +135,8 @@ def check_numpy_fs_from_to_fieldlist(ds, ds_input, md_full, flatten=False, dtype assert len(r) == 1 assert r.metadata("shortName") == ["msl"] assert r[0].metadata("param") == "msl" - assert np.allclose( - r[0].to_numpy(**np_kwargs), ds_input[0][1].to_numpy(**np_kwargs) + assert ns.allclose( + r[0].to_array(**np_kwargs), ds_input[0][1].to_array(**np_kwargs) ) if len(ds_input) == 3: diff --git a/tests/numpy_fs/test_numpy_fs_write.py b/tests/array_fieldlist/test_numpy_fl_write.py similarity index 68% rename from tests/numpy_fs/test_numpy_fs_write.py rename to tests/array_fieldlist/test_numpy_fl_write.py index 5b73a1db..3da88e5d 100644 --- a/tests/numpy_fs/test_numpy_fs_write.py +++ b/tests/array_fieldlist/test_numpy_fl_write.py @@ -19,46 +19,75 @@ from earthkit.data import from_source from earthkit.data.core.fieldlist import FieldList from earthkit.data.core.temporary import temp_file -from earthkit.data.testing import earthkit_examples_file +from earthkit.data.testing import ( + ARRAY_BACKENDS, + check_array_type, + earthkit_examples_file, + get_array_namespace, +) here = os.path.dirname(__file__) sys.path.insert(0, here) -from numpy_fs_fixtures import load_numpy_fs # noqa: E402 +from array_fl_fixtures import load_array_fl # noqa: E402 LOG = logging.getLogger(__name__) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_array_fl_grib_write(backend): + ds = from_source("file", earthkit_examples_file("test.grib"), backend=backend) + ns = get_array_namespace(backend) + + assert ds[0].metadata("shortName") == "2t" + assert len(ds) == 2 + v1 = ds[0].values + 1 + check_array_type(v1, backend) + + md = ds[0].metadata() + md1 = md.override(shortName="msl") + r = FieldList.from_array(v1, md1) + + with temp_file() as tmp: + r.save(tmp) + assert os.path.exists(tmp) + r_tmp = from_source("file", tmp, backend=backend) + v_tmp = r_tmp[0].values + assert ns.allclose(v1, v_tmp) + + +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize("_kwargs", [{}, {"check_nans": True}]) -def test_numpy_fs_grib_write_missing(_kwargs): - ds = from_source("file", earthkit_examples_file("test.grib")) +def test_array_fl_grib_write_missing(backend, _kwargs): + ds = from_source("file", earthkit_examples_file("test.grib"), backend=backend) + ns = get_array_namespace(backend) assert ds[0].metadata("shortName") == "2t" v = ds[0].values v1 = v + 1 - assert not np.isnan(v1[0]) - assert not np.isnan(v1[1]) - v1[0] = np.nan - assert np.isnan(v1[0]) - assert not np.isnan(v1[1]) + assert not ns.isnan(v1[0]) + assert not ns.isnan(v1[1]) + v1[0] = ns.nan + assert ns.isnan(v1[0]) + assert not ns.isnan(v1[1]) md = ds[0].metadata() md1 = md.override(shortName="msl") - r = FieldList.from_numpy(v1, md1) + r = FieldList.from_array(v1, md1) - assert np.isnan(r[0].values[0]) - assert not np.isnan(r[0].values[1]) + assert ns.isnan(r[0].values[0]) + assert not ns.isnan(r[0].values[1]) with temp_file() as tmp: r.save(tmp, **_kwargs) assert os.path.exists(tmp) - r_tmp = from_source("file", tmp) + r_tmp = from_source("file", tmp, backend=backend) v_tmp = r_tmp[0].values - assert np.isnan(v_tmp[0]) - assert not np.isnan(v_tmp[1]) + assert ns.isnan(v_tmp[0]) + assert not ns.isnan(v_tmp[1]) -def test_numpy_fs_grib_write_check_nans_bad(): +def test_array_fl_grib_write_check_nans_bad(): ds = from_source("file", earthkit_examples_file("test.grib")) assert ds[0].metadata("shortName") == "2t" @@ -85,7 +114,7 @@ def test_numpy_fs_grib_write_check_nans_bad(): r.save(tmp, check_nans=False) -def test_numpy_fs_grib_write_append(): +def test_array_fl_grib_write_append(): ds = from_source("file", earthkit_examples_file("test.grib")) assert ds[0].metadata("shortName") == "2t" @@ -118,7 +147,7 @@ def test_numpy_fs_grib_write_append(): assert r_tmp.metadata("shortName") == ["msl", "2d"] -def test_numpy_fs_grib_write_generating_proc_id(): +def test_array_fl_grib_write_generating_proc_id(): ds = from_source("file", earthkit_examples_file("test.grib")) assert ds[0].metadata("shortName") == "2t" @@ -149,11 +178,12 @@ def test_numpy_fs_grib_write_generating_proc_id(): assert np.allclose(r_tmp.values[1], v2) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "_kwargs,expected_value", [({}, 16), ({"bits_per_value": 12}, 12)] ) -def test_numpy_fs_grib_write_bits_per_value(_kwargs, expected_value): - ds, _ = load_numpy_fs(1) +def test_array_fl_grib_write_bits_per_value(backend, _kwargs, expected_value): + ds, _ = load_array_fl(1, backend) with temp_file() as tmp: ds.save(tmp, **_kwargs) diff --git a/tests/numpy_fs/test_numpy_fs.py b/tests/array_fieldlist/test_numpy_fs.py similarity index 86% rename from tests/numpy_fs/test_numpy_fs.py rename to tests/array_fieldlist/test_numpy_fs.py index 7e7ccf25..626695dd 100644 --- a/tests/numpy_fs/test_numpy_fs.py +++ b/tests/array_fieldlist/test_numpy_fs.py @@ -22,13 +22,13 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from numpy_fs_fixtures import ( # noqa: E402 - check_numpy_fs, - check_numpy_fs_from_to_fieldlist, +from array_fl_fixtures import ( # noqa: E402 + check_array_fl, + check_array_fl_from_to_fieldlist, ) -def test_numpy_fs_grib_single_field(): +def test_array_fl_grib_single_field(): ds = from_source("file", earthkit_examples_file("test.grib")) assert ds[0].metadata("shortName") == "2t" @@ -60,7 +60,7 @@ def _check_field(r): _check_field(r_tmp) -def test_numpy_fs_grib_multi_field(): +def test_array_fl_grib_multi_field(): ds = from_source("file", earthkit_examples_file("test.grib")) assert ds[0].metadata("shortName") == "2t" @@ -91,7 +91,7 @@ def test_numpy_fs_grib_multi_field(): assert f.metadata("name") == "2 metre dewpoint temperature", f"name {i}" -def test_numpy_fs_grib_from_list_of_arrays(): +def test_array_fl_grib_from_list_of_arrays(): ds = from_source("file", earthkit_examples_file("test.grib")) md_full = ds.metadata("param") assert len(ds) == 2 @@ -100,10 +100,10 @@ def test_numpy_fs_grib_from_list_of_arrays(): md = [f.metadata().override(generatingProcessIdentifier=150) for f in ds] r = FieldList.from_numpy(v, md) - check_numpy_fs(r, [ds], md_full) + check_array_fl(r, [ds], md_full) -def test_numpy_fs_grib_from_list_of_arrays_bad(): +def test_array_fl_grib_from_list_of_arrays_bad(): ds = from_source("file", earthkit_examples_file("test.grib")) v = ds[0].values @@ -126,28 +126,28 @@ def test_numpy_fs_grib_from_list_of_arrays_bad(): {"flatten": True, "dtype": np.float32}, ], ) -def test_numpy_fs_grib_from_to_fieldlist(kwargs): +def test_array_fl_grib_from_to_fieldlist(kwargs): ds = from_source("file", earthkit_examples_file("test.grib")) md_full = ds.metadata("param") assert len(ds) == 2 r = ds.to_fieldlist("numpy", **kwargs) - check_numpy_fs_from_to_fieldlist(r, [ds], md_full, **kwargs) + check_array_fl_from_to_fieldlist(r, [ds], md_full, **kwargs) -def test_numpy_fs_grib_from_to_fieldlist_repeat(): +def test_array_fl_grib_from_to_fieldlist_repeat(): ds = from_source("file", earthkit_examples_file("test.grib")) md_full = ds.metadata("param") assert len(ds) == 2 kwargs = {} r = ds.to_fieldlist("numpy", **kwargs) - check_numpy_fs_from_to_fieldlist(r, [ds], md_full, **kwargs) + check_array_fl_from_to_fieldlist(r, [ds], md_full, **kwargs) kwargs = {"flatten": True, "dtype": np.float32} r1 = r.to_fieldlist("numpy", **kwargs) assert r1 is not r - check_numpy_fs_from_to_fieldlist(r1, [ds], md_full, **kwargs) + check_array_fl_from_to_fieldlist(r1, [ds], md_full, **kwargs) if __name__ == "__main__": diff --git a/tests/numpy_fs/test_numpy_fs_concat.py b/tests/array_fieldlist/test_numpy_fs_concat.py similarity index 67% rename from tests/numpy_fs/test_numpy_fs_concat.py rename to tests/array_fieldlist/test_numpy_fs_concat.py index b15b543d..3cd98ffb 100644 --- a/tests/numpy_fs/test_numpy_fs_concat.py +++ b/tests/array_fieldlist/test_numpy_fs_concat.py @@ -19,38 +19,38 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from numpy_fs_fixtures import ( # noqa: E402 - check_numpy_fs, +from array_fl_fixtures import ( # noqa: E402 + check_array_fl, check_save_to_disk, - load_numpy_fs, + load_array_fl, ) @pytest.mark.parametrize("mode", ["oper", "multi"]) -def test_numpy_fs_grib_concat_2a(mode): - ds1, ds2, md = load_numpy_fs(2) +def test_array_fl_grib_concat_2a(mode): + ds1, ds2, md = load_array_fl(2) if mode == "oper": ds = ds1 + ds2 else: ds = from_source("multi", ds1, ds2) - check_numpy_fs(ds, [ds1, ds2], md) + check_array_fl(ds, [ds1, ds2], md) check_save_to_disk(ds, 8, md) -def test_numpy_fs_grib_concat_2b(): - ds1, ds2, md = load_numpy_fs(2) +def test_array_fl_grib_concat_2b(): + ds1, ds2, md = load_array_fl(2) ds1_ori = ds1 ds1 += ds2 - check_numpy_fs(ds1, [ds1_ori, ds2], md) + check_array_fl(ds1, [ds1_ori, ds2], md) check_save_to_disk(ds1, 8, md) @pytest.mark.parametrize("mode", ["oper", "multi"]) -def test_numpy_fs_grib_concat_3a(mode): - ds1, ds2, ds3, md = load_numpy_fs(3) +def test_array_fl_grib_concat_3a(mode): + ds1, ds2, ds3, md = load_array_fl(3) if mode == "oper": ds = ds1 + ds2 @@ -59,26 +59,26 @@ def test_numpy_fs_grib_concat_3a(mode): ds = from_source("multi", ds1, ds2) ds = from_source("multi", ds, ds3) - check_numpy_fs(ds, [ds1, ds2, ds3], md) + check_array_fl(ds, [ds1, ds2, ds3], md) check_save_to_disk(ds, 26, md) @pytest.mark.parametrize("mode", ["oper", "multi"]) -def test_numpy_fs_grib_concat_3b(mode): - ds1, ds2, ds3, md = load_numpy_fs(3) +def test_array_fl_grib_concat_3b(mode): + ds1, ds2, ds3, md = load_array_fl(3) if mode == "oper": ds = ds1 + ds2 + ds3 else: ds = from_source("multi", ds1, ds2, ds3) - check_numpy_fs(ds, [ds1, ds2, ds3], md) + check_array_fl(ds, [ds1, ds2, ds3], md) check_save_to_disk(ds, 26, md) -def test_numpy_fs_grib_from_empty_1(): +def test_array_fl_grib_from_empty_1(): ds_e = FieldList() - ds, md = load_numpy_fs(1) + ds, md = load_array_fl(1) ds1 = ds_e + ds assert id(ds1) == id(ds) assert len(ds1) == 2 @@ -86,9 +86,9 @@ def test_numpy_fs_grib_from_empty_1(): check_save_to_disk(ds1, 2, md) -def test_numpy_fs_grib_from_empty_2(): +def test_array_fl_grib_from_empty_2(): ds_e = FieldList() - ds, md = load_numpy_fs(1) + ds, md = load_array_fl(1) ds1 = ds + ds_e assert id(ds1) == id(ds) assert len(ds1) == 2 @@ -96,18 +96,18 @@ def test_numpy_fs_grib_from_empty_2(): check_save_to_disk(ds1, 2, md) -def test_numpy_fs_grib_from_empty_3(): +def test_array_fl_grib_from_empty_3(): ds_e = FieldList() - ds1, ds2, md = load_numpy_fs(2) + ds1, ds2, md = load_array_fl(2) ds = ds_e + ds1 + ds2 - check_numpy_fs(ds, [ds1, ds2], md) + check_array_fl(ds, [ds1, ds2], md) check_save_to_disk(ds, 8, md) -def test_numpy_fs_grib_from_empty_4(): +def test_array_fl_grib_from_empty_4(): ds = FieldList() - ds1, md = load_numpy_fs(1) + ds1, md = load_array_fl(1) ds += ds1 assert id(ds) == id(ds1) assert len(ds) == 2 @@ -115,12 +115,12 @@ def test_numpy_fs_grib_from_empty_4(): check_save_to_disk(ds, 2, md) -def test_numpy_fs_grib_from_empty_5(): +def test_array_fl_grib_from_empty_5(): ds = FieldList() - ds1, ds2, md = load_numpy_fs(2) + ds1, ds2, md = load_array_fl(2) ds += ds1 + ds2 - check_numpy_fs(ds, [ds1, ds2], md) + check_array_fl(ds, [ds1, ds2], md) check_save_to_disk(ds, 8, md) diff --git a/tests/numpy_fs/test_numpy_fs_metadata.py b/tests/array_fieldlist/test_numpy_fs_metadata.py similarity index 91% rename from tests/numpy_fs/test_numpy_fs_metadata.py rename to tests/array_fieldlist/test_numpy_fs_metadata.py index 52672a16..d6dac338 100644 --- a/tests/numpy_fs/test_numpy_fs_metadata.py +++ b/tests/array_fieldlist/test_numpy_fs_metadata.py @@ -16,14 +16,14 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from numpy_fs_fixtures import load_numpy_fs, load_numpy_fs_file # noqa: E402 +from array_fl_fixtures import load_array_fl, load_array_fl_file # noqa: E402 # Note: Almost all grib metadata tests are also run for numpyfs. # See grib/test_grib_metadata.py -def test_numpy_fs_values_metadata(): - ds, _ = load_numpy_fs(1) +def test_array_fl_values_metadata(): + ds, _ = load_array_fl(1) # values metadata keys = [ @@ -45,8 +45,8 @@ def test_numpy_fs_values_metadata(): ds[0].metadata(k) -def test_numpy_fs_values_metadata_internal(): - ds, _ = load_numpy_fs(1) +def test_array_fl_values_metadata_internal(): + ds, _ = load_array_fl(1) keys = { "shortName": "2t", @@ -57,8 +57,8 @@ def test_numpy_fs_values_metadata_internal(): assert ds[0].metadata(k) == v, k -def test_numpy_fs_metadata_keys(): - ds, _ = load_numpy_fs(1) +def test_array_fl_metadata_keys(): + ds, _ = load_array_fl(1) # The number/order of metadata keys can vary with the ecCodes version. # The same is true for the namespaces. @@ -90,8 +90,8 @@ def test_numpy_fs_metadata_keys(): assert "validityDate" in md -def test_numpy_fs_metadata_namespace(): - f, _ = load_numpy_fs_file("tuv_pl.grib") +def test_array_fl_metadata_namespace(): + f, _ = load_array_fl_file("tuv_pl.grib") r = f[0].metadata(namespace="vertical") ref = {"level": 1000, "typeOfLevel": "isobaricInhPa"} diff --git a/tests/numpy_fs/test_numpy_fs_summary.py b/tests/array_fieldlist/test_numpy_fs_summary.py similarity index 97% rename from tests/numpy_fs/test_numpy_fs_summary.py rename to tests/array_fieldlist/test_numpy_fs_summary.py index 8b1646fd..0d44762a 100644 --- a/tests/numpy_fs/test_numpy_fs_summary.py +++ b/tests/array_fieldlist/test_numpy_fs_summary.py @@ -14,14 +14,14 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from numpy_fs_fixtures import load_numpy_fs_file # noqa: E402 +from array_fl_fixtures import load_array_fl_file # noqa: E402 # Note: Almost all grib metadata tests are also run for numpyfs. # See grib/test_grib_summary.py -def test_numpy_fs_dump(): - f, _ = load_numpy_fs_file("test6.grib") +def test_array_fl_dump(): + f, _ = load_array_fl_file("test6.grib") namespaces = ( "default", diff --git a/tests/grib/grib_fixtures.py b/tests/grib/grib_fixtures.py index 9ad29033..fc4c2975 100644 --- a/tests/grib/grib_fixtures.py +++ b/tests/grib/grib_fixtures.py @@ -12,11 +12,7 @@ from earthkit.data import from_source from earthkit.data.core.fieldlist import FieldList -from earthkit.data.testing import ( - NO_PYTORCH, - earthkit_examples_file, - earthkit_test_data_file, -) +from earthkit.data.testing import earthkit_examples_file, earthkit_test_data_file def load_array_fieldlist(path, backend): @@ -42,58 +38,58 @@ def load_grib_data(filename, fl_type, backend, folder="example"): raise ValueError("Invalid fl_type={fl_type}") -def check_numpy_array_type(v, dtype=None): - import numpy as np +# def check_numpy_array_type(v, dtype=None): +# import numpy as np - assert isinstance(v, np.ndarray) - if dtype is not None: - if dtype == "float64": - dtype = np.float64 - elif dtype == "float32": - dtype = np.float32 - else: - raise ValueError("Unsupported dtype={dtype}") - assert v.dtype == dtype +# assert isinstance(v, np.ndarray) +# if dtype is not None: +# if dtype == "float64": +# dtype = np.float64 +# elif dtype == "float32": +# dtype = np.float32 +# else: +# raise ValueError("Unsupported dtype={dtype}") +# assert v.dtype == dtype -def check_pytorch_array_type(v, dtype=None): - import torch +# def check_pytorch_array_type(v, dtype=None): +# import torch - assert torch.is_tensor(v) - if dtype is not None: - if dtype == "float64": - dtype = torch.float64 - elif dtype == "float32": - dtype = torch.float32 - else: - raise ValueError("Unsupported dtype={dtype}") - assert v.dtype == dtype +# assert torch.is_tensor(v) +# if dtype is not None: +# if dtype == "float64": +# dtype = torch.float64 +# elif dtype == "float32": +# dtype = torch.float32 +# else: +# raise ValueError("Unsupported dtype={dtype}") +# assert v.dtype == dtype -def check_array_type(v, backend, **kwargs): - if backend is None or backend == "numpy": - check_numpy_array_type(v, **kwargs) - elif backend == "pytorch": - check_pytorch_array_type(v, **kwargs) - else: - raise ValueError("Invalid backend={backend}") +# def check_array_type(v, backend, **kwargs): +# if backend is None or backend == "numpy": +# check_numpy_array_type(v, **kwargs) +# elif backend == "pytorch": +# check_pytorch_array_type(v, **kwargs) +# else: +# raise ValueError("Invalid backend={backend}") -def get_array_namespace(backend): - from earthkit.data.core.array import ensure_backend +# def get_array_namespace(backend): +# from earthkit.data.core.array import ensure_backend - return ensure_backend(backend).array_ns +# return ensure_backend(backend).array_ns -def get_array(v, backend): - from earthkit.data.core.array import ensure_backend +# def get_array(v, backend): +# from earthkit.data.core.array import ensure_backend - b = ensure_backend(backend) - return b.from_other(v) +# b = ensure_backend(backend) +# return b.from_other(v) FL_TYPES = ["file", "array"] -ARRAY_BACKENDS = ["numpy"] -if not NO_PYTORCH: - ARRAY_BACKENDS.append("pytorch") +# ARRAY_BACKENDS = ["numpy"] +# if not NO_PYTORCH: +# ARRAY_BACKENDS.append("pytorch") diff --git a/tests/grib/test_grib_geography.py b/tests/grib/test_grib_geography.py index b5babd2a..a53a5d59 100644 --- a/tests/grib/test_grib_geography.py +++ b/tests/grib/test_grib_geography.py @@ -15,16 +15,12 @@ import numpy as np import pytest +from earthkit.data.testing import ARRAY_BACKENDS, check_array_type from earthkit.data.utils import projections here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import ( # noqa: E402 - ARRAY_BACKENDS, - FL_TYPES, - check_array_type, - load_grib_data, -) +from grib_fixtures import FL_TYPES, load_grib_data # noqa: E402 def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): diff --git a/tests/grib/test_grib_inidces.py b/tests/grib/test_grib_inidces.py index 635594ab..5a58b095 100644 --- a/tests/grib/test_grib_inidces.py +++ b/tests/grib/test_grib_inidces.py @@ -14,9 +14,11 @@ import pytest +from earthkit.data.testing import ARRAY_BACKENDS + here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 +from grib_fixtures import FL_TYPES, load_grib_data # noqa: E402 @pytest.mark.parametrize("fl_type", FL_TYPES) diff --git a/tests/grib/test_grib_metadata.py b/tests/grib/test_grib_metadata.py index 65a8030d..07c96ba5 100644 --- a/tests/grib/test_grib_metadata.py +++ b/tests/grib/test_grib_metadata.py @@ -17,11 +17,11 @@ import pytest from earthkit.data import from_source -from earthkit.data.testing import earthkit_examples_file +from earthkit.data.testing import ARRAY_BACKENDS, earthkit_examples_file here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 +from grib_fixtures import FL_TYPES, load_grib_data # noqa: E402 def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): diff --git a/tests/grib/test_grib_order_by.py b/tests/grib/test_grib_order_by.py index b89dc4ee..f36c01f5 100644 --- a/tests/grib/test_grib_order_by.py +++ b/tests/grib/test_grib_order_by.py @@ -16,10 +16,11 @@ import pytest from earthkit.data import from_source +from earthkit.data.testing import ARRAY_BACKENDS here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 +from grib_fixtures import FL_TYPES, load_grib_data # noqa: E402 # @pytest.mark.skipif(("GITHUB_WORKFLOW" in os.environ) or True, reason="Not yet ready") diff --git a/tests/grib/test_grib_output.py b/tests/grib/test_grib_output.py index a7ffa272..d708d04b 100644 --- a/tests/grib/test_grib_output.py +++ b/tests/grib/test_grib_output.py @@ -20,13 +20,14 @@ import earthkit.data from earthkit.data import from_source from earthkit.data.core.temporary import temp_file -from earthkit.data.testing import earthkit_examples_file +from earthkit.data.testing import ARRAY_BACKENDS, earthkit_examples_file EPSILON = 1e-4 -def test_grib_save_when_loaded_from_file(): - fs = from_source("file", earthkit_examples_file("test6.grib")) +@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +def test_grib_save_when_loaded_from_file(backend): + fs = from_source("file", earthkit_examples_file("test6.grib"), backend=backend) assert len(fs) == 6 with temp_file() as tmp: fs.save(tmp) diff --git a/tests/grib/test_grib_sel.py b/tests/grib/test_grib_sel.py index e667a688..e3a2df1a 100644 --- a/tests/grib/test_grib_sel.py +++ b/tests/grib/test_grib_sel.py @@ -17,10 +17,11 @@ import pytest from earthkit.data import from_source +from earthkit.data.testing import ARRAY_BACKENDS here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 +from grib_fixtures import FL_TYPES, load_grib_data # noqa: E402 # @pytest.mark.skipif(("GITHUB_WORKFLOW" in os.environ) or True, reason="Not yet ready") diff --git a/tests/grib/test_grib_slice.py b/tests/grib/test_grib_slice.py index dca31958..e17356d0 100644 --- a/tests/grib/test_grib_slice.py +++ b/tests/grib/test_grib_slice.py @@ -16,11 +16,11 @@ import pytest from earthkit.data import from_source -from earthkit.data.testing import earthkit_examples_file +from earthkit.data.testing import ARRAY_BACKENDS, earthkit_examples_file here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 +from grib_fixtures import FL_TYPES, load_grib_data # noqa: E402 @pytest.mark.parametrize("fl_type", FL_TYPES) diff --git a/tests/grib/test_grib_summary.py b/tests/grib/test_grib_summary.py index 904f3024..da3bf54e 100644 --- a/tests/grib/test_grib_summary.py +++ b/tests/grib/test_grib_summary.py @@ -13,9 +13,11 @@ import pytest +from earthkit.data.testing import ARRAY_BACKENDS + here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import ARRAY_BACKENDS, FL_TYPES, load_grib_data # noqa: E402 +from grib_fixtures import FL_TYPES, load_grib_data # noqa: E402 @pytest.mark.parametrize("fl_type", FL_TYPES) diff --git a/tests/grib/test_grib_values.py b/tests/grib/test_grib_values.py index eb78496b..aa7a09a9 100644 --- a/tests/grib/test_grib_values.py +++ b/tests/grib/test_grib_values.py @@ -15,17 +15,17 @@ import numpy as np import pytest -here = os.path.dirname(__file__) -sys.path.insert(0, here) -from grib_fixtures import ( # noqa: E402 +from earthkit.data.testing import ( ARRAY_BACKENDS, - FL_TYPES, check_array_type, get_array, get_array_namespace, - load_grib_data, ) +here = os.path.dirname(__file__) +sys.path.insert(0, here) +from grib_fixtures import FL_TYPES, load_grib_data # noqa: E402 + def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): assert v.shape == shape From 372fe9d5c9e15a292a51fc6ad94371969889cf75 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Sun, 18 Feb 2024 17:43:48 +0000 Subject: [PATCH 08/18] Impelement array backends for fieldlist --- Untitled.ipynb | 33 ++ docs/examples/grib_array_backends.ipynb | 355 +++++++++++-------- earthkit/data/core/array.py | 272 -------------- earthkit/data/core/fieldlist.py | 286 ++++++++++----- earthkit/data/readers/grib/index/__init__.py | 8 +- earthkit/data/readers/grib/memory.py | 24 +- earthkit/data/readers/grib/reader.py | 6 +- earthkit/data/readers/netcdf.py | 16 +- earthkit/data/sources/array_list.py | 52 +-- earthkit/data/sources/constants.py | 2 +- earthkit/data/sources/list_of_dicts.py | 2 +- earthkit/data/sources/numpy_list.py | 5 +- earthkit/data/testing.py | 36 +- earthkit/data/utils/array/__init__.py | 214 +++++++++++ earthkit/data/utils/array/numpy.py | 60 ++++ earthkit/data/utils/array/pytorch.py | 65 ++++ tests/array_fieldlist/array_fl_fixtures.py | 20 +- tests/array_fieldlist/test_numpy_fl_write.py | 32 +- tests/array_fieldlist/test_numpy_fs.py | 6 +- tests/documentation/test_notebooks.py | 6 +- tests/grib/grib_fixtures.py | 64 +--- tests/grib/test_grib_backend.py | 20 +- tests/grib/test_grib_convert.py | 18 +- tests/grib/test_grib_geography.py | 82 ++--- tests/grib/test_grib_inidces.py | 34 +- tests/grib/test_grib_metadata.py | 126 +++---- tests/grib/test_grib_order_by.py | 32 +- tests/grib/test_grib_output.py | 8 +- tests/grib/test_grib_sel.py | 106 +++--- tests/grib/test_grib_slice.py | 54 +-- tests/grib/test_grib_stream.py | 23 +- tests/grib/test_grib_summary.py | 60 ++-- tests/grib/test_grib_url_stream.py | 2 +- tests/grib/test_grib_values.py | 86 ++--- tests/{core => utils}/test_array.py | 6 +- 35 files changed, 1218 insertions(+), 1003 deletions(-) create mode 100644 Untitled.ipynb delete mode 100644 earthkit/data/core/array.py create mode 100644 earthkit/data/utils/array/__init__.py create mode 100644 earthkit/data/utils/array/numpy.py create mode 100644 earthkit/data/utils/array/pytorch.py rename tests/{core => utils}/test_array.py (93%) diff --git a/Untitled.ipynb b/Untitled.ipynb new file mode 100644 index 00000000..8c538c3f --- /dev/null +++ b/Untitled.ipynb @@ -0,0 +1,33 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "a922d141-d705-453e-9816-0899283d9cbb", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "dev_ecc", + "language": "python", + "name": "dev_ecc" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/grib_array_backends.ipynb b/docs/examples/grib_array_backends.ipynb index 93b28f64..c7c62774 100644 --- a/docs/examples/grib_array_backends.ipynb +++ b/docs/examples/grib_array_backends.ipynb @@ -57,7 +57,7 @@ "tags": [] }, "source": [ - "When reading GRIB data with :func:`from_source` we can specify the array ``backend`` we want to use when extracting the field values. The default backend is \"numpy\". For this example we choose the \"pytorch\" backend. Since pytorch is an optional dependency for earthkit-data we need to ensure it is installed in the environment." + "When reading GRIB data with :func:`from_source` we can specify the ``array_backend`` we want to use when extracting the field values. The default backend is \"numpy\". For this example we choose the \"pytorch\" backend. Since pytorch is an optional dependency for earthkit-data we need to ensure it is installed in the environment. We also need to install \"array_api_compat\" to make the array backends work." ] }, { @@ -73,7 +73,8 @@ }, "outputs": [], "source": [ - "!pip install torch --quiet" + "!pip install torch --quiet\n", + "!pip install array_api_compat --quiet" ] }, { @@ -89,7 +90,7 @@ }, "outputs": [], "source": [ - "ds = earthkit.data.from_source(\"file\", \"test6.grib\", backend=\"pytorch\")" + "ds = earthkit.data.from_source(\"file\", \"test4.grib\", array_backend=\"pytorch\")" ] }, { @@ -143,8 +144,8 @@ " ecmf\n", " t\n", " isobaricInhPa\n", - " 1000\n", - " 20180801\n", + " 500\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -154,10 +155,10 @@ " \n", " 1\n", " ecmf\n", - " u\n", + " z\n", " isobaricInhPa\n", - " 1000\n", - " 20180801\n", + " 500\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -167,23 +168,10 @@ " \n", " 2\n", " ecmf\n", - " v\n", - " isobaricInhPa\n", - " 1000\n", - " 20180801\n", - " 1200\n", - " 0\n", - " an\n", - " 0\n", - " regular_ll\n", - " \n", - " \n", - " 3\n", - " ecmf\n", " t\n", " isobaricInhPa\n", " 850\n", - " 20180801\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -191,25 +179,12 @@ " regular_ll\n", " \n", " \n", - " 4\n", - " ecmf\n", - " u\n", - " isobaricInhPa\n", - " 850\n", - " 20180801\n", - " 1200\n", - " 0\n", - " an\n", - " 0\n", - " regular_ll\n", - " \n", - " \n", - " 5\n", + " 3\n", " ecmf\n", - " v\n", + " z\n", " isobaricInhPa\n", " 850\n", - " 20180801\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -222,20 +197,16 @@ ], "text/plain": [ " centre shortName typeOfLevel level dataDate dataTime stepRange \\\n", - "0 ecmf t isobaricInhPa 1000 20180801 1200 0 \n", - "1 ecmf u isobaricInhPa 1000 20180801 1200 0 \n", - "2 ecmf v isobaricInhPa 1000 20180801 1200 0 \n", - "3 ecmf t isobaricInhPa 850 20180801 1200 0 \n", - "4 ecmf u isobaricInhPa 850 20180801 1200 0 \n", - "5 ecmf v isobaricInhPa 850 20180801 1200 0 \n", + "0 ecmf t isobaricInhPa 500 20070101 1200 0 \n", + "1 ecmf z isobaricInhPa 500 20070101 1200 0 \n", + "2 ecmf t isobaricInhPa 850 20070101 1200 0 \n", + "3 ecmf z isobaricInhPa 850 20070101 1200 0 \n", "\n", " dataType number gridType \n", "0 an 0 regular_ll \n", "1 an 0 regular_ll \n", "2 an 0 regular_ll \n", - "3 an 0 regular_ll \n", - "4 an 0 regular_ll \n", - "5 an 0 regular_ll " + "3 an 0 regular_ll " ] }, "execution_count": 4, @@ -291,8 +262,8 @@ { "data": { "text/plain": [ - "tensor([272.5642, 272.5642, 272.5642, 272.5642, 272.5642, 272.5642, 272.5642,\n", - " 272.5642, 272.5642, 272.5642], dtype=torch.float64)" + "tensor([228.0460, 228.0460, 228.0460, 228.0460, 228.0460, 228.0460, 228.0460,\n", + " 228.0460, 228.0460, 228.0460], dtype=torch.float64)" ] }, "execution_count": 5, @@ -319,7 +290,7 @@ { "data": { "text/plain": [ - "torch.Size([84])" + "torch.Size([65160])" ] }, "execution_count": 6, @@ -346,7 +317,7 @@ { "data": { "text/plain": [ - "torch.Size([6, 84])" + "torch.Size([4, 65160])" ] }, "execution_count": 7, @@ -402,8 +373,8 @@ { "data": { "text/plain": [ - "tensor([[272.5642, 272.5642],\n", - " [288.5642, 296.5642]], dtype=torch.float64)" + "tensor([[228.0460, 228.0460],\n", + " [228.6085, 228.5792]], dtype=torch.float64)" ] }, "execution_count": 8, @@ -424,7 +395,7 @@ { "data": { "text/plain": [ - "torch.Size([6, 7, 12])" + "torch.Size([4, 181, 360])" ] }, "execution_count": 9, @@ -451,7 +422,7 @@ { "data": { "text/plain": [ - "torch.Size([6, 84])" + "torch.Size([4, 65160])" ] }, "execution_count": 10, @@ -507,10 +478,10 @@ { "data": { "text/html": [ - "ArrayFieldList(fields=6)" + "ArrayFieldList(fields=4)" ], "text/plain": [ - "ArrayFieldList(fields=6)" + "ArrayFieldList(fields=4)" ] }, "execution_count": 11, @@ -574,8 +545,8 @@ " ecmf\n", " t\n", " isobaricInhPa\n", - " 1000\n", - " 20180801\n", + " 500\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -585,10 +556,10 @@ " \n", " 1\n", " ecmf\n", - " u\n", + " z\n", " isobaricInhPa\n", - " 1000\n", - " 20180801\n", + " 500\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -598,36 +569,10 @@ " \n", " 2\n", " ecmf\n", - " v\n", - " isobaricInhPa\n", - " 1000\n", - " 20180801\n", - " 1200\n", - " 0\n", - " an\n", - " 0\n", - " regular_ll\n", - " \n", - " \n", - " 3\n", - " ecmf\n", " t\n", " isobaricInhPa\n", " 850\n", - " 20180801\n", - " 1200\n", - " 0\n", - " an\n", - " 0\n", - " regular_ll\n", - " \n", - " \n", - " 4\n", - " ecmf\n", - " u\n", - " isobaricInhPa\n", - " 850\n", - " 20180801\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -635,12 +580,12 @@ " regular_ll\n", " \n", " \n", - " 5\n", + " 3\n", " ecmf\n", - " v\n", + " z\n", " isobaricInhPa\n", " 850\n", - " 20180801\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -653,20 +598,16 @@ ], "text/plain": [ " centre shortName typeOfLevel level dataDate dataTime stepRange \\\n", - "0 ecmf t isobaricInhPa 1000 20180801 1200 0 \n", - "1 ecmf u isobaricInhPa 1000 20180801 1200 0 \n", - "2 ecmf v isobaricInhPa 1000 20180801 1200 0 \n", - "3 ecmf t isobaricInhPa 850 20180801 1200 0 \n", - "4 ecmf u isobaricInhPa 850 20180801 1200 0 \n", - "5 ecmf v isobaricInhPa 850 20180801 1200 0 \n", + "0 ecmf t isobaricInhPa 500 20070101 1200 0 \n", + "1 ecmf z isobaricInhPa 500 20070101 1200 0 \n", + "2 ecmf t isobaricInhPa 850 20070101 1200 0 \n", + "3 ecmf z isobaricInhPa 850 20070101 1200 0 \n", "\n", " dataType number gridType \n", "0 an 0 regular_ll \n", "1 an 0 regular_ll \n", "2 an 0 regular_ll \n", - "3 an 0 regular_ll \n", - "4 an 0 regular_ll \n", - "5 an 0 regular_ll " + "3 an 0 regular_ll " ] }, "execution_count": 12, @@ -693,8 +634,8 @@ { "data": { "text/plain": [ - "tensor([272.5642, 272.5642, 272.5642, 272.5642, 272.5642, 272.5642, 272.5642,\n", - " 272.5642, 272.5642, 272.5642], dtype=torch.float64)" + "tensor([228.0460, 228.0460, 228.0460, 228.0460, 228.0460, 228.0460, 228.0460,\n", + " 228.0460, 228.0460, 228.0460], dtype=torch.float64)" ] }, "execution_count": 13, @@ -707,17 +648,18 @@ ] }, { - "cell_type": "markdown", - "id": "a78665fd-9a37-456a-8fda-a11c358aba64", + "cell_type": "raw", + "id": "3cdcbb75-86b1-4c38-b667-73ac563c8d97", "metadata": { "editable": true, + "raw_mimetype": "text/restructuredtext", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ - "Whe can build a new ArrayFiedlList straight from metadata and array values. This can be used for computations, when we want to alter the values and store the result in a new FieldList." + "Whe can build a new :py:class:`~data.sources.array_list.ArrayFieldList` straight from metadata and array values. This can be used for computations when we want to alter the values and store the result in a new FieldList." ] }, { @@ -771,8 +713,8 @@ " ecmf\n", " t\n", " isobaricInhPa\n", - " 1000\n", - " 20180801\n", + " 500\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -782,10 +724,10 @@ " \n", " 1\n", " ecmf\n", - " u\n", + " z\n", " isobaricInhPa\n", - " 1000\n", - " 20180801\n", + " 500\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -795,36 +737,10 @@ " \n", " 2\n", " ecmf\n", - " v\n", - " isobaricInhPa\n", - " 1000\n", - " 20180801\n", - " 1200\n", - " 0\n", - " an\n", - " 0\n", - " regular_ll\n", - " \n", - " \n", - " 3\n", - " ecmf\n", " t\n", " isobaricInhPa\n", " 850\n", - " 20180801\n", - " 1200\n", - " 0\n", - " an\n", - " 0\n", - " regular_ll\n", - " \n", - " \n", - " 4\n", - " ecmf\n", - " u\n", - " isobaricInhPa\n", - " 850\n", - " 20180801\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -832,12 +748,12 @@ " regular_ll\n", " \n", " \n", - " 5\n", + " 3\n", " ecmf\n", - " v\n", + " z\n", " isobaricInhPa\n", " 850\n", - " 20180801\n", + " 20070101\n", " 1200\n", " 0\n", " an\n", @@ -850,20 +766,16 @@ ], "text/plain": [ " centre shortName typeOfLevel level dataDate dataTime stepRange \\\n", - "0 ecmf t isobaricInhPa 1000 20180801 1200 0 \n", - "1 ecmf u isobaricInhPa 1000 20180801 1200 0 \n", - "2 ecmf v isobaricInhPa 1000 20180801 1200 0 \n", - "3 ecmf t isobaricInhPa 850 20180801 1200 0 \n", - "4 ecmf u isobaricInhPa 850 20180801 1200 0 \n", - "5 ecmf v isobaricInhPa 850 20180801 1200 0 \n", + "0 ecmf t isobaricInhPa 500 20070101 1200 0 \n", + "1 ecmf z isobaricInhPa 500 20070101 1200 0 \n", + "2 ecmf t isobaricInhPa 850 20070101 1200 0 \n", + "3 ecmf z isobaricInhPa 850 20070101 1200 0 \n", "\n", " dataType number gridType \n", "0 an 0 regular_ll \n", "1 an 0 regular_ll \n", "2 an 0 regular_ll \n", - "3 an 0 regular_ll \n", - "4 an 0 regular_ll \n", - "5 an 0 regular_ll " + "3 an 0 regular_ll " ] }, "execution_count": 14, @@ -889,7 +801,7 @@ "tags": [] }, "source": [ - "As expected, the values are now differing by 2 from the ones in the originial FieldList." + "As expected, the values in *r1* are now differing by 2 from the ones in the originial FieldList (*r*)." ] }, { @@ -907,8 +819,8 @@ { "data": { "text/plain": [ - "tensor([274.5642, 274.5642, 274.5642, 274.5642, 274.5642, 274.5642, 274.5642,\n", - " 274.5642, 274.5642, 274.5642], dtype=torch.float64)" + "tensor([230.0460, 230.0460, 230.0460, 230.0460, 230.0460, 230.0460, 230.0460,\n", + " 230.0460, 230.0460, 230.0460], dtype=torch.float64)" ] }, "execution_count": 15, @@ -920,10 +832,25 @@ "r1[0].values[:10]" ] }, + { + "cell_type": "raw", + "id": "dc2332e6-434c-406f-9f03-567e16afc876", + "metadata": { + "editable": true, + "raw_mimetype": "text/restructuredtext", + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "We can save the :py:class:`~data.sources.numpy_list.ArrayFieldList` into a GRIB file:" + ] + }, { "cell_type": "code", - "execution_count": null, - "id": "3e44a0c8-dbf8-4454-8519-c57330cc4d71", + "execution_count": 16, + "id": "9ceea7b3-5059-4378-9a26-d282caa3b74a", "metadata": { "editable": true, "slideshow": { @@ -931,8 +858,122 @@ }, "tags": [] }, - "outputs": [], - "source": [] + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
centreshortNametypeOfLevelleveldataDatedataTimestepRangedataTypenumbergridType
0ecmftisobaricInhPa5002007010112000an0regular_ll
1ecmfzisobaricInhPa5002007010112000an0regular_ll
2ecmftisobaricInhPa8502007010112000an0regular_ll
3ecmfzisobaricInhPa8502007010112000an0regular_ll
\n", + "
" + ], + "text/plain": [ + " centre shortName typeOfLevel level dataDate dataTime stepRange \\\n", + "0 ecmf t isobaricInhPa 500 20070101 1200 0 \n", + "1 ecmf z isobaricInhPa 500 20070101 1200 0 \n", + "2 ecmf t isobaricInhPa 850 20070101 1200 0 \n", + "3 ecmf z isobaricInhPa 850 20070101 1200 0 \n", + "\n", + " dataType number gridType \n", + "0 an 0 regular_ll \n", + "1 an 0 regular_ll \n", + "2 an 0 regular_ll \n", + "3 an 0 regular_ll " + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "path = \"_from_pytroch.grib\"\n", + "r1.save(path)\n", + "ds1 = earthkit.data.from_source(\"file\", path)\n", + "ds1.ls()" + ] } ], "metadata": { diff --git a/earthkit/data/core/array.py b/earthkit/data/core/array.py deleted file mode 100644 index e58ba53d..00000000 --- a/earthkit/data/core/array.py +++ /dev/null @@ -1,272 +0,0 @@ -# (C) Copyright 2020 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# - -import threading -from abc import ABCMeta, abstractmethod - - -class ArrayBackendItem: - def __init__(self, backend_type): - self.type = backend_type - self._obj = None - self._avail = None - self.lock = threading.Lock() - - def obj(self): - if self._obj is None: - with self.lock: - if self._obj is None: - self._obj = self.type() - return self._obj - - def available(self): - if self._avail is None: - if self._obj is not None: - self._avail = True - else: - try: - self.obj() - self._avail = True - except Exception: - self._avail = False - return self._avail - - -class ArrayBackendManager: - def __init__(self): - """The backend objects are created on demand to avoid unnecessary imports""" - self.backends = {k: ArrayBackendItem(v) for k, v in array_backend_types.items()} - self._np_backend = None - - def find_for_name(self, name): - b = self.backends.get(name, None) - if b is None: - raise TypeError(f"No backend found for name={name}") - - # this will try to create the backend if it does not exist yet and - # throw an exception when it is not possible - return b.obj() - - def find_for_array(self, v, guess=None): - if guess is not None: - if guess.is_native_array(v): - return guess - - # try all the backends - for _, b in self.backends.items(): - # this will try create the backend if it does not exist yest. - # If it fails available() will return False from this moment on. - if b.available() and b.obj().is_native_array(v): - return b.obj() - - raise TypeError(f"No backend found for array type={type(v)}") - - def numpy_backend(self): - if self._np_backend is None: - self._np_backend = self.find_for_name("numpy") - return self._np_backend - - -class ArrayBackend(metaclass=ABCMeta): - _array_ns = None - _default = "numpy" - _name = None - _array_name = "array" - _dtypes = {} - - def __init__(self): - self.lock = threading.Lock() - - @property - def array_ns(self): - """Delayed construction of array namespace""" - if self._array_ns is None: - with self.lock: - if self._array_ns is None: - self._array_ns = self._make_array_ns() - return self._array_ns - - @abstractmethod - def _make_array_ns(self): - pass - - @property - def name(self): - return self._name - - @property - def array_name(self): - return f"{self._name} {self._array_name}" - - def to_array(self, v, backend=None): - if backend is not None: - if backend is self: - return v - - return backend.to_backend(v, self) - else: - b = get_backend(v, strict=False) - return b.to_backend(v, self) - - def to_dtype(self, dtype): - if isinstance(dtype, str): - return self._dtypes.get(dtype, None) - return dtype - - def match_dtype(self, v, dtype): - if dtype is not None: - dtype = self.to_dtype(dtype) - return v.dtype == dtype if dtype is not None else False - return True - - @abstractmethod - def is_native_array(self, v, **kwargs): - pass - - @abstractmethod - def to_backend(self, v, backend): - pass - - @abstractmethod - def from_numpy(self, v): - pass - - @abstractmethod - def from_pytorch(self, v): - pass - - @abstractmethod - def from_other(self, v, **kwargs): - pass - - -class NumpyBackend(ArrayBackend): - _name = "numpy" - - def __init__(self): - super().__init__() - - def _make_array_ns(self): - import numpy as np - - try: - import array_api_compat - - ns = array_api_compat.array_namespace(np.ones(2)) - except Exception: - ns = np - - return ns - - def to_dtype(self, dtype): - return dtype - - def is_native_array(self, v, dtype=None): - import numpy as np - - if not isinstance(v, np.ndarray): - return False - if dtype is not None: - return v.dtype == dtype - return True - - def to_backend(self, v, backend): - return backend.from_numpy(v) - - def from_numpy(self, v): - return v - - def from_pytorch(self, v): - return v.numpy() - - def from_other(self, v, **kwargs): - import numpy as np - - return np.array(v, **kwargs) - - -class PytorchBackend(ArrayBackend): - _name = "pytorch" - _array_name = "tensor" - - def __init__(self): - super().__init__() - # pytorch is an optional dependency, we need to see on init - # if we can load it - self.array_ns - - def _make_array_ns(self): - try: - import array_api_compat - - except Exception: - raise ImportError("array_api_compat is required to use pytorch backend") - - try: - import torch - except Exception: - raise ImportError("pytorch is required to use pytorch backend") - - self._dtypes = {"float64": torch.float64, "float32": torch.float32} - - return array_api_compat.array_namespace(torch.ones(2)) - - def is_native_array(self, v, dtype=None): - import torch - - if not torch.is_tensor(v): - return False - return self.match_dtype(v, dtype) - - def to_backend(self, v, backend): - return backend.from_pytorch(v) - - def from_numpy(self, v): - import torch - - return torch.from_numpy(v) - - def from_pytorch(self, v): - return v - - def from_other(self, v, **kwargs): - import torch - - return torch.tensor(v, **kwargs) - - -array_backend_types = {"numpy": NumpyBackend, "pytorch": PytorchBackend} - -_MANAGER = ArrayBackendManager() - -NUMPY_BACKEND = _MANAGER.numpy_backend() - - -def ensure_backend(backend): - if backend is None: - return _MANAGER.find_for_name(ArrayBackend._default) - elif isinstance(backend, str): - return _MANAGER.find_for_name(backend) - else: - return backend - - -def get_backend(array, guess=None, strict=True): - if isinstance(array, list): - array = array[0] - - if guess is not None: - guess = ensure_backend(guess) - - b = _MANAGER.find_for_array(array, guess=guess) - if strict and guess is not None and b is not guess: - raise ValueError( - f"array type={b.array_name} and specified backend={guess} do not match" - ) - return b diff --git a/earthkit/data/core/fieldlist.py b/earthkit/data/core/fieldlist.py index 3d349a1c..886d333e 100644 --- a/earthkit/data/core/fieldlist.py +++ b/earthkit/data/core/fieldlist.py @@ -12,58 +12,102 @@ from collections import defaultdict from earthkit.data.core import Base -from earthkit.data.core.array import NUMPY_BACKEND, ensure_backend from earthkit.data.core.index import Index from earthkit.data.decorators import cached_method, detect_out_filename +from earthkit.data.utils.array import ensure_backend, numpy_backend from earthkit.data.utils.metadata import metadata_argument class Field(Base): - r"""Represents a Field.""" + r"""Represent a Field.""" + + def __init__( + self, + array_backend, + metadata=None, + raw_values_backend=numpy_backend(), + raw_other_backend=numpy_backend(), + ): + self.__metadata = metadata + self._array_backend = array_backend + self._raw_values_backend = raw_values_backend + self._raw_other_backend = raw_other_backend - raw_values_backend = NUMPY_BACKEND - raw_other_backend = NUMPY_BACKEND + @property + def array_backend(self): + r""":obj:`ArrayBackend`: Return the array backend of the field.""" + return self._array_backend - def __init__(self, backend, metadata=None): - self.__metadata = metadata - self.backend = backend + @property + def raw_values_backend(self): + r""":obj:`ArrayBackend`: Return the array backend the low level API + uses to extract the field values. + """ + return self._raw_values_backend + + @property + def raw_other_backend(self): + r""":obj:`ArrayBackend`: Return the array backend the low level API + uses to extract non-field-related values, e.g. latitudes, longitudes. + """ + return self._raw_other_backend + + def _to_array(self, v, array_backend=None, source_backend=None): + r"""Convert an array into an ``array backend``. - def _to_array(self, v, backend=None, raw=None): - if backend is None: - return self.backend.to_array(v, raw) + Parameters + ---------- + v: array-like + The values. + array_backend: :obj:`ArrayBackend` + The target array backend. + source_backend: :obj:`ArrayBackend` + The array backend of ``v``. When None, it will be automatically detected. + + Returns + ------- + array-like + ``v`` converted onto the ``array_backend``. + + """ + if array_backend is None: + return self._array_backend.to_array(v, source_backend) else: - backend = ensure_backend(backend) - return backend.to_array(v, raw) + array_backend = ensure_backend(array_backend) + return array_backend.to_array(v, source_backend) @abstractmethod def _values(self, dtype=None): - r"""Return the values as stored in the field as an ndarray. + r"""Return the raw values extracted from the underlying storage format + of the field. Parameters ---------- - dtype: str, numpy.dtype or None + dtype: str, array.dtype or None Typecode or data-type of the array. When it is :obj:`None` the default type used by the underlying data accessor is used. For GRIB it is - ``np.float64``. + ``float64``. - The original shape and backend type of the values is kept. + The original shape and array backend type of the raw values are kept. Returns ------- - ndarray - Field values + array-like + Field values in the format specified by :attr:`raw_values_backend`. """ self._not_implemented() @property def values(self): - r"""ndarray: Get the values stored in the field as a 1D ndarray.""" - v = self._to_array(self._values(), raw=self.raw_values_backend) + r"""array-like: Get the values stored in the field as a 1D array. The array type + is defined by :attr:`array_backend` + """ + v = self._to_array(self._values(), source_backend=self.raw_values_backend) if len(v.shape) != 1: n = math.prod(v.shape) n = (n,) - return self.backend.array_ns.reshape(v, n) + return self._array_backend.array_ns.reshape(v, n) return v def _make_metadata(self): @@ -87,7 +131,7 @@ def to_numpy(self, flatten=False, dtype=None): :obj:`shape` is returned. dtype: str, numpy.dtype or None Typecode or data-type of the array. When it is :obj:`None` the default - type used by the underlying data accessor is used. For GRIB it is ``np.float64``. + type used by the underlying data accessor is used. For GRIB it is ``float64``. Returns ------- @@ -96,36 +140,39 @@ def to_numpy(self, flatten=False, dtype=None): """ v = self._values(dtype=dtype) - v = NUMPY_BACKEND.to_array(v, self.raw_values_backend) + v = numpy_backend().to_array(v, self.raw_values_backend) shape = self._required_shape(flatten) if shape != v.shape: return v.reshape(shape) return v - def to_array(self, flatten=False, dtype=None, backend=None): - r"""Return the values stored in the field as an ndarray. + def to_array(self, flatten=False, dtype=None, array_backend=None): + r"""Return the values stored in the field in the + format of :attr:`array_backend`. Parameters ---------- flatten: bool - When it is True a flat ndarray is returned. Otherwise an ndarray with the field's + When it is True a flat array is returned. Otherwise an array with the field's :obj:`shape` is returned. - dtype: str, numpy.dtype or None + dtype: str, array.dtype or None Typecode or data-type of the array. When it is :obj:`None` the default - type used by the underlying data accessor is used. For GRIB it is ``np.float64``. + type used by the underlying data accessor is used. For GRIB it is ``float64``. Returns ------- - ndarray - Field values + array-array + Field values in the format od :attr:`array_backend`. """ v = self._to_array( - self._values(dtype=dtype), backend=backend, raw=self.raw_values_backend + self._values(dtype=dtype), + array_backend=array_backend, + source_backend=self.raw_values_backend, ) shape = self._required_shape(flatten) if shape != v.shape: - return self.backend.array_ns.reshape(v, shape) + return self._array_backend.array_ns.reshape(v, shape) return v def _required_shape(self, flatten): @@ -146,17 +193,18 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): flatten: bool When it is True a flat ndarray per key is returned. Otherwise an ndarray with the field's :obj:`shape` is returned for each key. - dtype: str, numpy.dtype or None + dtype: str, array.dtype or None Typecode or data-type of the arrays. When it is :obj:`None` the default - type used by the underlying data accessor is used. For GRIB it is ``np.float64``. + type used by the underlying data accessor is used. For GRIB it is ``float64``. Returns ------- - ndarray - An ndarray containing one ndarray per key is returned + array-like + An multi-dimensional array containing one array per key is returned (following the order in ``keys``). When ``keys`` is a single value only the - ndarray belonging to the key is returned. + ndarray belonging to the key is returned. The array format is specified by + :attr:`array_backend`. Examples @@ -201,19 +249,19 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): if k not in _keys: raise ValueError(f"data: invalid argument: {k}") - r = [self._to_array(_keys[k][0](dtype=dtype), raw=_keys[k][1]) for k in keys] + r = [ + self._to_array(_keys[k][0](dtype=dtype), source_backend=_keys[k][1]) + for k in keys + ] shape = self._required_shape(flatten) if shape != r[0].shape: # r = [x.reshape(shape) for x in r] - r = [self.backend.array_ns.reshape(x, shape) for x in r] + r = [self._array_backend.array_ns.reshape(x, shape) for x in r] if len(r) == 1: return r[0] else: - return self.backend.array_ns.stack(r) - # import numpy as np - - # return np.array(r) + return self._array_backend.array_ns.stack(r) def to_points(self, flatten=False, dtype=None): r"""Return the geographical coordinates in the data's original @@ -222,18 +270,19 @@ def to_points(self, flatten=False, dtype=None): Parameters ---------- flatten: bool - When it is True 1D ndarrays are returned. Otherwise ndarrays with the field's + When it is True 1D arrays are returned. Otherwise arrays with the field's :obj:`shape` are returned. - dtype: str, numpy.dtype or None + dtype: str, array.dtype or None Typecode or data-type of the arrays. When it is :obj:`None` the default type used by the underlying data accessor is used. For GRIB it is - ``np.float64``. + ``float64``. Returns ------- dict - Dictionary with items "x" and "y", containing the ndarrays of the x and - y coordinates, respectively. + Dictionary with items "x" and "y", containing the arrays of the x and + y coordinates, respectively. The array format is specified by + :attr:`array_backend`. Raises ------ @@ -248,12 +297,12 @@ def to_points(self, flatten=False, dtype=None): x = self._metadata.geography.x(dtype=dtype) y = self._metadata.geography.y(dtype=dtype) if x is not None and y is not None: - x = self._to_array(x, raw=self.raw_other_backend) - y = self._to_array(y, raw=self.raw_other_backend) + x = self._to_array(x, source_backend=self.raw_other_backend) + y = self._to_array(y, source_backend=self.raw_other_backend) shape = self._required_shape(flatten) if shape != x.shape: - x = self.backend.array_ns.reshape(x, shape) - y = self.backend.array_ns.reshape(y, shape) + x = self._array_backend.array_ns.reshape(x, shape) + y = self._array_backend.array_ns.reshape(y, shape) return dict(x=x, y=y) elif self.projection().CARTOPY_CRS == "PlateCarree": lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype) @@ -269,18 +318,19 @@ def to_latlon(self, flatten=False, dtype=None): Parameters ---------- flatten: bool - When it is True 1D ndarrays are returned. Otherwise ndarrays with the field's + When it is True 1D arrays are returned. Otherwise arrays with the field's :obj:`shape` are returned. - dtype: str, numpy.dtype or None + dtype: str, array.dtype or None Typecode or data-type of the arrays. When it is :obj:`None` the default type used by the underlying data accessor is used. For GRIB it is - ``np.float64``. + ``float64``. Returns ------- dict - Dictionary with items "lat" and "lon", containing the ndarrays of the latitudes and - longitudes, respectively. + Dictionary with items "lat" and "lon", containing the arrays of the latitudes and + longitudes, respectively. The array format is specified by + :attr:`array_backend`. See Also -------- @@ -581,32 +631,59 @@ def _attributes(self, names): class FieldList(Index): - r"""Represents a list of :obj:`Field` \s.""" + r"""Represent a list of :obj:`Field` \s. + + Parameters + ---------- + array_backend: str, :obj:`ArrayBackend` + The array backend. When it is None the array backend + defaults to "numpy". + """ _md_indices = {} - def __init__(self, backend=None, **kwargs): - self.backend = ensure_backend(backend) + def __init__(self, array_backend=None, **kwargs): + self._array_backend = ensure_backend(array_backend) super().__init__(**kwargs) def _init_from_mask(self, index): - self.backend = index._index.backend + self._array_backend = index._index.array_backend def _init_from_multi(self, index): - self.backend = index._indexes[0].backend + self._array_backend = index._indexes[0].array_backend @staticmethod def from_numpy(array, metadata): from earthkit.data.sources.array_list import ArrayFieldList - return ArrayFieldList(array, metadata, backend=NUMPY_BACKEND) + return ArrayFieldList(array, metadata, array_backend=numpy_backend()) @staticmethod def from_array(array, metadata): + r"""Create an :class:`ArrayFieldList`. + + Parameters + ---------- + array: array-like, list + The fields' values. When it is a list must contain one array per field. The array + type must be supported by :class:`ArrayBackend`. + metadata: list + The fields' metadata. Must contain one :class:`Metadata` object per field. + + In the generated :class:`ArrayFieldList`, each field is represented by an array + storing the field values and a :class:`MetaData` object holding + the field metadata. The shape and dtype of the array is controlled by the ``kwargs``. + Please note that generated :class:`ArrayFieldList` stores all the field values in + a single array. + """ from earthkit.data.sources.array_list import ArrayFieldList return ArrayFieldList(array, metadata) + @property + def array_backend(self): + return self._array_backend + def ignore(self): # When the concrete type is Fieldlist we assume the object was # created with Fieldlist() i.e. it is empty. We ignore it from @@ -722,7 +799,7 @@ def index(self, key): return self._md_indices[key] def to_numpy(self, **kwargs): - r"""Return the field values as an ndarray. It is formed as the array of the + r"""Return all the fields' values as an ndarray. It is formed as the array of the :obj:`data.core.fieldlist.Field.to_numpy` values per field. Parameters @@ -737,6 +814,7 @@ def to_numpy(self, **kwargs): See Also -------- + to_array values """ import numpy as np @@ -744,17 +822,36 @@ def to_numpy(self, **kwargs): return np.array([f.to_numpy(**kwargs) for f in self]) def to_array(self, **kwargs): + r"""Return all the fields' values as an array. It is formed as the array of the + :obj:`data.core.fieldlist.Field.to_array` values per field. + + Parameters + ---------- + **kwargs: dict, optional + Keyword arguments passed to :obj:`data.core.fieldlist.Field.to_array` + + Returns + ------- + array-like + Array containing the field values. The array format is specified by + :attr:`array_backend`. + + See Also + -------- + values + to_numpy + """ x = [f.to_array(**kwargs) for f in self] - return self.backend.array_ns.stack(x) + return self._array_backend.array_ns.stack(x) @property def values(self): - r"""ndarray: Get the field values as a 2D ndarray. It is formed as the array of + r"""ndarray: Get all the fields' values as a 2D array. It is formed as the array of :obj:`GribField.values ` per field. See Also -------- - to_numpy + to_array >>> import earthkit.data @@ -772,7 +869,7 @@ def values(self): """ x = [f.values for f in self] - return self.backend.array_ns.stack(x) + return self._array_backend.array_ns.stack(x) def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): r"""Return the values and/or the geographical coordinates. @@ -787,15 +884,15 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): flatten: bool When it is True the "lat", "lon" arrays and the "value" arrays per field will all be flattened. Otherwise they will preserve the field's :obj:`shape`. - dtype: str, numpy.dtype or None + dtype: str, array.dtype or None Typecode or data-type of the arrays. When it is :obj:`None` the default type used by the underlying data accessor is used. For GRIB it is - ``np.float64``. + ``float64``. Returns ------- - ndarray - The elements of the ndarray (in the order of the ``keys``) are as follows: + array-like + The elements of the array (in the order of the ``keys``) are as follows: * the latitudes array from the first field when "lat" is in ``keys`` * the longitudes array from the first field when "lon" is in ``keys`` @@ -858,10 +955,10 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): else: raise ValueError(f"data: invalid argument: {k}") - return self.backend.array_ns.stack(r) + return self._array_backend.array_ns.stack(r) elif len(self) == 0: - return self.backend.array_ns.stack([]) + return self._array_backend.array_ns.stack([]) else: raise ValueError("Fields do not have the same grid geometry") @@ -1098,8 +1195,9 @@ def to_points(self, **kwargs): Returns ------- dict - Dictionary with items "x" and "y", containing the ndarrays of the x and - y coordinates, respectively. + Dictionary with items "x" and "y", containing the arrays of the x and + y coordinates, respectively. The array format is specified by + :attr:`array_backend`. Raises ------ @@ -1125,8 +1223,9 @@ def to_latlon(self, **kwargs): Returns ------- dict - Dictionary with items "lat" and "lon", containing the ndarrays of the latitudes and - longitudes, respectively. + Dictionary with items "lat" and "lon", containing the arrays of the latitudes and + longitudes, respectively. The array format is specified by + :attr:`array_backend`. Raises ------ @@ -1255,54 +1354,51 @@ def write(self, f, **kwargs): for s in self: s.write(f, **kwargs) - def to_fieldlist(self, backend, **kwargs): - r"""Convert to a new :class:`FieldList` based on the ``backend``. + def to_fieldlist(self, array_backend=None, **kwargs): + r"""Convert to a new :class:`FieldList` based on the ``array_backend``. When the :class:`FieldList` is already in the required format no new :class:`FieldList` is created but the current one is returned. Parameters ---------- - backend: str - Specifies the backend for the generated fieldlist. The supported values are as follows: - - - "numpy": the generated fieldlist is a :class:`NumpyFieldList`, which represents - each field by an ndarray storing the field values and a :class:`MetaData` object holding - the field metadata. The shape and dtype of the ndarray is controlled by the ``kwargs``. - Please note that generated :class:`NumpyFieldList` stores all the field values in - a single ndarray. + array_backend: str, :obj:`ArrayBackend` + Specifies the array backend for the generated fieldlist. The array + type must be supported by :class:`ArrayBackend`. **kwargs: dict, optional - When ``backend`` is "numpy" ``kwargs`` are passed to :obj:`to_numpy` to + ``kwargs`` are passed to :obj:`to_array` to extract the field values the resulting object will store. Returns ------- :class:`FieldList` - the current :class:`FieldList` if it is already in the required format - - :class:`NumpyFieldList` when ``backend`` is "numpy" + - a new :class:`ArrayFieldList` otherwise Examples -------- The following example will convert a fieldlist read from a file into a - :class:`NumpyFieldList` storing single precision field values. + :class:`ArrayFieldList` storing single precision field values. >>> import numpy as np >>> import earthkit.data >>> ds = earthkit.data.from_source("file", "docs/examples/tuv_pl.grib") >>> ds.path 'docs/examples/tuv_pl.grib' - >>> r = ds.to_fieldlist("numpy", dtype=np.float32) + >>> r = ds.to_fieldlist(array_backend="numpy", dtype=np.float32) >>> r - NumpyFieldList(fields=18) + ArrayFieldList(fields=18) >>> hasattr(r, "path") False >>> r.to_numpy().dtype dtype('float32') """ - backend = ensure_backend(backend) - return self._to_array_fieldlist(backend=backend, **kwargs) + if array_backend is None: + array_backend = self._array_backend + array_backend = ensure_backend(array_backend) + return self._to_array_fieldlist(array_backend=array_backend, **kwargs) def _to_array_fieldlist(self, **kwargs): md = [f.metadata() for f in self] diff --git a/earthkit/data/readers/grib/index/__init__.py b/earthkit/data/readers/grib/index/__init__.py index e68c7c6e..4875df1e 100644 --- a/earthkit/data/readers/grib/index/__init__.py +++ b/earthkit/data/readers/grib/index/__init__.py @@ -126,8 +126,10 @@ def merge(cls, sources): raise ValueError( "GribFieldList can only be merged to another GribFieldLists" ) - if not all(s.backend is s[0].backend for s in sources): - raise ValueError("Only fieldlists with the same backend can be merged") + if not all(s.array_backend is s[0].array_backend for s in sources): + raise ValueError( + "Only fieldlists with the same array backend can be merged" + ) return GribMultiFieldList(sources) @@ -203,7 +205,7 @@ class GribFieldListInFiles(GribFieldList): def _getitem(self, n): if isinstance(n, int): part = self.part(n if n >= 0 else len(self) + n) - return GribField(part.path, part.offset, part.length, self.backend) + return GribField(part.path, part.offset, part.length, self.array_backend) def __len__(self): return self.number_of_parts() diff --git a/earthkit/data/readers/grib/memory.py b/earthkit/data/readers/grib/memory.py index d1f230c3..b75a902c 100644 --- a/earthkit/data/readers/grib/memory.py +++ b/earthkit/data/readers/grib/memory.py @@ -11,18 +11,18 @@ import eccodes -from earthkit.data.core.array import ensure_backend from earthkit.data.readers import Reader from earthkit.data.readers.grib.codes import GribCodesHandle, GribField from earthkit.data.readers.grib.index import GribFieldList +from earthkit.data.utils.array import ensure_backend LOG = logging.getLogger(__name__) class GribMemoryReader(Reader): - def __init__(self, backend=None): + def __init__(self, array_backend=None): self._peeked = None - self.backend = ensure_backend(backend) + self._array_backend = ensure_backend(array_backend) def __iter__(self): return self @@ -43,7 +43,9 @@ def _next_handle(self): def _message_from_handle(self, handle): if handle is not None: - return GribFieldInMemory(GribCodesHandle(handle, None, None), self.backend) + return GribFieldInMemory( + GribCodesHandle(handle, None, None), self._array_backend + ) def peek(self): """Returns the next available message without consuming it""" @@ -126,7 +128,7 @@ class GribStreamReader(GribMemoryReader): def __init__(self, stream, **kwargs): super().__init__() self._stream = stream - self._reader = eccodes.StreamReader(stream, **kwargs) + self._reader = eccodes.StreamReader(stream) def __del__(self): self._stream.close() @@ -144,8 +146,8 @@ def mutate_source(self): class GribFieldInMemory(GribField): """Represents a GRIB message in memory""" - def __init__(self, handle, backend=None): - super().__init__(None, None, None, backend) + def __init__(self, handle, array_backend=None): + super().__init__(None, None, None, array_backend) self._handle = handle @GribField.handle.getter @@ -161,10 +163,10 @@ class GribFieldListInMemory(GribFieldList, Reader): """Represent a GRIB field list in memory""" @staticmethod - def from_fields(fields, backend=None): - if backend is None and len(fields) > 0: - backend = fields[0].backend - fs = GribFieldListInMemory(None, None, backend=backend) + def from_fields(fields, array_backend=None): + if array_backend is None and len(fields) > 0: + array_backend = fields[0].array_backend + fs = GribFieldListInMemory(None, None, array_backend=array_backend) fs._fields = fields fs._loaded = True return fs diff --git a/earthkit/data/readers/grib/reader.py b/earthkit/data/readers/grib/reader.py index 1186b0af..01e1ff6c 100644 --- a/earthkit/data/readers/grib/reader.py +++ b/earthkit/data/readers/grib/reader.py @@ -19,10 +19,12 @@ class GRIBReader(GribFieldListInOneFile, Reader): appendable = True # GRIB messages can be added to the same file def __init__(self, source, path, parts=None): - backend = source._kwargs.get("backend", None) + array_backend = source._kwargs.get("array_backend", None) Reader.__init__(self, source, path) - GribFieldListInOneFile.__init__(self, path, parts=parts, backend=backend) + GribFieldListInOneFile.__init__( + self, path, parts=parts, array_backend=array_backend + ) def __repr__(self): return "GRIBReader(%s)" % (self.path,) diff --git a/earthkit/data/readers/netcdf.py b/earthkit/data/readers/netcdf.py index 32c068e0..5d94c629 100644 --- a/earthkit/data/readers/netcdf.py +++ b/earthkit/data/readers/netcdf.py @@ -151,7 +151,7 @@ def bbox(self, variable): def get_fields_from_ds( ds, - backend, + array_backend, field_type=None, check_only=False, ): # noqa C901 @@ -261,7 +261,7 @@ def get_fields_from_ds( if check_only: return True - fields.append(field_type(ds, name, slices, non_dim_coords, backend)) + fields.append(field_type(ds, name, slices, non_dim_coords, array_backend)) # if not fields: # raise Exception("NetCDFReader no 2D fields found in %s" % (self.path,)) @@ -378,8 +378,8 @@ def _valid_datetime(self): class XArrayField(Field): - def __init__(self, ds, variable, slices, non_dim_coords, backend): - super().__init__(backend) + def __init__(self, ds, variable, slices, non_dim_coords, array_backend): + super().__init__(array_backend) self._ds = ds self._da = ds[variable] @@ -467,7 +467,7 @@ def has_fields(self): if self._fields is None: return get_fields_from_ds( DataSet(self.ds), - self.backend, + self.array_backend, field_type=self.FIELD_TYPE, check_only=True, ) @@ -480,7 +480,7 @@ def _scan(self): def _get_fields(self): return get_fields_from_ds( - DataSet(self.ds), self.backend, field_type=self.FIELD_TYPE + DataSet(self.ds), self.array_backend, field_type=self.FIELD_TYPE ) def to_pandas(self): @@ -565,7 +565,7 @@ def _get_fields(self): xr.open_mfdataset(self.path, combine="by_coords") ) as ds: # or nested return get_fields_from_ds( - DataSet(ds), self.backend, field_type=self.FIELD_TYPE + DataSet(ds), self.array_backend, field_type=self.FIELD_TYPE ) def has_fields(self): @@ -577,7 +577,7 @@ def has_fields(self): ) as ds: # or nested return get_fields_from_ds( DataSet(ds), - self.backend, + self.array_backend, field_type=self.FIELD_TYPE, check_only=True, ) diff --git a/earthkit/data/sources/array_list.py b/earthkit/data/sources/array_list.py index 87255aa6..95bbc476 100644 --- a/earthkit/data/sources/array_list.py +++ b/earthkit/data/sources/array_list.py @@ -10,11 +10,11 @@ import logging import math -from earthkit.data.core.array import get_backend from earthkit.data.core.fieldlist import Field, FieldList from earthkit.data.core.index import MaskIndex, MultiIndex from earthkit.data.readers.grib.pandas import PandasMixIn from earthkit.data.readers.grib.xarray import XarrayMixIn +from earthkit.data.utils.array import get_backend LOG = logging.getLogger(__name__) @@ -28,14 +28,15 @@ class ArrayField(Field): Array storing the values of the field metadata: :class:`Metadata` Metadata object describing the field metadata. - backend: str, ArrayBackend + array_backend: str, ArrayBackend Array backend. Must match the type of ``array``. """ - def __init__(self, array, metadata, backend): - super().__init__(backend, metadata=metadata) + def __init__(self, array, metadata, array_backend): + super().__init__( + array_backend, raw_values_backend=array_backend, metadata=metadata + ) self._array = array - self.raw_values_backend = backend def _make_metadata(self): pass @@ -45,7 +46,7 @@ def _values(self, dtype=None): if dtype is None: return self._array else: - return self.backend.array_ns.astype(self._array, dtype, copy=False) + return self.array_backend.array_ns.astype(self._array, dtype, copy=False) def __repr__(self): return f"{self.__class__.__name__}()" @@ -66,7 +67,7 @@ def write(self, f, **kwargs): class ArrayFieldListCore(PandasMixIn, XarrayMixIn, FieldList): - def __init__(self, array, metadata, *args, backend=None, **kwargs): + def __init__(self, array, metadata, *args, array_backend=None, **kwargs): self._array = array self._metadata = metadata @@ -74,17 +75,17 @@ def __init__(self, array, metadata, *args, backend=None, **kwargs): self._metadata = [self._metadata] # get backend and check consistency - backend = get_backend(self._array, guess=backend, strict=True) + array_backend = get_backend(self._array, guess=array_backend, strict=True) - FieldList.__init__(self, *args, backend=backend, **kwargs) + FieldList.__init__(self, *args, array_backend=array_backend, **kwargs) - if self.backend.is_native_array(self._array): + if self.array_backend.is_native_array(self._array): if self._array.shape[0] != len(self._metadata): # we have a single array and a single metadata if len(self._metadata) == 1 and self._shape_match( self._array.shape, self._metadata[0].geography.shape() ): - self._array = self.backend.array_ns.stack([self._array]) + self._array = self.array_backend.array_ns.stack([self._array]) else: raise ValueError( ( @@ -102,17 +103,20 @@ def __init__(self, array, metadata, *args, backend=None, **kwargs): ) for i, a in enumerate(self._array): - if not self.backend.is_native_array(a): + if not self.array_backend.is_native_array(a): raise ValueError( ( - f"All array element must be an {self.backend.array_name}." + f"All array element must be an {self.array_backend.array_name}." " Type at position={i} is {type(a)}" ) ) else: raise TypeError( - f"array must be an {self.backend.array_name} or a list of {self.backend.array_name}s" + ( + f"array must be an {self.array_backend.array_name} or a" + f" list of {self.array_backend.array_name}s" + ) ) # hide internal metadata related to values @@ -135,8 +139,10 @@ def merge(cls, sources): raise ValueError( "ArrayFieldList can only be merged to another ArrayFieldLists" ) - if not all(s.backend is s[0].backend for s in sources): - raise ValueError("Only fieldlists with the same backend can be merged") + if not all(s.array_backend is s[0].array_backend for s in sources): + raise ValueError( + "Only fieldlists with the same array backend can be merged" + ) merger = ListMerger(sources) return merger.to_fieldlist() @@ -144,11 +150,13 @@ def merge(cls, sources): def __repr__(self): return f"{self.__class__.__name__}(fields={len(self)})" - def _to_array_fieldlist(self, backend=None, **kwargs): + def _to_array_fieldlist(self, array_backend=None, **kwargs): if self[0]._array_matches(self._array[0], **kwargs): return self else: - return type(self)(self.to_array(backend=backend, **kwargs), self._metadata) + return type(self)( + self.to_array(array_backend=array_backend, **kwargs), self._metadata + ) def save(self, filename, append=False, check_nans=True, bits_per_value=16): r"""Write all the fields into a file. @@ -203,8 +211,10 @@ def to_fieldlist(self): for f in s: array.append(f._array) metadata.append(f._metadata) - backend = None if len(self.sources) == 0 else self.sources[0].backend - return ArrayFieldList(array, metadata, backend=backend) + array_backend = ( + None if len(self.sources) == 0 else self.sources[0].array_backend + ) + return ArrayFieldList(array, metadata, array_backend=array_backend) class ArrayFieldList(ArrayFieldListCore): @@ -222,7 +232,7 @@ class ArrayFieldList(ArrayFieldListCore): def _getitem(self, n): if isinstance(n, int): - return ArrayField(self._array[n], self._metadata[n], self.backend) + return ArrayField(self._array[n], self._metadata[n], self.array_backend) def __len__(self): return ( diff --git a/earthkit/data/sources/constants.py b/earthkit/data/sources/constants.py index 91107733..724cfad2 100644 --- a/earthkit/data/sources/constants.py +++ b/earthkit/data/sources/constants.py @@ -328,7 +328,7 @@ def _getitem(self, n): self.procs[param], self.maker.shape, self.maker.field.metadata().geography, - self.backend, + self.array_backend, ) diff --git a/earthkit/data/sources/list_of_dicts.py b/earthkit/data/sources/list_of_dicts.py index 67e11202..1287f33d 100644 --- a/earthkit/data/sources/list_of_dicts.py +++ b/earthkit/data/sources/list_of_dicts.py @@ -191,7 +191,7 @@ def __init__(self, list_of_dicts, *args, **kwargs): super().__init__(*args, **kwargs) def __getitem__(self, n): - return VirtualGribField(self.list_of_dicts[n], self.backend) + return VirtualGribField(self.list_of_dicts[n], self.array_backend) def __len__(self): return len(self.list_of_dicts) diff --git a/earthkit/data/sources/numpy_list.py b/earthkit/data/sources/numpy_list.py index 2f3767c9..acd65cec 100644 --- a/earthkit/data/sources/numpy_list.py +++ b/earthkit/data/sources/numpy_list.py @@ -7,11 +7,12 @@ # nor does it submit to any jurisdiction. # -from earthkit.data.core.array import NUMPY_BACKEND from earthkit.data.sources.array_list import ArrayFieldList class NumpyFieldList(ArrayFieldList): def __init__(self, *args, **kwargs): + from earthkit.data.utils.array import numpy_backend + kwargs.pop("backend", None) - super().__init__(*args, backend=NUMPY_BACKEND, **kwargs) + super().__init__(*args, array_backend=numpy_backend(), **kwargs) diff --git a/earthkit/data/testing.py b/earthkit/data/testing.py index a3ecff86..33a5ce0b 100644 --- a/earthkit/data/testing.py +++ b/earthkit/data/testing.py @@ -150,49 +150,21 @@ def load_nc_or_xr_source(path, mode): return from_object(xarray.open_dataset(path)) -# def check_numpy_array_type(v, dtype=None): -# import numpy as np - -# assert isinstance(v, np.ndarray) -# if dtype is not None: -# if dtype == "float64": -# dtype = np.float64 -# elif dtype == "float32": -# dtype = np.float32 -# else: -# raise ValueError("Unsupported dtype={dtype}") -# assert v.dtype == dtype - - -# def check_pytorch_array_type(v, dtype=None): -# import torch - -# assert torch.is_tensor(v) -# if dtype is not None: -# if dtype == "float64": -# dtype = torch.float64 -# elif dtype == "float32": -# dtype = torch.float32 -# else: -# raise ValueError("Unsupported dtype={dtype}") -# assert v.dtype == dtype - - def check_array_type(v, backend, **kwargs): - from earthkit.data.core.array import ensure_backend + from earthkit.data.utils.array import ensure_backend b = ensure_backend(backend) - assert b.is_native_array(v, **kwargs) + assert b.is_native_array(v, **kwargs), f"{type(v)}, {backend=}, {kwargs=}" def get_array_namespace(backend): - from earthkit.data.core.array import ensure_backend + from earthkit.data.utils.array import ensure_backend return ensure_backend(backend).array_ns def get_array(v, backend, **kwargs): - from earthkit.data.core.array import ensure_backend + from earthkit.data.utils.array import ensure_backend b = ensure_backend(backend) return b.from_other(v, **kwargs) diff --git a/earthkit/data/utils/array/__init__.py b/earthkit/data/utils/array/__init__.py new file mode 100644 index 00000000..3a69dae5 --- /dev/null +++ b/earthkit/data/utils/array/__init__.py @@ -0,0 +1,214 @@ +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import logging +import os +import threading +from abc import ABCMeta, abstractmethod +from importlib import import_module + +LOG = logging.getLogger(__name__) + + +class ArrayBackendManager: + def __init__(self): + self.backends = None + self._np_backend = None + self.loaded = None + self.lock = threading.Lock() + + def find_for_name(self, name): + self._load() + b = self.backends.get(name, None) + if b is None: + raise TypeError(f"No array backend found for name={name}") + + # throw an exception when the backend is not available + if not b.available: + raise Exception(f"Could not load array backend for name={name}") + + return b + + def find_for_array(self, v, guess=None): + self._load() + if guess is not None and guess.is_native_array(v): + return guess + + # try all the backends. This will only try to load/import an unloaded/unimported + # backend when necessary + for _, b in self.backends.items(): + if b.is_native_array(v): + return b + + raise TypeError(f"No array backend found for array type={type(v)}") + + @property + def numpy_backend(self): + if self._np_backend is None: + self._np_backend = self.find_for_name("numpy") + return self._np_backend + + def _load(self): + """Load the available backend objects""" + if self.loaded is None: + with self.lock: + self.backends = {} + here = os.path.dirname(__file__) + for path in sorted(os.listdir(here)): + if path[0] in ("_", "."): + continue + + if path.endswith(".py") or os.path.isdir(os.path.join(here, path)): + name, _ = os.path.splitext(path) + try: + module = import_module(f".{name}", package=__name__) + if hasattr(module, "Backend"): + w = getattr(module, "Backend") + self.backends[name] = w() + except Exception as e: + LOG.exception( + f"Failed to import array backend code {name} from {path}. {e}" + ) + self.loaded = True + + +class ArrayBackendCore: + def __init__(self, backend): + self.ns = None + self.dtypes = None + + try: + self.ns, self.dtypes = backend._load() + self.avail = True + except Exception as e: + LOG.exception(f"Failed to load array backend {backend.name}. {e}") + self.avail = False + + +class ArrayBackend(metaclass=ABCMeta): + """The backend objects are created upfront but only loaded on + demand to avoid unnecessary imports + """ + + _name = None + _array_name = "array" + _core = None + + def __init__(self): + self.lock = threading.Lock() + + def _load_core(self): + if self._core is None: + with self.lock: + if self._core is None: + self._core = ArrayBackendCore(self) + + @property + def available(self): + self._load_core() + return self._core.avail + + @abstractmethod + def _load(self): + """Called from arrayBackendCore. It must return ns and dtypes""" + pass + + @property + def array_ns(self): + """Delayed construction of array namespace""" + self._load_core() + return self._core.ns + + @property + def name(self): + return self._name + + @property + def array_name(self): + return f"{self._name} {self._array_name}" + + def to_array(self, v, backend=None): + if backend is not None: + if backend is self: + return v + + return backend.to_backend(v, self) + else: + b = get_backend(v, strict=False) + return b.to_backend(v, self) + + @property + def _dtypes(self): + self._load_core() + return self._core.dtypes + + def to_dtype(self, dtype): + if isinstance(dtype, str): + return self._dtypes.get(dtype, None) + return dtype + + def match_dtype(self, v, dtype): + if dtype is not None: + dtype = self.to_dtype(dtype) + f = v.dtype == dtype if dtype is not None else False + return f + return True + + @abstractmethod + def is_native_array(self, v, **kwargs): + pass + + @abstractmethod + def to_backend(self, v, backend): + pass + + @abstractmethod + def from_numpy(self, v): + pass + + @abstractmethod + def from_pytorch(self, v): + pass + + @abstractmethod + def from_other(self, v, **kwargs): + pass + + +_MANAGER = ArrayBackendManager() + +# The public API + + +def ensure_backend(backend): + if backend is None: + return numpy_backend() + if isinstance(backend, str): + return _MANAGER.find_for_name(backend) + else: + return backend + + +def get_backend(array, guess=None, strict=True): + if isinstance(array, list): + array = array[0] + + if guess is not None: + guess = ensure_backend(guess) + + b = _MANAGER.find_for_array(array, guess=guess) + if strict and guess is not None and b is not guess: + raise ValueError( + f"array type={b.array_name} and specified backend={guess} do not match" + ) + return b + + +def numpy_backend(): + return _MANAGER.numpy_backend diff --git a/earthkit/data/utils/array/numpy.py b/earthkit/data/utils/array/numpy.py new file mode 100644 index 00000000..35f1a9c0 --- /dev/null +++ b/earthkit/data/utils/array/numpy.py @@ -0,0 +1,60 @@ +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import os + +from . import ArrayBackend + + +class NumpyBackend(ArrayBackend): + _name = "numpy" + + def _load(self): + import numpy as np + + try: + import array_api_compat + + ns = array_api_compat.array_namespace(np.ones(2)) + except Exception: + ns = np + + return ns, {} + + def to_dtype(self, dtype): + return dtype + + def is_native_array(self, v, dtype=None): + if self.available is None and "numpy" not in os.modules: + return False + + import numpy as np + + if not isinstance(v, np.ndarray): + return False + if dtype is not None: + return v.dtype == dtype + return True + + def to_backend(self, v, backend): + return backend.from_numpy(v) + + def from_numpy(self, v): + return v + + def from_pytorch(self, v): + return v.numpy() + + def from_other(self, v, **kwargs): + import numpy as np + + return np.array(v, **kwargs) + + +Backend = NumpyBackend diff --git a/earthkit/data/utils/array/pytorch.py b/earthkit/data/utils/array/pytorch.py new file mode 100644 index 00000000..0a0c366e --- /dev/null +++ b/earthkit/data/utils/array/pytorch.py @@ -0,0 +1,65 @@ +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import os + +from . import ArrayBackend + + +class PytorchBackend(ArrayBackend): + _name = "pytorch" + _array_name = "tensor" + + def _load(self): + try: + import array_api_compat + + except Exception as e: + raise ImportError( + f"array_api_compat is required to use pytorch backend, {e}" + ) + + try: + import torch + except Exception as e: + raise ImportError(f"torch is required to use pytorch backend, {e}") + + dt = {"float64": torch.float64, "float32": torch.float32} + ns = array_api_compat.array_namespace(torch.ones(2)) + + return ns, dt + + def is_native_array(self, v, dtype=None): + if self.available is None and "torch" not in os.modules: + return False + + import torch + + if not torch.is_tensor(v): + return False + return self.match_dtype(v, dtype) + + def to_backend(self, v, backend): + return backend.from_pytorch(v) + + def from_numpy(self, v): + import torch + + return torch.from_numpy(v) + + def from_pytorch(self, v): + return v + + def from_other(self, v, **kwargs): + import torch + + return torch.tensor(v, **kwargs) + + +Backend = PytorchBackend diff --git a/tests/array_fieldlist/array_fl_fixtures.py b/tests/array_fieldlist/array_fl_fixtures.py index 276e18df..f3bc0c6e 100644 --- a/tests/array_fieldlist/array_fl_fixtures.py +++ b/tests/array_fieldlist/array_fl_fixtures.py @@ -17,7 +17,7 @@ from earthkit.data.testing import earthkit_examples_file, get_array_namespace -def load_array_fl(num, backend=None): +def load_array_fl(num, array_backend=None): assert num in [1, 2, 3] files = ["test.grib", "test6.grib", "tuv_pl.grib"] files = files[:num] @@ -26,7 +26,9 @@ def load_array_fl(num, backend=None): md = [] for fname in files: ds_in.append( - from_source("file", earthkit_examples_file(fname), backend=backend) + from_source( + "file", earthkit_examples_file(fname), array_backend=array_backend + ) ) md += ds_in[-1].metadata("param") @@ -41,8 +43,10 @@ def load_array_fl(num, backend=None): return (*ds, md) -def load_array_fl_file(fname, backend=None): - ds_in = from_source("file", earthkit_examples_file(fname), backend=backend) +def load_array_fl_file(fname, array_backend=None): + ds_in = from_source( + "file", earthkit_examples_file(fname), array_backend=array_backend + ) md = ds_in.metadata("param") ds = FieldList.from_array( @@ -52,10 +56,10 @@ def load_array_fl_file(fname, backend=None): return (ds, md) -def check_array_fl(ds, ds_input, md_full, backend=None): +def check_array_fl(ds, ds_input, md_full, array_backend=None): assert len(ds_input) in [1, 2, 3] - ns = get_array_namespace(backend) + ns = get_array_namespace(array_backend) assert len(ds) == len(md_full) assert ds.metadata("param") == md_full @@ -96,13 +100,13 @@ def check_array_fl(ds, ds_input, md_full, backend=None): def check_array_fl_from_to_fieldlist( - ds, ds_input, md_full, backend=None, flatten=False, dtype=None + ds, ds_input, md_full, array_backend=None, flatten=False, dtype=None ): assert len(ds_input) in [1, 2, 3] assert len(ds) == len(md_full) assert ds.metadata("param") == md_full - ns = get_array_namespace(backend) + ns = get_array_namespace(array_backend) np_kwargs = {"flatten": flatten, "dtype": dtype} diff --git a/tests/array_fieldlist/test_numpy_fl_write.py b/tests/array_fieldlist/test_numpy_fl_write.py index 3da88e5d..b1a3c752 100644 --- a/tests/array_fieldlist/test_numpy_fl_write.py +++ b/tests/array_fieldlist/test_numpy_fl_write.py @@ -33,15 +33,17 @@ LOG = logging.getLogger(__name__) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_array_fl_grib_write(backend): - ds = from_source("file", earthkit_examples_file("test.grib"), backend=backend) - ns = get_array_namespace(backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_array_fl_grib_write(array_backend): + ds = from_source( + "file", earthkit_examples_file("test.grib"), array_backend=array_backend + ) + ns = get_array_namespace(array_backend) assert ds[0].metadata("shortName") == "2t" assert len(ds) == 2 v1 = ds[0].values + 1 - check_array_type(v1, backend) + check_array_type(v1, array_backend) md = ds[0].metadata() md1 = md.override(shortName="msl") @@ -50,16 +52,18 @@ def test_array_fl_grib_write(backend): with temp_file() as tmp: r.save(tmp) assert os.path.exists(tmp) - r_tmp = from_source("file", tmp, backend=backend) + r_tmp = from_source("file", tmp, array_backend=array_backend) v_tmp = r_tmp[0].values assert ns.allclose(v1, v_tmp) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize("_kwargs", [{}, {"check_nans": True}]) -def test_array_fl_grib_write_missing(backend, _kwargs): - ds = from_source("file", earthkit_examples_file("test.grib"), backend=backend) - ns = get_array_namespace(backend) +def test_array_fl_grib_write_missing(array_backend, _kwargs): + ds = from_source( + "file", earthkit_examples_file("test.grib"), array_backend=array_backend + ) + ns = get_array_namespace(array_backend) assert ds[0].metadata("shortName") == "2t" @@ -81,7 +85,7 @@ def test_array_fl_grib_write_missing(backend, _kwargs): with temp_file() as tmp: r.save(tmp, **_kwargs) assert os.path.exists(tmp) - r_tmp = from_source("file", tmp, backend=backend) + r_tmp = from_source("file", tmp, array_backend=array_backend) v_tmp = r_tmp[0].values assert ns.isnan(v_tmp[0]) assert not ns.isnan(v_tmp[1]) @@ -178,12 +182,12 @@ def test_array_fl_grib_write_generating_proc_id(): assert np.allclose(r_tmp.values[1], v2) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "_kwargs,expected_value", [({}, 16), ({"bits_per_value": 12}, 12)] ) -def test_array_fl_grib_write_bits_per_value(backend, _kwargs, expected_value): - ds, _ = load_array_fl(1, backend) +def test_array_fl_grib_write_bits_per_value(array_backend, _kwargs, expected_value): + ds, _ = load_array_fl(1, array_backend) with temp_file() as tmp: ds.save(tmp, **_kwargs) diff --git a/tests/array_fieldlist/test_numpy_fs.py b/tests/array_fieldlist/test_numpy_fs.py index 626695dd..4a323f20 100644 --- a/tests/array_fieldlist/test_numpy_fs.py +++ b/tests/array_fieldlist/test_numpy_fs.py @@ -131,7 +131,7 @@ def test_array_fl_grib_from_to_fieldlist(kwargs): md_full = ds.metadata("param") assert len(ds) == 2 - r = ds.to_fieldlist("numpy", **kwargs) + r = ds.to_fieldlist(array_backend="numpy", **kwargs) check_array_fl_from_to_fieldlist(r, [ds], md_full, **kwargs) @@ -141,11 +141,11 @@ def test_array_fl_grib_from_to_fieldlist_repeat(): assert len(ds) == 2 kwargs = {} - r = ds.to_fieldlist("numpy", **kwargs) + r = ds.to_fieldlist(array_backend="numpy", **kwargs) check_array_fl_from_to_fieldlist(r, [ds], md_full, **kwargs) kwargs = {"flatten": True, "dtype": np.float32} - r1 = r.to_fieldlist("numpy", **kwargs) + r1 = r.to_fieldlist(array_backend="numpy", **kwargs) assert r1 is not r check_array_fl_from_to_fieldlist(r1, [ds], md_full, **kwargs) diff --git a/tests/documentation/test_notebooks.py b/tests/documentation/test_notebooks.py index 411d8170..5d91e385 100644 --- a/tests/documentation/test_notebooks.py +++ b/tests/documentation/test_notebooks.py @@ -15,7 +15,7 @@ import pytest -from earthkit.data.testing import MISSING, earthkit_file +from earthkit.data.testing import MISSING, NO_PYTORCH, earthkit_file # See https://www.blog.pythonlibrary.org/2018/10/16/testing-jupyter-notebooks/ @@ -31,9 +31,11 @@ "polytope.ipynb", "grib_fdb_write.ipynb", "demo_source_plugin.ipynb", - "grib_array_backends.ipynb", ] +if NO_PYTORCH: + SKIP.append("grib_array_backends.ipynb") + def notebooks_list(): notebooks = [] diff --git a/tests/grib/grib_fixtures.py b/tests/grib/grib_fixtures.py index fc4c2975..62d3948c 100644 --- a/tests/grib/grib_fixtures.py +++ b/tests/grib/grib_fixtures.py @@ -15,14 +15,14 @@ from earthkit.data.testing import earthkit_examples_file, earthkit_test_data_file -def load_array_fieldlist(path, backend): - ds = from_source("file", path, backend=backend) +def load_array_fieldlist(path, array_backend): + ds = from_source("file", path, array_backend=array_backend) return FieldList.from_array( ds.values, [m.override(generatingProcessIdentifier=120) for m in ds.metadata()] ) -def load_grib_data(filename, fl_type, backend, folder="example"): +def load_grib_data(filename, fl_type, array_backend, folder="example"): if folder == "example": path = earthkit_examples_file(filename) elif folder == "data": @@ -31,65 +31,11 @@ def load_grib_data(filename, fl_type, backend, folder="example"): raise ValueError("Invalid folder={folder}") if fl_type == "file": - return from_source("file", path, backend=backend) + return from_source("file", path, array_backend=array_backend) elif fl_type == "array": - return load_array_fieldlist(path, backend) + return load_array_fieldlist(path, array_backend) else: raise ValueError("Invalid fl_type={fl_type}") -# def check_numpy_array_type(v, dtype=None): -# import numpy as np - -# assert isinstance(v, np.ndarray) -# if dtype is not None: -# if dtype == "float64": -# dtype = np.float64 -# elif dtype == "float32": -# dtype = np.float32 -# else: -# raise ValueError("Unsupported dtype={dtype}") -# assert v.dtype == dtype - - -# def check_pytorch_array_type(v, dtype=None): -# import torch - -# assert torch.is_tensor(v) -# if dtype is not None: -# if dtype == "float64": -# dtype = torch.float64 -# elif dtype == "float32": -# dtype = torch.float32 -# else: -# raise ValueError("Unsupported dtype={dtype}") -# assert v.dtype == dtype - - -# def check_array_type(v, backend, **kwargs): -# if backend is None or backend == "numpy": -# check_numpy_array_type(v, **kwargs) -# elif backend == "pytorch": -# check_pytorch_array_type(v, **kwargs) -# else: -# raise ValueError("Invalid backend={backend}") - - -# def get_array_namespace(backend): -# from earthkit.data.core.array import ensure_backend - -# return ensure_backend(backend).array_ns - - -# def get_array(v, backend): -# from earthkit.data.core.array import ensure_backend - -# b = ensure_backend(backend) -# return b.from_other(v) - - FL_TYPES = ["file", "array"] - -# ARRAY_BACKENDS = ["numpy"] -# if not NO_PYTORCH: -# ARRAY_BACKENDS.append("pytorch") diff --git a/tests/grib/test_grib_backend.py b/tests/grib/test_grib_backend.py index 1c0d0a15..3b2cdb1d 100644 --- a/tests/grib/test_grib_backend.py +++ b/tests/grib/test_grib_backend.py @@ -16,7 +16,7 @@ from earthkit.data.testing import NO_PYTORCH, earthkit_examples_file -@pytest.mark.parametrize("_kwargs", [{}, {"backend": "numpy"}]) +@pytest.mark.parametrize("_kwargs", [{}, {"array_backend": "numpy"}]) def test_grib_file_numpy_backend(_kwargs): ds = from_source("file", earthkit_examples_file("test6.grib"), **_kwargs) @@ -43,12 +43,19 @@ def test_grib_file_numpy_backend(_kwargs): assert isinstance(ds.to_numpy(), np.ndarray) assert ds.to_numpy().shape == (6, 7, 12) + ds1 = ds.to_fieldlist() + assert len(ds1) == len(ds) + assert ds1.array_backend.name == "numpy" + assert getattr(ds1, "path", None) is None + @pytest.mark.skipif(NO_PYTORCH, reason="No pytorch installed") def test_grib_file_pytorch_backend(): import torch - ds = from_source("file", earthkit_examples_file("test6.grib"), backend="pytorch") + ds = from_source( + "file", earthkit_examples_file("test6.grib"), array_backend="pytorch" + ) assert len(ds) == 6 @@ -73,6 +80,11 @@ def test_grib_file_pytorch_backend(): assert isinstance(ds.to_numpy(), np.ndarray) assert ds.to_numpy().shape == (6, 7, 12) + ds1 = ds.to_fieldlist() + assert len(ds1) == len(ds) + assert ds1.array_backend.name == "pytorch" + assert getattr(ds1, "path", None) is None + def test_grib_array_numpy_backend(): s = from_source("file", earthkit_examples_file("test6.grib")) @@ -111,7 +123,9 @@ def test_grib_array_numpy_backend(): def test_grib_array_pytorch_backend(): import torch - s = from_source("file", earthkit_examples_file("test6.grib"), backend="pytorch") + s = from_source( + "file", earthkit_examples_file("test6.grib"), array_backend="pytorch" + ) ds = FieldList.from_array( s.values, diff --git a/tests/grib/test_grib_convert.py b/tests/grib/test_grib_convert.py index 2780e073..b7d9b7bb 100644 --- a/tests/grib/test_grib_convert.py +++ b/tests/grib/test_grib_convert.py @@ -21,10 +21,10 @@ @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ["numpy"]) -def test_icon_to_xarray(fl_type, backend): +@pytest.mark.parametrize("array_backend", ["numpy"]) +def test_icon_to_xarray(fl_type, array_backend): # test the conversion to xarray for an icon (unstructured grid) grib file. - g = load_grib_data("test_icon.grib", fl_type, backend, folder="data") + g = load_grib_data("test_icon.grib", fl_type, array_backend, folder="data") ds = g.to_xarray() assert len(ds.data_vars) == 1 @@ -35,9 +35,9 @@ def test_icon_to_xarray(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ["numpy"]) -def test_to_xarray_filter_by_keys(fl_type, backend): - g = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ["numpy"]) +def test_to_xarray_filter_by_keys(fl_type, array_backend): + g = load_grib_data("tuv_pl.grib", fl_type, array_backend) g = g.sel(param="t", level=500) + g.sel(param="u") assert len(g) > 1 @@ -53,9 +53,9 @@ def test_to_xarray_filter_by_keys(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ["numpy"]) -def test_grib_to_pandas(fl_type, backend): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ["numpy"]) +def test_grib_to_pandas(fl_type, array_backend): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") # all points df = f.to_pandas() diff --git a/tests/grib/test_grib_geography.py b/tests/grib/test_grib_geography.py index a53a5d59..cbb81af1 100644 --- a/tests/grib/test_grib_geography.py +++ b/tests/grib/test_grib_geography.py @@ -31,17 +31,17 @@ def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize("index", [0, None]) -def test_grib_to_latlon_single(fl_type, backend, index): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +def test_grib_to_latlon_single(fl_type, array_backend, index): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") eps = 1e-5 g = f[index] if index is not None else f v = g.to_latlon(flatten=True) assert isinstance(v, dict) - check_array_type(v["lon"], backend, dtype="float64") - check_array_type(v["lat"], backend, dtype="float64") + check_array_type(v["lon"], array_backend, dtype="float64") + check_array_type(v["lat"], array_backend, dtype="float64") check_array( v["lon"], (84,), @@ -61,16 +61,16 @@ def test_grib_to_latlon_single(fl_type, backend, index): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize("index", [0, None]) -def test_grib_to_latlon_single_shape(fl_type, backend, index): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +def test_grib_to_latlon_single_shape(fl_type, array_backend, index): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") g = f[index] if index is not None else f v = g.to_latlon() assert isinstance(v, dict) - check_array_type(v["lon"], backend, dtype="float64") - check_array_type(v["lat"], backend, dtype="float64") + check_array_type(v["lon"], array_backend, dtype="float64") + check_array_type(v["lat"], array_backend, dtype="float64") # x assert v["lon"].shape == (7, 12) @@ -84,10 +84,10 @@ def test_grib_to_latlon_single_shape(fl_type, backend, index): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ["numpy"]) +@pytest.mark.parametrize("array_backend", ["numpy"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_latlon_multi(fl_type, backend, dtype): - f = load_grib_data("test.grib", fl_type, backend) +def test_grib_to_latlon_multi(fl_type, array_backend, dtype): + f = load_grib_data("test.grib", fl_type, array_backend) v_ref = f[0].to_latlon(flatten=True, dtype=dtype) v = f.to_latlon(flatten=True, dtype=dtype) @@ -102,10 +102,10 @@ def test_grib_to_latlon_multi(fl_type, backend, dtype): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_to_latlon_multi_non_shared_grid(fl_type, backend): - f1 = load_grib_data("test.grib", fl_type, backend) - f2 = load_grib_data("test4.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_to_latlon_multi_non_shared_grid(fl_type, array_backend): + f1 = load_grib_data("test.grib", fl_type, array_backend) + f2 = load_grib_data("test4.grib", fl_type, array_backend) f = f1 + f2 with pytest.raises(ValueError): @@ -113,17 +113,17 @@ def test_grib_to_latlon_multi_non_shared_grid(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize("index", [0, None]) -def test_grib_to_points_single(fl_type, backend, index): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +def test_grib_to_points_single(fl_type, array_backend, index): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") eps = 1e-5 g = f[index] if index is not None else f v = g.to_points(flatten=True) assert isinstance(v, dict) - check_array_type(v["x"], backend, dtype="float64") - check_array_type(v["y"], backend, dtype="float64") + check_array_type(v["x"], array_backend, dtype="float64") + check_array_type(v["y"], array_backend, dtype="float64") check_array( v["x"], (84,), @@ -143,18 +143,18 @@ def test_grib_to_points_single(fl_type, backend, index): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_to_points_unsupported_grid(fl_type, backend): - f = load_grib_data("mercator.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_to_points_unsupported_grid(fl_type, array_backend): + f = load_grib_data("mercator.grib", fl_type, array_backend, folder="data") with pytest.raises(ValueError): f[0].to_points() @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ["numpy"]) +@pytest.mark.parametrize("array_backend", ["numpy"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_points_multi(fl_type, backend, dtype): - f = load_grib_data("test.grib", fl_type, backend) +def test_grib_to_points_multi(fl_type, array_backend, dtype): + f = load_grib_data("test.grib", fl_type, array_backend) v_ref = f[0].to_points(flatten=True, dtype=dtype) v = f.to_points(flatten=True, dtype=dtype) @@ -169,10 +169,10 @@ def test_grib_to_points_multi(fl_type, backend, dtype): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_to_points_multi_non_shared_grid(fl_type, backend): - f1 = load_grib_data("test.grib", fl_type, backend) - f2 = load_grib_data("test4.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_to_points_multi_non_shared_grid(fl_type, array_backend): + f1 = load_grib_data("test.grib", fl_type, array_backend) + f2 = load_grib_data("test4.grib", fl_type, array_backend) f = f1 + f2 with pytest.raises(ValueError): @@ -180,9 +180,9 @@ def test_grib_to_points_multi_non_shared_grid(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_bbox(fl_type, backend): - ds = load_grib_data("test.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_bbox(fl_type, array_backend): + ds = load_grib_data("test.grib", fl_type, array_backend) bb = ds.bounding_box() assert len(bb) == 2 for b in bb: @@ -190,10 +190,10 @@ def test_bbox(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize("index", [0, None]) -def test_grib_projection_ll(fl_type, backend, index): - f = load_grib_data("test.grib", fl_type, backend) +def test_grib_projection_ll(fl_type, array_backend, index): + f = load_grib_data("test.grib", fl_type, array_backend) if index is not None: g = f[index] @@ -205,9 +205,9 @@ def test_grib_projection_ll(fl_type, backend, index): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_projection_mercator(fl_type, backend): - f = load_grib_data("mercator.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_projection_mercator(fl_type, array_backend): + f = load_grib_data("mercator.grib", fl_type, array_backend, folder="data") projection = f[0].projection() assert isinstance(projection, projections.Mercator) assert projection.parameters == { diff --git a/tests/grib/test_grib_inidces.py b/tests/grib/test_grib_inidces.py index 5a58b095..f700f547 100644 --- a/tests/grib/test_grib_inidces.py +++ b/tests/grib/test_grib_inidces.py @@ -22,9 +22,9 @@ @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_indices_base(fl_type, backend): - ds = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_indices_base(fl_type, array_backend): + ds = load_grib_data("tuv_pl.grib", fl_type, array_backend) ref = { "class": ["od"], @@ -56,9 +56,9 @@ def test_grib_indices_base(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_indices_sel(fl_type, backend): - ds = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_indices_sel(fl_type, array_backend): + ds = load_grib_data("tuv_pl.grib", fl_type, array_backend) ref = { "class": ["od"], @@ -86,10 +86,10 @@ def test_grib_indices_sel(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_indices_multi(fl_type, backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, backend) - f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_indices_multi(fl_type, array_backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) + f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") ds = f1 + f2 ref = { @@ -153,10 +153,10 @@ def test_grib_indices_multi(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_indices_multi_Del(fl_type, backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, backend) - f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_indices_multi_Del(fl_type, array_backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) + f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") ds = f1 + f2 ref = { @@ -179,9 +179,9 @@ def test_grib_indices_multi_Del(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_indices_order_by(fl_type, backend): - ds = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_indices_order_by(fl_type, array_backend): + ds = load_grib_data("tuv_pl.grib", fl_type, array_backend) ref = { "class": ["od"], diff --git a/tests/grib/test_grib_metadata.py b/tests/grib/test_grib_metadata.py index 07c96ba5..ced1e290 100644 --- a/tests/grib/test_grib_metadata.py +++ b/tests/grib/test_grib_metadata.py @@ -36,7 +36,7 @@ def repeat_list_items(items, count): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,expected_value", [ @@ -54,8 +54,8 @@ def repeat_list_items(items, count): (("shortName", "level"), ("2t", 0)), ], ) -def test_grib_metadata_grib(fl_type, backend, key, expected_value): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +def test_grib_metadata_grib(fl_type, array_backend, key, expected_value): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") sn = f.metadata(key) assert sn == [expected_value] sn = f[0].metadata(key) @@ -63,7 +63,7 @@ def test_grib_metadata_grib(fl_type, backend, key, expected_value): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,astype,expected_value", [ @@ -77,8 +77,8 @@ def test_grib_metadata_grib(fl_type, backend, key, expected_value): ("level", int, 0), ], ) -def test_grib_metadata_astype_1(fl_type, backend, key, astype, expected_value): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +def test_grib_metadata_astype_1(fl_type, array_backend, key, astype, expected_value): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") sn = f.metadata(key, astype=astype) assert sn == [expected_value] sn = f[0].metadata(key, astype=astype) @@ -86,7 +86,7 @@ def test_grib_metadata_astype_1(fl_type, backend, key, astype, expected_value): @pytest.mark.parametrize("fs_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,expected_value", [ @@ -98,15 +98,15 @@ def test_grib_metadata_astype_1(fl_type, backend, key, astype, expected_value): ("level:int", repeat_list_items([1000, 850, 700, 500, 400, 300], 3)), ], ) -def test_grib_metadata_18(fs_type, backend, key, expected_value): +def test_grib_metadata_18(fs_type, array_backend, key, expected_value): # f = load_grib_data("tuv_pl.grib", mode) - ds = load_grib_data("tuv_pl.grib", fs_type, backend) + ds = load_grib_data("tuv_pl.grib", fs_type, array_backend) sn = ds.metadata(key) assert sn == expected_value @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,astype,expected_value", [ @@ -124,14 +124,14 @@ def test_grib_metadata_18(fs_type, backend, key, expected_value): ), ], ) -def test_grib_metadata_astype_18(fl_type, backend, key, astype, expected_value): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_metadata_astype_18(fl_type, array_backend, key, astype, expected_value): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) sn = f.metadata(key, astype=astype) assert sn == expected_value @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,expected_value", [ @@ -140,15 +140,15 @@ def test_grib_metadata_astype_18(fl_type, backend, key, astype, expected_value): ("latitudeOfFirstGridPointInDegrees:float", 90.0), ], ) -def test_grib_metadata_double_1(fl_type, backend, key, expected_value): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +def test_grib_metadata_double_1(fl_type, array_backend, key, expected_value): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") r = f.metadata(key) assert len(r) == 1 assert np.isclose(r[0], expected_value) @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key", [ @@ -157,8 +157,8 @@ def test_grib_metadata_double_1(fl_type, backend, key, expected_value): ("latitudeOfFirstGridPointInDegrees:float"), ], ) -def test_grib_metadata_double_18(fl_type, backend, key): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_metadata_double_18(fl_type, array_backend, key): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) ref = [90.0] * 18 r = f.metadata(key) @@ -166,7 +166,7 @@ def test_grib_metadata_double_18(fl_type, backend, key): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,astype", [ @@ -174,8 +174,8 @@ def test_grib_metadata_double_18(fl_type, backend, key): ("latitudeOfFirstGridPointInDegrees", float), ], ) -def test_grib_metadata_double_astype_18(fl_type, backend, key, astype): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_metadata_double_astype_18(fl_type, array_backend, key, astype): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) ref = [90.0] * 18 @@ -184,10 +184,10 @@ def test_grib_metadata_double_astype_18(fl_type, backend, key, astype): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_get_long_array_1(fl_type, backend): +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_get_long_array_1(fl_type, array_backend): f = load_grib_data( - "rgg_small_subarea_cellarea_ref.grib", fl_type, backend, folder="data" + "rgg_small_subarea_cellarea_ref.grib", fl_type, array_backend, folder="data" ) assert len(f) == 1 @@ -203,9 +203,9 @@ def test_grib_get_long_array_1(fl_type, backend): @pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("backend", [None]) -def test_grib_get_double_array_values_1(fl_type, backend): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", [None]) +def test_grib_get_double_array_values_1(fl_type, array_backend): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") v = f.metadata("values") assert len(v) == 1 @@ -223,9 +223,9 @@ def test_grib_get_double_array_values_1(fl_type, backend): @pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("backend", [None]) -def test_grib_get_double_array_values_18(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", [None]) +def test_grib_get_double_array_values_18(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) v = f.metadata("values") assert isinstance(v, list) assert len(v) == 18 @@ -254,9 +254,9 @@ def test_grib_get_double_array_values_18(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_get_double_array_1(fl_type, backend): - f = load_grib_data("ml_data.grib", fl_type, backend, folder="data")[0] +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_get_double_array_1(fl_type, array_backend): + f = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data")[0] # f is now a field! v = f.metadata("pv") assert isinstance(v, np.ndarray) @@ -268,9 +268,9 @@ def test_grib_get_double_array_1(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_get_double_array_18(fl_type, backend): - f = load_grib_data("ml_data.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_get_double_array_18(fl_type, array_backend): + f = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") v = f.metadata("pv") assert isinstance(v, list) assert len(v) == 36 @@ -286,9 +286,9 @@ def test_grib_get_double_array_18(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_metadata_type_qualifier(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend)[0:4] +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_metadata_type_qualifier(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend)[0:4] # to str r = f.metadata("centre:s") @@ -326,9 +326,9 @@ def test_grib_metadata_type_qualifier(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_metadata_astype(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend)[0:4] +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_metadata_astype(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend)[0:4] # to str r = f.metadata("centre", astype=None) @@ -361,9 +361,9 @@ def test_grib_metadata_astype(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_metadata_generic(fl_type, backend): - f_full = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_metadata_generic(fl_type, array_backend): + f_full = load_grib_data("tuv_pl.grib", fl_type, array_backend) f = f_full[0:4] @@ -391,9 +391,9 @@ def test_grib_metadata_generic(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_metadata_missing_value(fl_type, backend): - f = load_grib_data("ml_data.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_metadata_missing_value(fl_type, array_backend): + f = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") with pytest.raises(KeyError): f[0].metadata("scaleFactorOfSecondFixedSurface") @@ -403,9 +403,9 @@ def test_grib_metadata_missing_value(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_metadata_missing_key(fl_type, backend): - f = load_grib_data("test.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_metadata_missing_key(fl_type, array_backend): + f = load_grib_data("test.grib", fl_type, array_backend) with pytest.raises(KeyError): f[0].metadata("_badkey_") @@ -415,9 +415,9 @@ def test_grib_metadata_missing_key(fl_type, backend): @pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("backend", [None]) -def test_grib_metadata_namespace(fl_type, backend): - f = load_grib_data("test6.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", [None]) +def test_grib_metadata_namespace(fl_type, array_backend): + f = load_grib_data("test6.grib", fl_type, array_backend) r = f[0].metadata(namespace="vertical") ref = {"level": 1000, "typeOfLevel": "isobaricInhPa"} @@ -496,9 +496,9 @@ def test_grib_metadata_namespace(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_datetime(fl_type, backend): - s = load_grib_data("test.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_datetime(fl_type, array_backend): + s = load_grib_data("test.grib", fl_type, array_backend) ref = { "base_time": [datetime.datetime(2020, 5, 13, 12)], @@ -527,18 +527,18 @@ def test_grib_datetime(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_valid_datetime(fl_type, backend): - ds = load_grib_data("t_time_series.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_valid_datetime(fl_type, array_backend): + ds = load_grib_data("t_time_series.grib", fl_type, array_backend, folder="data") f = ds[4] assert f.metadata("valid_datetime") == datetime.datetime(2020, 12, 21, 18) @pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("backend", [None]) -def test_message(fl_type, backend): - f = load_grib_data("test.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", [None]) +def test_message(fl_type, array_backend): + f = load_grib_data("test.grib", fl_type, array_backend) v = f[0].message() assert len(v) == 526 assert v[:4] == b"GRIB" diff --git a/tests/grib/test_grib_order_by.py b/tests/grib/test_grib_order_by.py index f36c01f5..f515cf7e 100644 --- a/tests/grib/test_grib_order_by.py +++ b/tests/grib/test_grib_order_by.py @@ -25,9 +25,9 @@ # @pytest.mark.skipif(("GITHUB_WORKFLOW" in os.environ) or True, reason="Not yet ready") @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_order_by_single_message(fl_type, backend): - s = load_grib_data("test_single.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_order_by_single_message(fl_type, array_backend): + s = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") r = s.order_by("shortName") assert len(r) == 1 @@ -56,7 +56,7 @@ def __call__(self, x, y): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta", [ @@ -104,11 +104,11 @@ def __call__(self, x, y): ) def test_grib_order_by_single_file_( fl_type, - backend, + array_backend, params, expected_meta, ): - f = load_grib_data("test6.grib", fl_type, backend) + f = load_grib_data("test6.grib", fl_type, array_backend) g = f.order_by(params) assert len(g) == len(f) @@ -118,7 +118,7 @@ def test_grib_order_by_single_file_( @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta", [ @@ -147,9 +147,9 @@ def test_grib_order_by_single_file_( ), ], ) -def test_grib_order_by_multi_file(fl_type, backend, params, expected_meta): - f1 = load_grib_data("test4.grib", fl_type, backend) - f2 = load_grib_data("test6.grib", fl_type, backend) +def test_grib_order_by_multi_file(fl_type, array_backend, params, expected_meta): + f1 = load_grib_data("test4.grib", fl_type, array_backend) + f2 = load_grib_data("test6.grib", fl_type, array_backend) f = from_source("multi", [f1, f2]) g = f.order_by(params) @@ -160,9 +160,9 @@ def test_grib_order_by_multi_file(fl_type, backend, params, expected_meta): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_order_by_with_sel(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_order_by_with_sel(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) g = f.sel(level=500) assert len(g) == 3 @@ -178,9 +178,9 @@ def test_grib_order_by_with_sel(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_order_by_valid_datetime(fl_type, backend): - f = load_grib_data("t_time_series.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_order_by_valid_datetime(fl_type, array_backend): + f = load_grib_data("t_time_series.grib", fl_type, array_backend, folder="data") g = f.order_by(valid_datetime="descending") assert len(g) == 10 diff --git a/tests/grib/test_grib_output.py b/tests/grib/test_grib_output.py index d708d04b..bb07752e 100644 --- a/tests/grib/test_grib_output.py +++ b/tests/grib/test_grib_output.py @@ -25,9 +25,11 @@ EPSILON = 1e-4 -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_save_when_loaded_from_file(backend): - fs = from_source("file", earthkit_examples_file("test6.grib"), backend=backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_save_when_loaded_from_file(array_backend): + fs = from_source( + "file", earthkit_examples_file("test6.grib"), array_backend=array_backend + ) assert len(fs) == 6 with temp_file() as tmp: fs.save(tmp) diff --git a/tests/grib/test_grib_sel.py b/tests/grib/test_grib_sel.py index e3a2df1a..26a15a1a 100644 --- a/tests/grib/test_grib_sel.py +++ b/tests/grib/test_grib_sel.py @@ -27,9 +27,9 @@ @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_sel_single_message(fl_type, backend): - s = load_grib_data("test_single.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_sel_single_message(fl_type, array_backend): + s = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") r = s.sel(shortName="2t") assert len(r) == 1 @@ -37,7 +37,7 @@ def test_grib_sel_single_message(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta,metadata_keys", [ @@ -65,8 +65,10 @@ def test_grib_sel_single_message(fl_type, backend): ), ], ) -def test_grib_sel_single_file_1(fl_type, backend, params, expected_meta, metadata_keys): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_sel_single_file_1( + fl_type, array_backend, params, expected_meta, metadata_keys +): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) g = f.sel(**params) assert len(g) == len(expected_meta) @@ -80,9 +82,9 @@ def test_grib_sel_single_file_1(fl_type, backend, params, expected_meta, metadat @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_sel_single_file_2(fl_type, backend): - f = load_grib_data("t_time_series.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_sel_single_file_2(fl_type, array_backend): + f = load_grib_data("t_time_series.grib", fl_type, array_backend, folder="data") g = f.sel(shortName=["t"], step=[3, 6]) assert len(g) == 2 @@ -102,9 +104,9 @@ def test_grib_sel_single_file_2(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_sel_single_file_as_dict(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_sel_single_file_as_dict(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) g = f.sel({"shortName": "t", "level": [500, 700], "mars.type": "an"}) assert len(g) == 2 @@ -115,7 +117,7 @@ def test_grib_sel_single_file_as_dict(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "param_id,level,expected_meta", [ @@ -127,8 +129,10 @@ def test_grib_sel_single_file_as_dict(fl_type, backend): (131, (slice(510, 520)), []), ], ) -def test_grib_sel_slice_single_file(fl_type, backend, param_id, level, expected_meta): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_sel_slice_single_file( + fl_type, array_backend, param_id, level, expected_meta +): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) g = f.sel(paramId=param_id, level=level) assert len(g) == len(expected_meta) @@ -137,10 +141,10 @@ def test_grib_sel_slice_single_file(fl_type, backend, param_id, level, expected_ @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_sel_multi_file(fl_type, backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, backend) - f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_sel_multi_file(fl_type, array_backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) + f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") f = from_source("multi", [f1, f2]) # single resulting field @@ -155,10 +159,10 @@ def test_grib_sel_multi_file(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_sel_slice_multi_file(fl_type, backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, backend) - f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_sel_slice_multi_file(fl_type, array_backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) + f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") f = from_source("multi", [f1, f2]) @@ -171,10 +175,10 @@ def test_grib_sel_slice_multi_file(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_sel_date(fl_type, backend): +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_sel_date(fl_type, array_backend): # date and time - f = load_grib_data("t_time_series.grib", fl_type, backend, folder="data") + f = load_grib_data("t_time_series.grib", fl_type, array_backend, folder="data") g = f.sel(date=20201221, time=1200, step=9) # g = f.sel(date="20201221", time="12", step="9") @@ -190,9 +194,9 @@ def test_grib_sel_date(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_sel_valid_datetime(fl_type, backend): - f = load_grib_data("t_time_series.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_sel_valid_datetime(fl_type, array_backend): + f = load_grib_data("t_time_series.grib", fl_type, array_backend, folder="data") g = f.sel(valid_datetime=datetime.datetime(2020, 12, 21, 21)) assert len(g) == 2 @@ -207,9 +211,9 @@ def test_grib_sel_valid_datetime(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_isel_single_message(fl_type, backend): - s = load_grib_data("test_single.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_isel_single_message(fl_type, array_backend): + s = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") r = s.isel(shortName=0) assert len(r) == 1 @@ -217,7 +221,7 @@ def test_grib_isel_single_message(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta,metadata_keys", [ @@ -254,8 +258,10 @@ def test_grib_isel_single_message(fl_type, backend): ), ], ) -def test_grib_isel_single_file(fl_type, backend, params, expected_meta, metadata_keys): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_isel_single_file( + fl_type, array_backend, params, expected_meta, metadata_keys +): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) g = f.isel(**params) assert len(g) == len(expected_meta) @@ -268,7 +274,7 @@ def test_grib_isel_single_file(fl_type, backend, params, expected_meta, metadata @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "param_id,level,expected_meta", [ @@ -280,8 +286,10 @@ def test_grib_isel_single_file(fl_type, backend, params, expected_meta, metadata (1, (slice(None, None, 2)), [[131, 850], [131, 500], [131, 300]]), ], ) -def test_grib_isel_slice_single_file(fl_type, backend, param_id, level, expected_meta): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_isel_slice_single_file( + fl_type, array_backend, param_id, level, expected_meta +): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) g = f.isel(paramId=param_id, level=level) assert len(g) == len(expected_meta) @@ -290,9 +298,9 @@ def test_grib_isel_slice_single_file(fl_type, backend, param_id, level, expected @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_isel_slice_invalid(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_isel_slice_invalid(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) with pytest.raises(IndexError): f.isel(level=500) @@ -302,10 +310,10 @@ def test_grib_isel_slice_invalid(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_isel_multi_file(fl_type, backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, backend) - f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_isel_multi_file(fl_type, array_backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) + f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") f = from_source("multi", [f1, f2]) # single resulting field @@ -319,10 +327,10 @@ def test_grib_isel_multi_file(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_isel_slice_multi_file(fl_type, backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, backend) - f2 = load_grib_data("ml_data.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_isel_slice_multi_file(fl_type, array_backend): + f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) + f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") f = from_source("multi", [f1, f2]) g = f.isel(shortName=1, level=slice(20, 22)) diff --git a/tests/grib/test_grib_slice.py b/tests/grib/test_grib_slice.py index e17356d0..1e3a42df 100644 --- a/tests/grib/test_grib_slice.py +++ b/tests/grib/test_grib_slice.py @@ -24,7 +24,7 @@ @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "index,expected_meta", [ @@ -35,8 +35,8 @@ (-5, ["u", 400]), ], ) -def test_grib_single_index(fl_type, backend, index, expected_meta): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_single_index(fl_type, array_backend, index, expected_meta): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) # f = from_source("file", earthkit_examples_file("tuv_pl.grib")) r = f[index] @@ -48,15 +48,15 @@ def test_grib_single_index(fl_type, backend, index, expected_meta): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_single_index_bad(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_single_index_bad(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) with pytest.raises(IndexError): f[27] @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "indexes,expected_meta", [ @@ -68,8 +68,8 @@ def test_grib_single_index_bad(fl_type, backend): (slice(14, None), [["v", 400], ["t", 300], ["u", 300], ["v", 300]]), ], ) -def test_grib_slice_single_file(fl_type, backend, indexes, expected_meta): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_slice_single_file(fl_type, array_backend, indexes, expected_meta): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) r = f[indexes] assert len(r) == 4 assert r.metadata(["shortName", "level"]) == expected_meta @@ -105,13 +105,13 @@ def test_grib_slice_multi_file(indexes, expected_meta): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "indexes1,indexes2", [(np.array([1, 16, 5, 9]), np.array([1, 3])), ([1, 16, 5, 9], [1, 3])], ) -def test_grib_array_indexing(fl_type, backend, indexes1, indexes2): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_array_indexing(fl_type, array_backend, indexes1, indexes2): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) r = f[indexes1] assert len(r) == 4 @@ -123,18 +123,18 @@ def test_grib_array_indexing(fl_type, backend, indexes1, indexes2): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize("indexes", [(np.array([1, 19, 5, 9])), ([1, 19, 5, 9])]) -def test_grib_array_indexing_bad(fl_type, backend, indexes): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_array_indexing_bad(fl_type, array_backend, indexes): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) with pytest.raises(IndexError): f[indexes] @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_fieldlist_iterator(fl_type, backend): - g = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_fieldlist_iterator(fl_type, array_backend): + g = load_grib_data("tuv_pl.grib", fl_type, array_backend) sn = g.metadata("shortName") assert len(sn) == 18 iter_sn = [f.metadata("shortName") for f in g] @@ -145,12 +145,12 @@ def test_grib_fieldlist_iterator(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_fieldlist_iterator_with_zip(fl_type, backend): +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_fieldlist_iterator_with_zip(fl_type, array_backend): # test something different to the iterator - does not try to # 'go off the edge' of the fieldlist, because the length is determined by # the list of levels - g = load_grib_data("tuv_pl.grib", fl_type, backend) + g = load_grib_data("tuv_pl.grib", fl_type, array_backend) ref_levs = g.metadata("level") assert len(ref_levs) == 18 levs1 = [] @@ -163,10 +163,10 @@ def test_grib_fieldlist_iterator_with_zip(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_fieldlist_iterator_with_zip_multiple(fl_type, backend): +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_fieldlist_iterator_with_zip_multiple(fl_type, array_backend): # same as test_fieldlist_iterator_with_zip() but multiple times - g = load_grib_data("tuv_pl.grib", fl_type, backend) + g = load_grib_data("tuv_pl.grib", fl_type, array_backend) ref_levs = g.metadata("level") assert len(ref_levs) == 18 for i in range(2): @@ -180,9 +180,9 @@ def test_grib_fieldlist_iterator_with_zip_multiple(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_fieldlist_reverse_iterator(fl_type, backend): - g = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_fieldlist_reverse_iterator(fl_type, array_backend): + g = load_grib_data("tuv_pl.grib", fl_type, array_backend) sn = g.metadata("shortName") sn_reversed = list(reversed(sn)) assert sn_reversed[0] == "v" diff --git a/tests/grib/test_grib_stream.py b/tests/grib/test_grib_stream.py index b004cf6a..59d4a1f8 100644 --- a/tests/grib/test_grib_stream.py +++ b/tests/grib/test_grib_stream.py @@ -14,7 +14,7 @@ from earthkit.data import from_source from earthkit.data.core.temporary import temp_file -from earthkit.data.testing import earthkit_examples_file +from earthkit.data.testing import ARRAY_BACKENDS, earthkit_examples_file def repeat_list_items(items, count): @@ -37,6 +37,7 @@ def test_grib_from_stream_invalid_args(_kwargs, error): from_source("stream", stream, **_kwargs) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "_kwargs", [ @@ -46,9 +47,9 @@ def test_grib_from_stream_invalid_args(_kwargs, error): {"group_by": ["level", "gridType"]}, ], ) -def test_grib_from_stream_group_by(_kwargs): +def test_grib_from_stream_group_by(array_backend, _kwargs): with open(earthkit_examples_file("test6.grib"), "rb") as stream: - fs = from_source("stream", stream, **_kwargs) + fs = from_source("stream", stream, **_kwargs, array_backend=array_backend) # no methods are available with pytest.raises(TypeError): @@ -61,7 +62,9 @@ def test_grib_from_stream_group_by(_kwargs): for i, f in enumerate(fs): assert len(f) == 3 assert f.metadata(("param", "level")) == ref[i] - assert f.to_fieldlist("numpy") is not f + afl = f.to_fieldlist(array_backend=array_backend) + assert afl is not f + assert len(afl) == 3 # stream consumed, no data is available assert sum([1 for _ in fs]) == 0 @@ -95,12 +98,12 @@ def test_grib_from_stream_group_by_convert_to_numpy(convert_kwargs, expected_sha convert_kwargs = {} for i, f in enumerate(ds): - df = f.to_fieldlist("numpy", **convert_kwargs) + df = f.to_fieldlist(array_backend="numpy", **convert_kwargs) assert len(df) == 3 assert df.metadata(("param", "level")) == ref[i] assert df._array.shape == expected_shape assert df.to_numpy(**convert_kwargs).shape == expected_shape - assert df.to_fieldlist("numpy", **convert_kwargs) is df + assert df.to_fieldlist(array_backend="numpy", **convert_kwargs) is df # stream consumed, no data is available assert sum([1 for _ in ds]) == 0 @@ -190,11 +193,11 @@ def test_grib_from_stream_multi_batch_convert_to_numpy(convert_kwargs, expected_ convert_kwargs = {} for i, f in enumerate(ds): - df = f.to_fieldlist("numpy", **convert_kwargs) + df = f.to_fieldlist(array_backend="numpy", **convert_kwargs) assert df.metadata(("param", "level")) == ref[i], i assert df._array.shape == expected_shape, i assert df.to_numpy(**convert_kwargs).shape == expected_shape, i - assert df.to_fieldlist("numpy", **convert_kwargs) is df, i + assert df.to_fieldlist(array_backend="numpy", **convert_kwargs) is df, i # stream consumed, no data is available assert sum([1 for _ in ds]) == 0 @@ -286,7 +289,7 @@ def test_grib_from_stream_in_memory_convert_to_numpy(convert_kwargs, expected_sh batch_size=0, ) - ds = ds_s.to_fieldlist("numpy", **convert_kwargs) + ds = ds_s.to_fieldlist(array_backend="numpy", **convert_kwargs) assert len(ds) == 6 @@ -326,7 +329,7 @@ def test_grib_from_stream_in_memory_convert_to_numpy(convert_kwargs, expected_sh assert np.allclose(vals, ref) assert ds._array.shape == expected_shape - assert ds.to_fieldlist("numpy", **convert_kwargs) is ds + assert ds.to_fieldlist(array_backend="numpy", **convert_kwargs) is ds def test_grib_save_when_loaded_from_stream(): diff --git a/tests/grib/test_grib_summary.py b/tests/grib/test_grib_summary.py index da3bf54e..3b7e3718 100644 --- a/tests/grib/test_grib_summary.py +++ b/tests/grib/test_grib_summary.py @@ -21,9 +21,9 @@ @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_describe(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_describe(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) # full contents df = f.describe() @@ -147,9 +147,9 @@ def test_grib_describe(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_ls(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_ls(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) # default keys f1 = f[0:4] @@ -202,9 +202,9 @@ def test_grib_ls(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_ls_keys(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_ls_keys(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) # default keys # positive num (=head) @@ -229,9 +229,9 @@ def test_grib_ls_keys(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_ls_namespace(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_ls_namespace(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) df = f.ls(n=2, namespace="vertical") ref = { @@ -251,9 +251,9 @@ def test_grib_ls_namespace(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_ls_invalid_num(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_ls_invalid_num(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) with pytest.raises(ValueError): f.ls(n=0) @@ -263,17 +263,17 @@ def test_grib_ls_invalid_num(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_ls_invalid_arg(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_ls_invalid_arg(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) with pytest.raises(TypeError): f.ls(invalid=1) @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_ls_num(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_ls_num(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) # default keys @@ -319,9 +319,9 @@ def test_grib_ls_num(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_head_num(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_head_num(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) # default keys df = f.head(n=2) @@ -345,9 +345,9 @@ def test_grib_head_num(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_tail_num(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_tail_num(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) # default keys df = f.tail(n=2) @@ -371,9 +371,9 @@ def test_grib_tail_num(fl_type, backend): @pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("backend", [None]) -def test_grib_dump(fl_type, backend): - f = load_grib_data("test6.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", [None]) +def test_grib_dump(fl_type, array_backend): + f = load_grib_data("test6.grib", fl_type, array_backend) namespaces = ( "default", diff --git a/tests/grib/test_grib_url_stream.py b/tests/grib/test_grib_url_stream.py index 82612eea..55fd2741 100644 --- a/tests/grib/test_grib_url_stream.py +++ b/tests/grib/test_grib_url_stream.py @@ -70,7 +70,7 @@ def test_grib_url_stream_group_by(_kwargs): for i, f in enumerate(fs): assert len(f) == 3 assert f.metadata(("param", "level")) == ref[i] - assert f.to_fieldlist("numpy") is not f + assert f.to_fieldlist(array_backend="numpy") is not f cnt += 1 assert cnt == len(ref) diff --git a/tests/grib/test_grib_values.py b/tests/grib/test_grib_values.py index aa7a09a9..b29094cb 100644 --- a/tests/grib/test_grib_values.py +++ b/tests/grib/test_grib_values.py @@ -35,14 +35,14 @@ def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_values_1(fl_type, backend): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_values_1(fl_type, array_backend): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") eps = 1e-5 # whole file v = f.values - check_array_type(v, backend, dtype="float64") + check_array_type(v, array_backend, dtype="float64") assert v.shape == (1, 84) v = v[0].flatten() check_array( @@ -57,20 +57,22 @@ def test_grib_values_1(fl_type, backend): # field v1 = f[0].values - check_array_type(v1, backend) + check_array_type(v1, array_backend) assert v1.shape == (84,) assert np.allclose(v, v1, eps) -@pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_values_18(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +# @pytest.mark.parametrize("fl_type", FL_TYPES) +# @pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("fl_type", ["file"]) +@pytest.mark.parametrize("array_backend", ["pytorch"]) +def test_grib_values_18(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) eps = 1e-5 # whole file v = f.values - check_array_type(v, backend, dtype="float64") + check_array_type(v, array_backend, dtype="float64") assert v.shape == (18, 84) vf = v[0].flatten() check_array( @@ -94,9 +96,9 @@ def test_grib_values_18(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_to_numpy_1(fl_type, backend): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_to_numpy_1(fl_type, array_backend): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") eps = 1e-5 v = f.to_numpy() @@ -114,7 +116,7 @@ def test_grib_to_numpy_1(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "first,options, expected_shape", [ @@ -126,8 +128,8 @@ def test_grib_to_numpy_1(fl_type, backend): (True, {"flatten": False}, (7, 12)), ], ) -def test_grib_to_numpy_1_shape(fl_type, backend, first, options, expected_shape): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +def test_grib_to_numpy_1_shape(fl_type, array_backend, first, options, expected_shape): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") v_ref = f[0].to_numpy().flatten() eps = 1e-5 @@ -142,9 +144,9 @@ def test_grib_to_numpy_1_shape(fl_type, backend, first, options, expected_shape) @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_to_numpy_18(fl_type, backend): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_to_numpy_18(fl_type, array_backend): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) eps = 1e-5 @@ -175,7 +177,7 @@ def test_grib_to_numpy_18(fl_type, backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "options, expected_shape", [ @@ -197,8 +199,8 @@ def test_grib_to_numpy_18(fl_type, backend): ({"flatten": False}, (18, 7, 12)), ], ) -def test_grib_to_numpy_18_shape(fl_type, backend, options, expected_shape): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_to_numpy_18_shape(fl_type, array_backend, options, expected_shape): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) eps = 1e-5 @@ -223,10 +225,10 @@ def test_grib_to_numpy_18_shape(fl_type, backend, options, expected_shape): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ["numpy"]) +@pytest.mark.parametrize("array_backend", ["numpy"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_numpy_1_dtype(fl_type, backend, dtype): - f = load_grib_data("test_single.grib", fl_type, backend, folder="data") +def test_grib_to_numpy_1_dtype(fl_type, array_backend, dtype): + f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") v = f[0].to_numpy(dtype=dtype) assert v.dtype == dtype @@ -236,10 +238,10 @@ def test_grib_to_numpy_1_dtype(fl_type, backend, dtype): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ["numpy"]) +@pytest.mark.parametrize("array_backend", ["numpy"]) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_numpy_18_dtype(fl_type, backend, dtype): - f = load_grib_data("tuv_pl.grib", fl_type, backend) +def test_grib_to_numpy_18_dtype(fl_type, array_backend, dtype): + f = load_grib_data("tuv_pl.grib", fl_type, array_backend) v = f[0].to_numpy(dtype=dtype) assert v.dtype == dtype @@ -249,7 +251,7 @@ def test_grib_to_numpy_18_dtype(fl_type, backend, dtype): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ["numpy"]) +@pytest.mark.parametrize("array_backend", ["numpy"]) @pytest.mark.parametrize( "kwarg,expected_shape,expected_dtype", [ @@ -262,8 +264,8 @@ def test_grib_to_numpy_18_dtype(fl_type, backend, dtype): ({"flatten": False, "dtype": np.float64}, (11, 19), np.float64), ], ) -def test_grib_field_data(fl_type, backend, kwarg, expected_shape, expected_dtype): - ds = load_grib_data("test.grib", fl_type, backend) +def test_grib_field_data(fl_type, array_backend, kwarg, expected_shape, expected_dtype): + ds = load_grib_data("test.grib", fl_type, array_backend) latlon = ds[0].to_latlon(**kwarg) v = ds[0].to_numpy(**kwarg) @@ -301,7 +303,7 @@ def test_grib_field_data(fl_type, backend, kwarg, expected_shape, expected_dtype @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ["numpy"]) +@pytest.mark.parametrize("array_backend", ["numpy"]) @pytest.mark.parametrize( "kwarg,expected_shape,expected_dtype", [ @@ -314,8 +316,10 @@ def test_grib_field_data(fl_type, backend, kwarg, expected_shape, expected_dtype ({"flatten": False, "dtype": np.float64}, (11, 19), np.float64), ], ) -def test_grib_fieldlist_data(fl_type, backend, kwarg, expected_shape, expected_dtype): - ds = load_grib_data("test.grib", fl_type, backend) +def test_grib_fieldlist_data( + fl_type, array_backend, kwarg, expected_shape, expected_dtype +): + ds = load_grib_data("test.grib", fl_type, array_backend) latlon = ds.to_latlon(**kwarg) v = ds.to_numpy(**kwarg) @@ -354,19 +358,21 @@ def test_grib_fieldlist_data(fl_type, backend, kwarg, expected_shape, expected_d @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("backend", ARRAY_BACKENDS) -def test_grib_values_with_missing(fl_type, backend): - f = load_grib_data("test_single_with_missing.grib", fl_type, backend, folder="data") +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_values_with_missing(fl_type, array_backend): + f = load_grib_data( + "test_single_with_missing.grib", fl_type, array_backend, folder="data" + ) v = f[0].values - check_array_type(v, backend) + check_array_type(v, array_backend) assert v.shape == (84,) eps = 0.001 - ns = get_array_namespace(backend) + ns = get_array_namespace(array_backend) assert ns.count_nonzero(ns.isnan(v)) == 38 - mask = get_array([12, 14, 15, 24, 25, 26] + list(range(28, 60)), backend) + mask = get_array([12, 14, 15, 24, 25, 26] + list(range(28, 60)), array_backend) assert np.isclose(v[0], 260.4356, eps) assert np.isclose(v[11], 260.4356, eps) assert np.isclose(v[-1], 227.1856, eps) diff --git a/tests/core/test_array.py b/tests/utils/test_array.py similarity index 93% rename from tests/core/test_array.py rename to tests/utils/test_array.py index 6b6e59fb..afbb6761 100644 --- a/tests/core/test_array.py +++ b/tests/utils/test_array.py @@ -11,11 +11,11 @@ import pytest -from earthkit.data.core.array import ensure_backend, get_backend from earthkit.data.testing import NO_PYTORCH +from earthkit.data.utils.array import ensure_backend, get_backend -def test_core_array_backend_numpy(): +def test_utils_array_backend_numpy(): b = ensure_backend("numpy") assert b.name == "numpy" @@ -43,7 +43,7 @@ def test_core_array_backend_numpy(): @pytest.mark.skipif(NO_PYTORCH, reason="No pytorch installed") -def test_core_array_backend_pytorch(): +def test_utils_array_backend_pytorch(): b = ensure_backend("pytorch") assert b.name == "pytorch" From 6d94e1b073a6f7aafe7705f25a136ca9e6f25602 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Mon, 26 Feb 2024 10:51:42 +0000 Subject: [PATCH 09/18] Add cupy backend --- earthkit/data/testing.py | 11 ++- earthkit/data/utils/array/__init__.py | 7 ++ earthkit/data/utils/array/cupy.py | 68 ++++++++++++++++ earthkit/data/utils/array/numpy.py | 9 ++- earthkit/data/utils/array/pytorch.py | 7 +- tests/grib/test_grib_backend.py | 112 +++++++++++++++++++++++--- tests/grib/test_grib_values.py | 2 +- tests/utils/test_array.py | 28 ++++++- 8 files changed, 223 insertions(+), 21 deletions(-) create mode 100644 earthkit/data/utils/array/cupy.py diff --git a/earthkit/data/testing.py b/earthkit/data/testing.py index 33a5ce0b..5c8aab14 100644 --- a/earthkit/data/testing.py +++ b/earthkit/data/testing.py @@ -102,7 +102,13 @@ def modules_installed(*modules): NO_POLYTOPE = not os.path.exists(os.path.expanduser("~/.polytopeapirc")) NO_ECCOVJSON = not modules_installed("eccovjson") NO_PYTORCH = not modules_installed("torch") - +NO_CUPY = not modules_installed("cupy") +if not NO_CUPY: + try: + import cupy as cp + a = cp.ones(2) + except Exception: + NO_CUPY = True def MISSING(*modules): return not modules_installed(*modules) @@ -174,6 +180,9 @@ def get_array(v, backend, **kwargs): if not NO_PYTORCH: ARRAY_BACKENDS.append("pytorch") +if not NO_CUPY: + ARRAY_BACKENDS.append("cupy") + def main(path): import sys diff --git a/earthkit/data/utils/array/__init__.py b/earthkit/data/utils/array/__init__.py index 3a69dae5..7453a263 100644 --- a/earthkit/data/utils/array/__init__.py +++ b/earthkit/data/utils/array/__init__.py @@ -108,6 +108,9 @@ def _load_core(self): with self.lock: if self._core is None: self._core = ArrayBackendCore(self) + + def _loaded(self): + return self._core is not None @property def available(self): @@ -176,6 +179,10 @@ def from_numpy(self, v): def from_pytorch(self, v): pass + @abstractmethod + def from_cupy(self, v): + pass + @abstractmethod def from_other(self, v, **kwargs): pass diff --git a/earthkit/data/utils/array/cupy.py b/earthkit/data/utils/array/cupy.py new file mode 100644 index 00000000..0525b7fa --- /dev/null +++ b/earthkit/data/utils/array/cupy.py @@ -0,0 +1,68 @@ +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import sys + +from . import ArrayBackend + + +class CupyBackend(ArrayBackend): + _name = "cupy" + _array_name = "tensor" + + def _load(self): + try: + import array_api_compat + + except Exception as e: + raise ImportError( + f"array_api_compat is required to use cupy backend, {e}" + ) + + try: + import cupy as cp + except Exception as e: + raise ImportError(f"cupy is required to use cupy backend, {e}") + + dt = {"float64": cp.float64, "float32": cp.float32} + ns = array_api_compat.array_namespace(cp.ones(2)) + + return ns, dt + + def is_native_array(self, v, dtype=None): + if (not self._loaded() and "cupy" not in sys.modules) or not self.available: + return False + + import cupy as cp + + if not isinstance(v, cp.ndarray): + return False + return self.match_dtype(v, dtype) + + def to_backend(self, v, backend): + return backend.from_cupy(v) + + def from_numpy(self, v): + import cupy as cp + + return cp.array(v) + + def from_pytorch(self, v): + return None + + def from_cupy(self, v): + return v + + def from_other(self, v, **kwargs): + import cupy as cp + + return cp.array(v, **kwargs) + + +Backend = CupyBackend diff --git a/earthkit/data/utils/array/numpy.py b/earthkit/data/utils/array/numpy.py index 35f1a9c0..1ead3441 100644 --- a/earthkit/data/utils/array/numpy.py +++ b/earthkit/data/utils/array/numpy.py @@ -7,7 +7,7 @@ # nor does it submit to any jurisdiction. # -import os +import sys from . import ArrayBackend @@ -31,9 +31,9 @@ def to_dtype(self, dtype): return dtype def is_native_array(self, v, dtype=None): - if self.available is None and "numpy" not in os.modules: + if (not self._loaded() and "numpy" not in sys.modules) or not self.available: return False - + import numpy as np if not isinstance(v, np.ndarray): @@ -51,6 +51,9 @@ def from_numpy(self, v): def from_pytorch(self, v): return v.numpy() + def from_cupy(self, v): + return v.get() + def from_other(self, v, **kwargs): import numpy as np diff --git a/earthkit/data/utils/array/pytorch.py b/earthkit/data/utils/array/pytorch.py index 0a0c366e..8ffebab3 100644 --- a/earthkit/data/utils/array/pytorch.py +++ b/earthkit/data/utils/array/pytorch.py @@ -7,7 +7,7 @@ # nor does it submit to any jurisdiction. # -import os +import sys from . import ArrayBackend @@ -36,7 +36,7 @@ def _load(self): return ns, dt def is_native_array(self, v, dtype=None): - if self.available is None and "torch" not in os.modules: + if (not self._loaded() and "torch" not in sys.modules) or not self.available: return False import torch @@ -53,6 +53,9 @@ def from_numpy(self, v): return torch.from_numpy(v) + def from_cupy(self, v): + return None + def from_pytorch(self, v): return v diff --git a/tests/grib/test_grib_backend.py b/tests/grib/test_grib_backend.py index 3b2cdb1d..a886f129 100644 --- a/tests/grib/test_grib_backend.py +++ b/tests/grib/test_grib_backend.py @@ -13,7 +13,7 @@ import pytest from earthkit.data import FieldList, from_source -from earthkit.data.testing import NO_PYTORCH, earthkit_examples_file +from earthkit.data.testing import NO_CUPY, NO_PYTORCH, earthkit_examples_file @pytest.mark.parametrize("_kwargs", [{}, {"array_backend": "numpy"}]) @@ -51,14 +51,14 @@ def test_grib_file_numpy_backend(_kwargs): @pytest.mark.skipif(NO_PYTORCH, reason="No pytorch installed") def test_grib_file_pytorch_backend(): - import torch - ds = from_source( "file", earthkit_examples_file("test6.grib"), array_backend="pytorch" ) assert len(ds) == 6 + import torch + assert torch.is_tensor(ds[0].values) assert ds[0].values.shape == (84,) @@ -68,23 +68,67 @@ def test_grib_file_pytorch_backend(): 84, ) - assert torch.is_tensor(ds[0].to_array()) - assert ds[0].to_array().shape == (7, 12) + x = ds[0].to_array() + assert torch.is_tensor(x) + assert x.shape == (7, 12) - assert torch.is_tensor(ds.to_array()) - assert ds.to_array().shape == (6, 7, 12) + x = ds.to_array() + assert torch.is_tensor(x) + assert x.shape == (6, 7, 12) - assert isinstance(ds[0].to_numpy(), np.ndarray) - assert ds[0].to_numpy().shape == (7, 12) + x = ds[0].to_numpy() + assert isinstance(x, np.ndarray) + assert x.shape == (7, 12) - assert isinstance(ds.to_numpy(), np.ndarray) - assert ds.to_numpy().shape == (6, 7, 12) + x = ds.to_numpy() + assert isinstance(x, np.ndarray) + assert x.shape == (6, 7, 12) ds1 = ds.to_fieldlist() assert len(ds1) == len(ds) assert ds1.array_backend.name == "pytorch" assert getattr(ds1, "path", None) is None +@pytest.mark.skipif(NO_CUPY, reason="No cupy installed") +def test_grib_file_cupy_backend(): + ds = from_source( + "file", earthkit_examples_file("test6.grib"), array_backend="cupy" + ) + + import cupy as cp + + assert len(ds) == 6 + + assert isinstance(ds[0].values, cp.ndarray) + assert ds[0].values.shape == (84,) + + assert isinstance(ds.values, cp.ndarray) + assert ds.values.shape == ( + 6, + 84, + ) + + x = ds[0].to_array() + assert isinstance(x, cp.ndarray) + assert x.shape == (7, 12) + + x = ds.to_array() + assert isinstance(x, cp.ndarray) + assert x.shape == (6, 7, 12) + + x = ds[0].to_numpy() + assert isinstance(x, np.ndarray) + assert x.shape == (7, 12) + + x = ds.to_numpy() + assert isinstance(x, np.ndarray) + assert x.shape == (6, 7, 12) + + ds1 = ds.to_fieldlist() + assert len(ds1) == len(ds) + assert ds1.array_backend.name == "cupy" + assert getattr(ds1, "path", None) is None + def test_grib_array_numpy_backend(): s = from_source("file", earthkit_examples_file("test6.grib")) @@ -121,8 +165,6 @@ def test_grib_array_numpy_backend(): @pytest.mark.skipif(NO_PYTORCH, reason="No pytorch installed") def test_grib_array_pytorch_backend(): - import torch - s = from_source( "file", earthkit_examples_file("test6.grib"), array_backend="pytorch" ) @@ -135,6 +177,8 @@ def test_grib_array_pytorch_backend(): with pytest.raises(AttributeError): ds.path + import torch + assert torch.is_tensor(ds[0].values) assert ds[0].values.shape == (84,) @@ -156,6 +200,48 @@ def test_grib_array_pytorch_backend(): assert isinstance(ds.to_numpy(), np.ndarray) assert ds.to_numpy().shape == (6, 7, 12) +@pytest.mark.skipif(NO_CUPY, reason="No cupy installed") +def test_grib_array_cupy_backend(): + s = from_source( + "file", earthkit_examples_file("test6.grib"), array_backend="cupy" + ) + + ds = FieldList.from_array( + s.values, + [m for m in s.metadata()], + ) + assert len(ds) == 6 + with pytest.raises(AttributeError): + ds.path + + import cupy as cp + + assert isinstance(ds[0].values, cp.ndarray) + assert ds[0].values.shape == (84,) + + assert isinstance( ds.values, cp.ndarray) + assert ds.values.shape == ( + 6, + 84, + ) + + x = ds[0].to_array() + assert isinstance(x, cp.ndarray) + assert x.shape == (7, 12) + + x = ds.to_array() + assert isinstance(x, cp.ndarray) + assert x.shape == (6, 7, 12) + + x = ds[0].to_numpy() + assert isinstance(x, np.ndarray) + assert x.shape == (7, 12) + + x = ds.to_numpy() + assert isinstance(x, np.ndarray) + assert x.shape == (6, 7, 12) + + if __name__ == "__main__": from earthkit.data.testing import main diff --git a/tests/grib/test_grib_values.py b/tests/grib/test_grib_values.py index b29094cb..f1654f74 100644 --- a/tests/grib/test_grib_values.py +++ b/tests/grib/test_grib_values.py @@ -65,7 +65,7 @@ def test_grib_values_1(fl_type, array_backend): # @pytest.mark.parametrize("fl_type", FL_TYPES) # @pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", ["pytorch"]) +@pytest.mark.parametrize("array_backend", ["numpy"]) def test_grib_values_18(fl_type, array_backend): f = load_grib_data("tuv_pl.grib", fl_type, array_backend) eps = 1e-5 diff --git a/tests/utils/test_array.py b/tests/utils/test_array.py index afbb6761..5b55e96b 100644 --- a/tests/utils/test_array.py +++ b/tests/utils/test_array.py @@ -11,7 +11,7 @@ import pytest -from earthkit.data.testing import NO_PYTORCH +from earthkit.data.testing import NO_CUPY, NO_PYTORCH from earthkit.data.utils.array import ensure_backend, get_backend @@ -68,6 +68,32 @@ def test_utils_array_backend_pytorch(): assert np.isclose(b.array_ns.mean(v), 1.0) +@pytest.mark.skipif(NO_CUPY, reason="No pytorch installed") +def test_utils_array_backend_cupy(): + b = ensure_backend("cupy") + assert b.name == "cupy" + + import numpy as np + import cupy as cp + + v = cp.ones(10, dtype=cp.float64) + v_np = np.ones(10, dtype=np.float64) + v_lst = [1.0] * 10 + + assert b.is_native_array(v) + assert id(b.from_cupy(v)) == id(v) + assert cp.allclose(b.from_numpy(v_np), v) + assert cp.allclose(b.from_other(v_lst, dtype=cp.float64), v) + assert get_backend(v) is b + assert get_backend(v, guess=b) is b + + np_b = ensure_backend("numpy") + r = b.to_backend(v, np_b) + assert isinstance(r, np.ndarray) + assert np.allclose(r, v_np) + + assert np.isclose(b.array_ns.mean(v), 1.0) + if __name__ == "__main__": from earthkit.data.testing import main From e441eff4e1e28fa3551c495bc2477b5edd77a776 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Mon, 26 Feb 2024 12:49:32 +0000 Subject: [PATCH 10/18] Impelement array backends for fieldlist --- earthkit/data/core/fieldlist.py | 3 +- earthkit/data/utils/array/__init__.py | 53 +++++++++++++++++---------- earthkit/data/utils/array/cupy.py | 19 ++-------- earthkit/data/utils/array/numpy.py | 19 +++++----- earthkit/data/utils/array/pytorch.py | 14 +++---- tests/utils/test_array.py | 18 ++++++--- 6 files changed, 66 insertions(+), 60 deletions(-) diff --git a/earthkit/data/core/fieldlist.py b/earthkit/data/core/fieldlist.py index 886d333e..aa9430a7 100644 --- a/earthkit/data/core/fieldlist.py +++ b/earthkit/data/core/fieldlist.py @@ -60,7 +60,8 @@ def _to_array(self, v, array_backend=None, source_backend=None): v: array-like The values. array_backend: :obj:`ArrayBackend` - The target array backend. + The target array backend. When it is None ``self.array_backend`` will + be used. source_backend: :obj:`ArrayBackend` The array backend of ``v``. When None, it will be automatically detected. diff --git a/earthkit/data/utils/array/__init__.py b/earthkit/data/utils/array/__init__.py index 7453a263..58e5b7b9 100644 --- a/earthkit/data/utils/array/__init__.py +++ b/earthkit/data/utils/array/__init__.py @@ -99,6 +99,7 @@ class ArrayBackend(metaclass=ABCMeta): _name = None _array_name = "array" _core = None + _converters = {} def __init__(self): self.lock = threading.Lock() @@ -108,7 +109,7 @@ def _load_core(self): with self.lock: if self._core is None: self._core = ArrayBackendCore(self) - + def _loaded(self): return self._core is not None @@ -136,15 +137,22 @@ def name(self): def array_name(self): return f"{self._name} {self._array_name}" - def to_array(self, v, backend=None): - if backend is not None: - if backend is self: - return v + def to_array(self, v, source_backend=None): + r"""Convert an array into the current backend. + + Parameters + ---------- + v: array-like + Array. + source_backend: :obj:`ArrayBackend` + The array backend of ``v``. When it is None automatically detected. - return backend.to_backend(v, self) - else: - b = get_backend(v, strict=False) - return b.to_backend(v, self) + Returns + ------- + array-like + ``v`` converted into the array backend defined by ``self``. + """ + return self.from_backend(v, source_backend) @property def _dtypes(self): @@ -168,20 +176,27 @@ def is_native_array(self, v, **kwargs): pass @abstractmethod - def to_backend(self, v, backend): + def to_numpy(self, v): pass - @abstractmethod - def from_numpy(self, v): - pass + def to_backend(self, v, backend, **kwargs): + assert backend is not None + backend = ensure_backend(backend) + return backend.from_backend(v, self, **kwargs) - @abstractmethod - def from_pytorch(self, v): - pass + def from_backend(self, v, backend, **kwargs): + if backend is None: + backend = get_backend(v, strict=False) - @abstractmethod - def from_cupy(self, v): - pass + if self is backend: + return v + + if backend is not None: + b = self._converters.get(backend.name, None) + if b is not None: + return b(v) + + return self.from_other(v, **kwargs) @abstractmethod def from_other(self, v, **kwargs): diff --git a/earthkit/data/utils/array/cupy.py b/earthkit/data/utils/array/cupy.py index 0525b7fa..2d7bc724 100644 --- a/earthkit/data/utils/array/cupy.py +++ b/earthkit/data/utils/array/cupy.py @@ -21,9 +21,7 @@ def _load(self): import array_api_compat except Exception as e: - raise ImportError( - f"array_api_compat is required to use cupy backend, {e}" - ) + raise ImportError(f"array_api_compat is required to use cupy backend, {e}") try: import cupy as cp @@ -45,19 +43,8 @@ def is_native_array(self, v, dtype=None): return False return self.match_dtype(v, dtype) - def to_backend(self, v, backend): - return backend.from_cupy(v) - - def from_numpy(self, v): - import cupy as cp - - return cp.array(v) - - def from_pytorch(self, v): - return None - - def from_cupy(self, v): - return v + def to_numpy(self, v): + return v.get() def from_other(self, v, **kwargs): import cupy as cp diff --git a/earthkit/data/utils/array/numpy.py b/earthkit/data/utils/array/numpy.py index 1ead3441..00ef5b31 100644 --- a/earthkit/data/utils/array/numpy.py +++ b/earthkit/data/utils/array/numpy.py @@ -33,7 +33,7 @@ def to_dtype(self, dtype): def is_native_array(self, v, dtype=None): if (not self._loaded() and "numpy" not in sys.modules) or not self.available: return False - + import numpy as np if not isinstance(v, np.ndarray): @@ -42,18 +42,17 @@ def is_native_array(self, v, dtype=None): return v.dtype == dtype return True - def to_backend(self, v, backend): - return backend.from_numpy(v) + def from_backend(self, v, backend, **kwargs): + if self is backend: + return v + elif backend is not None: + return backend.to_numpy(v) + else: + return super().from_backend(v, backend, **kwargs) - def from_numpy(self, v): + def to_numpy(self, v): return v - def from_pytorch(self, v): - return v.numpy() - - def from_cupy(self, v): - return v.get() - def from_other(self, v, **kwargs): import numpy as np diff --git a/earthkit/data/utils/array/pytorch.py b/earthkit/data/utils/array/pytorch.py index 8ffebab3..1581933f 100644 --- a/earthkit/data/utils/array/pytorch.py +++ b/earthkit/data/utils/array/pytorch.py @@ -16,6 +16,10 @@ class PytorchBackend(ArrayBackend): _name = "pytorch" _array_name = "tensor" + def __init__(self): + super().__init__() + self._converters = {"numpy": self.from_numpy} + def _load(self): try: import array_api_compat @@ -45,20 +49,14 @@ def is_native_array(self, v, dtype=None): return False return self.match_dtype(v, dtype) - def to_backend(self, v, backend): - return backend.from_pytorch(v) + def to_numpy(self, v): + return v.numpy() def from_numpy(self, v): import torch return torch.from_numpy(v) - def from_cupy(self, v): - return None - - def from_pytorch(self, v): - return v - def from_other(self, v, **kwargs): import torch diff --git a/tests/utils/test_array.py b/tests/utils/test_array.py index 5b55e96b..34bda5e8 100644 --- a/tests/utils/test_array.py +++ b/tests/utils/test_array.py @@ -25,7 +25,9 @@ def test_utils_array_backend_numpy(): v_lst = [1.0] * 10 assert b.is_native_array(v) - assert id(b.from_numpy(v)) == id(v) + assert id(b.to_numpy(v)) == id(v) + assert id(b.from_backend(v, b)) == id(v) + assert id(b.from_backend(v, None)) == id(v) assert np.allclose(b.from_other(v_lst, dtype=np.float64), v) assert get_backend(v) is b assert get_backend(v, guess=b) is b @@ -55,34 +57,38 @@ def test_utils_array_backend_pytorch(): v_lst = [1.0] * 10 assert b.is_native_array(v) - assert id(b.from_pytorch(v)) == id(v) + assert id(b.from_backend(v, b)) == id(v) + assert id(b.from_backend(v, None)) == id(v) + assert torch.allclose(b.from_backend(v_np, None), v) assert torch.allclose(b.from_numpy(v_np), v) assert torch.allclose(b.from_other(v_lst, dtype=torch.float64), v) assert get_backend(v) is b assert get_backend(v, guess=b) is b np_b = ensure_backend("numpy") - r = b.to_backend(v, np_b) + r = b._backend(v, np_b) assert isinstance(r, np.ndarray) assert np.allclose(r, v_np) assert np.isclose(b.array_ns.mean(v), 1.0) + @pytest.mark.skipif(NO_CUPY, reason="No pytorch installed") def test_utils_array_backend_cupy(): b = ensure_backend("cupy") assert b.name == "cupy" - import numpy as np import cupy as cp + import numpy as np v = cp.ones(10, dtype=cp.float64) v_np = np.ones(10, dtype=np.float64) v_lst = [1.0] * 10 assert b.is_native_array(v) - assert id(b.from_cupy(v)) == id(v) - assert cp.allclose(b.from_numpy(v_np), v) + assert id(b.from_backend(v, b)) == id(v) + assert id(b.from_backend(v, None)) == id(v) + assert cp.allclose(b.from_backend(v_np, None), v) assert cp.allclose(b.from_other(v_lst, dtype=cp.float64), v) assert get_backend(v) is b assert get_backend(v, guess=b) is b From 79a7fa1b6f31826ccefc17c2c86f92d2ed66c93e Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Mon, 26 Feb 2024 13:20:25 +0000 Subject: [PATCH 11/18] Impelement array backends for fieldlist --- earthkit/data/utils/array/__init__.py | 18 ++++++++++++++---- earthkit/data/utils/array/cupy.py | 9 ++------- earthkit/data/utils/array/numpy.py | 8 ++------ earthkit/data/utils/array/pytorch.py | 8 ++------ tests/utils/test_array.py | 2 +- 5 files changed, 21 insertions(+), 24 deletions(-) diff --git a/earthkit/data/utils/array/__init__.py b/earthkit/data/utils/array/__init__.py index 58e5b7b9..be23207b 100644 --- a/earthkit/data/utils/array/__init__.py +++ b/earthkit/data/utils/array/__init__.py @@ -9,6 +9,7 @@ import logging import os +import sys import threading from abc import ABCMeta, abstractmethod from importlib import import_module @@ -110,9 +111,6 @@ def _load_core(self): if self._core is None: self._core = ArrayBackendCore(self) - def _loaded(self): - return self._core is not None - @property def available(self): self._load_core() @@ -171,10 +169,22 @@ def match_dtype(self, v, dtype): return f return True + def is_native_array(self, v, dtype=None): + if ( + self._core is None and self._module_name not in sys.modules + ) or not self.available: + return False + return self._is_native_array(v, dtype=dtype) + @abstractmethod - def is_native_array(self, v, **kwargs): + def _is_native_array(self, v, **kwargs): pass + def _quick_check_available(self): + return ( + self._core is None and self._module_name not in sys.modules + ) or not self.available + @abstractmethod def to_numpy(self, v): pass diff --git a/earthkit/data/utils/array/cupy.py b/earthkit/data/utils/array/cupy.py index 2d7bc724..baf209b1 100644 --- a/earthkit/data/utils/array/cupy.py +++ b/earthkit/data/utils/array/cupy.py @@ -7,14 +7,12 @@ # nor does it submit to any jurisdiction. # -import sys - from . import ArrayBackend class CupyBackend(ArrayBackend): _name = "cupy" - _array_name = "tensor" + _module_name = "cupy" def _load(self): try: @@ -33,10 +31,7 @@ def _load(self): return ns, dt - def is_native_array(self, v, dtype=None): - if (not self._loaded() and "cupy" not in sys.modules) or not self.available: - return False - + def _is_native_array(self, v, dtype=None): import cupy as cp if not isinstance(v, cp.ndarray): diff --git a/earthkit/data/utils/array/numpy.py b/earthkit/data/utils/array/numpy.py index 00ef5b31..0ee92f86 100644 --- a/earthkit/data/utils/array/numpy.py +++ b/earthkit/data/utils/array/numpy.py @@ -7,13 +7,12 @@ # nor does it submit to any jurisdiction. # -import sys - from . import ArrayBackend class NumpyBackend(ArrayBackend): _name = "numpy" + _module_name = "numpy" def _load(self): import numpy as np @@ -30,10 +29,7 @@ def _load(self): def to_dtype(self, dtype): return dtype - def is_native_array(self, v, dtype=None): - if (not self._loaded() and "numpy" not in sys.modules) or not self.available: - return False - + def _is_native_array(self, v, dtype=None): import numpy as np if not isinstance(v, np.ndarray): diff --git a/earthkit/data/utils/array/pytorch.py b/earthkit/data/utils/array/pytorch.py index 1581933f..80d0c491 100644 --- a/earthkit/data/utils/array/pytorch.py +++ b/earthkit/data/utils/array/pytorch.py @@ -7,14 +7,13 @@ # nor does it submit to any jurisdiction. # -import sys - from . import ArrayBackend class PytorchBackend(ArrayBackend): _name = "pytorch" _array_name = "tensor" + _module_name = "torch" def __init__(self): super().__init__() @@ -39,10 +38,7 @@ def _load(self): return ns, dt - def is_native_array(self, v, dtype=None): - if (not self._loaded() and "torch" not in sys.modules) or not self.available: - return False - + def _is_native_array(self, v, dtype=None): import torch if not torch.is_tensor(v): diff --git a/tests/utils/test_array.py b/tests/utils/test_array.py index 34bda5e8..8160d702 100644 --- a/tests/utils/test_array.py +++ b/tests/utils/test_array.py @@ -66,7 +66,7 @@ def test_utils_array_backend_pytorch(): assert get_backend(v, guess=b) is b np_b = ensure_backend("numpy") - r = b._backend(v, np_b) + r = b.to_backend(v, np_b) assert isinstance(r, np.ndarray) assert np.allclose(r, v_np) From 76195c5f4beb841f084188b4bda9ee26b20194b0 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Mon, 26 Feb 2024 15:37:00 +0000 Subject: [PATCH 12/18] Impelement array backends for fieldlist --- earthkit/data/core/fieldlist.py | 22 +++++++++++----------- earthkit/data/utils/array/__init__.py | 7 ++++--- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/earthkit/data/core/fieldlist.py b/earthkit/data/core/fieldlist.py index aa9430a7..4d2f0c1b 100644 --- a/earthkit/data/core/fieldlist.py +++ b/earthkit/data/core/fieldlist.py @@ -40,15 +40,15 @@ def array_backend(self): @property def raw_values_backend(self): - r""":obj:`ArrayBackend`: Return the array backend the low level API - uses to extract the field values. + r""":obj:`ArrayBackend`: Return the array backend used by the low level API + to extract the field values. """ return self._raw_values_backend @property def raw_other_backend(self): - r""":obj:`ArrayBackend`: Return the array backend the low level API - uses to extract non-field-related values, e.g. latitudes, longitudes. + r""":obj:`ArrayBackend`: Return the array backend used by the low level API + to extract non-field-related values, e.g. latitudes, longitudes. """ return self._raw_other_backend @@ -192,7 +192,7 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): Specifies the type of data to be returned. Any combination of "lat", "lon" and "value" is allowed here. flatten: bool - When it is True a flat ndarray per key is returned. Otherwise an ndarray with the field's + When it is True a flat array per key is returned. Otherwise an array with the field's :obj:`shape` is returned for each key. dtype: str, array.dtype or None Typecode or data-type of the arrays. When it is :obj:`None` the default @@ -204,7 +204,7 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): array-like An multi-dimensional array containing one array per key is returned (following the order in ``keys``). When ``keys`` is a single value only the - ndarray belonging to the key is returned. The array format is specified by + array belonging to the key is returned. The array format is specified by :attr:`array_backend`. @@ -666,8 +666,8 @@ def from_array(array, metadata): Parameters ---------- array: array-like, list - The fields' values. When it is a list must contain one array per field. The array - type must be supported by :class:`ArrayBackend`. + The fields' values. When it is a list it must contain one array per field. + The array type must be supported by :class:`ArrayBackend`. metadata: list The fields' metadata. Must contain one :class:`Metadata` object per field. @@ -847,7 +847,7 @@ def to_array(self, **kwargs): @property def values(self): - r"""ndarray: Get all the fields' values as a 2D array. It is formed as the array of + r"""array-likr: Get all the fields' values as a 2D array. It is formed as the array of :obj:`GribField.values ` per field. See Also @@ -1356,7 +1356,7 @@ def write(self, f, **kwargs): s.write(f, **kwargs) def to_fieldlist(self, array_backend=None, **kwargs): - r"""Convert to a new :class:`FieldList` based on the ``array_backend``. + r"""Convert to a new :class:`FieldList`. When the :class:`FieldList` is already in the required format no new :class:`FieldList` is created but the current one is returned. @@ -1364,7 +1364,7 @@ def to_fieldlist(self, array_backend=None, **kwargs): Parameters ---------- array_backend: str, :obj:`ArrayBackend` - Specifies the array backend for the generated fieldlist. The array + Specifies the array backend for the generated :class:`FieldList`. The array type must be supported by :class:`ArrayBackend`. **kwargs: dict, optional diff --git a/earthkit/data/utils/array/__init__.py b/earthkit/data/utils/array/__init__.py index be23207b..be39735d 100644 --- a/earthkit/data/utils/array/__init__.py +++ b/earthkit/data/utils/array/__init__.py @@ -118,7 +118,7 @@ def available(self): @abstractmethod def _load(self): - """Called from arrayBackendCore. It must return ns and dtypes""" + """Load the backend object. Called from arrayBackendCore.""" pass @property @@ -142,8 +142,9 @@ def to_array(self, v, source_backend=None): ---------- v: array-like Array. - source_backend: :obj:`ArrayBackend` - The array backend of ``v``. When it is None automatically detected. + source_backend: :obj:`ArrayBackend`, str + The array backend of ``v``. When None ``source_backend`` + is automatically detected. Returns ------- From 5509fec407cdff0d9b2cca9c12c54fbcda3cdf21 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 5 Mar 2024 16:47:14 +0000 Subject: [PATCH 13/18] Implement array backends for fieldlist --- earthkit/data/testing.py | 4 +++- tests/downstream-ci-requirements.txt | 1 + tests/grib/test_grib_backend.py | 13 +++++-------- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/earthkit/data/testing.py b/earthkit/data/testing.py index 5c8aab14..005e66bc 100644 --- a/earthkit/data/testing.py +++ b/earthkit/data/testing.py @@ -102,14 +102,16 @@ def modules_installed(*modules): NO_POLYTOPE = not os.path.exists(os.path.expanduser("~/.polytopeapirc")) NO_ECCOVJSON = not modules_installed("eccovjson") NO_PYTORCH = not modules_installed("torch") -NO_CUPY = not modules_installed("cupy") +NO_CUPY = not modules_installed("cupy") if not NO_CUPY: try: import cupy as cp + a = cp.ones(2) except Exception: NO_CUPY = True + def MISSING(*modules): return not modules_installed(*modules) diff --git a/tests/downstream-ci-requirements.txt b/tests/downstream-ci-requirements.txt index 0a9b842d..54a5cbc3 100644 --- a/tests/downstream-ci-requirements.txt +++ b/tests/downstream-ci-requirements.txt @@ -17,6 +17,7 @@ pyfdb pyodc pyyaml scipy +torch tqdm xarray>=0.19.0 earthkit-meteo>=0.0.1 diff --git a/tests/grib/test_grib_backend.py b/tests/grib/test_grib_backend.py index a886f129..799541e3 100644 --- a/tests/grib/test_grib_backend.py +++ b/tests/grib/test_grib_backend.py @@ -89,11 +89,10 @@ def test_grib_file_pytorch_backend(): assert ds1.array_backend.name == "pytorch" assert getattr(ds1, "path", None) is None + @pytest.mark.skipif(NO_CUPY, reason="No cupy installed") def test_grib_file_cupy_backend(): - ds = from_source( - "file", earthkit_examples_file("test6.grib"), array_backend="cupy" - ) + ds = from_source("file", earthkit_examples_file("test6.grib"), array_backend="cupy") import cupy as cp @@ -200,11 +199,10 @@ def test_grib_array_pytorch_backend(): assert isinstance(ds.to_numpy(), np.ndarray) assert ds.to_numpy().shape == (6, 7, 12) + @pytest.mark.skipif(NO_CUPY, reason="No cupy installed") def test_grib_array_cupy_backend(): - s = from_source( - "file", earthkit_examples_file("test6.grib"), array_backend="cupy" - ) + s = from_source("file", earthkit_examples_file("test6.grib"), array_backend="cupy") ds = FieldList.from_array( s.values, @@ -219,7 +217,7 @@ def test_grib_array_cupy_backend(): assert isinstance(ds[0].values, cp.ndarray) assert ds[0].values.shape == (84,) - assert isinstance( ds.values, cp.ndarray) + assert isinstance(ds.values, cp.ndarray) assert ds.values.shape == ( 6, 84, @@ -242,7 +240,6 @@ def test_grib_array_cupy_backend(): assert x.shape == (6, 7, 12) - if __name__ == "__main__": from earthkit.data.testing import main From 2d1363f024d48b5000ab241efb3b8b94f4c25d43 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 5 Mar 2024 16:47:55 +0000 Subject: [PATCH 14/18] Implement array backends for fieldlist --- tests/downstream-ci-requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/downstream-ci-requirements.txt b/tests/downstream-ci-requirements.txt index 54a5cbc3..8376a1a9 100644 --- a/tests/downstream-ci-requirements.txt +++ b/tests/downstream-ci-requirements.txt @@ -1,3 +1,4 @@ +array_api_compat cdsapi cfgrib>=0.9.10.1 eccodes>=1.5.0 From a2164c78f86a63237e635cf5f821955f719c3dd3 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 5 Mar 2024 17:13:47 +0000 Subject: [PATCH 15/18] Implement array backends for fieldlist --- tests/downstream-ci-requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/downstream-ci-requirements.txt b/tests/downstream-ci-requirements.txt index 8376a1a9..9708f802 100644 --- a/tests/downstream-ci-requirements.txt +++ b/tests/downstream-ci-requirements.txt @@ -33,3 +33,4 @@ pytest-forked pytest-timeout nbformat nbconvert +ipykernel From 48ba53b8edbfc09b822daa1baeaef1706b17d8b6 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 5 Mar 2024 17:32:26 +0000 Subject: [PATCH 16/18] Use new ecmwf-opendata --- setup.cfg | 2 +- tests/downstream-ci-requirements.txt | 2 +- tests/environment-unit-tests.yml | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/setup.cfg b/setup.cfg index 5c19020c..1bbb626a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,7 +25,7 @@ install_requires = cfgrib>=0.9.10.1 eccodes>=1.5.0 ecmwf-api-client>=1.6.1 - ecmwf-opendata>=0.1.2 + ecmwf-opendata>=0.3.3 polytope-client>=0.7.2 dask entrypoints diff --git a/tests/downstream-ci-requirements.txt b/tests/downstream-ci-requirements.txt index 9708f802..29f92053 100644 --- a/tests/downstream-ci-requirements.txt +++ b/tests/downstream-ci-requirements.txt @@ -3,7 +3,7 @@ cdsapi cfgrib>=0.9.10.1 eccodes>=1.5.0 ecmwf-api-client>=1.6.1 -ecmwf-opendata>=0.1.2 +ecmwf-opendata>=0.3.3 polytope-client>=0.7.2 dask entrypoints diff --git a/tests/environment-unit-tests.yml b/tests/environment-unit-tests.yml index be929f5a..b833bce5 100644 --- a/tests/environment-unit-tests.yml +++ b/tests/environment-unit-tests.yml @@ -26,7 +26,7 @@ dependencies: - pip: - git+https://github.com/ecmwf/multiurl - git+https://github.com/ecmwf/pyfdb - - ecmwf-opendata>=0.1.2 + - ecmwf-opendata>=0.3.3 - polytope-client>=0.7.2 - earthkit-meteo>=0.0.1 - git+https://github.com/ecmwf/earthkit-data-demo-source From 8a1ce7d8eed93bd96b31aa92752e28ebe8648dea Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 5 Mar 2024 17:57:00 +0000 Subject: [PATCH 17/18] Disable ecmwf open data notebook test --- tests/documentation/test_notebooks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/documentation/test_notebooks.py b/tests/documentation/test_notebooks.py index 5d91e385..e622b563 100644 --- a/tests/documentation/test_notebooks.py +++ b/tests/documentation/test_notebooks.py @@ -31,6 +31,7 @@ "polytope.ipynb", "grib_fdb_write.ipynb", "demo_source_plugin.ipynb", + "ecmwf_open_data.ipynb", ] if NO_PYTORCH: From 14ea2821311de86a15f25692176d006433fd023a Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Tue, 5 Mar 2024 18:13:10 +0000 Subject: [PATCH 18/18] Disable pandas notebook test --- tests/documentation/test_notebooks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/documentation/test_notebooks.py b/tests/documentation/test_notebooks.py index e622b563..1be2a2fa 100644 --- a/tests/documentation/test_notebooks.py +++ b/tests/documentation/test_notebooks.py @@ -32,6 +32,7 @@ "grib_fdb_write.ipynb", "demo_source_plugin.ipynb", "ecmwf_open_data.ipynb", + "pandas.ipynb", ] if NO_PYTORCH: