diff --git a/docs/examples/grib_array_backends.ipynb b/docs/examples/grib_array_backends.ipynb index f950a810..95d2b4de 100644 --- a/docs/examples/grib_array_backends.ipynb +++ b/docs/examples/grib_array_backends.ipynb @@ -11,7 +11,7 @@ "tags": [] }, "source": [ - "## GRIB: using array backends" + "## GRIB: using array fieldlists" ] }, { @@ -25,7 +25,7 @@ "tags": [] }, "source": [ - "In this example we will use a GRIB file containing 4 messages. First we ensure the file is available." + "In this example we will use a GRIB file containing 4 messages. First we ensure the file is available and read it into a fieldlist." ] }, { @@ -42,7 +42,8 @@ "outputs": [], "source": [ "import earthkit.data\n", - "earthkit.data.download_example_file(\"test4.grib\")" + "earthkit.data.download_example_file(\"test4.grib\")\n", + "ds_in = earthkit.data.from_source(\"file\", \"test4.grib\")" ] }, { @@ -54,15 +55,81 @@ "slideshow": { "slide_type": "" }, - "tags": [] + "tags": [], + "vscode": { + "languageId": "raw" + } }, "source": [ - "When reading GRIB data with :func:`from_source` we can specify the ``array_backend`` we want to use when extracting the field values. The default backend is \"numpy\". For this example we choose the \"pytorch\" backend. Since pytorch is an optional dependency for earthkit-data we need to ensure it is installed in the environment. We also need to install \"array_api_compat\" to make the array backends work." + "Using the :meth:`~data.core.fieldlist.FieldList.to_fieldlist` method we can convert this object into an array fieldlist where each field contains an array (holding the field values) and a :py:class:`~data.readers.grib.metadata.RestrictedGribMetadata` object representing the related metadata. Array fieldlists are entirely stored in memory. The resulting array format is controlled by ``array_backend`` keyword argument of :meth:`~data.core.fieldlist.FieldList.to_fieldlist`. 
When using its default value (None) the underlying array format of the original fieldlist is kept. For GRIB data read from a file or stream this will be \"numpy\". " + ] + }, + { + "cell_type": "markdown", + "id": "3374ab00-054b-4aba-a355-f119b619be1e", + "metadata": {}, + "source": [ + "### Numpy array fieldlist" + ] + }, + { + "cell_type": "markdown", + "id": "2d99d5aa-8e9e-4c75-8394-188891c75c29", + "metadata": {}, + "source": [ + "The \"numpy\" fieldlist we generate in the cell below works exactly in the same way as the original one but stores all the data in memory." ] }, { "cell_type": "code", "execution_count": 2, + "id": "6ec4489f-2030-4a55-8292-7f50fb845677", + "metadata": {}, + "outputs": [], + "source": [ + "ds = ds_in.to_fieldlist()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "07e0823d-ff1d-43d1-8ffb-e5a6d0616ce4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "4" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(ds)" + ] + }, + { + "cell_type": "markdown", + "id": "e4106883-8035-4da5-90eb-b83f0d903c4b", + "metadata": {}, + "source": [ + "### Pytorch array fieldlist" + ] + }, + { + "cell_type": "markdown", + "id": "e34b8492-60d5-4052-abf9-b02d1a692c1d", + "metadata": {}, + "source": [ + "For the next example we choose the \"pytorch\" array backend. Since pytorch is an optional dependency for earthkit-data we need to ensure it is installed in the environment." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 4, "id": "b3b30d9f-0edb-4938-baec-7026acd70192", "metadata": { "editable": true, @@ -73,13 +140,12 @@ }, "outputs": [], "source": [ - "!pip install torch --quiet\n", - "!pip install array_api_compat --quiet" + "!pip install torch --quiet" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 5, "id": "5d2174d7-0f36-4b20-8ad5-bd93fd12f91b", "metadata": { "editable": true, @@ -90,12 +156,12 @@ }, "outputs": [], "source": [ - "ds = earthkit.data.from_source(\"file\", \"test4.grib\", array_backend=\"pytorch\")" + "ds = ds_in.to_fieldlist(array_backend=\"pytorch\")" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, "id": "3c108a25-5f41-422f-9adb-98c932205dce", "metadata": { "editable": true, @@ -209,7 +275,7 @@ "3 an 0 regular_ll " ] }, - "execution_count": 4, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -249,7 +315,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 7, "id": "21a1e8b1-f0e5-4de6-bdbf-e92c5df5989e", "metadata": { "editable": true, @@ -266,7 +332,7 @@ " 228.0460, 228.0460, 228.0460], dtype=torch.float64)" ] }, - "execution_count": 5, + "execution_count": 7, "metadata": {}, "output_type": "execute_result" } @@ -277,7 +343,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 8, "id": "fbd92126-bb5b-47e5-80a7-d4bed3097764", "metadata": { "editable": true, @@ -293,7 +359,7 @@ "torch.Size([65160])" ] }, - "execution_count": 6, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -304,7 +370,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 9, "id": "714b320c-90ea-4326-bc5f-340dda66daab", "metadata": { "editable": true, @@ -320,7 +386,7 @@ "torch.Size([4, 65160])" ] }, - "execution_count": 7, + "execution_count": 9, "metadata": {}, "output_type": "execute_result" } @@ -355,12 +421,12 @@ "tags": [] }, "source": [ - "The 
:py:meth:`Field.to_array() ` and :py:meth:`FieldList.to_array() ` methods return the values based on the underlying backend. " + ":py:meth:`Field.to_array() ` and :py:meth:`FieldList.to_array() ` return the values based on the underlying backend. " ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 10, "id": "f60797eb-d578-4638-b1d0-bd18949dd249", "metadata": { "editable": true, @@ -377,7 +443,7 @@ " [228.6085, 228.5792]], dtype=torch.float64)" ] }, - "execution_count": 8, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -388,9 +454,15 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 11, "id": "f4d053fa-2acf-4949-9bbd-a0b2ccf30318", - "metadata": {}, + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, "outputs": [ { "data": { @@ -398,7 +470,7 @@ "torch.Size([4, 181, 360])" ] }, - "execution_count": 9, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -409,7 +481,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 12, "id": "04cd31df-34cd-47a9-90cc-833b9805bd55", "metadata": { "editable": true, @@ -425,7 +497,7 @@ "torch.Size([4, 65160])" ] }, - "execution_count": 10, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -436,7 +508,7 @@ }, { "cell_type": "markdown", - "id": "f0efdd2b-e078-43e2-936e-35db3a645a6c", + "id": "87cff063-95af-45b9-bae1-91cb28ca23ea", "metadata": { "editable": true, "slideshow": { @@ -445,12 +517,12 @@ "tags": [] }, "source": [ - "#### Array fieldlists" + "#### to_numpy()" ] }, { "cell_type": "raw", - "id": "ea2a5619-6022-4166-85f6-227b9282ffa7", + "id": "23eba4ad-a0f3-4f94-aff3-83b7a0dd07d6", "metadata": { "editable": true, "raw_mimetype": "text/restructuredtext", @@ -460,13 +532,13 @@ "tags": [] }, "source": [ - "Our fieldlist can be converted into an in-memory :py:class:`~data.sources.array_list.ArrayFieldList` where each message consists of a 
:py:class:`~data.readers.grib.metadata.GribMetadata` object and an array with the given backend storing the field values." + ":py:meth:`Field.to_numpy() ` and :py:meth:`FieldList.to_numpy() ` still return ndarrays." ] }, { "cell_type": "code", - "execution_count": 11, - "id": "c5d3efc1-1299-4e0d-b59a-301e795bffc5", + "execution_count": 13, + "id": "b60eea0f-0da1-48f1-a64e-f1b75e36a737", "metadata": { "editable": true, "slideshow": { @@ -477,27 +549,83 @@ "outputs": [ { "data": { - "text/html": [ - "ArrayFieldList(fields=4)" - ], "text/plain": [ - "ArrayFieldList(fields=4)" + "array([[228.04600525, 228.04600525],\n", + " [228.60850525, 228.57920837]])" ] }, - "execution_count": 11, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "r = ds.to_fieldlist(array_backend=\"pytorch\")\n", - "r" + "ds[0].to_numpy()[:2,:2]" ] }, { "cell_type": "code", - "execution_count": 12, - "id": "f87b0384-60cb-4bf2-8669-e193427c28e1", + "execution_count": 14, + "id": "ae6bd7cd-b043-4993-b715-e5ea2531490a", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "(4, 181, 360)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.to_numpy().shape" + ] + }, + { + "cell_type": "markdown", + "id": "f0efdd2b-e078-43e2-936e-35db3a645a6c", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "### Building array fieldlists with from_array()" + ] + }, + { + "cell_type": "raw", + "id": "3cdcbb75-86b1-4c38-b667-73ac563c8d97", + "metadata": { + "editable": true, + "raw_mimetype": "text/restructuredtext", + "slideshow": { + "slide_type": "" + }, + "tags": [], + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "Whe can build a new array fieldlist straight from metadata and array values using :meth:`~data.core.fieldlist.FieldList.from_array`. 
This can be used for computations when we want to alter the values and store the result in a new FieldList." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a92db7cb-e2e1-472e-92b6-0f42360d2105", "metadata": { "editable": true, "slideshow": { @@ -610,19 +738,36 @@ "3 an 0 regular_ll " ] }, - "execution_count": 12, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "r.ls()" + "md = ds.metadata()\n", + "v = ds.to_array() + 2\n", + "r1 = earthkit.data.FieldList.from_array(v, md)\n", + "r1.ls()" + ] + }, + { + "cell_type": "markdown", + "id": "920f6b98-f0e6-4ffc-a5e4-f8f643c23f76", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "As expected, the values in *r1* are now differing by 2 from the ones in the original fieldlist (*r*)." ] }, { "cell_type": "code", - "execution_count": 13, - "id": "871e9c13-06e8-4ed6-90a9-8696c95ede8b", + "execution_count": 16, + "id": "5b78ea8a-64e9-4bfa-9d11-8b390995994b", "metadata": { "editable": true, "slideshow": { @@ -634,45 +779,32 @@ { "data": { "text/plain": [ - "tensor([228.0460, 228.0460, 228.0460, 228.0460, 228.0460, 228.0460, 228.0460,\n", - " 228.0460, 228.0460, 228.0460], dtype=torch.float64)" + "tensor([230.0460, 230.0460, 230.0460, 230.0460, 230.0460, 230.0460, 230.0460,\n", + " 230.0460, 230.0460, 230.0460], dtype=torch.float64)" ] }, - "execution_count": 13, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "r[0].values[:10]" + "r1[0].values[:10]" ] }, { - "cell_type": "raw", - "id": "3cdcbb75-86b1-4c38-b667-73ac563c8d97", - "metadata": { - "editable": true, - "raw_mimetype": "text/restructuredtext", - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "cell_type": "markdown", + "id": "48f6f5a4-1304-43ef-b357-563cfa071309", + "metadata": {}, "source": [ - "Whe can build a new :py:class:`~data.sources.array_list.ArrayFieldList` straight from metadata 
and array values. This can be used for computations when we want to alter the values and store the result in a new FieldList." + "### Building an array fieldlist in a loop" ] }, { "cell_type": "code", - "execution_count": 14, - "id": "a92db7cb-e2e1-472e-92b6-0f42360d2105", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, + "execution_count": 17, + "id": "a4b086a5-406d-4d1c-bd93-1edbde84bf81", + "metadata": {}, "outputs": [ { "data": { @@ -778,58 +910,37 @@ "3 an 0 regular_ll " ] }, - "execution_count": 14, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ + "from earthkit.data import SimpleFieldList\n", + "from earthkit.data import ArrayField\n", + "\n", "md = ds.metadata()\n", "v = ds.to_array() + 2\n", - "r1 = earthkit.data.FieldList.from_array(v, md)\n", + "\n", + "r1 = SimpleFieldList()\n", + "for k in range(len(md)):\n", + " r1.append(ArrayField(v[k], md[k]))\n", "r1.ls()" ] }, { "cell_type": "markdown", - "id": "920f6b98-f0e6-4ffc-a5e4-f8f643c23f76", - "metadata": { - "editable": true, - "slideshow": { - "slide_type": "" - }, - "tags": [] - }, - "source": [ - "As expected, the values in *r1* are now differing by 2 from the ones in the originial FieldList (*r*)." 
- ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "5b78ea8a-64e9-4bfa-9d11-8b390995994b", + "id": "78b1c896-a892-4192-99a1-bc9dc1cfcd4f", "metadata": { "editable": true, + "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, - "outputs": [ - { - "data": { - "text/plain": [ - "tensor([230.0460, 230.0460, 230.0460, 230.0460, 230.0460, 230.0460, 230.0460,\n", - " 230.0460, 230.0460, 230.0460], dtype=torch.float64)" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ - "r1[0].values[:10]" + "### Saving to GRIB" ] }, { @@ -844,12 +955,12 @@ "tags": [] }, "source": [ - "We can save the :py:class:`~data.sources.numpy_list.ArrayFieldList` into a GRIB file:" + "We can save array fieldlists into GRIB." ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 18, "id": "9ceea7b3-5059-4378-9a26-d282caa3b74a", "metadata": { "editable": true, @@ -963,7 +1074,7 @@ "3 an 0 regular_ll " ] }, - "execution_count": 16, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } diff --git a/docs/examples/numpy_fieldlist.ipynb b/docs/examples/numpy_fieldlist.ipynb index e91a30e4..92beb62b 100644 --- a/docs/examples/numpy_fieldlist.ipynb +++ b/docs/examples/numpy_fieldlist.ipynb @@ -26,7 +26,7 @@ "tags": [] }, "source": [ - "In this notebook we will show how to do some computations with GRIB data and generate a :py:class:`~data.sources.numpy_list.ArrayFieldList` from the results.\n", + "In this notebook we will show how to do some computations with GRIB data and generate a array fieldlist from the results.\n", "\n", "First we :ref:`read ` some GRIB data containing pressure level fields." ] @@ -227,7 +227,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 7, @@ -280,7 +280,7 @@ "tags": [] }, "source": [ - "A new FieldList (type of :py:class:`~data.sources.numpy_list.ArrayFieldList`) can be created from the resulting ndarray and the modified metadata. 
It behaves as if it were a :py:class:`~data.readers.grib.index.GribFieldList`." + "A new FieldList (type of :py:class:`~data.indexing.fieldlist.SimpleFieldList`) can be created from the resulting ndarray and the modified metadata. It will store a list of :py:class:`~data.sources.array_list.ArrayField` fields each containing a values array and a metadata object entirely in memory. This fieldlist behaves as if it were a :py:class:`~data.readers.grib.index.GribFieldList`." ] }, { @@ -373,7 +373,7 @@ { "data": { "text/plain": [ - "ArrayField()" + "ArrayField(pt,850,20180801,1200,0,0)" ] }, "execution_count": 10, @@ -447,7 +447,7 @@ "tags": [] }, "source": [ - "We can save the :py:class:`~data.sources.numpy_list.ArrayFieldList` into a GRIB file:" + "We can save our new fieldlist into a GRIB file:" ] }, { @@ -716,7 +716,7 @@ "tags": [] }, "source": [ - "We create a :py:class:`~data.sources.numpy_list.ArrayFieldList` from the resulting ndarray and the modified metadata." + "We create an array fieldlist from the resulting ndarray and the modified metadata." 
] }, { @@ -887,7 +887,7 @@ "tags": [] }, "source": [ - "We can save the :py:class:`~data.sources.numpy_list.ArrayFieldList` into a GRIB file:" + "We can save the fieldlist into a GRIB file:" ] }, { @@ -1169,10 +1169,13 @@ "slideshow": { "slide_type": "" }, - "tags": [] + "tags": [], + "vscode": { + "languageId": "raw" + } }, "source": [ - "We can save the :py:class:`~data.sources.numpy_list.ArrayFieldList` into a GRIB file:" + "We can save the results into a GRIB file:" ] }, { @@ -1293,9 +1296,9 @@ ], "metadata": { "kernelspec": { - "display_name": "dev", + "display_name": "dev_ecc", "language": "python", - "name": "dev" + "name": "dev_ecc" }, "language_info": { "codemirror_mode": { diff --git a/pyproject.toml b/pyproject.toml index dd7fc265..c7ae81f3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,7 @@ dynamic = [ "version", ] dependencies = [ + "array-api-compat", "cfgrib>=0.9.10.1", "dask", "earthkit-geo>=0.2", @@ -68,7 +69,6 @@ optional-dependencies.cds = [ "cdsapi>=0.7.1", ] optional-dependencies.ci = [ - "array-api-compat", "torch", ] optional-dependencies.covjsonkit = [ diff --git a/src/earthkit/data/__init__.py b/src/earthkit/data/__init__.py index 34a2fd98..8e1c642d 100644 --- a/src/earthkit/data/__init__.py +++ b/src/earthkit/data/__init__.py @@ -22,15 +22,18 @@ from .core.caching import CACHE as cache from .core.fieldlist import FieldList from .core.settings import SETTINGS as settings +from .indexing.fieldlist import SimpleFieldList from .readers.grib.output import new_grib_output from .sources import Source from .sources import from_source from .sources import from_source_lazily +from .sources.array_list import ArrayField from .utils.examples import download_example_file from .utils.examples import remote_example_file __all__ = [ "ALL", + "ArrayField", "cache", "download_example_file", "FieldList", @@ -41,6 +44,7 @@ "new_grib_output", "remote_example_file", "settings", + "SimpleFieldList", "Source", "__version__", ] diff --git 
a/src/earthkit/data/core/fieldlist.py b/src/earthkit/data/core/fieldlist.py index cf4d9e77..ad268fbf 100644 --- a/src/earthkit/data/core/fieldlist.py +++ b/src/earthkit/data/core/fieldlist.py @@ -18,8 +18,9 @@ from earthkit.data.core.index import MultiIndex from earthkit.data.decorators import cached_method from earthkit.data.decorators import detect_out_filename -from earthkit.data.utils.array import ensure_backend -from earthkit.data.utils.array import numpy_backend +from earthkit.data.utils.array import array_namespace +from earthkit.data.utils.array import array_to_numpy +from earthkit.data.utils.array import convert_array from earthkit.data.utils.metadata.args import metadata_argument @@ -80,59 +81,9 @@ class Field(Base): def __init__( self, - array_backend, metadata=None, - raw_values_backend=None, - raw_other_backend=None, ): self.__metadata = metadata - self._array_backend = array_backend - self._raw_values_backend = ensure_backend(raw_values_backend) - self._raw_other_backend = ensure_backend(raw_other_backend) - - @property - def array_backend(self): - r""":obj:`ArrayBackend`: Return the array backend of the field.""" - return self._array_backend - - @property - def raw_values_backend(self): - r""":obj:`ArrayBackend`: Return the array backend used by the low level API - to extract the field values. - """ - return self._raw_values_backend - - @property - def raw_other_backend(self): - r""":obj:`ArrayBackend`: Return the array backend used by the low level API - to extract non-field-related values, e.g. latitudes, longitudes. - """ - return self._raw_other_backend - - def _to_array(self, v, array_backend=None, source_backend=None): - r"""Convert an array into an ``array backend``. - - Parameters - ---------- - v: array-like - The values. - array_backend: :obj:`ArrayBackend` - The target array backend. When it is None ``self.array_backend`` will - be used. - source_backend: :obj:`ArrayBackend` - The array backend of ``v``. 
When None, it will be automatically detected. - - Returns - ------- - array-like - ``v`` converted onto the ``array_backend``. - - """ - if array_backend is None: - return self._array_backend.to_array(v, source_backend) - else: - array_backend = ensure_backend(array_backend) - return array_backend.to_array(v, source_backend) @abstractmethod def _values(self, dtype=None): @@ -146,27 +97,20 @@ def _values(self, dtype=None): type used by the underlying data accessor is used. For GRIB it is ``float64``. - The original shape and array backend type of the raw values are kept. + The original shape and array type of the raw values are kept. Returns ------- array-like - Field values in the format specified by :attr:`raw_values_backend`. + Field values. """ self._not_implemented() @property def values(self): - r"""array-like: Get the values stored in the field as a 1D array. The array type - is defined by :attr:`array_backend` - """ - v = self._to_array(self._values(), source_backend=self.raw_values_backend) - if len(v.shape) != 1: - n = math.prod(v.shape) - n = (n,) - return self._array_backend.array_ns.reshape(v, n) - return v + r"""array-like: Get the values stored in the field as a 1D array.""" + return self._flatten(self._values()) @property def _metadata(self): @@ -197,8 +141,7 @@ def to_numpy(self, flatten=False, dtype=None, index=None): Field values """ - v = self._values(dtype=dtype) - v = numpy_backend().to_array(v, self.raw_values_backend) + v = array_to_numpy(self._values(dtype=dtype)) shape = self._required_shape(flatten) if shape != v.shape: v = v.reshape(shape) @@ -207,8 +150,7 @@ def to_numpy(self, flatten=False, dtype=None, index=None): return v def to_array(self, flatten=False, dtype=None, array_backend=None, index=None): - r"""Return the values stored in the field in the - format of :attr:`array_backend`. + r"""Return the values stored in the field. 
Parameters ---------- @@ -218,6 +160,9 @@ def to_array(self, flatten=False, dtype=None, array_backend=None, index=None): dtype: str, array.dtype or None Typecode or data-type of the array. When it is :obj:`None` the default type used by the underlying data accessor is used. For GRIB it is ``float64``. + array_backend: str, module or None + The array backend to be used. When it is :obj:`None` the underlying array format + of the field is used. index: array indexing object, optional The index of the values and to be extracted. When it is None all the values are extracted @@ -225,30 +170,18 @@ def to_array(self, flatten=False, dtype=None, array_backend=None, index=None): Returns ------- array-array - Field values in the format od :attr:`array_backend`. + Field values. """ - v = self._to_array( - self._values(dtype=dtype), - array_backend=array_backend, - source_backend=self.raw_values_backend, - ) - shape = self._required_shape(flatten) - if shape != v.shape: - v = self._array_backend.array_ns.reshape(v, shape) + v = self._values(dtype=dtype) + if array_backend is not None: + v = convert_array(v, target_backend=array_backend) + + v = self._reshape(v, flatten) if index is not None: v = v[index] return v - def _required_shape(self, flatten, shape=None): - if shape is None: - shape = self.shape - return shape if not flatten else (math.prod(shape),) - - def _array_matches(self, array, flatten=False, dtype=None): - shape = self._required_shape(flatten) - return shape == array.shape and (dtype is None or dtype == array.dtype) - def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None, index=None): r"""Return the values and/or the geographical coordinates for each grid point. @@ -271,9 +204,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None, index=No ------- array-like An multi-dimensional array containing one array per key is returned - (following the order in ``keys``). 
When ``keys`` is a single value only the - array belonging to the key is returned. The array format is specified by - :attr:`array_backend`. + (following the order in ``keys``). The underlying array format + of the field is used. When ``keys`` is a single value only the + array belonging to the key is returned. Examples -------- @@ -305,9 +238,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None, index=No """ _keys = dict( - lat=(self._metadata.geography.latitudes, self.raw_other_backend), - lon=(self._metadata.geography.longitudes, self.raw_other_backend), - value=(self._values, self.raw_values_backend), + lat=self._metadata.geography.latitudes, + lon=self._metadata.geography.longitudes, + value=self._values, ) if isinstance(keys, str): @@ -317,20 +250,29 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None, index=No if k not in _keys: raise ValueError(f"data: invalid argument: {k}") - r = [] + r = {} for k in keys: - v = self._to_array(_keys[k][0](dtype=dtype), source_backend=_keys[k][1]) - shape = self._required_shape(flatten) - if shape != v.shape: - v = self._array_backend.array_ns.reshape(v, shape) + # TODO: convert dtype + v = _keys[k](dtype=dtype) + v = self._reshape(v, flatten) if index is not None: v = v[index] - r.append(v) - + r[k] = v + + # convert latlon to array format + ll = {k: r[k] for k in r if k != "value"} + if ll: + sample = r.get("value", None) + if sample is None: + sample = self._values(dtype=dtype) + for k, v in zip(ll.keys(), convert_array(list(ll.values()), target_array_sample=sample)): + r[k] = v + + r = list(r.values()) if len(r) == 1: return r[0] else: - return self._array_backend.array_ns.stack(r) + return array_namespace(r[0]).stack(r) def to_points(self, flatten=False, dtype=None, index=None): r"""Return the geographical coordinates in the data's original @@ -353,8 +295,8 @@ def to_points(self, flatten=False, dtype=None, index=None): ------- dict Dictionary with items "x" and "y", 
containing the arrays of the x and - y coordinates, respectively. The array format is specified by - :attr:`array_backend`. + y coordinates, respectively. The underlying array format + of the field is used. Raises ------ @@ -368,23 +310,27 @@ def to_points(self, flatten=False, dtype=None, index=None): """ x = self._metadata.geography.x(dtype=dtype) y = self._metadata.geography.y(dtype=dtype) + r = {} if x is not None and y is not None: - x = self._to_array(x, source_backend=self.raw_other_backend) - y = self._to_array(y, source_backend=self.raw_other_backend) - shape = self._required_shape(flatten) - if shape != x.shape: - x = self._array_backend.array_ns.reshape(x, shape) - y = self._array_backend.array_ns.reshape(y, shape) + x = self._reshape(x, flatten) + y = self._reshape(y, flatten) if index is not None: x = x[index] y = y[index] - return dict(x=x, y=y) + r = dict(x=x, y=y) elif self.projection().CARTOPY_CRS == "PlateCarree": lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype, index=index) return dict(x=lon, y=lat) else: raise ValueError("to_points(): geographical coordinates in original CRS are not available") + # convert values to array format + assert r + sample = self._values(dtype=dtype) + for k, v in zip(r.keys(), convert_array(list(r.values()), target_array_sample=sample)): + r[k] = v + return r + def to_latlon(self, flatten=False, dtype=None, index=None): r"""Return the latitudes/longitudes of all the gridpoints in the field. @@ -405,8 +351,8 @@ def to_latlon(self, flatten=False, dtype=None, index=None): ------- dict Dictionary with items "lat" and "lon", containing the arrays of the latitudes and - longitudes, respectively. The array format is specified by - :attr:`array_backend`. + longitudes, respectively. The underlying array format + of the field is used. 
See Also -------- @@ -759,62 +705,110 @@ def _attributes(self, names, remapping=None, joiner=None, default=None): # return {name: metadata(name) for name in names} + @staticmethod + def _flatten(v): + """Flatten the array without copying the data." + + Parameters + ---------- + v: array-like + The array to be flattened. + + Returns + ------- + array-like + 1-D array. + """ + if len(v.shape) != 1: + n = math.prod(v.shape) + n = (n,) + return array_namespace(v).reshape(v, n) + return v + + def _reshape(self, v, flatten): + """Reshape the array to the required shape.""" + shape = self._required_shape(flatten) + if shape != v.shape: + v = array_namespace(v).reshape(v, shape) + return v + + def _required_shape(self, flatten, shape=None): + """Return the required shape of the array.""" + if shape is None: + shape = self.shape + return shape if not flatten else (math.prod(shape),) + + def _array_matches(self, array, flatten=False, dtype=None): + """Check if the array matches the field and conditions.""" + shape = self._required_shape(flatten) + return shape == array.shape and (dtype is None or dtype == array.dtype) + class FieldList(Index): - r"""Represent a list of :obj:`Field` \s. + r"""Represent a list of :obj:`Field` \s.""" + + def __init__(self, **kwargs): + if "array_backend" in kwargs: + import warnings - Parameters - ---------- - array_backend: str, :obj:`ArrayBackend` - The array backend. When it is None the array backend - defaults to "numpy". - """ + warnings.warn( + ( + "array_backend option is not supported any longer in FieldList!" 
+ " Use to_fieldlist() instead" + ), + DeprecationWarning, + ) + kwargs.pop("array_backend", None) - def __init__(self, array_backend=None, **kwargs): - self._array_backend = ensure_backend(array_backend) super().__init__(**kwargs) def _init_from_mask(self, index): - self._array_backend = index._index.array_backend + pass def _init_from_multi(self, index): - self._array_backend = index._indexes[0].array_backend + pass @staticmethod def from_fields(fields): - raise NotImplementedError + r"""Create a :class:`SimpleFieldList`. + + Parameters + ---------- + fields: list + List of :obj:`Field` objects. + + Returns + ------- + :class:`SimpleFieldList` + + """ + from earthkit.data.indexing.fieldlist import SimpleFieldList + + return SimpleFieldList(fields) @staticmethod def from_numpy(array, metadata): - from earthkit.data.sources.array_list import ArrayFieldList - - return ArrayFieldList(array, metadata, array_backend=numpy_backend()) + return FieldList.from_array(array, metadata) @staticmethod def from_array(array, metadata): - r"""Create an :class:`ArrayFieldList`. + r"""Create an :class:`SimpleFieldList`. Parameters ---------- array: array-like, list The fields' values. When it is a list it must contain one array per field. - The array type must be supported by :class:`ArrayBackend`. - metadata: list - The fields' metadata. Must contain one :class:`Metadata` object per field. + metadata: list, :class:`Metadata` + The fields' metadata. Must contain one :class:`Metadata` object per field. Or + it can be a single :class:`Metadata` object when all the fields have the same metadata. - In the generated :class:`ArrayFieldList`, each field is represented by an array + In the generated :class:`SimpleFieldList`, each field is represented by an array storing the field values and a :class:`MetaData` object holding the field metadata. The shape and dtype of the array is controlled by the ``kwargs``. 
- Please note that generated :class:`ArrayFieldList` stores all the field values in - a single array. """ - from earthkit.data.sources.array_list import ArrayFieldList + from earthkit.data.sources.array_list import from_array - return ArrayFieldList(array, metadata) - - @property - def array_backend(self): - return self._array_backend + return from_array(array, metadata) def ignore(self): # When the concrete type is Fieldlist we assume the object was @@ -934,8 +928,7 @@ def to_array(self, **kwargs): Returns ------- array-like - Array containing the field values. The array format is specified by - :attr:`array_backend`. + Array containing the field values. See Also -------- @@ -943,7 +936,9 @@ def to_array(self, **kwargs): to_numpy """ x = [f.to_array(**kwargs) for f in self] - return self._array_backend.array_ns.stack(x) + if len(x) == 0: + return None + return array_namespace(x[0]).stack(x) @property def values(self): @@ -969,7 +964,9 @@ def values(self): """ x = [f.values for f in self] - return self._array_backend.array_ns.stack(x) + if len(x) == 0: + return None + return array_namespace(x[0]).stack(x) def data( self, @@ -1061,10 +1058,10 @@ def data( r.extend([f.to_array(flatten=flatten, dtype=dtype, index=index) for f in self]) else: raise ValueError(f"data: invalid argument: {k}") - return self._array_backend.array_ns.stack(r) + return array_namespace(r[0]).stack(r) elif len(self) == 0: - return self._array_backend.array_ns.stack([]) + return array_namespace(r[0]).array_ns.stack([]) else: raise ValueError("Fields do not have the same grid geometry") @@ -1297,8 +1294,7 @@ def to_points(self, **kwargs): ------- dict Dictionary with items "x" and "y", containing the arrays of the x and - y coordinates, respectively. The array format is specified by - :attr:`array_backend`. + y coordinates, respectively. 
Raises ------ @@ -1328,8 +1324,7 @@ def to_latlon(self, index=None, **kwargs): ------- dict Dictionary with items "lat" and "lon", containing the arrays of the latitudes and - longitudes, respectively. The array format is specified by - :attr:`array_backend`. + longitudes, respectively. Raises ------ @@ -1443,7 +1438,7 @@ def save(self, filename, append=False, **kwargs): -------- :obj:`write` :meth:`GribFieldList.save() ` - :meth:`NumpyFieldList.save() ` + :meth:`SimpleFieldList.save() ` """ flag = "wb" if not append else "ab" @@ -1481,7 +1476,7 @@ def to_fieldlist(self, array_backend=None, **kwargs): Parameters ---------- - array_backend: str, :obj:`ArrayBackend` + array_backend: str, module, :obj:`ArrayBackend` Specifies the array backend for the generated :class:`FieldList`. The array type must be supported by :class:`ArrayBackend`. @@ -1493,12 +1488,12 @@ def to_fieldlist(self, array_backend=None, **kwargs): ------- :class:`FieldList` - the current :class:`FieldList` if it is already in the required format - - a new :class:`ArrayFieldList` otherwise + - a new :class:`SimpleFieldList` with :class`ArrayField` fields otherwise Examples -------- The following example will convert a fieldlist read from a file into a - :class:`ArrayFieldList` storing single precision field values. + :class:`SimpleFieldList` storing single precision field values. 
>>> import numpy as np >>> import earthkit.data @@ -1507,21 +1502,19 @@ def to_fieldlist(self, array_backend=None, **kwargs): 'docs/examples/tuv_pl.grib' >>> r = ds.to_fieldlist(array_backend="numpy", dtype=np.float32) >>> r - ArrayFieldList(fields=18) + SimpleFieldList(fields=18) >>> hasattr(r, "path") False >>> r.to_numpy().dtype dtype('float32') """ - if array_backend is None: - array_backend = self._array_backend - array_backend = ensure_backend(array_backend) - return self._to_array_fieldlist(array_backend=array_backend, **kwargs) - - def _to_array_fieldlist(self, **kwargs): - md = [f.metadata() for f in self] - return self.from_array(self.to_array(**kwargs), md) + array = [] + md = [] + for f in self: + array.append(f.to_array(array_backend=array_backend, **kwargs)) + md.append(f._metadata) + return self.from_array(array, md) def cube(self, *args, **kwargs): from earthkit.data.indexing.cube import FieldCube diff --git a/src/earthkit/data/indexing/fieldlist.py b/src/earthkit/data/indexing/fieldlist.py index ac494a4c..f5acb4aa 100644 --- a/src/earthkit/data/indexing/fieldlist.py +++ b/src/earthkit/data/indexing/fieldlist.py @@ -10,7 +10,7 @@ from earthkit.data.core.fieldlist import FieldList -class FieldArray(FieldList): +class SimpleFieldList(FieldList): def __init__(self, fields=None): self.fields = fields if fields is not None else [] @@ -25,3 +25,64 @@ def __len__(self): def __repr__(self) -> str: return f"FieldArray({len(self.fields)})" + + def __getstate__(self) -> dict: + ret = {} + ret["_fields"] = self.fields + return ret + + def __setstate__(self, state: dict): + self.fields = state.pop("_fields") + + def to_pandas(self, *args, **kwargs): + # TODO make it generic + if len(self) > 0: + if self[0]._metadata.data_format() == "grib": + from earthkit.data.readers.grib.pandas import PandasMixIn + + class _C(PandasMixIn, SimpleFieldList): + pass + + return _C(self.fields).to_pandas(*args, **kwargs) + else: + import pandas as pd + + return pd.DataFrame() + + 
def to_xarray(self, *args, **kwargs): + # TODO make it generic + if len(self) > 0: + if self[0]._metadata.data_format() == "grib": + from earthkit.data.readers.grib.xarray import XarrayMixIn + + class _C(XarrayMixIn, SimpleFieldList): + pass + + return _C(self.fields).to_xarray(*args, **kwargs) + else: + import xarray as xr + + return xr.Dataset() + + def mutate_source(self): + return self + + @classmethod + def new_mask_index(cls, *args, **kwargs): + assert len(args) == 2 + fs = args[0] + indices = list(args[1]) + return cls.from_fields([fs.fields[i] for i in indices]) + + @classmethod + def merge(cls, sources): + if not all(isinstance(_, SimpleFieldList) for _ in sources): + raise ValueError("SimpleFieldList can only be merged to another SimpleFieldLists") + + from itertools import chain + + return cls.from_fields(list(chain(*[f for f in sources]))) + + +# For backwards compatibility +FieldArray = SimpleFieldList diff --git a/src/earthkit/data/readers/grib/codes.py b/src/earthkit/data/readers/grib/codes.py index 3a9986d3..55efdc2b 100644 --- a/src/earthkit/data/readers/grib/codes.py +++ b/src/earthkit/data/readers/grib/codes.py @@ -245,8 +245,8 @@ class GribField(Field): _handle = None - def __init__(self, path, offset, length, backend, handle_manager=None, use_metadata_cache=False): - super().__init__(backend) + def __init__(self, path, offset, length, handle_manager=None, use_metadata_cache=False): + super().__init__() self.path = path self._offset = offset self._length = length diff --git a/src/earthkit/data/readers/grib/file.py b/src/earthkit/data/readers/grib/file.py index 6a064022..903f4ce5 100644 --- a/src/earthkit/data/readers/grib/file.py +++ b/src/earthkit/data/readers/grib/file.py @@ -19,10 +19,9 @@ class GRIBReader(GribFieldListInOneFile, Reader): appendable = True # GRIB messages can be added to the same file def __init__(self, source, path, parts=None): - _kwargs = {} for k in [ - "array_backend", + # "array_backend", "grib_field_policy", 
"grib_handle_policy", "grib_handle_cache_size", diff --git a/src/earthkit/data/readers/grib/index/__init__.py b/src/earthkit/data/readers/grib/index/__init__.py index f39d6523..1e381cf0 100644 --- a/src/earthkit/data/readers/grib/index/__init__.py +++ b/src/earthkit/data/readers/grib/index/__init__.py @@ -124,9 +124,6 @@ def availability_path(self): def merge(cls, sources): if not all(isinstance(_, GribFieldList) for _ in sources): raise ValueError("GribFieldList can only be merged to another GribFieldLists") - if not all(s.array_backend is s[0].array_backend for s in sources): - raise ValueError("Only fieldlists with the same array backend can be merged") - return GribMultiFieldList(sources) def _custom_availability(self, ignore_keys=None, filter_keys=lambda k: True): @@ -377,7 +374,6 @@ def _create_field(self, n): part.path, part.offset, part.length, - self.array_backend, handle_manager=self._handle_manager, use_metadata_cache=self._use_metadata_cache, ) diff --git a/src/earthkit/data/readers/grib/memory.py b/src/earthkit/data/readers/grib/memory.py index 226b40b9..df13eaf1 100644 --- a/src/earthkit/data/readers/grib/memory.py +++ b/src/earthkit/data/readers/grib/memory.py @@ -12,20 +12,18 @@ import eccodes +from earthkit.data.indexing.fieldlist import SimpleFieldList from earthkit.data.readers import Reader from earthkit.data.readers.grib.codes import GribCodesHandle from earthkit.data.readers.grib.codes import GribField -from earthkit.data.readers.grib.index import GribFieldList from earthkit.data.readers.grib.metadata import GribFieldMetadata -from earthkit.data.utils.array import ensure_backend LOG = logging.getLogger(__name__) class GribMemoryReader(Reader): - def __init__(self, array_backend=None, **kwargs): + def __init__(self, **kwargs): self._peeked = None - self._array_backend = ensure_backend(array_backend) def __iter__(self): return self @@ -47,7 +45,7 @@ def _next_handle(self): def _message_from_handle(self, handle): if handle is not None: - return 
GribFieldInMemory(GribCodesHandle(handle, None, None), self._array_backend) + return GribFieldInMemory(GribCodesHandle(handle, None, None)) def batched(self, n): from earthkit.data.utils.batch import batched @@ -119,8 +117,8 @@ def mutate_source(self): class GribFieldInMemory(GribField): """Represents a GRIB message in memory""" - def __init__(self, handle, array_backend=None): - super().__init__(None, None, None, array_backend) + def __init__(self, handle): + super().__init__(None, None, None) self._handle = handle @GribField.handle.getter @@ -142,44 +140,38 @@ def to_fieldlist(fields): @staticmethod def from_buffer(buf): handle = eccodes.codes_new_from_message(buf) - return GribFieldInMemory(GribCodesHandle(handle, None, None), None) + return GribFieldInMemory(GribCodesHandle(handle, None, None)) -class GribFieldListInMemory(GribFieldList, Reader): - """Represent a GRIB field list in memory""" +class GribFieldListInMemory(SimpleFieldList): + """Represent a GRIB field list in memory loaded lazily""" - @staticmethod - def from_fields(fields, array_backend=None): - if array_backend is None and len(fields) > 0: - array_backend = fields[0].array_backend - fs = GribFieldListInMemory(None, None, array_backend=array_backend) - fs._fields = fields - fs._loaded = True - return fs + # @staticmethod + # def from_fields(fields): + # if array_backend is None and len(fields) > 0: + # array_backend = fields[0].array_backend + # fs = GribFieldListInMemory(None, None, array_backend=array_backend) + # fs.fields = fields + # fs._loaded = True + # return fs def __init__(self, source, reader, *args, **kwargs): """The reader must support __next__.""" if source is not None: - Reader.__init__(self, source, "") - GribFieldList.__init__(self, *args, **kwargs) - - self._reader = reader + self._reader = reader self._loaded = False - self._fields = [] def __len__(self): self._load() - return len(self._fields) + return super().__len__() - def _getitem(self, n): + def __getitem__(self, n): 
self._load() - if isinstance(n, int): - n = n if n >= 0 else len(self) + n - return self._fields[n] + return super().__getitem__(n) def _load(self): if not self._loaded: - self._fields = [f for f in self._reader] + self.fields = [f for f in self._reader] self._loaded = True self._reader = None @@ -203,5 +195,5 @@ def __getstate__(self): def __setstate__(self, state): fields = [GribFieldInMemory.from_buffer(m) for m in state["messages"]] self.__init__(None, None) - self._fields = fields + self.fields = fields self._loaded = True diff --git a/src/earthkit/data/readers/netcdf/field.py b/src/earthkit/data/readers/netcdf/field.py index 9fdde80e..45a84731 100644 --- a/src/earthkit/data/readers/netcdf/field.py +++ b/src/earthkit/data/readers/netcdf/field.py @@ -213,8 +213,8 @@ def _key_name(key): class XArrayField(Field): - def __init__(self, ds, variable, slices, non_dim_coords, array_backend): - super().__init__(array_backend) + def __init__(self, ds, variable, slices, non_dim_coords): + super().__init__() self._ds = ds self._da = ds[variable] diff --git a/src/earthkit/data/readers/netcdf/fieldlist.py b/src/earthkit/data/readers/netcdf/fieldlist.py index 8ba9a85c..56b66f23 100644 --- a/src/earthkit/data/readers/netcdf/fieldlist.py +++ b/src/earthkit/data/readers/netcdf/fieldlist.py @@ -28,7 +28,6 @@ def get_fields_from_ds( ds, - array_backend, field_type=None, check_only=False, ): # noqa C901 @@ -145,7 +144,7 @@ def _skip_attr(v, attr_name): if check_only: return True - fields.append(field_type(ds, name, slices, non_dim_coords, array_backend)) + fields.append(field_type(ds, name, slices, non_dim_coords)) # if not fields: # raise Exception("NetCDFReader no 2D fields found in %s" % (self.path,)) @@ -174,7 +173,6 @@ def has_fields(self): if self._fields is None: return get_fields_from_ds( DataSet(self.xr_dataset), - self.array_backend, field_type=self.FIELD_TYPE, check_only=True, ) @@ -182,7 +180,7 @@ def has_fields(self): return len(self._fields) > 0 def _get_fields(self, 
ds): - return get_fields_from_ds(ds, self.array_backend, field_type=self.FIELD_TYPE) + return get_fields_from_ds(ds, field_type=self.FIELD_TYPE) def to_pandas(self, **kwargs): return self.to_xarray(**kwargs).to_pandas() diff --git a/src/earthkit/data/sources/array_list.py b/src/earthkit/data/sources/array_list.py index c297622d..6695a188 100644 --- a/src/earthkit/data/sources/array_list.py +++ b/src/earthkit/data/sources/array_list.py @@ -11,14 +11,7 @@ import math from earthkit.data.core.fieldlist import Field -from earthkit.data.core.fieldlist import FieldList -from earthkit.data.core.index import MaskIndex -from earthkit.data.core.index import MultiIndex -from earthkit.data.readers.grib.pandas import PandasMixIn -from earthkit.data.readers.grib.xarray import XarrayMixIn -from earthkit.data.utils.array import ensure_backend -from earthkit.data.utils.array import get_backend -from earthkit.data.utils.metadata.dict import UserMetadata +from earthkit.data.utils.array import array_namespace LOG = logging.getLogger(__name__) @@ -32,25 +25,22 @@ class ArrayField(Field): Array storing the values of the field metadata: :class:`Metadata` Metadata object describing the field metadata. - array_backend: str, ArrayBackend - Array backend. Must match the type of ``array``. 
""" - def __init__(self, array, metadata, array_backend=None): + def __init__(self, array, metadata): if isinstance(array, list): - array_backend = ensure_backend(array_backend) - array = array_backend.from_other(array) + import numpy as np + + array = np.array(array) if isinstance(metadata, dict): - metadata = UserMetadata(metadata, values=array) + from earthkit.data.utils.metadata.dict import UserMetadata - if array_backend is None: - array_backend = get_backend(array, guess=array_backend, strict=True) + metadata = UserMetadata(metadata, values=array) - if array_backend is None: - raise ValueError("array_backend must be provided") + metadata = metadata._hide_internal_keys() - super().__init__(array_backend, raw_values_backend=array_backend, metadata=metadata) + super().__init__(metadata=metadata) self._array = array def _values(self, dtype=None): @@ -58,7 +48,7 @@ def _values(self, dtype=None): if dtype is None: return self._array else: - return self.array_backend.array_ns.astype(self._array, dtype, copy=False) + return array_namespace(self._array).astype(self._array, dtype, copy=False) def __repr__(self): return self.__class__.__name__ + "(%s,%s,%s,%s,%s,%s)" % ( @@ -92,213 +82,58 @@ def __getstate__(self) -> dict: ret = {} ret["_array"] = self._array ret["_metadata"] = self._metadata - ret["_array_backend"] = self._array_backend.name return ret def __setstate__(self, state: dict): self._array = state.pop("_array") metadata = state.pop("_metadata") - array_backend = state.pop("_array_backend") - array_backend = ensure_backend(array_backend) - super().__init__(array_backend, raw_values_backend=array_backend, metadata=metadata) - - -class ArrayFieldListCore(PandasMixIn, XarrayMixIn, FieldList): - def __init__(self, array, metadata, *args, array_backend=None, **kwargs): - self._array = array - self._metadata = metadata - - if not isinstance(self._metadata, list): - self._metadata = [self._metadata] - - if isinstance(self._array, list): - if len(self._array) == 0: - 
raise ValueError("array must not be empty") - if isinstance(self._array[0], list): - array_backend = ensure_backend(array_backend) - self._array = [array_backend.from_other(a) for a in self._array] - elif isinstance(self._array[0], (int, float)): - array_backend = ensure_backend(array_backend) - self._array = array_backend.from_other(self._array) - - # get backend and check consistency - array_backend = get_backend(self._array, guess=array_backend, strict=True) - - FieldList.__init__(self, *args, array_backend=array_backend, **kwargs) - - if self.array_backend.is_native_array(self._array): - if self._array.shape[0] != len(self._metadata): - # we have a single array and a single metadata - if len(self._metadata) == 1 and self._shape_match( - self._array.shape, self._metadata[0].geography.shape() - ): - self._array = self.array_backend.array_ns.stack([self._array]) - else: - raise ValueError( - ( - f"first array dimension ({self._array.shape[0]}) differs " - f"from number of metadata objects ({len(self._metadata)})" - ) - ) - elif isinstance(self._array, list): - if len(self._array) != len(self._metadata): - raise ValueError( - ( - f"array len ({len(self._array)}) differs " - f"from number of metadata objects ({len(self._metadata)})" - ) - ) + super().__init__(metadata=metadata) - for i, a in enumerate(self._array): - if not self.array_backend.is_native_array(a): - raise ValueError( - ( - f"All array element must be an {self.array_backend.array_name}." 
- " Type at position={i} is {type(a)}" - ) - ) - - else: - raise TypeError( - ( - f"array must be an {self.array_backend.array_name} or a" - f" list of {self.array_backend.array_name}s" - ) - ) - - # hide internal metadata related to values - self._metadata = [md._hide_internal_keys() for md in self._metadata] - def _shape_match(self, shape1, shape2): +def from_array(array, metadata): + def _shape_match(shape1, shape2): if shape1 == shape2: return True if len(shape1) == 1 and shape1[0] == math.prod(shape2): return True return False - @classmethod - def new_mask_index(self, *args, **kwargs): - return ArrayMaskFieldList(*args, **kwargs) - - @classmethod - def merge(cls, sources): - if not all(isinstance(_, ArrayFieldListCore) for _ in sources): - raise ValueError("ArrayFieldList can only be merged to another ArrayFieldLists") - if not all(s.array_backend is s[0].array_backend for s in sources): - raise ValueError("Only fieldlists with the same array backend can be merged") - - merger = ListMerger(sources) - return merger.to_fieldlist() - - def __repr__(self): - return f"{self.__class__.__name__}(fields={len(self)})" + if not isinstance(metadata, list): + metadata = [metadata] + + # array_ns = get_backend(self._array).array_ns + if isinstance(array, list): + if len(array) == 0: + raise ValueError("array must not be empty") + + if not isinstance(array, list): + array_ns = array_namespace(array) + if array_ns is None: + raise ValueError(f"array type {type(array)} is not supported") + elif array.shape[0] != len(metadata): + # we have a single array and a single metadata + if len(metadata) == 1 and _shape_match(array.shape, metadata[0].geography.shape()): + array = array_ns.stack([array]) + else: + raise ValueError( + ( + f"first array dimension={array.shape[0]} differs " + f"from number of metadata objects={len(metadata)}" + ) + ) + else: + if len(array) != len(metadata): + raise ValueError( + (f"array len=({len(array)}) differs " f"from number of metadata 
objects=({len(metadata)})") + ) - def _to_array_fieldlist(self, array_backend=None, **kwargs): - if self[0]._array_matches(self._array[0], **kwargs): - return self + fields = [] + for i, a in enumerate(array): + if len(metadata) == 1: + fields.append(ArrayField(a, metadata[0])) else: - return type(self)(self.to_array(array_backend=array_backend, **kwargs), self._metadata) - - def save(self, filename, append=False, check_nans=True, bits_per_value=None): - r"""Write all the fields into a file. - - Parameters - ---------- - filename: str - The target file path. - append: bool - When it is true append data to the target file. Otherwise - the target file be overwritten if already exists. - check_nans: bool - Replace nans in the values with GRIB missing values when generating the output. - bits_per_value: int or None - Set the ``bitsPerValue`` GRIB key in the generated output. When None the - ``bitsPerValue`` stored in the metadata will be used. - """ - super().save( - filename, - append=append, - check_nans=check_nans, - bits_per_value=bits_per_value, - ) - - def __getstate__(self) -> dict: - ret = {} - ret["_array"] = self._array - ret["_metadata"] = self._metadata - ret["_array_backend"] = self._array_backend.name - return ret - - def __setstate__(self, state: dict): - self._array = state.pop("_array") - self._metadata = state.pop("_metadata") - array_backend = state.pop("_array_backend") - array_backend = ensure_backend(array_backend) - super().__init__(array_backend, raw_values_backend=array_backend, metadata=self._metadata) - - -# class MultiUnwindMerger: -# def __init__(self, sources): -# self.sources = list(self._flatten(sources)) - -# def _flatten(self, sources): -# if isinstance(sources, ArrayMultiFieldList): -# for s in sources.indexes: -# yield from self._flatten(s) -# elif isinstance(sources, list): -# for s in sources: -# yield from self._flatten(s) -# else: -# yield sources - -# def to_fieldlist(self): - -# return ArrayMultiFieldList(self.sources) - - 
-class ListMerger: - def __init__(self, sources): - self.sources = sources - - def to_fieldlist(self): - array = [] - metadata = [] - for s in self.sources: - for f in s: - array.append(f._array) - metadata.append(f._metadata) - array_backend = None if len(self.sources) == 0 else self.sources[0].array_backend - return ArrayFieldList(array, metadata, array_backend=array_backend) - - -class ArrayFieldList(ArrayFieldListCore): - r"""Represent a list of :obj:`ArrayField `\ s. - - The preferred way to create a ArrayFieldList is to use either the - static :obj:`from_array` method or the :obj:`to_fieldlist` method. - - See Also - -------- - from_array - to_fieldlist - - """ - - def _getitem(self, n): - if isinstance(n, int): - return ArrayField(self._array[n], self._metadata[n], self.array_backend) - - def __len__(self): - return len(self._array) if isinstance(self._array, list) else self._array.shape[0] - - -class ArrayMaskFieldList(ArrayFieldListCore, MaskIndex): - def __init__(self, *args, **kwargs): - MaskIndex.__init__(self, *args, **kwargs) - FieldList._init_from_mask(self, self) + fields.append(ArrayField(a, metadata[i])) + from earthkit.data.indexing.fieldlist import SimpleFieldList -class ArrayMultiFieldList(ArrayFieldListCore, MultiIndex): - def __init__(self, *args, **kwargs): - MultiIndex.__init__(self, *args, **kwargs) - FieldList._init_from_multi(self, self) + return SimpleFieldList(fields) diff --git a/src/earthkit/data/sources/forcings.py b/src/earthkit/data/sources/forcings.py index d849ff97..e8730f4e 100644 --- a/src/earthkit/data/sources/forcings.py +++ b/src/earthkit/data/sources/forcings.py @@ -217,7 +217,7 @@ def wrapper(date): class ForcingField(Field): - def __init__(self, maker, date, param, proc, number=None, array_backend=None): + def __init__(self, maker, date, param, proc, number=None): self.maker = maker self.date = date self.param = param @@ -234,7 +234,6 @@ def __init__(self, maker, date, param, proc, number=None, array_backend=None): 
levtype=None, ) super().__init__( - array_backend, metadata=ForcingMetadata(d, self.maker.field.metadata().geography), ) @@ -378,7 +377,6 @@ def _getitem(self, n): param, self.procs[param], number=number, - array_backend=self.array_backend, ) diff --git a/src/earthkit/data/sources/list_of_dicts.py b/src/earthkit/data/sources/list_of_dicts.py index 02c3a046..3f4d3286 100644 --- a/src/earthkit/data/sources/list_of_dicts.py +++ b/src/earthkit/data/sources/list_of_dicts.py @@ -12,7 +12,6 @@ from earthkit.data.utils.metadata.dict import UserMetadata from . import Source -from .array_list import ArrayFieldList LOG = logging.getLogger(__name__) @@ -23,13 +22,19 @@ def __init__(self, list_of_dicts, *args, **kwargs): self._kwargs = kwargs def mutate(self): - array = [] - metadata = [] + import numpy as np + + from earthkit.data.indexing.fieldlist import SimpleFieldList + + from .array_list import ArrayField + + fields = [] for f in self.d: v = f["values"] - array.append(v) - metadata.append(UserMetadata(f, values=v)) - return ArrayFieldList(array, metadata, **self._kwargs) + if isinstance(v, list): + v = np.array(v) + fields.append(ArrayField(v, UserMetadata(f, values=v))) + return SimpleFieldList(fields=fields) source = FieldlistFromDicts diff --git a/src/earthkit/data/sources/numpy_list.py b/src/earthkit/data/sources/numpy_list.py index acd65cec..dea662b6 100644 --- a/src/earthkit/data/sources/numpy_list.py +++ b/src/earthkit/data/sources/numpy_list.py @@ -12,7 +12,6 @@ class NumpyFieldList(ArrayFieldList): def __init__(self, *args, **kwargs): - from earthkit.data.utils.array import numpy_backend kwargs.pop("backend", None) - super().__init__(*args, array_backend=numpy_backend(), **kwargs) + super().__init__(*args, **kwargs) diff --git a/src/earthkit/data/testing.py b/src/earthkit/data/testing.py index ccd58ea7..65e33539 100644 --- a/src/earthkit/data/testing.py +++ b/src/earthkit/data/testing.py @@ -174,24 +174,26 @@ def load_nc_or_xr_source(path, mode): return 
from_object(xarray.open_dataset(path)) -def check_array_type(v, backend, **kwargs): - from earthkit.data.utils.array import ensure_backend +def check_array_type(array, expected_backend, dtype=None): + from earthkit.data.utils.array import get_backend - b = ensure_backend(backend) - assert b.is_native_array(v, **kwargs), f"{type(v)}, {backend=}, {kwargs=}" + b1 = get_backend(array) + b2 = get_backend(expected_backend) + assert b1 == b2, f"{b1=}, {b2=}" -def get_array_namespace(backend): - from earthkit.data.utils.array import ensure_backend + expected_dtype = dtype + if expected_dtype is not None: + assert b2.match_dtype(array, expected_dtype), f"{array.dtype}, {expected_dtype=}" - return ensure_backend(backend).array_ns +def get_array_namespace(backend): + if backend is None: + backend = "numpy" -def get_array(v, backend, **kwargs): - from earthkit.data.utils.array import ensure_backend + from earthkit.data.utils.array import get_backend - b = ensure_backend(backend) - return b.from_other(v, **kwargs) + return get_backend(backend).namespace ARRAY_BACKENDS = ["numpy"] diff --git a/src/earthkit/data/utils/array.py b/src/earthkit/data/utils/array.py new file mode 100644 index 00000000..caf55e4b --- /dev/null +++ b/src/earthkit/data/utils/array.py @@ -0,0 +1,320 @@ +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+# + +import logging +from abc import ABCMeta +from abc import abstractmethod +from functools import cached_property + +LOG = logging.getLogger(__name__) + + +class ArrayNamespace: + @cached_property + def api(self): + try: + import array_api_compat + + return array_api_compat + except Exception: + raise ImportError("array_api_compat is required to use array namespace") + + @cached_property + def numpy(self): + import numpy as np + + return self.api.array_namespace(np.ones(2)) + + def namespace(self, *arrays): + # default namespace is numpy + if not arrays: + return self.numpy + + # if isinstance(arrays[0], self.numpy.ndarray): + # return self.numpy + + # if self.api is not None: + return self.api.array_namespace(*arrays) + + # raise ValueError("Can't find namespace for array. Please install array_api_compat package") + + +_NAMESPACE = ArrayNamespace() + + +class ArrayBackend(metaclass=ABCMeta): + @property + def name(self): + return self._name + + @abstractmethod + def _make_sample(self): + return None + + @property + def namespace(self): + return _NAMESPACE.namespace(self._make_sample()) + + @property + @abstractmethod + def module(self): + pass + + @abstractmethod + def to_numpy(self, v): + pass + + @abstractmethod + def from_numpy(self, v): + pass + + @abstractmethod + def from_other(self, v, **kwargs): + pass + + @property + @abstractmethod + def dtypes(self): + pass + + def to_dtype(self, dtype): + if isinstance(dtype, str): + return self.dtypes.get(dtype, None) + return dtype + + def match_dtype(self, v, dtype): + if dtype is not None: + dtype = self.to_dtype(dtype) + f = v.dtype == dtype if dtype is not None else False + return f + return True + + +class NumpyBackend(ArrayBackend): + _name = "numpy" + _module_name = "numpy" + + def _make_sample(self): + import numpy as np + + return np.ones(2) + + @cached_property + def module(self): + import numpy as np + + return np + + def to_numpy(self, v): + return v + + def from_numpy(self, v): + return v + + def 
from_other(self, v, **kwargs): + import numpy as np + + if not kwargs and isinstance(v, np.ndarray): + return v + + return np.array(v, **kwargs) + + @cached_property + def dtypes(self): + import numpy + + return {"float64": numpy.float64, "float32": numpy.float32} + + +class PytorchBackend(ArrayBackend): + _name = "pytorch" + _module_name = "torch" + + def _make_sample(self): + import torch + + return torch.ones(2) + + @cached_property + def module(self): + import torch + + return torch + + def to_numpy(self, v): + return v.numpy() + + def from_numpy(self, v): + import torch + + return torch.from_numpy(v) + + def from_other(self, v, **kwargs): + import torch + + return torch.tensor(v, **kwargs) + + @cached_property + def dtypes(self): + import torch + + return {"float64": torch.float64, "float32": torch.float32} + + +class CupyBackend(ArrayBackend): + _name = "cupy" + _module_name = "cupy" + + def _make_sample(self): + import cupy + + return cupy.ones(2) + + @cached_property + def module(self): + import cupy + + return cupy + + def from_numpy(self, v): + return self.from_other(v) + + def to_numpy(self, v): + return v.get() + + def from_other(self, v, **kwargs): + import cupy as cp + + return cp.array(v, **kwargs) + + @cached_property + def dtypes(self): + import cupy as cp + + return {"float64": cp.float64, "float32": cp.float32} + + +_NUMPY = NumpyBackend() +_PYTORCH = PytorchBackend() +_CUPY = CupyBackend() + +_BACKENDS = [_NUMPY, _PYTORCH, _CUPY] +_BACKENDS_BY_NAME = {v._name: v for v in _BACKENDS} +_BACKENDS_BY_MODULE = {v._module_name: v for v in _BACKENDS} + + +def array_namespace(*args): + return _NAMESPACE.namespace(*args) + + +def array_to_numpy(array): + return backend_from_array(array).to_numpy(array) + + +def backend_from_array(array, raise_exception=True): + if isinstance(array, _NAMESPACE.numpy.ndarray): + return _NUMPY + + if _NAMESPACE.api is not None: + if _NAMESPACE.api.is_torch_array(array): + return _PYTORCH + elif 
_NAMESPACE.api.is_cupy_array(array): + return _CUPY + + if raise_exception: + raise ValueError(f"Can't find namespace for array type={type(array)}") + + +def backend_from_name(name, raise_exception=True): + r = _BACKENDS_BY_NAME.get(name, None) + if raise_exception and r is None: + raise ValueError(f"Unknown array backend name={name}") + return r + + +def backend_from_module(module, raise_exception=True): + import inspect + + r = None + if inspect.ismodule(module): + r = _BACKENDS_BY_MODULE.get(module.__name__, None) + if raise_exception and r is None: + raise ValueError(f"Unknown array backend module={module}") + return r + + +def get_backend(data): + if isinstance(data, ArrayBackend): + return data + if isinstance(data, str): + return backend_from_name(data, raise_exception=True) + + r = backend_from_module(data, raise_exception=False) + if r is None: + r = backend_from_array(data) + + return r + + +class Converter: + def __init__(self, source, target): + self.source = source + self.target = target + + def __call__(self, array, **kwargs): + if self.source == _NUMPY: + return self.target.from_numpy(array, **kwargs) + return self.target.from_other(array, **kwargs) + + +def converter(array, target): + if target is None: + return None + + source_backend = backend_from_array(array) + target_backend = get_backend(target) + + if source_backend == target_backend: + return None + return Converter(source_backend, target_backend) + + +def convert_array(array, target_backend=None, target_array_sample=None, **kwargs): + if target_backend is not None and target_array_sample is not None: + raise ValueError("Only one of target_backend or target_array_sample can be specified") + if target_backend is not None: + target = target_backend + else: + target = backend_from_array(target_array_sample) + + r = [] + target_is_list = True + if not isinstance(array, (list, tuple)): + array = [array] + target_is_list = False + + for a in array: + c = converter(a, target) + if c is None: +
r.append(a) + else: + r.append(c(a, **kwargs)) + + if not target_is_list: + return r[0] + return r + + +def match(v1, v2): + return get_backend(v1) == get_backend(v2) + + +# added for backward compatibility +def ensure_backend(backend): + return None diff --git a/src/earthkit/data/utils/array/__init__.py deleted file mode 100644 index b53e9d40..00000000 --- a/src/earthkit/data/utils/array/__init__.py +++ /dev/null @@ -1,244 +0,0 @@ -# (C) Copyright 2020 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# - -import logging -import os -import sys -import threading -from abc import ABCMeta -from abc import abstractmethod -from importlib import import_module - -LOG = logging.getLogger(__name__) - - -class ArrayBackendManager: - def __init__(self): - self.backends = None - self._np_backend = None - self.loaded = None - self.lock = threading.Lock() - - def find_for_name(self, name): - self._load() - b = self.backends.get(name, None) - if b is None: - raise TypeError(f"No array backend found for name={name}") - - # throw an exception when the backend is not available - if not b.available: - raise Exception(f"Could not load array backend for name={name}") - - return b - - def find_for_array(self, v, guess=None): - self._load() - - if guess is not None and guess.is_native_array(v): - return guess - - # try all the backends.
This will only try to load/import an unloaded/unimported - # backend when necessary - for _, b in self.backends.items(): - if b.is_native_array(v): - return b - - raise TypeError(f"No array backend found for array type={type(v)}") - - @property - def numpy_backend(self): - if self._np_backend is None: - self._np_backend = self.find_for_name("numpy") - return self._np_backend - - def _load(self): - """Load the available backend objects""" - if self.loaded is None: - with self.lock: - self.backends = {} - here = os.path.dirname(__file__) - for path in sorted(os.listdir(here)): - if path[0] in ("_", "."): - continue - - if path.endswith(".py") or os.path.isdir(os.path.join(here, path)): - name, _ = os.path.splitext(path) - try: - module = import_module(f".{name}", package=__name__) - if hasattr(module, "Backend"): - w = getattr(module, "Backend") - self.backends[name] = w() - except Exception as e: - LOG.exception(f"Failed to import array backend code {name} from {path}. {e}") - self.loaded = True - - -class ArrayBackendCore: - def __init__(self, backend): - self.ns = None - self.dtypes = None - - try: - self.ns, self.dtypes = backend._load() - self.avail = True - except Exception as e: - LOG.exception(f"Failed to load array backend {backend.name}. {e}") - self.avail = False - - -class ArrayBackend(metaclass=ABCMeta): - """The backend objects are created upfront but only loaded on - demand to avoid unnecessary imports - """ - - _name = None - _array_name = "array" - _core = None - _converters = {} - - def __init__(self): - self.lock = threading.Lock() - - def _load_core(self): - if self._core is None: - with self.lock: - if self._core is None: - self._core = ArrayBackendCore(self) - - @property - def available(self): - self._load_core() - return self._core.avail - - @abstractmethod - def _load(self): - """Load the backend object. 
Called from arrayBackendCore.""" - pass - - @property - def array_ns(self): - """Delayed construction of array namespace""" - self._load_core() - return self._core.ns - - @property - def name(self): - return self._name - - @property - def array_name(self): - return f"{self._name} {self._array_name}" - - def to_array(self, v, source_backend=None): - r"""Convert an array into the current backend. - - Parameters - ---------- - v: array-like - Array. - source_backend: :obj:`ArrayBackend`, str - The array backend of ``v``. When None ``source_backend`` - is automatically detected. - - Returns - ------- - array-like - ``v`` converted into the array backend defined by ``self``. - """ - return self.from_backend(v, source_backend) - - @property - def _dtypes(self): - self._load_core() - return self._core.dtypes - - def to_dtype(self, dtype): - if isinstance(dtype, str): - return self._dtypes.get(dtype, None) - return dtype - - def match_dtype(self, v, dtype): - if dtype is not None: - dtype = self.to_dtype(dtype) - f = v.dtype == dtype if dtype is not None else False - return f - return True - - def is_native_array(self, v, dtype=None): - if (self._core is None and self._module_name not in sys.modules) or not self.available: - return False - return self._is_native_array(v, dtype=dtype) - - @abstractmethod - def _is_native_array(self, v, **kwargs): - pass - - def _quick_check_available(self): - return (self._core is None and self._module_name not in sys.modules) or not self.available - - @abstractmethod - def to_numpy(self, v): - pass - - def to_backend(self, v, backend, **kwargs): - assert backend is not None - backend = ensure_backend(backend) - return backend.from_backend(v, self, **kwargs) - - def from_backend(self, v, backend, **kwargs): - if backend is None: - backend = get_backend(v, strict=False) - - if self is backend: - return v - - if backend is not None: - b = self._converters.get(backend.name, None) - if b is not None: - return b(v) - - return self.from_other(v, 
**kwargs) - - @abstractmethod - def from_other(self, v, **kwargs): - pass - - -_MANAGER = ArrayBackendManager() - -# The public API - - -def ensure_backend(backend): - if backend is None: - return numpy_backend() - if isinstance(backend, str): - return _MANAGER.find_for_name(backend) - else: - return backend - - -def get_backend(array, guess=None, strict=True): - if isinstance(array, list): - array = array[0] - - if guess is not None: - guess = ensure_backend(guess) - - if isinstance(array, list): - return guess - - b = _MANAGER.find_for_array(array, guess=guess) - if strict and guess is not None and b is not guess: - raise ValueError(f"array type={b.array_name} and specified backend={guess} do not match") - return b - - -def numpy_backend(): - return _MANAGER.numpy_backend diff --git a/src/earthkit/data/utils/array/cupy.py b/src/earthkit/data/utils/array/cupy.py deleted file mode 100644 index baf209b1..00000000 --- a/src/earthkit/data/utils/array/cupy.py +++ /dev/null @@ -1,50 +0,0 @@ -# (C) Copyright 2020 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# - -from . 
import ArrayBackend - - -class CupyBackend(ArrayBackend): - _name = "cupy" - _module_name = "cupy" - - def _load(self): - try: - import array_api_compat - - except Exception as e: - raise ImportError(f"array_api_compat is required to use cupy backend, {e}") - - try: - import cupy as cp - except Exception as e: - raise ImportError(f"cupy is required to use cupy backend, {e}") - - dt = {"float64": cp.float64, "float32": cp.float32} - ns = array_api_compat.array_namespace(cp.ones(2)) - - return ns, dt - - def _is_native_array(self, v, dtype=None): - import cupy as cp - - if not isinstance(v, cp.ndarray): - return False - return self.match_dtype(v, dtype) - - def to_numpy(self, v): - return v.get() - - def from_other(self, v, **kwargs): - import cupy as cp - - return cp.array(v, **kwargs) - - -Backend = CupyBackend diff --git a/src/earthkit/data/utils/array/numpy.py b/src/earthkit/data/utils/array/numpy.py deleted file mode 100644 index 0ee92f86..00000000 --- a/src/earthkit/data/utils/array/numpy.py +++ /dev/null @@ -1,58 +0,0 @@ -# (C) Copyright 2020 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# - -from . 
import ArrayBackend - - -class NumpyBackend(ArrayBackend): - _name = "numpy" - _module_name = "numpy" - - def _load(self): - import numpy as np - - try: - import array_api_compat - - ns = array_api_compat.array_namespace(np.ones(2)) - except Exception: - ns = np - - return ns, {} - - def to_dtype(self, dtype): - return dtype - - def _is_native_array(self, v, dtype=None): - import numpy as np - - if not isinstance(v, np.ndarray): - return False - if dtype is not None: - return v.dtype == dtype - return True - - def from_backend(self, v, backend, **kwargs): - if self is backend: - return v - elif backend is not None: - return backend.to_numpy(v) - else: - return super().from_backend(v, backend, **kwargs) - - def to_numpy(self, v): - return v - - def from_other(self, v, **kwargs): - import numpy as np - - return np.array(v, **kwargs) - - -Backend = NumpyBackend diff --git a/src/earthkit/data/utils/array/pytorch.py b/src/earthkit/data/utils/array/pytorch.py deleted file mode 100644 index 8aee9e10..00000000 --- a/src/earthkit/data/utils/array/pytorch.py +++ /dev/null @@ -1,60 +0,0 @@ -# (C) Copyright 2020 ECMWF. -# -# This software is licensed under the terms of the Apache Licence Version 2.0 -# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. -# In applying this licence, ECMWF does not waive the privileges and immunities -# granted to it by virtue of its status as an intergovernmental organisation -# nor does it submit to any jurisdiction. -# - -from . 
import ArrayBackend - - -class PytorchBackend(ArrayBackend): - _name = "pytorch" - _array_name = "tensor" - _module_name = "torch" - - def __init__(self): - super().__init__() - self._converters = {"numpy": self.from_numpy} - - def _load(self): - try: - import array_api_compat - - except Exception as e: - raise ImportError(f"array_api_compat is required to use pytorch backend, {e}") - - try: - import torch - except Exception as e: - raise ImportError(f"torch is required to use pytorch backend, {e}") - - dt = {"float64": torch.float64, "float32": torch.float32} - ns = array_api_compat.array_namespace(torch.ones(2)) - - return ns, dt - - def _is_native_array(self, v, dtype=None): - import torch - - if not torch.is_tensor(v): - return False - return self.match_dtype(v, dtype) - - def to_numpy(self, v): - return v.numpy() - - def from_numpy(self, v): - import torch - - return torch.from_numpy(v) - - def from_other(self, v, **kwargs): - import torch - - return torch.tensor(v, **kwargs) - - -Backend = PytorchBackend diff --git a/tests/array_fieldlist/array_fl_fixtures.py b/tests/array_fieldlist/array_fl_fixtures.py index 4a66264d..d275286e 100644 --- a/tests/array_fieldlist/array_fl_fixtures.py +++ b/tests/array_fieldlist/array_fl_fixtures.py @@ -26,7 +26,8 @@ def load_array_fl(num, array_backend=None): ds_in = [] md = [] for fname in files: - ds_in.append(from_source("file", earthkit_examples_file(fname), array_backend=array_backend)) + ds = from_source("file", earthkit_examples_file(fname)) + ds_in.append(ds.to_fieldlist(array_backend=array_backend)) md += ds_in[-1].metadata("param") ds = [] @@ -38,7 +39,8 @@ def load_array_fl(num, array_backend=None): def load_array_fl_file(fname, array_backend=None): - ds_in = from_source("file", earthkit_examples_file(fname), array_backend=array_backend) + ds_in = from_source("file", earthkit_examples_file(fname)) + ds_in = ds_in.to_fieldlist(array_backend=array_backend) md = ds_in.metadata("param") ds = 
FieldList.from_array(ds_in.values, [m.override(edition=1) for m in ds_in.metadata()]) @@ -101,7 +103,6 @@ def check_array_fl_from_to_fieldlist(ds, ds_input, md_full, array_backend=None, assert ns.allclose(ds[0].to_array(**np_kwargs), ds_input[0][0].to_array(**np_kwargs)) assert ds.to_array(**np_kwargs).shape == ds_input[0].to_array(**np_kwargs).shape - assert ds._array.shape == ds_input[0].to_array(**np_kwargs).shape # check slice r = ds[1] diff --git a/tests/array_fieldlist/test_numpy_fl_write.py b/tests/array_fieldlist/test_numpy_fl_write.py index 5ce6b199..46a64bd4 100644 --- a/tests/array_fieldlist/test_numpy_fl_write.py +++ b/tests/array_fieldlist/test_numpy_fl_write.py @@ -33,7 +33,9 @@ @pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) def test_array_fl_grib_write_to_path(array_backend): - ds = from_source("file", earthkit_examples_file("test.grib"), array_backend=array_backend) + ds = from_source("file", earthkit_examples_file("test.grib")) + ds = ds.to_fieldlist(array_backend=array_backend) + ns = get_array_namespace(array_backend) assert ds[0].metadata("shortName") == "2t" @@ -48,14 +50,16 @@ def test_array_fl_grib_write_to_path(array_backend): with temp_file() as tmp: r.save(tmp) assert os.path.exists(tmp) - r_tmp = from_source("file", tmp, array_backend=array_backend) + r_tmp = from_source("file", tmp) + r_tmp = r_tmp.to_fieldlist(array_backend=array_backend) v_tmp = r_tmp[0].values assert ns.allclose(v1, v_tmp) @pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) def test_array_fl_grib_write_to_filehandle(array_backend): - ds = from_source("file", earthkit_examples_file("test.grib"), array_backend=array_backend) + ds = from_source("file", earthkit_examples_file("test.grib")) + ds = ds.to_fieldlist(array_backend=array_backend) ns = get_array_namespace(array_backend) assert ds[0].metadata("shortName") == "2t" @@ -71,7 +75,8 @@ def test_array_fl_grib_write_to_filehandle(array_backend): with open(tmp, "wb") as fh: r.write(fh) assert 
os.path.exists(tmp) - r_tmp = from_source("file", tmp, array_backend=array_backend) + r_tmp = from_source("file", tmp) + r_tmp = r_tmp.to_fieldlist(array_backend=array_backend) v_tmp = r_tmp[0].values assert ns.allclose(v1, v_tmp) @@ -79,7 +84,8 @@ def test_array_fl_grib_write_to_filehandle(array_backend): @pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize("_kwargs", [{}, {"check_nans": True}]) def test_array_fl_grib_write_missing(array_backend, _kwargs): - ds = from_source("file", earthkit_examples_file("test.grib"), array_backend=array_backend) + ds = from_source("file", earthkit_examples_file("test.grib")) + ds = ds.to_fieldlist(array_backend=array_backend) ns = get_array_namespace(array_backend) assert ds[0].metadata("shortName") == "2t" @@ -102,7 +108,8 @@ def test_array_fl_grib_write_missing(array_backend, _kwargs): with temp_file() as tmp: r.save(tmp, **_kwargs) assert os.path.exists(tmp) - r_tmp = from_source("file", tmp, array_backend=array_backend) + r_tmp = from_source("file", tmp) + r_tmp = r_tmp.to_fieldlist(array_backend=array_backend) v_tmp = r_tmp[0].values assert ns.isnan(v_tmp[0]) assert not ns.isnan(v_tmp[1]) diff --git a/tests/grib/grib_fixtures.py b/tests/grib/grib_fixtures.py index 4ece33af..95a9794a 100644 --- a/tests/grib/grib_fixtures.py +++ b/tests/grib/grib_fixtures.py @@ -11,32 +11,42 @@ from earthkit.data import from_source -from earthkit.data.core.fieldlist import FieldList +from earthkit.data.testing import ARRAY_BACKENDS from earthkit.data.testing import earthkit_examples_file from earthkit.data.testing import earthkit_test_data_file +from earthkit.data.utils.array import get_backend def load_array_fieldlist(path, array_backend): - ds = from_source("file", path, array_backend=array_backend) - return FieldList.from_array( - ds.values, [m.override(generatingProcessIdentifier=120) for m in ds.metadata()] - ) + ds = from_source("file", path) + return ds.to_fieldlist(array_backend=array_backend) + # return 
FieldList.from_array( + # ds.values, [m.override(generatingProcessIdentifier=120) for m in ds.metadata()] + # ) -def load_grib_data(filename, fl_type, array_backend, folder="example"): +def load_grib_data(filename, fl_type, folder="example"): + if isinstance(filename, str): + filename = [filename] + if folder == "example": - path = earthkit_examples_file(filename) + path = [earthkit_examples_file(name) for name in filename] elif folder == "data": - path = earthkit_test_data_file(filename) + path = [earthkit_test_data_file(name) for name in filename] else: - raise ValueError("Invalid folder={folder}") + raise ValueError(f"Invalid folder={folder}") if fl_type == "file": - return from_source("file", path, array_backend=array_backend) - elif fl_type == "array": - return load_array_fieldlist(path, array_backend) + return from_source("file", path), get_backend("numpy") + elif fl_type in ARRAY_BACKENDS: + array_backend = fl_type + return load_array_fieldlist(path, array_backend), get_backend(array_backend) else: - raise ValueError("Invalid fl_type={fl_type}") + raise ValueError(f"Invalid fl_type={fl_type}") -FL_TYPES = ["file", "array"] +FL_TYPES = ["file"] +FL_TYPES.extend(ARRAY_BACKENDS) +FL_ARRAYS = ARRAY_BACKENDS +FL_NUMPY = ["file", "numpy"] +FL_FILE = ["file"] diff --git a/tests/grib/test_grib_backend.py b/tests/grib/test_grib_backend.py index 6321d6fd..ea0ece11 100644 --- a/tests/grib/test_grib_backend.py +++ b/tests/grib/test_grib_backend.py @@ -17,12 +17,18 @@ from earthkit.data.testing import NO_CUPY from earthkit.data.testing import NO_PYTORCH from earthkit.data.testing import earthkit_examples_file +from earthkit.data.utils.array import _CUPY +from earthkit.data.utils.array import _NUMPY +from earthkit.data.utils.array import _PYTORCH +from earthkit.data.utils.array import get_backend @pytest.mark.parametrize("_kwargs", [{}, {"array_backend": "numpy"}]) def test_grib_file_numpy_backend(_kwargs): - ds = from_source("file", earthkit_examples_file("test6.grib"), 
**_kwargs) + ds = from_source("file", earthkit_examples_file("test6.grib")) + ds = ds.to_fieldlist(**_kwargs) + assert getattr(ds, "path", None) is None assert len(ds) == 6 assert isinstance(ds[0].values, np.ndarray) @@ -46,16 +52,20 @@ def test_grib_file_numpy_backend(_kwargs): assert isinstance(ds.to_numpy(), np.ndarray) assert ds.to_numpy().shape == (6, 7, 12) + assert get_backend(ds[0].to_array()) == _NUMPY + ds1 = ds.to_fieldlist() assert len(ds1) == len(ds) - assert ds1.array_backend.name == "numpy" assert getattr(ds1, "path", None) is None + assert get_backend(ds1[0].to_array()) == _NUMPY @pytest.mark.skipif(NO_PYTORCH, reason="No pytorch installed") def test_grib_file_pytorch_backend(): - ds = from_source("file", earthkit_examples_file("test6.grib"), array_backend="pytorch") + ds = from_source("file", earthkit_examples_file("test6.grib")) + ds = ds.to_fieldlist(array_backend="pytorch") + assert getattr(ds, "path", None) is None assert len(ds) == 6 import torch @@ -85,18 +95,22 @@ def test_grib_file_pytorch_backend(): assert isinstance(x, np.ndarray) assert x.shape == (6, 7, 12) + assert get_backend(ds[0].to_array()) == _PYTORCH + ds1 = ds.to_fieldlist() assert len(ds1) == len(ds) - assert ds1.array_backend.name == "pytorch" assert getattr(ds1, "path", None) is None + assert get_backend(ds1[0].to_array()) == _PYTORCH @pytest.mark.skipif(NO_CUPY, reason="No cupy installed") def test_grib_file_cupy_backend(): - ds = from_source("file", earthkit_examples_file("test6.grib"), array_backend="cupy") + ds = from_source("file", earthkit_examples_file("test6.grib")) + ds = ds.to_fieldlist(array_backend="cupy") import cupy as cp + assert getattr(ds, "path", None) is None assert len(ds) == 6 assert isinstance(ds[0].values, cp.ndarray) @@ -124,10 +138,12 @@ def test_grib_file_cupy_backend(): assert isinstance(x, np.ndarray) assert x.shape == (6, 7, 12) + assert get_backend(ds[0].to_array()) == _CUPY + ds1 = ds.to_fieldlist() assert len(ds1) == len(ds) - assert 
ds1.array_backend.name == "cupy" assert getattr(ds1, "path", None) is None + assert get_backend(ds1[0].to_array()) == _CUPY def test_grib_array_numpy_backend(): @@ -165,18 +181,18 @@ def test_grib_array_numpy_backend(): @pytest.mark.skipif(NO_PYTORCH, reason="No pytorch installed") def test_grib_array_pytorch_backend(): - s = from_source("file", earthkit_examples_file("test6.grib"), array_backend="pytorch") + s = from_source("file", earthkit_examples_file("test6.grib")) + + import torch ds = FieldList.from_array( - s.values, + torch.tensor(s.values), [m for m in s.metadata()], ) assert len(ds) == 6 with pytest.raises(AttributeError): ds.path - import torch - assert torch.is_tensor(ds[0].values) assert ds[0].values.shape == (84,) @@ -201,18 +217,18 @@ def test_grib_array_pytorch_backend(): @pytest.mark.skipif(NO_CUPY, reason="No cupy installed") def test_grib_array_cupy_backend(): - s = from_source("file", earthkit_examples_file("test6.grib"), array_backend="cupy") + s = from_source("file", earthkit_examples_file("test6.grib")) + + import cupy as cp ds = FieldList.from_array( - s.values, + cp.array(s.values), [m for m in s.metadata()], ) assert len(ds) == 6 with pytest.raises(AttributeError): ds.path - import cupy as cp - assert isinstance(ds[0].values, cp.ndarray) assert ds[0].values.shape == (84,) diff --git a/tests/grib/test_grib_concat.py b/tests/grib/test_grib_concat.py index 01bd0260..5c92f477 100644 --- a/tests/grib/test_grib_concat.py +++ b/tests/grib/test_grib_concat.py @@ -101,6 +101,22 @@ def test_grib_concat_3b(mode): _check_save_to_disk(ds, 26, md) +@pytest.mark.parametrize("mode", ["oper", "multi"]) +def test_grib_concat_mixed(mode): + ds1 = from_source("file", earthkit_examples_file("test.grib")) + ds2 = ds1.to_fieldlist() + md = ds1.metadata("param") + ds2.metadata("param") + + if mode == "oper": + ds = ds1 + ds2 + else: + ds = from_source("multi", ds1, ds2) + + assert len(ds) == 4 + assert ds.metadata("param") == md + _check_save_to_disk(ds, 4, md) 
+ + def test_grib_from_empty_1(): ds_e = FieldList() ds = from_source("file", earthkit_examples_file("test.grib")) diff --git a/tests/grib/test_grib_convert.py b/tests/grib/test_grib_convert.py index 51bd2fbc..602ed8aa 100644 --- a/tests/grib/test_grib_convert.py +++ b/tests/grib/test_grib_convert.py @@ -17,15 +17,14 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) -from grib_fixtures import FL_TYPES # noqa: E402 +from grib_fixtures import FL_NUMPY # noqa: E402 from grib_fixtures import load_grib_data # noqa: E402 -@pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ["numpy"]) -def test_icon_to_xarray(fl_type, array_backend): +@pytest.mark.parametrize("fl_type", FL_NUMPY) +def test_icon_to_xarray(fl_type): # test the conversion to xarray for an icon (unstructured grid) grib file. - g = load_grib_data("test_icon.grib", fl_type, array_backend, folder="data") + g, _ = load_grib_data("test_icon.grib", fl_type, folder="data") ds = g.to_xarray(engine="cfgrib") assert len(ds.data_vars) == 1 @@ -35,10 +34,9 @@ def test_icon_to_xarray(fl_type, array_backend): assert ds["pres"].sizes["values"] == 6 -@pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ["numpy"]) -def test_to_xarray_filter_by_keys(fl_type, array_backend): - g = load_grib_data("tuv_pl.grib", fl_type, array_backend) +@pytest.mark.parametrize("fl_type", FL_NUMPY) +def test_to_xarray_filter_by_keys(fl_type): + g, _ = load_grib_data("tuv_pl.grib", fl_type) g = g.sel(param="t", level=500) + g.sel(param="u") assert len(g) > 1 @@ -51,10 +49,9 @@ def test_to_xarray_filter_by_keys(fl_type, array_backend): assert r["t"].sizes["isobaricInhPa"] == 1 -@pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ["numpy"]) -def test_grib_to_pandas(fl_type, array_backend): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +@pytest.mark.parametrize("fl_type", FL_NUMPY) +def 
test_grib_to_pandas(fl_type): + f, _ = load_grib_data("test_single.grib", fl_type, folder="data") # all points df = f.to_pandas() diff --git a/tests/grib/test_grib_geography.py b/tests/grib/test_grib_geography.py index 53744ab9..da0380fb 100644 --- a/tests/grib/test_grib_geography.py +++ b/tests/grib/test_grib_geography.py @@ -16,7 +16,6 @@ import pytest import earthkit.data -from earthkit.data.testing import ARRAY_BACKENDS from earthkit.data.testing import check_array_type from earthkit.data.testing import earthkit_examples_file from earthkit.data.testing import earthkit_test_data_file @@ -24,6 +23,7 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) +from grib_fixtures import FL_NUMPY # noqa: E402 from grib_fixtures import FL_TYPES # noqa: E402 from grib_fixtures import load_grib_data # noqa: E402 @@ -36,10 +36,9 @@ def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize("index", [0, None]) -def test_grib_to_latlon_single(fl_type, array_backend, index): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_to_latlon_single(fl_type, index): + f, array_backend = load_grib_data("test_single.grib", fl_type, folder="data") eps = 1e-5 g = f[index] if index is not None else f @@ -66,10 +65,9 @@ def test_grib_to_latlon_single(fl_type, array_backend, index): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize("index", [0, None]) -def test_grib_to_latlon_single_shape(fl_type, array_backend, index): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_to_latlon_single_shape(fl_type, index): + f, array_backend = load_grib_data("test_single.grib", fl_type, folder="data") g = f[index] if index is not None else f v = g.to_latlon() @@ -88,11 +86,10 @@ def 
test_grib_to_latlon_single_shape(fl_type, array_backend, index): assert np.allclose(y, np.ones(12) * (90 - i * 30)) -@pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ["numpy"]) +@pytest.mark.parametrize("fl_type", FL_NUMPY) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_latlon_multi(fl_type, array_backend, dtype): - f = load_grib_data("test.grib", fl_type, array_backend) +def test_grib_to_latlon_multi(fl_type, dtype): + f, _ = load_grib_data("test.grib", fl_type) v_ref = f[0].to_latlon(flatten=True, dtype=dtype) v = f.to_latlon(flatten=True, dtype=dtype) @@ -107,10 +104,9 @@ def test_grib_to_latlon_multi(fl_type, array_backend, dtype): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_to_latlon_multi_non_shared_grid(fl_type, array_backend): - f1 = load_grib_data("test.grib", fl_type, array_backend) - f2 = load_grib_data("test4.grib", fl_type, array_backend) +def test_grib_to_latlon_multi_non_shared_grid(fl_type): + f1, _ = load_grib_data("test.grib", fl_type) + f2, _ = load_grib_data("test4.grib", fl_type) f = f1 + f2 with pytest.raises(ValueError): @@ -118,10 +114,9 @@ def test_grib_to_latlon_multi_non_shared_grid(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize("index", [0, None]) -def test_grib_to_points_single(fl_type, array_backend, index): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_to_points_single(fl_type, index): + f, array_backend = load_grib_data("test_single.grib", fl_type, folder="data") eps = 1e-5 g = f[index] if index is not None else f @@ -148,18 +143,16 @@ def test_grib_to_points_single(fl_type, array_backend, index): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def 
test_grib_to_points_unsupported_grid(fl_type, array_backend): - f = load_grib_data("mercator.grib", fl_type, array_backend, folder="data") +def test_grib_to_points_unsupported_grid(fl_type): + f, _ = load_grib_data("mercator.grib", fl_type, folder="data") with pytest.raises(ValueError): f[0].to_points() -@pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ["numpy"]) +@pytest.mark.parametrize("fl_type", FL_NUMPY) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_points_multi(fl_type, array_backend, dtype): - f = load_grib_data("test.grib", fl_type, array_backend) +def test_grib_to_points_multi(fl_type, dtype): + f, _ = load_grib_data("test.grib", fl_type) v_ref = f[0].to_points(flatten=True, dtype=dtype) v = f.to_points(flatten=True, dtype=dtype) @@ -174,10 +167,9 @@ def test_grib_to_points_multi(fl_type, array_backend, dtype): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_to_points_multi_non_shared_grid(fl_type, array_backend): - f1 = load_grib_data("test.grib", fl_type, array_backend) - f2 = load_grib_data("test4.grib", fl_type, array_backend) +def test_grib_to_points_multi_non_shared_grid(fl_type): + f1, _ = load_grib_data("test.grib", fl_type) + f2, _ = load_grib_data("test4.grib", fl_type) f = f1 + f2 with pytest.raises(ValueError): @@ -185,9 +177,8 @@ def test_grib_to_points_multi_non_shared_grid(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_bbox(fl_type, array_backend): - ds = load_grib_data("test.grib", fl_type, array_backend) +def test_bbox(fl_type): + ds, _ = load_grib_data("test.grib", fl_type) bb = ds.bounding_box() assert len(bb) == 2 for b in bb: @@ -195,10 +186,9 @@ def test_bbox(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) 
@pytest.mark.parametrize("index", [0, None]) -def test_grib_projection_ll(fl_type, array_backend, index): - f = load_grib_data("test.grib", fl_type, array_backend) +def test_grib_projection_ll(fl_type, index): + f, _ = load_grib_data("test.grib", fl_type) if index is not None: g = f[index] @@ -208,9 +198,8 @@ def test_grib_projection_ll(fl_type, array_backend, index): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_projection_mercator(fl_type, array_backend): - f = load_grib_data("mercator.grib", fl_type, array_backend, folder="data") +def test_grib_projection_mercator(fl_type): + f, _ = load_grib_data("mercator.grib", fl_type, folder="data") projection = f[0].projection() assert isinstance(projection, projections.Mercator) assert projection.parameters == { diff --git a/tests/grib/test_grib_inidces.py b/tests/grib/test_grib_inidces.py index 37d722a2..feb9a570 100644 --- a/tests/grib/test_grib_inidces.py +++ b/tests/grib/test_grib_inidces.py @@ -14,8 +14,6 @@ import pytest -from earthkit.data.testing import ARRAY_BACKENDS - here = os.path.dirname(__file__) sys.path.insert(0, here) from grib_fixtures import FL_TYPES # noqa: E402 @@ -23,9 +21,8 @@ @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_indices_base(fl_type, array_backend): - ds = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_indices_base(fl_type): + ds, _ = load_grib_data("tuv_pl.grib", fl_type) ref_full = { "class": ["od"], @@ -66,9 +63,8 @@ def test_grib_indices_base(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_indices_sel(fl_type, array_backend): - ds = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_indices_sel(fl_type): + ds, _ = load_grib_data("tuv_pl.grib", fl_type) ref = { "class": ["od"], @@ -96,10 +92,9 @@ def 
test_grib_indices_sel(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_indices_multi(fl_type, array_backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) - f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") +def test_grib_indices_multi(fl_type): + f1, _ = load_grib_data("tuv_pl.grib", fl_type) + f2, _ = load_grib_data("ml_data.grib", fl_type, folder="data") ds = f1 + f2 ref = { @@ -163,10 +158,9 @@ def test_grib_indices_multi(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_indices_multi_sel(fl_type, array_backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) - f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") +def test_grib_indices_multi_sel(fl_type): + f1, _ = load_grib_data("tuv_pl.grib", fl_type) + f2, _ = load_grib_data("ml_data.grib", fl_type, folder="data") ds = f1 + f2 ref = { @@ -189,9 +183,8 @@ def test_grib_indices_multi_sel(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_indices_order_by(fl_type, array_backend): - ds = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_indices_order_by(fl_type): + ds, _ = load_grib_data("tuv_pl.grib", fl_type) ref = { "class": ["od"], diff --git a/tests/grib/test_grib_iter.py b/tests/grib/test_grib_iter.py index 76487ec7..6df7fa6f 100644 --- a/tests/grib/test_grib_iter.py +++ b/tests/grib/test_grib_iter.py @@ -10,17 +10,24 @@ # +import os +import sys + import pytest from earthkit.data import from_source -from earthkit.data.testing import ARRAY_BACKENDS from earthkit.data.testing import earthkit_examples_file +here = os.path.dirname(__file__) +sys.path.insert(0, here) +from grib_fixtures import FL_ARRAYS # noqa: E402 +from grib_fixtures 
import load_grib_data # noqa: E402 + -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("fl_type", FL_ARRAYS) @pytest.mark.parametrize("group", ["param"]) -def test_grib_group_by(array_backend, group): - ds = from_source("file", earthkit_examples_file("test6.grib"), array_backend=array_backend) +def test_grib_group_by(fl_type, group): + ds, array_backend = load_grib_data("test6.grib", fl_type) ref = [ [("t", 1000), ("t", 850)], @@ -31,7 +38,7 @@ def test_grib_group_by(array_backend, group): for i, f in enumerate(ds.group_by(group)): assert len(f) == 2 assert f.metadata(("param", "level")) == ref[i] - afl = f.to_fieldlist(array_backend=array_backend) + afl = f.to_fieldlist(array_backend=array_backend._name) assert afl is not f assert len(afl) == 2 cnt += len(f) @@ -39,14 +46,10 @@ def test_grib_group_by(array_backend, group): assert cnt == len(ds) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +@pytest.mark.parametrize("fl_type", FL_ARRAYS) @pytest.mark.parametrize("group", ["level", ["level", "gridType"]]) -def test_grib_multi_group_by(array_backend, group): - ds = from_source( - "file", - [earthkit_examples_file("test4.grib"), earthkit_examples_file("test6.grib")], - array_backend=array_backend, - ) +def test_grib_multi_group_by(fl_type, group): + ds, _ = load_grib_data(["test4.grib", "test6.grib"], fl_type) ref = [ [("t", 500), ("z", 500)], diff --git a/tests/grib/test_grib_metadata.py b/tests/grib/test_grib_metadata.py index 195df36e..9e7eaac5 100644 --- a/tests/grib/test_grib_metadata.py +++ b/tests/grib/test_grib_metadata.py @@ -17,11 +17,11 @@ import pytest from earthkit.data import from_source -from earthkit.data.testing import ARRAY_BACKENDS from earthkit.data.testing import earthkit_examples_file here = os.path.dirname(__file__) sys.path.insert(0, here) +from grib_fixtures import FL_FILE # noqa: E402 from grib_fixtures import FL_TYPES # noqa: E402 from grib_fixtures import load_grib_data # noqa: E402 @@ -38,7 
+38,6 @@ def repeat_list_items(items, count): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,expected_value", [ @@ -56,8 +55,8 @@ def repeat_list_items(items, count): (("shortName", "level"), ("2t", 0)), ], ) -def test_grib_metadata_grib(fl_type, array_backend, key, expected_value): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_metadata_grib(fl_type, key, expected_value): + f, _ = load_grib_data("test_single.grib", fl_type, folder="data") sn = f.metadata(key) assert sn == [expected_value] sn = f[0].metadata(key) @@ -65,7 +64,6 @@ def test_grib_metadata_grib(fl_type, array_backend, key, expected_value): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,astype,expected_value", [ @@ -79,8 +77,8 @@ def test_grib_metadata_grib(fl_type, array_backend, key, expected_value): ("level", int, 0), ], ) -def test_grib_metadata_astype_1(fl_type, array_backend, key, astype, expected_value): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_metadata_astype_1(fl_type, key, astype, expected_value): + f, _ = load_grib_data("test_single.grib", fl_type, folder="data") sn = f.metadata(key, astype=astype) assert sn == [expected_value] sn = f[0].metadata(key, astype=astype) @@ -88,7 +86,6 @@ def test_grib_metadata_astype_1(fl_type, array_backend, key, astype, expected_va @pytest.mark.parametrize("fs_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,expected_value", [ @@ -100,15 +97,14 @@ def test_grib_metadata_astype_1(fl_type, array_backend, key, astype, expected_va ("level:int", repeat_list_items([1000, 850, 700, 500, 400, 300], 3)), ], ) -def test_grib_metadata_18(fs_type, array_backend, key, expected_value): +def test_grib_metadata_18(fs_type, key, expected_value): 
# f = load_grib_data("tuv_pl.grib", mode) - ds = load_grib_data("tuv_pl.grib", fs_type, array_backend) + ds, _ = load_grib_data("tuv_pl.grib", fs_type) sn = ds.metadata(key) assert sn == expected_value @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,astype,expected_value", [ @@ -126,14 +122,13 @@ def test_grib_metadata_18(fs_type, array_backend, key, expected_value): ), ], ) -def test_grib_metadata_astype_18(fl_type, array_backend, key, astype, expected_value): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_metadata_astype_18(fl_type, key, astype, expected_value): + f, _ = load_grib_data("tuv_pl.grib", fl_type) sn = f.metadata(key, astype=astype) assert sn == expected_value @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,expected_value", [ @@ -142,15 +137,14 @@ def test_grib_metadata_astype_18(fl_type, array_backend, key, astype, expected_v ("latitudeOfFirstGridPointInDegrees:float", 90.0), ], ) -def test_grib_metadata_double_1(fl_type, array_backend, key, expected_value): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_metadata_double_1(fl_type, key, expected_value): + f, _ = load_grib_data("test_single.grib", fl_type, folder="data") r = f.metadata(key) assert len(r) == 1 assert np.isclose(r[0], expected_value) @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key", [ @@ -159,8 +153,8 @@ def test_grib_metadata_double_1(fl_type, array_backend, key, expected_value): ("latitudeOfFirstGridPointInDegrees:float"), ], ) -def test_grib_metadata_double_18(fl_type, array_backend, key): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_metadata_double_18(fl_type, key): + f, _ = load_grib_data("tuv_pl.grib", fl_type) ref = 
[90.0] * 18 r = f.metadata(key) @@ -168,7 +162,6 @@ def test_grib_metadata_double_18(fl_type, array_backend, key): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "key,astype", [ @@ -176,8 +169,8 @@ def test_grib_metadata_double_18(fl_type, array_backend, key): ("latitudeOfFirstGridPointInDegrees", float), ], ) -def test_grib_metadata_double_astype_18(fl_type, array_backend, key, astype): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_metadata_double_astype_18(fl_type, key, astype): + f, _ = load_grib_data("tuv_pl.grib", fl_type) ref = [90.0] * 18 @@ -186,9 +179,8 @@ def test_grib_metadata_double_astype_18(fl_type, array_backend, key, astype): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_get_long_array_1(fl_type, array_backend): - f = load_grib_data("rgg_small_subarea_cellarea_ref.grib", fl_type, array_backend, folder="data") +def test_grib_get_long_array_1(fl_type): + f, _ = load_grib_data("rgg_small_subarea_cellarea_ref.grib", fl_type, folder="data") assert len(f) == 1 pl = f.metadata("pl") @@ -202,10 +194,9 @@ def test_grib_get_long_array_1(fl_type, array_backend): assert pl[72] == 312 -@pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", [None]) -def test_grib_get_double_array_values_1(fl_type, array_backend): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +@pytest.mark.parametrize("fl_type", FL_FILE) +def test_grib_get_double_array_values_1(fl_type): + f, _ = load_grib_data("test_single.grib", fl_type, folder="data") v = f.metadata("values") assert len(v) == 1 @@ -222,10 +213,9 @@ def test_grib_get_double_array_values_1(fl_type, array_backend): ) -@pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", [None]) -def test_grib_get_double_array_values_18(fl_type, array_backend): - 
f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +@pytest.mark.parametrize("fl_type", FL_FILE) +def test_grib_get_double_array_values_18(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) v = f.metadata("values") assert isinstance(v, list) assert len(v) == 18 @@ -254,9 +244,10 @@ def test_grib_get_double_array_values_18(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_get_double_array_1(fl_type, array_backend): - f = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data")[0] +def test_grib_get_double_array_1(fl_type): + f_in, _ = load_grib_data("ml_data.grib", fl_type, folder="data") + + f = f_in[0] # f is now a field! v = f.metadata("pv") assert isinstance(v, np.ndarray) @@ -268,9 +259,8 @@ def test_grib_get_double_array_1(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_get_double_array_18(fl_type, array_backend): - f = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") +def test_grib_get_double_array_18(fl_type): + f, _ = load_grib_data("ml_data.grib", fl_type, folder="data") v = f.metadata("pv") assert isinstance(v, list) assert len(v) == 36 @@ -286,9 +276,9 @@ def test_grib_get_double_array_18(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_metadata_type_qualifier(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend)[0:4] +def test_grib_metadata_type_qualifier(fl_type): + f_in, _ = load_grib_data("tuv_pl.grib", fl_type) + f = f_in[0:4] # to str r = f.metadata("centre:s") @@ -326,9 +316,9 @@ def test_grib_metadata_type_qualifier(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def 
test_grib_metadata_astype_core(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend)[0:4] +def test_grib_metadata_astype_core(fl_type): + f_in, _ = load_grib_data("tuv_pl.grib", fl_type) + f = f_in[0:4] # to str r = f.metadata("centre", astype=None) @@ -361,9 +351,8 @@ def test_grib_metadata_astype_core(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_metadata_generic(fl_type, array_backend): - f_full = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_metadata_generic(fl_type): + f_full, _ = load_grib_data("tuv_pl.grib", fl_type) f = f_full[0:4] @@ -391,9 +380,8 @@ def test_grib_metadata_generic(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_metadata_missing_value(fl_type, array_backend): - f = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") +def test_grib_metadata_missing_value(fl_type): + f, _ = load_grib_data("ml_data.grib", fl_type, folder="data") with pytest.raises(KeyError): f[0].metadata("scaleFactorOfSecondFixedSurface") @@ -403,9 +391,8 @@ def test_grib_metadata_missing_value(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_metadata_missing_key(fl_type, array_backend): - f = load_grib_data("test.grib", fl_type, array_backend) +def test_grib_metadata_missing_key(fl_type): + f, _ = load_grib_data("test.grib", fl_type) with pytest.raises(KeyError): f[0].metadata("_badkey_") @@ -414,10 +401,9 @@ def test_grib_metadata_missing_key(fl_type, array_backend): assert v == 0 -@pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", [None]) -def test_grib_metadata_namespace(fl_type, array_backend): - f = load_grib_data("test6.grib", fl_type, array_backend) 
+@pytest.mark.parametrize("fl_type", FL_FILE) +def test_grib_metadata_namespace(fl_type): + f, _ = load_grib_data("test6.grib", fl_type) r = f[0].metadata(namespace="vertical") ref = {"level": 1000, "typeOfLevel": "isobaricInhPa"} @@ -496,9 +482,8 @@ def test_grib_metadata_namespace(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_datetime(fl_type, array_backend): - s = load_grib_data("test.grib", fl_type, array_backend) +def test_grib_datetime(fl_type): + s, _ = load_grib_data("test.grib", fl_type) ref = { "base_time": [datetime.datetime(2020, 5, 13, 12)], @@ -527,18 +512,16 @@ def test_grib_datetime(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_valid_datetime(fl_type, array_backend): - ds = load_grib_data("t_time_series.grib", fl_type, array_backend, folder="data") +def test_grib_valid_datetime(fl_type): + ds, _ = load_grib_data("t_time_series.grib", fl_type, folder="data") f = ds[4] assert f.metadata("valid_datetime") == "2020-12-21T18:00:00" -@pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", [None]) -def test_message(fl_type, array_backend): - f = load_grib_data("test.grib", fl_type, array_backend) +@pytest.mark.parametrize("fl_type", FL_FILE) +def test_message(fl_type): + f, _ = load_grib_data("test.grib", fl_type) v = f[0].message() assert len(v) == 526 assert v[:4] == b"GRIB" @@ -547,10 +530,9 @@ def test_message(fl_type, array_backend): assert v[:4] == b"GRIB" -@pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", [None]) -def test_grib_tilde_shortname(fl_type, array_backend): - f = load_grib_data("tilde_shortname.grib", fl_type, array_backend, folder="data") +@pytest.mark.parametrize("fl_type", FL_FILE) +def test_grib_tilde_shortname(fl_type): + f, _ = load_grib_data("tilde_shortname.grib", 
fl_type, folder="data") assert f[0].metadata("shortName") == "106" assert f[0].metadata("shortName", astype=int) == 0 diff --git a/tests/grib/test_grib_order_by.py b/tests/grib/test_grib_order_by.py index f7cd571e..5c71c816 100644 --- a/tests/grib/test_grib_order_by.py +++ b/tests/grib/test_grib_order_by.py @@ -15,7 +15,6 @@ import pytest from earthkit.data import from_source -from earthkit.data.testing import ARRAY_BACKENDS here = os.path.dirname(__file__) sys.path.insert(0, here) @@ -25,9 +24,8 @@ # @pytest.mark.skipif(("GITHUB_WORKFLOW" in os.environ) or True, reason="Not yet ready") @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_order_by_single_message(fl_type, array_backend): - s = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_order_by_single_message(fl_type): + s, _ = load_grib_data("test_single.grib", fl_type, folder="data") r = s.order_by("shortName") assert len(r) == 1 @@ -56,7 +54,6 @@ def __call__(self, x, y): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta", [ @@ -104,11 +101,10 @@ def __call__(self, x, y): ) def test_grib_order_by_single_file_( fl_type, - array_backend, params, expected_meta, ): - f = load_grib_data("test6.grib", fl_type, array_backend) + f, _ = load_grib_data("test6.grib", fl_type) g = f.order_by(params) assert len(g) == len(f) @@ -118,7 +114,6 @@ def test_grib_order_by_single_file_( @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta", [ @@ -147,9 +142,9 @@ def test_grib_order_by_single_file_( ), ], ) -def test_grib_order_by_multi_file(fl_type, array_backend, params, expected_meta): - f1 = load_grib_data("test4.grib", fl_type, array_backend) - f2 = load_grib_data("test6.grib", fl_type, array_backend) +def 
test_grib_order_by_multi_file(fl_type, params, expected_meta): + f1, _ = load_grib_data("test4.grib", fl_type) + f2, _ = load_grib_data("test6.grib", fl_type) f = from_source("multi", [f1, f2]) g = f.order_by(params) @@ -160,9 +155,8 @@ def test_grib_order_by_multi_file(fl_type, array_backend, params, expected_meta) @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_order_by_with_sel(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_order_by_with_sel(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) g = f.sel(level=500) assert len(g) == 3 @@ -178,9 +172,8 @@ def test_grib_order_by_with_sel(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_order_by_valid_datetime(fl_type, array_backend): - f = load_grib_data("t_time_series.grib", fl_type, array_backend, folder="data") +def test_grib_order_by_valid_datetime(fl_type): + f, _ = load_grib_data("t_time_series.grib", fl_type, folder="data") g = f.order_by(valid_datetime="descending") assert len(g) == 10 @@ -202,9 +195,8 @@ def test_grib_order_by_valid_datetime(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_order_by_remapping(fl_type, array_backend): - ds = load_grib_data("test6.grib", fl_type, array_backend) +def test_grib_order_by_remapping(fl_type): + ds, _ = load_grib_data("test6.grib", fl_type) ordering = ["t850", "t1000", "u1000", "v850", "v1000", "u850"] ref = [("t", 850), ("t", 1000), ("u", 1000), ("v", 850), ("v", 1000), ("u", 850)] diff --git a/tests/grib/test_grib_output.py b/tests/grib/test_grib_output.py index 40df73c8..8e607a95 100644 --- a/tests/grib/test_grib_output.py +++ b/tests/grib/test_grib_output.py @@ -20,15 +20,19 @@ import earthkit.data from earthkit.data import from_source from 
earthkit.data.core.temporary import temp_file -from earthkit.data.testing import ARRAY_BACKENDS from earthkit.data.testing import earthkit_examples_file +here = os.path.dirname(__file__) +sys.path.insert(0, here) +from grib_fixtures import FL_ARRAYS # noqa: E402 +from grib_fixtures import load_grib_data # noqa: E402 + EPSILON = 1e-4 -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_save_when_loaded_from_file(array_backend): - fs = from_source("file", earthkit_examples_file("test6.grib"), array_backend=array_backend) +@pytest.mark.parametrize("fl_type", FL_ARRAYS) +def test_grib_save_when_loaded_from_file(fl_type): + fs, _ = load_grib_data("test6.grib", fl_type) assert len(fs) == 6 with temp_file() as tmp: fs.save(tmp) diff --git a/tests/grib/test_grib_sel.py b/tests/grib/test_grib_sel.py index 7396c6c6..7fcb96d0 100644 --- a/tests/grib/test_grib_sel.py +++ b/tests/grib/test_grib_sel.py @@ -16,7 +16,6 @@ import pytest from earthkit.data import from_source -from earthkit.data.testing import ARRAY_BACKENDS here = os.path.dirname(__file__) sys.path.insert(0, here) @@ -27,9 +26,8 @@ @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_sel_single_message(fl_type, array_backend): - s = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_sel_single_message(fl_type): + s, _ = load_grib_data("test_single.grib", fl_type, folder="data") r = s.sel(shortName="2t") assert len(r) == 1 @@ -37,7 +35,6 @@ def test_grib_sel_single_message(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta,metadata_keys", [ @@ -65,8 +62,8 @@ def test_grib_sel_single_message(fl_type, array_backend): ), ], ) -def test_grib_sel_single_file_1(fl_type, array_backend, params, expected_meta, metadata_keys): - f = load_grib_data("tuv_pl.grib", fl_type, 
array_backend) +def test_grib_sel_single_file_1(fl_type, params, expected_meta, metadata_keys): + f, _ = load_grib_data("tuv_pl.grib", fl_type) g = f.sel(**params) assert len(g) == len(expected_meta) @@ -80,9 +77,8 @@ def test_grib_sel_single_file_1(fl_type, array_backend, params, expected_meta, m @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_sel_single_file_2(fl_type, array_backend): - f = load_grib_data("t_time_series.grib", fl_type, array_backend, folder="data") +def test_grib_sel_single_file_2(fl_type): + f, _ = load_grib_data("t_time_series.grib", fl_type, folder="data") g = f.sel(shortName=["t"], step=[3, 6]) assert len(g) == 2 @@ -102,9 +98,8 @@ def test_grib_sel_single_file_2(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_sel_single_file_as_dict(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_sel_single_file_as_dict(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) g = f.sel({"shortName": "t", "level": [500, 700], "mars.type": "an"}) assert len(g) == 2 @@ -115,7 +110,6 @@ def test_grib_sel_single_file_as_dict(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "param_id,level,expected_meta", [ @@ -127,8 +121,8 @@ def test_grib_sel_single_file_as_dict(fl_type, array_backend): (131, (slice(510, 520)), []), ], ) -def test_grib_sel_slice_single_file(fl_type, array_backend, param_id, level, expected_meta): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_sel_slice_single_file(fl_type, param_id, level, expected_meta): + f, _ = load_grib_data("tuv_pl.grib", fl_type) g = f.sel(paramId=param_id, level=level) assert len(g) == len(expected_meta) @@ -137,10 +131,12 @@ def test_grib_sel_slice_single_file(fl_type, 
array_backend, param_id, level, exp @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_sel_multi_file(fl_type, array_backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) - f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") +def test_grib_sel_multi_file(fl_type): + f1, _ = load_grib_data( + "tuv_pl.grib", + fl_type, + ) + f2, _ = load_grib_data("ml_data.grib", fl_type, folder="data") f = from_source("multi", [f1, f2]) # single resulting field @@ -155,10 +151,9 @@ def test_grib_sel_multi_file(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_sel_slice_multi_file(fl_type, array_backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) - f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") +def test_grib_sel_slice_multi_file(fl_type): + f1, _ = load_grib_data("tuv_pl.grib", fl_type) + f2, _ = load_grib_data("ml_data.grib", fl_type, folder="data") f = from_source("multi", [f1, f2]) @@ -171,10 +166,9 @@ def test_grib_sel_slice_multi_file(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_sel_date(fl_type, array_backend): +def test_grib_sel_date(fl_type): # date and time - f = load_grib_data("t_time_series.grib", fl_type, array_backend, folder="data") + f, _ = load_grib_data("t_time_series.grib", fl_type, folder="data") g = f.sel(date=20201221, time=1200, step=9) # g = f.sel(date="20201221", time="12", step="9") @@ -190,9 +184,8 @@ def test_grib_sel_date(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_sel_valid_datetime(fl_type, array_backend): - f = load_grib_data("t_time_series.grib", fl_type, array_backend, folder="data") +def 
test_grib_sel_valid_datetime(fl_type): + f, _ = load_grib_data("t_time_series.grib", fl_type, folder="data") g = f.sel(valid_datetime="2020-12-21T21:00:00") assert len(g) == 2 @@ -207,9 +200,8 @@ def test_grib_sel_valid_datetime(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_isel_single_message(fl_type, array_backend): - s = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_isel_single_message(fl_type): + s, _ = load_grib_data("test_single.grib", fl_type, folder="data") r = s.isel(shortName=0) assert len(r) == 1 @@ -217,7 +209,6 @@ def test_grib_isel_single_message(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "params,expected_meta,metadata_keys", [ @@ -254,8 +245,8 @@ def test_grib_isel_single_message(fl_type, array_backend): ), ], ) -def test_grib_isel_single_file(fl_type, array_backend, params, expected_meta, metadata_keys): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_isel_single_file(fl_type, params, expected_meta, metadata_keys): + f, _ = load_grib_data("tuv_pl.grib", fl_type) g = f.isel(**params) assert len(g) == len(expected_meta) @@ -268,7 +259,6 @@ def test_grib_isel_single_file(fl_type, array_backend, params, expected_meta, me @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "param_id,level,expected_meta", [ @@ -280,8 +270,8 @@ def test_grib_isel_single_file(fl_type, array_backend, params, expected_meta, me (1, (slice(None, None, 2)), [[131, 850], [131, 500], [131, 300]]), ], ) -def test_grib_isel_slice_single_file(fl_type, array_backend, param_id, level, expected_meta): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_isel_slice_single_file(fl_type, param_id, level, 
expected_meta): + f, _ = load_grib_data("tuv_pl.grib", fl_type) g = f.isel(paramId=param_id, level=level) assert len(g) == len(expected_meta) @@ -290,9 +280,8 @@ def test_grib_isel_slice_single_file(fl_type, array_backend, param_id, level, ex @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_isel_slice_invalid(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_isel_slice_invalid(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) with pytest.raises(IndexError): f.isel(level=500) @@ -302,10 +291,9 @@ def test_grib_isel_slice_invalid(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_isel_multi_file(fl_type, array_backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) - f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") +def test_grib_isel_multi_file(fl_type): + f1, _ = load_grib_data("tuv_pl.grib", fl_type) + f2, _ = load_grib_data("ml_data.grib", fl_type, folder="data") f = from_source("multi", [f1, f2]) # single resulting field @@ -319,10 +307,9 @@ def test_grib_isel_multi_file(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_isel_slice_multi_file(fl_type, array_backend): - f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend) - f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data") +def test_grib_isel_slice_multi_file(fl_type): + f1, _ = load_grib_data("tuv_pl.grib", fl_type) + f2, _ = load_grib_data("ml_data.grib", fl_type, folder="data") f = from_source("multi", [f1, f2]) g = f.isel(shortName=1, level=slice(20, 22)) @@ -334,18 +321,16 @@ def test_grib_isel_slice_multi_file(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", 
ARRAY_BACKENDS) -def test_grib_sel_remapping_1(fl_type, array_backend): - ds = load_grib_data("test6.grib", fl_type, array_backend) +def test_grib_sel_remapping_1(fl_type): + ds, _ = load_grib_data("test6.grib", fl_type) ref = [("t", 850)] r = ds.sel(param_level="t850", remapping={"param_level": "{param}{levelist}"}) assert r.metadata("param", "level") == ref @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_sel_remapping_2(fl_type, array_backend): - ds = load_grib_data("test6.grib", fl_type, array_backend) +def test_grib_sel_remapping_2(fl_type): + ds, _ = load_grib_data("test6.grib", fl_type) ref = [("u", 1000), ("t", 850)] r = ds.sel(param_level=["t850", "u1000"], remapping={"param_level": "{param}{levelist}"}) assert r.metadata("param", "level") == ref diff --git a/tests/grib/test_grib_serialise.py b/tests/grib/test_grib_serialise.py index edd94e34..670c6298 100644 --- a/tests/grib/test_grib_serialise.py +++ b/tests/grib/test_grib_serialise.py @@ -22,32 +22,47 @@ here = os.path.dirname(__file__) sys.path.insert(0, here) +from grib_fixtures import FL_NUMPY # noqa: E402 from grib_fixtures import load_grib_data # noqa: E402 -@pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", ["numpy"]) -def test_grib_serialise_metadata(fl_type, array_backend): - ds = load_grib_data("test.grib", fl_type, array_backend) +def _pickle(data, representation): + if representation == "file": + with temp_file() as tmp: + with open(tmp, "wb") as f: + pickle.dump(data, f) + + with open(tmp, "rb") as f: + data_res = pickle.load(f) + elif representation == "memory": + pickled_data = pickle.dumps(data) + data_res = pickle.loads(pickled_data) + else: + raise ValueError(f"Invalid representation: {representation}") + return data_res + +@pytest.mark.parametrize("fl_type", FL_NUMPY) +@pytest.mark.parametrize("representation", ["file", "memory"]) +def test_grib_serialise_metadata(fl_type, 
representation): + ds, _ = load_grib_data("test.grib", fl_type) md = ds[0].metadata().override() - pickled_md = pickle.dumps(md) - md2 = pickle.loads(pickled_md) + + md2 = _pickle(md, representation) keys = ["param", "date", "time", "step", "level", "gridType", "type"] for k in keys: assert md[k] == md2[k] -@pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", ["numpy"]) -def test_grib_serialise_array_field(fl_type, array_backend): - ds0 = load_grib_data("test.grib", fl_type, array_backend) +@pytest.mark.parametrize("fl_type", FL_NUMPY) +@pytest.mark.parametrize("representation", ["file", "memory"]) +def test_grib_serialise_array_field_memory(fl_type, representation): + ds0, _ = load_grib_data("test.grib", fl_type) ds = ds0.to_fieldlist() for idx in range(len(ds)): - pickled_f = pickle.dumps(ds[idx]) - f2 = pickle.loads(pickled_f) + f2 = _pickle(ds[idx], representation) assert np.allclose(ds[idx].values, f2.values), f"index={idx}" assert np.allclose(ds[idx].to_numpy(), f2.to_numpy()), f"index={idx}" @@ -57,14 +72,13 @@ def test_grib_serialise_array_field(fl_type, array_backend): assert ds[idx].metadata(k) == f2.metadata(k), f"index={idx}" -@pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", ["numpy"]) -def test_grib_serialise_array_fieldlist(fl_type, array_backend): - ds0 = load_grib_data("test.grib", fl_type, array_backend) +@pytest.mark.parametrize("fl_type", FL_NUMPY) +@pytest.mark.parametrize("representation", ["file", "memory"]) +def test_grib_serialise_array_fieldlist(fl_type, representation): + ds0, _ = load_grib_data("test.grib", fl_type) ds = ds0.to_fieldlist() - pickled_f = pickle.dumps(ds) - ds2 = pickle.loads(pickled_f) + ds2 = _pickle(ds, representation) assert len(ds) == len(ds2) assert np.allclose(ds.values, ds2.values) @@ -82,16 +96,15 @@ def test_grib_serialise_array_fieldlist(fl_type, array_backend): with temp_file() as tmp: ds2.save(tmp) assert os.path.exists(tmp) - r_tmp = 
from_source("file", tmp, array_backend=array_backend) + r_tmp = from_source("file", tmp) assert len(ds2) == len(r_tmp) v_tmp = r_tmp[0].to_numpy() assert np.allclose(v1 + 1, v_tmp) @pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", ["numpy"]) -def test_grib_serialise_file_fieldlist_core(fl_type, array_backend): - ds = load_grib_data("test.grib", fl_type, array_backend) +def test_grib_serialise_file_fieldlist_core(fl_type): + ds, _ = load_grib_data("test.grib", fl_type) pickled_f = pickle.dumps(ds) ds2 = pickle.loads(pickled_f) @@ -108,9 +121,8 @@ def test_grib_serialise_file_fieldlist_core(fl_type, array_backend): @pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", ["numpy"]) -def test_grib_serialise_file_fieldlist_sel(fl_type, array_backend): - ds0 = load_grib_data("test6.grib", fl_type, array_backend) +def test_grib_serialise_file_fieldlist_sel(fl_type): + ds0, _ = load_grib_data("test6.grib", fl_type) ds = ds0.sel(param="t") assert len(ds) == 2 @@ -129,11 +141,10 @@ def test_grib_serialise_file_fieldlist_sel(fl_type, array_backend): @pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", ["numpy"]) -def test_grib_serialise_file_fieldlist_concat(fl_type, array_backend): - ds = load_grib_data("test.grib", fl_type, array_backend) + load_grib_data( - "test6.grib", fl_type, array_backend - ) +def test_grib_serialise_file_fieldlist_concat(fl_type): + ds00, _ = load_grib_data("test.grib", fl_type) + ds01, _ = load_grib_data("test6.grib", fl_type) + ds = ds00 + ds01 assert len(ds) == 8 pickled_f = pickle.dumps(ds) diff --git a/tests/grib/test_grib_simplefieldlist.py b/tests/grib/test_grib_simplefieldlist.py new file mode 100644 index 00000000..bd45c42b --- /dev/null +++ b/tests/grib/test_grib_simplefieldlist.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. 
+# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import pytest + +from earthkit.data import ArrayField +from earthkit.data import SimpleFieldList +from earthkit.data import from_source +from earthkit.data.testing import earthkit_examples_file + + +def _check(ds, group): + assert len(ds) == 6 + + ref = [("t", 1000), ("u", 1000), ("v", 1000), ("t", 850), ("u", 850), ("v", 850)] + + assert ds.metadata(("param", "level")) == ref + + ref = [ + [("t", 1000), ("t", 850)], + [("u", 1000), ("u", 850)], # + [("v", 1000), ("v", 850)], + ] + cnt = 0 + for i, f in enumerate(ds.group_by(group)): + assert len(f) == 2 + assert f.metadata(("param", "level")) == ref[i] + afl = f.to_fieldlist() + assert afl is not f + assert len(afl) == 2 + cnt += len(f) + + assert cnt == len(ds) + + +@pytest.mark.parametrize("group", ["param"]) +def test_grib_simple_fl_1(group): + ds_in = from_source("file", earthkit_examples_file("test6.grib")) + + ds = SimpleFieldList() + for f in ds_in: + ds.append(f) + + _check(ds, group) + + +@pytest.mark.parametrize("group", ["param"]) +def test_grib_simple_fl_2(group): + ds = from_source("file", earthkit_examples_file("test6.grib")) + + ds = SimpleFieldList([f for f in ds]) + + _check(ds, group) + + +@pytest.mark.parametrize("group", ["param"]) +def test_grib_simple_fl_3(group): + ds_in = from_source("file", earthkit_examples_file("test6.grib")) + + ds = SimpleFieldList() + for f in ds_in: + ds.append(ArrayField(f.to_numpy(), f.metadata())) + + _check(ds, group) diff --git a/tests/grib/test_grib_slice.py b/tests/grib/test_grib_slice.py index fe515bff..8749f16e 100644 --- a/tests/grib/test_grib_slice.py +++ b/tests/grib/test_grib_slice.py @@ -16,7 +16,6 
@@ import pytest from earthkit.data import from_source -from earthkit.data.testing import ARRAY_BACKENDS from earthkit.data.testing import earthkit_examples_file here = os.path.dirname(__file__) @@ -26,7 +25,6 @@ @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "index,expected_meta", [ @@ -37,8 +35,8 @@ (-5, ["u", 400]), ], ) -def test_grib_single_index(fl_type, array_backend, index, expected_meta): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_single_index(fl_type, index, expected_meta): + f, _ = load_grib_data("tuv_pl.grib", fl_type) # f = from_source("file", earthkit_examples_file("tuv_pl.grib")) r = f[index] @@ -50,15 +48,13 @@ def test_grib_single_index(fl_type, array_backend, index, expected_meta): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_single_index_bad(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_single_index_bad(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) with pytest.raises(IndexError): f[27] @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "indexes,expected_meta", [ @@ -70,8 +66,8 @@ def test_grib_single_index_bad(fl_type, array_backend): (slice(14, None), [["v", 400], ["t", 300], ["u", 300], ["v", 300]]), ], ) -def test_grib_slice_single_file(fl_type, array_backend, indexes, expected_meta): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_slice_single_file(fl_type, indexes, expected_meta): + f, _ = load_grib_data("tuv_pl.grib", fl_type) r = f[indexes] assert len(r) == 4 assert r.metadata(["shortName", "level"]) == expected_meta @@ -107,7 +103,6 @@ def test_grib_slice_multi_file(indexes, expected_meta): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", 
ARRAY_BACKENDS) @pytest.mark.parametrize( "indexes1,indexes2", [ @@ -116,8 +111,8 @@ def test_grib_slice_multi_file(indexes, expected_meta): ((1, 16, 5, 9), (1, 3)), ], ) -def test_grib_array_indexing(fl_type, array_backend, indexes1, indexes2): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_array_indexing(fl_type, indexes1, indexes2): + f, _ = load_grib_data("tuv_pl.grib", fl_type) r = f[indexes1] assert len(r) == 4 @@ -130,7 +125,6 @@ def test_grib_array_indexing(fl_type, array_backend, indexes1, indexes2): @pytest.mark.skip(reason="Index range checking disabled") @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "indexes", [ @@ -139,16 +133,15 @@ def test_grib_array_indexing(fl_type, array_backend, indexes1, indexes2): ((1, 16, 5, 9), (1, 3)), ], ) -def test_grib_array_indexing_bad(fl_type, array_backend, indexes): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_array_indexing_bad(fl_type, indexes): + f, _ = load_grib_data("tuv_pl.grib", fl_type) with pytest.raises(IndexError): f[indexes] @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_fieldlist_iterator(fl_type, array_backend): - g = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_fieldlist_iterator(fl_type): + g, _ = load_grib_data("tuv_pl.grib", fl_type) sn = g.metadata("shortName") assert len(sn) == 18 iter_sn = [f.metadata("shortName") for f in g] @@ -159,12 +152,11 @@ def test_grib_fieldlist_iterator(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_fieldlist_iterator_with_zip(fl_type, array_backend): +def test_grib_fieldlist_iterator_with_zip(fl_type): # test something different to the iterator - does not try to # 'go off the edge' of the fieldlist, because the length is determined by # the
list of levels - g = load_grib_data("tuv_pl.grib", fl_type, array_backend) + g, _ = load_grib_data("tuv_pl.grib", fl_type) ref_levs = g.metadata("level") assert len(ref_levs) == 18 levs1 = [] @@ -177,10 +169,9 @@ def test_grib_fieldlist_iterator_with_zip(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_fieldlist_iterator_with_zip_multiple(fl_type, array_backend): +def test_grib_fieldlist_iterator_with_zip_multiple(fl_type): # same as test_fieldlist_iterator_with_zip() but multiple times - g = load_grib_data("tuv_pl.grib", fl_type, array_backend) + g, _ = load_grib_data("tuv_pl.grib", fl_type) ref_levs = g.metadata("level") assert len(ref_levs) == 18 for i in range(2): @@ -194,9 +185,8 @@ def test_grib_fieldlist_iterator_with_zip_multiple(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_fieldlist_reverse_iterator(fl_type, array_backend): - g = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_fieldlist_reverse_iterator(fl_type): + g, _ = load_grib_data("tuv_pl.grib", fl_type) sn = g.metadata("shortName") sn_reversed = list(reversed(sn)) assert sn_reversed[0] == "v" diff --git a/tests/grib/test_grib_stream.py b/tests/grib/test_grib_stream.py index c1ca6fbc..8b14ffb0 100644 --- a/tests/grib/test_grib_stream.py +++ b/tests/grib/test_grib_stream.py @@ -63,20 +63,21 @@ def test_grib_from_stream_iter(): assert sum([1 for _ in ds]) == 0 -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_from_stream_fieldlist_backend(array_backend): - with open(earthkit_examples_file("test6.grib"), "rb") as stream: - ds = from_source("stream", stream, array_backend=array_backend) +# @pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +# def test_grib_from_stream_fieldlist_backend(array_backend): +# with open(earthkit_examples_file("test6.grib"), "rb") as 
stream: +# ds = from_source("stream", stream, array_backend=array_backend) - assert isinstance(ds, StreamFieldList) - assert ds.array_backend.name == array_backend - assert ds.to_array().shape == (6, 7, 12) +# assert isinstance(ds, StreamFieldList) - assert sum([1 for _ in ds]) == 0 +# # assert ds.array_backend.name == array_backend +# assert ds.to_array().shape == (6, 7, 12) - with pytest.raises((RuntimeError, ValueError)): - ds.to_array() +# assert sum([1 for _ in ds]) == 0 + +# with pytest.raises((RuntimeError, ValueError)): +# ds.to_array() @pytest.mark.parametrize( @@ -135,9 +136,11 @@ def test_grib_from_stream_batched_convert_to_numpy(convert_kwargs, expected_shap for i, f in enumerate(ds.batched(2)): df = f.to_fieldlist(array_backend="numpy", **convert_kwargs) assert df.metadata(("param", "level")) == ref[i], i - assert df._array.shape == expected_shape, i assert df.to_numpy(**convert_kwargs).shape == expected_shape, i - assert df.to_fieldlist(array_backend="numpy", **convert_kwargs) is df, i + df1 = df.to_fieldlist(array_backend="numpy", **convert_kwargs) + assert df1 is not df, i + assert df1.metadata(("param", "level")) == ref[i], i + assert df1.to_numpy(**convert_kwargs).shape == expected_shape, i # stream consumed, no data is available assert sum([1 for _ in ds]) == 0 @@ -147,7 +150,7 @@ def test_grib_from_stream_batched_convert_to_numpy(convert_kwargs, expected_shap @pytest.mark.parametrize("group", ["level", ["level", "gridType"]]) def test_grib_from_stream_group_by(array_backend, group): with open(earthkit_examples_file("test6.grib"), "rb") as stream: - ds = from_source("stream", stream, array_backend=array_backend) + ds = from_source("stream", stream) # no methods are available with pytest.raises((TypeError, NotImplementedError)): @@ -199,9 +202,12 @@ def test_grib_from_stream_group_by_convert_to_numpy(convert_kwargs, expected_sha df = f.to_fieldlist(array_backend="numpy", **convert_kwargs) assert len(df) == 3 assert df.metadata(("param", "level")) 
== ref[i] - assert df._array.shape == expected_shape assert df.to_numpy(**convert_kwargs).shape == expected_shape - assert df.to_fieldlist(array_backend="numpy", **convert_kwargs) is df + df1 = df.to_fieldlist(array_backend="numpy", **convert_kwargs) + assert df1 is not df + assert len(df1) == 3 + assert df1.metadata(("param", "level")) == ref[i] + assert df1.to_numpy(**convert_kwargs).shape == expected_shape # stream consumed, no data is available assert sum([1 for _ in ds]) == 0 @@ -318,8 +324,11 @@ def test_grib_from_stream_in_memory_convert_to_numpy(convert_kwargs, expected_sh vals = ds.to_numpy(**convert_kwargs)[:, 0] assert np.allclose(vals, ref) - assert ds._array.shape == expected_shape - assert ds.to_fieldlist(array_backend="numpy", **convert_kwargs) is ds + assert ds.to_numpy(**convert_kwargs).shape == expected_shape + ds1 = ds.to_fieldlist(array_backend="numpy", **convert_kwargs) + assert ds1 is not ds + assert len(ds1) == 6 + assert ds1.to_numpy(**convert_kwargs).shape == expected_shape def test_grib_save_when_loaded_from_stream(): diff --git a/tests/grib/test_grib_summary.py b/tests/grib/test_grib_summary.py index ce0cf120..c064ea03 100644 --- a/tests/grib/test_grib_summary.py +++ b/tests/grib/test_grib_summary.py @@ -13,18 +13,16 @@ import pytest -from earthkit.data.testing import ARRAY_BACKENDS - here = os.path.dirname(__file__) sys.path.insert(0, here) +from grib_fixtures import FL_FILE # noqa: E402 from grib_fixtures import FL_TYPES # noqa: E402 from grib_fixtures import load_grib_data # noqa: E402 @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_describe(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_describe(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) # full contents df = f.describe() @@ -148,9 +146,8 @@ def test_grib_describe(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) 
-@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_ls(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_ls(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) # default keys f1 = f[0:4] @@ -203,9 +200,8 @@ def test_grib_ls(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_ls_keys(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_ls_keys(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) # default keys # positive num (=head) @@ -230,9 +226,8 @@ def test_grib_ls_keys(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_ls_namespace(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_ls_namespace(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) df = f.ls(n=2, namespace="vertical") ref = { @@ -252,9 +247,8 @@ def test_grib_ls_namespace(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_ls_invalid_num(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_ls_invalid_num(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) with pytest.raises(ValueError): f.ls(n=0) @@ -264,17 +258,15 @@ def test_grib_ls_invalid_num(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_ls_invalid_arg(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_ls_invalid_arg(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) with pytest.raises(TypeError): f.ls(invalid=1) @pytest.mark.parametrize("fl_type", FL_TYPES) 
-@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_ls_num(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_ls_num(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) # default keys @@ -320,9 +312,8 @@ def test_grib_ls_num(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_head_num(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_head_num(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) # default keys df = f.head(n=2) @@ -346,9 +337,8 @@ def test_grib_head_num(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_tail_num(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_tail_num(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) # default keys df = f.tail(n=2) @@ -371,10 +361,9 @@ def test_grib_tail_num(fl_type, array_backend): assert ref == df.to_dict() -@pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", [None]) -def test_grib_dump(fl_type, array_backend): - f = load_grib_data("test6.grib", fl_type, array_backend) +@pytest.mark.parametrize("fl_type", FL_FILE) +def test_grib_dump(fl_type): + f, _ = load_grib_data("test6.grib", fl_type) namespaces = ( "default", diff --git a/tests/grib/test_grib_values.py b/tests/grib/test_grib_values.py index 14faac06..d8c2499d 100644 --- a/tests/grib/test_grib_values.py +++ b/tests/grib/test_grib_values.py @@ -15,13 +15,12 @@ import numpy as np import pytest -from earthkit.data.testing import ARRAY_BACKENDS from earthkit.data.testing import check_array_type -from earthkit.data.testing import get_array -from earthkit.data.testing import get_array_namespace here = os.path.dirname(__file__) sys.path.insert(0, here) +from 
grib_fixtures import FL_FILE # noqa: E402 +from grib_fixtures import FL_NUMPY # noqa: E402 from grib_fixtures import FL_TYPES # noqa: E402 from grib_fixtures import load_grib_data # noqa: E402 @@ -34,9 +33,8 @@ def check_array(v, shape=None, first=None, last=None, meanv=None, eps=1e-3): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_values_1(fl_type, array_backend): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_values_1(fl_type): + f, array_backend = load_grib_data("test_single.grib", fl_type, folder="data") eps = 1e-5 # whole file @@ -61,12 +59,9 @@ def test_grib_values_1(fl_type, array_backend): assert np.allclose(v, v1, eps) -# @pytest.mark.parametrize("fl_type", FL_TYPES) -# @pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -@pytest.mark.parametrize("fl_type", ["file"]) -@pytest.mark.parametrize("array_backend", ["numpy"]) -def test_grib_values_18(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +@pytest.mark.parametrize("fl_type", FL_FILE) +def test_grib_values_18(fl_type): + f, array_backend = load_grib_data("tuv_pl.grib", fl_type) eps = 1e-5 # whole file @@ -95,9 +90,8 @@ def test_grib_values_18(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_to_numpy_1(fl_type, array_backend): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_to_numpy_1(fl_type): + f, _ = load_grib_data("test_single.grib", fl_type, folder="data") eps = 1e-5 v = f.to_numpy() @@ -115,7 +109,6 @@ def test_grib_to_numpy_1(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "first,options, expected_shape", [ @@ -127,8 +120,8 @@ def test_grib_to_numpy_1(fl_type, array_backend): (True, 
{"flatten": False}, (7, 12)), ], ) -def test_grib_to_numpy_1_shape(fl_type, array_backend, first, options, expected_shape): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_to_numpy_1_shape(fl_type, first, options, expected_shape): + f, _ = load_grib_data("test_single.grib", fl_type, folder="data") v_ref = f[0].to_numpy().flatten() eps = 1e-5 @@ -143,9 +136,8 @@ def test_grib_to_numpy_1_shape(fl_type, array_backend, first, options, expected_ @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_to_numpy_18(fl_type, array_backend): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_to_numpy_18(fl_type): + f, _ = load_grib_data("tuv_pl.grib", fl_type) eps = 1e-5 @@ -176,7 +168,6 @@ def test_grib_to_numpy_18(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) @pytest.mark.parametrize( "options, expected_shape", [ @@ -198,8 +189,8 @@ def test_grib_to_numpy_18(fl_type, array_backend): ({"flatten": False}, (18, 7, 12)), ], ) -def test_grib_to_numpy_18_shape(fl_type, array_backend, options, expected_shape): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_to_numpy_18_shape(fl_type, options, expected_shape): + f, _ = load_grib_data("tuv_pl.grib", fl_type) eps = 1e-5 @@ -223,11 +214,10 @@ def test_grib_to_numpy_18_shape(fl_type, array_backend, options, expected_shape) assert np.allclose(vf15, vr, eps) -@pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ["numpy"]) +@pytest.mark.parametrize("fl_type", FL_NUMPY) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_numpy_1_dtype(fl_type, array_backend, dtype): - f = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_to_numpy_1_dtype(fl_type, dtype): + f, _ = load_grib_data("test_single.grib", 
fl_type, folder="data") v = f[0].to_numpy(dtype=dtype) assert v.dtype == dtype @@ -236,11 +226,10 @@ def test_grib_to_numpy_1_dtype(fl_type, array_backend, dtype): assert v.dtype == dtype -@pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ["numpy"]) +@pytest.mark.parametrize("fl_type", FL_NUMPY) @pytest.mark.parametrize("dtype", [np.float32, np.float64]) -def test_grib_to_numpy_18_dtype(fl_type, array_backend, dtype): - f = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_to_numpy_18_dtype(fl_type, dtype): + f, _ = load_grib_data("tuv_pl.grib", fl_type) v = f[0].to_numpy(dtype=dtype) assert v.dtype == dtype @@ -250,9 +239,8 @@ def test_grib_to_numpy_18_dtype(fl_type, array_backend, dtype): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_to_numpy_1_index(fl_type, array_backend): - ds = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") +def test_grib_to_numpy_1_index(fl_type): + ds, _ = load_grib_data("test_single.grib", fl_type, folder="data") eps = 1e-5 @@ -288,9 +276,8 @@ def test_grib_to_numpy_1_index(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_to_numpy_18_index(fl_type, array_backend): - ds = load_grib_data("tuv_pl.grib", fl_type, array_backend) +def test_grib_to_numpy_18_index(fl_type): + ds, _ = load_grib_data("tuv_pl.grib", fl_type) eps = 1e-5 @@ -366,22 +353,22 @@ def test_grib_to_numpy_18_index(fl_type, array_backend): ) -@pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ["numpy"]) +@pytest.mark.parametrize("fl_type", FL_FILE) +# @pytest.mark.parametrize("fl_type", FL_NUMPY) @pytest.mark.parametrize( "kwarg,expected_shape,expected_dtype", [ ({}, (11, 19), np.float64), - ({"flatten": True}, (209,), np.float64), - ({"flatten": True, "dtype": np.float32}, (209,), 
np.float32), - ({"flatten": True, "dtype": np.float64}, (209,), np.float64), - ({"flatten": False}, (11, 19), np.float64), - ({"flatten": False, "dtype": np.float32}, (11, 19), np.float32), - ({"flatten": False, "dtype": np.float64}, (11, 19), np.float64), + # ({"flatten": True}, (209,), np.float64), + # ({"flatten": True, "dtype": np.float32}, (209,), np.float32), + # ({"flatten": True, "dtype": np.float64}, (209,), np.float64), + # ({"flatten": False}, (11, 19), np.float64), + # ({"flatten": False, "dtype": np.float32}, (11, 19), np.float32), + # ({"flatten": False, "dtype": np.float64}, (11, 19), np.float64), ], ) -def test_grib_field_data(fl_type, array_backend, kwarg, expected_shape, expected_dtype): - ds = load_grib_data("test.grib", fl_type, array_backend) +def test_grib_field_data(fl_type, kwarg, expected_shape, expected_dtype): + ds, _ = load_grib_data("test.grib", fl_type) latlon = ds[0].to_latlon(**kwarg) v = ds[0].to_numpy(**kwarg) @@ -418,8 +405,7 @@ def test_grib_field_data(fl_type, array_backend, kwarg, expected_shape, expected assert np.allclose(d[1], latlon["lon"]) -@pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ["numpy"]) +@pytest.mark.parametrize("fl_type", FL_NUMPY) @pytest.mark.parametrize( "kwarg,expected_shape,expected_dtype", [ @@ -432,8 +418,8 @@ def test_grib_field_data(fl_type, array_backend, kwarg, expected_shape, expected ({"flatten": False, "dtype": np.float64}, (11, 19), np.float64), ], ) -def test_grib_fieldlist_data(fl_type, array_backend, kwarg, expected_shape, expected_dtype): - ds = load_grib_data("test.grib", fl_type, array_backend) +def test_grib_fieldlist_data(fl_type, kwarg, expected_shape, expected_dtype): + ds, _ = load_grib_data("test.grib", fl_type) latlon = ds.to_latlon(**kwarg) v = ds.to_numpy(**kwarg) @@ -471,10 +457,9 @@ def test_grib_fieldlist_data(fl_type, array_backend, kwarg, expected_shape, expe assert np.allclose(d[2], latlon["lon"]) -@pytest.mark.parametrize("fl_type", 
FL_TYPES) -@pytest.mark.parametrize("array_backend", ["numpy"]) -def test_grib_fieldlist_data_index(fl_type, array_backend): - ds = load_grib_data("tuv_pl.grib", fl_type, array_backend) +@pytest.mark.parametrize("fl_type", FL_NUMPY) +def test_grib_fieldlist_data_index(fl_type): + ds, _ = load_grib_data("tuv_pl.grib", fl_type) eps = 1e-5 @@ -571,19 +556,18 @@ def test_grib_fieldlist_data_index(fl_type, array_backend): @pytest.mark.parametrize("fl_type", FL_TYPES) -@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) -def test_grib_values_with_missing(fl_type, array_backend): - f = load_grib_data("test_single_with_missing.grib", fl_type, array_backend, folder="data") +def test_grib_values_with_missing(fl_type): + f, array_backend = load_grib_data("test_single_with_missing.grib", fl_type, folder="data") v = f[0].values check_array_type(v, array_backend) assert v.shape == (84,) eps = 0.001 - ns = get_array_namespace(array_backend) + ns = array_backend.namespace assert ns.count_nonzero(ns.isnan(v)) == 38 - mask = get_array([12, 14, 15, 24, 25, 26] + list(range(28, 60)), array_backend) + mask = array_backend.from_other([12, 14, 15, 24, 25, 26] + list(range(28, 60))) assert np.isclose(v[0], 260.4356, eps) assert np.isclose(v[11], 260.4356, eps) assert np.isclose(v[-1], 227.1856, eps) diff --git a/tests/utils/test_array.py b/tests/utils/test_array.py index 0ab40a59..b5d67d83 100644 --- a/tests/utils/test_array.py +++ b/tests/utils/test_array.py @@ -13,43 +13,49 @@ from earthkit.data.testing import NO_CUPY from earthkit.data.testing import NO_PYTORCH -from earthkit.data.utils.array import ensure_backend +from earthkit.data.utils.array import _CUPY +from earthkit.data.utils.array import _NUMPY +from earthkit.data.utils.array import _PYTORCH from earthkit.data.utils.array import get_backend +"""These tests are for the array backend utilities mostly used in other tests.""" + def test_utils_array_backend_numpy(): - b = ensure_backend("numpy") + b = get_backend("numpy") 
assert b.name == "numpy" + assert b is _NUMPY import numpy as np v = np.ones(10) v_lst = [1.0] * 10 - assert b.is_native_array(v) assert id(b.to_numpy(v)) == id(v) - assert id(b.from_backend(v, b)) == id(v) - assert id(b.from_backend(v, None)) == id(v) + assert id(b.from_numpy(v)) == id(v) + assert id(b.from_other(v)) == id(v) + assert np.allclose(b.from_other(v_lst, dtype=np.float64), v) assert get_backend(v) is b - assert get_backend(v, guess=b) is b + assert get_backend(np) is b - assert np.isclose(b.array_ns.mean(v), 1.0) + assert np.isclose(b.namespace.mean(v), 1.0) if not NO_PYTORCH: import torch v_pt = torch.ones(10, dtype=torch.float64) - pt_b = ensure_backend("pytorch") - r = b.to_backend(v, pt_b) + pt_b = get_backend("pytorch") + r = pt_b.from_other(v) assert torch.is_tensor(r) assert torch.allclose(r, v_pt) @pytest.mark.skipif(NO_PYTORCH, reason="No pytorch installed") def test_utils_array_backend_pytorch(): - b = ensure_backend("pytorch") + b = get_backend("pytorch") assert b.name == "pytorch" + assert b is _PYTORCH import numpy as np import torch @@ -58,27 +64,22 @@ def test_utils_array_backend_pytorch(): v_np = np.ones(10, dtype=np.float64) v_lst = [1.0] * 10 - assert b.is_native_array(v) - assert id(b.from_backend(v, b)) == id(v) - assert id(b.from_backend(v, None)) == id(v) - assert torch.allclose(b.from_backend(v_np, None), v) assert torch.allclose(b.from_numpy(v_np), v) assert torch.allclose(b.from_other(v_lst, dtype=torch.float64), v) assert get_backend(v) is b - assert get_backend(v, guess=b) is b - np_b = ensure_backend("numpy") - r = b.to_backend(v, np_b) + r = b.to_numpy(v) assert isinstance(r, np.ndarray) assert np.allclose(r, v_np) - assert np.isclose(b.array_ns.mean(v), 1.0) + assert np.isclose(b.namespace.mean(v), 1.0) @pytest.mark.skipif(NO_CUPY, reason="No pytorch installed") def test_utils_array_backend_cupy(): - b = ensure_backend("cupy") + b = get_backend("cupy") assert b.name == "cupy" + assert b is _CUPY import cupy as cp import 
numpy as np @@ -87,20 +88,18 @@ v_np = np.ones(10, dtype=np.float64) v_lst = [1.0] * 10 - assert b.is_native_array(v) - assert id(b.from_backend(v, b)) == id(v) - assert id(b.from_backend(v, None)) == id(v) - assert cp.allclose(b.from_backend(v_np, None), v) + # assert b.is_native_array(v) + # assert id(b.from_backend(v, b)) == id(v) + # assert id(b.from_backend(v, None)) == id(v) + assert cp.allclose(b.from_numpy(v_np), v) assert cp.allclose(b.from_other(v_lst, dtype=cp.float64), v) assert get_backend(v) is b - assert get_backend(v, guess=b) is b - np_b = ensure_backend("numpy") - r = b.to_backend(v, np_b) + r = b.to_numpy(v) assert isinstance(r, np.ndarray) assert np.allclose(r, v_np) - assert np.isclose(b.array_ns.mean(v), 1.0) + assert np.isclose(b.namespace.mean(v), 1.0) if __name__ == "__main__":