diff --git a/ci/doc.yml b/ci/doc.yml index 414cd07..867cb57 100644 --- a/ci/doc.yml +++ b/ci/doc.yml @@ -3,8 +3,9 @@ channels: - conda-forge dependencies: - cupy-core + - rapidsai::kvikio>=25.04.00 - pip - - python=3.10 + - python=3.11 - sphinx - sphinx-design - sphinx-copybutton @@ -14,7 +15,8 @@ dependencies: - ipywidgets - furo>=2024.8.6 - myst-nb - - xarray + - xarray>=2025.03.0 + - zarr>=3.0.3 - pip: # relative to this file. Needs to be editable to be accepted. - --editable .. diff --git a/cupy_xarray/__init__.py b/cupy_xarray/__init__.py index 5c3a06c..0bb96aa 100644 --- a/cupy_xarray/__init__.py +++ b/cupy_xarray/__init__.py @@ -1,4 +1,5 @@ from . import _version -from .accessors import CupyDataArrayAccessor, CupyDatasetAccessor # noqa +from .accessors import CupyDataArrayAccessor, CupyDatasetAccessor # noqa: F401 +from .kvikio import KvikioBackendEntrypoint # noqa: F401 __version__ = _version.get_versions()["version"] diff --git a/cupy_xarray/kvikio.py b/cupy_xarray/kvikio.py new file mode 100644 index 0000000..7e3180d --- /dev/null +++ b/cupy_xarray/kvikio.py @@ -0,0 +1,104 @@ +""" +:doc:`kvikIO ` backend for xarray to read Zarr stores directly into CuPy +arrays in GPU memory. +""" + +import functools + +from xarray.backends.common import _normalize_path # TODO: can this be public +from xarray.backends.store import StoreBackendEntrypoint +from xarray.backends.zarr import ZarrBackendEntrypoint, ZarrStore +from xarray.core.dataset import Dataset +from xarray.core.utils import close_on_error # TODO: can this be public. + +try: + import kvikio.zarr + import zarr + + has_kvikio = True +except ImportError: + has_kvikio = False + + +class KvikioBackendEntrypoint(ZarrBackendEntrypoint): + """ + Xarray backend to read Zarr stores using 'kvikio' engine. + + For more information about the underlying library, visit + :doc:`kvikIO's Zarr page`. + """ + + available = has_kvikio + description = "Open zarr files (.zarr) using Kvikio" + url = "https://docs.rapids.ai/api/kvikio/stable/api/#zarr" + + # disabled by default + # We need to provide this because of the subclassing from + # ZarrBackendEntrypoint + def guess_can_open(self, filename_or_obj): + return False + + def open_dataset( + self, + filename_or_obj, + mask_and_scale=True, + decode_times=True, + concat_characters=True, + decode_coords=True, + drop_variables=None, + use_cftime=None, + decode_timedelta=None, + group=None, + mode="r", + synchronizer=None, + consolidated=None, + chunk_store=None, + storage_options=None, + zarr_version=None, + zarr_format=None, + store=None, + engine=None, + use_zarr_fill_value_as_mask=None, + cache_members: bool = True, + ) -> Dataset: + filename_or_obj = _normalize_path(filename_or_obj) + if not store: + with zarr.config.enable_gpu(): + _store = kvikio.zarr.GDSStore(root=filename_or_obj) + + # Override default buffer prototype to be GPU buffer + # buffer_prototype = zarr.core.buffer.core.default_buffer_prototype() + buffer_prototype = zarr.core.buffer.gpu.buffer_prototype + _store.get = functools.partial(_store.get, prototype=buffer_prototype) + _store.get_partial_values = functools.partial( + _store.get_partial_values, prototype=buffer_prototype + ) + + store = ZarrStore.open_group( + store=_store, + group=group, + mode=mode, + synchronizer=synchronizer, + consolidated=consolidated, + consolidate_on_close=False, + chunk_store=chunk_store, + storage_options=storage_options, + zarr_version=zarr_version, + use_zarr_fill_value_as_mask=None, + zarr_format=zarr_format, + cache_members=cache_members, + ) + + store_entrypoint = StoreBackendEntrypoint() + with close_on_error(store): + ds = store_entrypoint.open_dataset( + store, + mask_and_scale=mask_and_scale, + decode_times=decode_times, + concat_characters=concat_characters, + decode_coords=decode_coords, + drop_variables=drop_variables, + use_cftime=use_cftime, + decode_timedelta=decode_timedelta, + ) + return ds diff --git a/cupy_xarray/tests/test_kvikio.py b/cupy_xarray/tests/test_kvikio.py new file mode 100644 index 0000000..b6fe67f --- /dev/null +++ b/cupy_xarray/tests/test_kvikio.py @@ -0,0 +1,54 @@ +import cupy as cp +import numpy as np +import pytest +import xarray as xr +from xarray.core.indexing import ExplicitlyIndexedNDArrayMixin + +kvikio = pytest.importorskip("kvikio") +zarr = pytest.importorskip("zarr") + +import kvikio.zarr # noqa +import xarray.core.indexing # noqa + + +@pytest.fixture +def store(tmp_path): + ds = xr.Dataset( + { + "a": ("x", np.arange(10), {"foo": "bar"}), + "scalar": np.array(1), + }, + coords={"x": ("x", np.arange(-5, 5))}, + ) + + for var in ds.variables: + ds[var].encoding["compressors"] = None + + store_path = tmp_path / "kvikio.zarr" + ds.to_zarr(store_path, consolidated=True) + return store_path + + +def test_entrypoint(): + assert "kvikio" in xr.backends.list_engines() + + +@pytest.mark.parametrize("consolidated", [True, False]) +def test_lazy_load(consolidated, store): + with xr.open_dataset(store, engine="kvikio", consolidated=consolidated) as ds: + for _, da in ds.data_vars.items(): + assert isinstance(da.variable._data, ExplicitlyIndexedNDArrayMixin) + + +@pytest.mark.parametrize("indexer", [slice(None), slice(2, 4), 2, [2, 3, 5]]) +def test_lazy_indexing(indexer, store): + with zarr.config.enable_gpu(), xr.open_dataset(store, engine="kvikio") as ds: + ds = ds.isel(x=indexer) + for _, da in ds.data_vars.items(): + assert isinstance(da.variable._data, ExplicitlyIndexedNDArrayMixin) + + loaded = ds.compute() + for _, da in loaded.data_vars.items(): + if da.ndim == 0: + continue + assert isinstance(da.data, cp.ndarray) diff --git a/docs/api.rst b/docs/api.rst index 70d22b0..17bdb12 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -51,3 +51,16 @@ Methods Dataset.cupy.as_cupy Dataset.cupy.as_numpy + + +KvikIO engine +------------- + +.. currentmodule:: cupy_xarray + +.. automodule:: cupy_xarray.kvikio + +.. autosummary:: + :toctree: generated/ + + KvikioBackendEntrypoint diff --git a/docs/conf.py b/docs/conf.py index 2dffa80..ebba5ed 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -57,6 +57,7 @@ "python": ("https://docs.python.org/3/", None), "dask": ("https://docs.dask.org/en/latest", None), "cupy": ("https://docs.cupy.dev/en/latest", None), + "kvikio": ("https://docs.rapids.ai/api/kvikio/stable", None), "xarray": ("http://docs.xarray.dev/en/latest/", None), } diff --git a/docs/examples/07_kvikio.ipynb b/docs/examples/07_kvikio.ipynb new file mode 100644 index 0000000..0a9dd45 --- /dev/null +++ b/docs/examples/07_kvikio.ipynb @@ -0,0 +1,5334 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5920bb97-1d76-4363-9aee-d1c5cd395409", + "metadata": {}, + "source": [ + "# Kvikio demo\n", + "\n", + "Requires\n", + "- [ ] https://github.com/pydata/xarray/pull/10078\n", + "- [ ] https://github.com/rapidsai/kvikio/pull/646" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c9ee3a73-6f7b-4875-b5a6-2e6d48fade44", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exception reporting mode: Minimal\n", + "numpy : 2.2.3\n", + "zarr : 3.0.5\n", + "cupy_xarray: 0.1.4+36.ge26ed24.dirty\n", + "kvikio : 25.4.0\n", + "xarray : 2025.1.3.dev22+g0184702f\n", + "\n" + ] + } + ], + "source": [ + "%load_ext watermark\n", + "%xmode minimal\n", + "\n", + "import cupy_xarray # registers cupy accessor\n", + "import kvikio.zarr\n", + "\n", + "import numpy as np\n", + "import xarray as xr\n", + "import zarr\n", + "\n", + "%watermark -iv" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "83b1b514-eeb8-4a81-a3e8-3a7dc82ffce4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'netcdf4': \n", + " Open netCDF (.nc, .nc4 and .cdf) and most HDF5 files using netCDF4 in Xarray\n", + " Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.NetCDF4BackendEntrypoint.html,\n", + " 'kvikio': \n", + " Open zarr files (.zarr) using Kvikio\n", + " Learn more at https://docs.rapids.ai/api/kvikio/stable/api/#zarr,\n", + " 'store': \n", + " Open AbstractDataStore instances in Xarray\n", + " Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.StoreBackendEntrypoint.html,\n", + " 'zarr': \n", + " Open zarr files (.zarr) using zarr in Xarray\n", + " Learn more at https://docs.xarray.dev/en/stable/generated/xarray.backends.ZarrBackendEntrypoint.html}" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "xr.backends.list_engines()" + ] + }, + { + "cell_type": "markdown", + "id": "5f12848d-a5ec-4cea-9a49-4f2bcefd9114", + "metadata": { + "tags": [] + }, + "source": [ + "## Create example dataset\n", + "\n", + "- cannot be compressed" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "d481cc3b-420e-4b7c-8c5e-77d874128b12", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/user/mambaforge/envs/cupy-xarray-doc/lib/python3.11/site-packages/xarray/core/dataset.py:2503: SerializationWarning: saving variable None with floating point data as an integer dtype without any _FillValue to use for NaNs\n", + " return to_zarr( # type: ignore[call-overload,misc]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "store = \"/tmp/air-temperature.zarr\"\n", + "airt = xr.tutorial.open_dataset(\"air_temperature\", engine=\"netcdf4\")\n", + "for var in airt.variables:\n", + " airt[var].encoding[\"compressors\"] = None\n", + "airt[\"scalar\"] = 12.0\n", + "airt.to_zarr(store, mode=\"w\", zarr_format=3, consolidated=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1a3d0ec7-22fb-4558-8e60-9627266e3111", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "id": "883d5507-988f-453a-b576-87bb563b540f", + "metadata": { + "tags": [] + }, + "source": [ + "## Test opening\n", + "\n", + "### Standard usage" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "4a9ba63c-0b29-4eb8-9171-965b90071496", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_72617/982297347.py:1: RuntimeWarning: Failed to open Zarr store with consolidated metadata, but successfully read with non-consolidated metadata. This is typically much slower for opening a dataset. To silence this warning, consider:\n", + "1. Consolidating metadata in this existing store with zarr.consolidate_metadata().\n", + "2. Explicitly setting consolidated=False, to avoid trying to read consolidate metadata, or\n", + "3. Explicitly setting consolidated=True, to raise an error in this case instead of falling back to try reading non-consolidated metadata.\n", + " ds_cpu = xr.open_dataset(store, engine=\"zarr\")\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
+       "array([[[241.2 , 242.5 , ..., 235.5 , 238.6 ],\n",
+       "        [243.8 , 244.5 , ..., 235.3 , 239.3 ],\n",
+       "        ...,\n",
+       "        [295.9 , 296.2 , ..., 295.9 , 295.2 ],\n",
+       "        [296.29, 296.79, ..., 296.79, 296.6 ]],\n",
+       "\n",
+       "       [[242.1 , 242.7 , ..., 233.6 , 235.8 ],\n",
+       "        [243.6 , 244.1 , ..., 232.5 , 235.7 ],\n",
+       "        ...,\n",
+       "        [296.2 , 296.7 , ..., 295.5 , 295.1 ],\n",
+       "        [296.29, 297.2 , ..., 296.4 , 296.6 ]],\n",
+       "\n",
+       "       ...,\n",
+       "\n",
+       "       [[245.79, 244.79, ..., 243.99, 244.79],\n",
+       "        [249.89, 249.29, ..., 242.49, 244.29],\n",
+       "        ...,\n",
+       "        [296.29, 297.19, ..., 295.09, 294.39],\n",
+       "        [297.79, 298.39, ..., 295.49, 295.19]],\n",
+       "\n",
+       "       [[245.09, 244.29, ..., 241.49, 241.79],\n",
+       "        [249.89, 249.29, ..., 240.29, 241.69],\n",
+       "        ...,\n",
+       "        [296.09, 296.89, ..., 295.69, 295.19],\n",
+       "        [297.69, 298.09, ..., 296.19, 295.69]]], shape=(2920, 25, 53))\n",
+       "Coordinates:\n",
+       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
+       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
+       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "Attributes:\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    units:         degK\n",
+       "    precision:     2\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    var_desc:      Air temperature\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    statistic:     Individual Obs\n",
+       "    parent_stat:   Other\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]
" + ], + "text/plain": [ + " Size: 31MB\n", + "array([[[241.2 , 242.5 , ..., 235.5 , 238.6 ],\n", + " [243.8 , 244.5 , ..., 235.3 , 239.3 ],\n", + " ...,\n", + " [295.9 , 296.2 , ..., 295.9 , 295.2 ],\n", + " [296.29, 296.79, ..., 296.79, 296.6 ]],\n", + "\n", + " [[242.1 , 242.7 , ..., 233.6 , 235.8 ],\n", + " [243.6 , 244.1 , ..., 232.5 , 235.7 ],\n", + " ...,\n", + " [296.2 , 296.7 , ..., 295.5 , 295.1 ],\n", + " [296.29, 297.2 , ..., 296.4 , 296.6 ]],\n", + "\n", + " ...,\n", + "\n", + " [[245.79, 244.79, ..., 243.99, 244.79],\n", + " [249.89, 249.29, ..., 242.49, 244.29],\n", + " ...,\n", + " [296.29, 297.19, ..., 295.09, 294.39],\n", + " [297.79, 298.39, ..., 295.49, 295.19]],\n", + "\n", + " [[245.09, 244.29, ..., 241.49, 241.79],\n", + " [249.89, 249.29, ..., 240.29, 241.69],\n", + " ...,\n", + " [296.09, 296.89, ..., 295.69, 295.19],\n", + " [297.69, 298.09, ..., 296.19, 295.69]]], shape=(2920, 25, 53))\n", + "Coordinates:\n", + " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", + " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", + " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", + "Attributes:\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " units: degK\n", + " precision: 2\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " var_desc: Air temperature\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " statistic: Individual Obs\n", + " parent_stat: Other\n", + " actual_range: [185.16000366210938, 322.1000061035156]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds_cpu = xr.open_dataset(store, engine=\"zarr\")\n", + "print(ds_cpu.air.data.__class__)\n", + "ds_cpu.air" + ] + }, + { + "cell_type": "markdown", + "id": "95161182-6b58-4dbd-9752-9961c251be1a", + "metadata": {}, + "source": [ + "### Now with kvikio!\n", + "\n", + " - must read with `consolidated=False` (https://github.com/rapidsai/kvikio/issues/119)\n", + " - dask.from_zarr to GDSStore / open_mfdataset" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8fd27bdf-e317-4de3-891e-41d38d06dcaf", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "MemoryCachedArray(array=CopyOnWriteArray(array=LazilyIndexedArray(array=_ElementwiseFunctionArray(LazilyIndexedArray(array=, key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None)))), func=functools.partial(, scale_factor=0.01, add_offset=None, dtype=), dtype=dtype('float64')), key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None))))))\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 31MB\n",
+       "Dimensions:  (time: 2920, lat: 25, lon: 53)\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
+       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
+       "Data variables:\n",
+       "    scalar   float64 8B ...\n",
+       "    air      (time, lat, lon) float64 31MB ...\n",
+       "Attributes:\n",
+       "    Conventions:  COARDS\n",
+       "    title:        4x daily NMC reanalysis (1948)\n",
+       "    description:  Data is from NMC initialized reanalysis\\n(4x/day).  These a...\n",
+       "    platform:     Model\n",
+       "    references:   http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
" + ], + "text/plain": [ + " Size: 31MB\n", + "Dimensions: (time: 2920, lat: 25, lon: 53)\n", + "Coordinates:\n", + " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", + " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", + " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", + "Data variables:\n", + " scalar float64 8B ...\n", + " air (time, lat, lon) float64 31MB ...\n", + "Attributes:\n", + " Conventions: COARDS\n", + " title: 4x daily NMC reanalysis (1948)\n", + " description: Data is from NMC initialized reanalysis\\n(4x/day). These a...\n", + " platform: Model\n", + " references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly..." + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# Consolidated must be False\n", + "ds = xr.open_dataset(store, engine=\"kvikio\", consolidated=False)\n", + "print(ds.air._variable._data)\n", + "ds" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "6c939a04-1588-4693-9483-c6ad7152951a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'scalar' ()> Size: 8B\n",
+       "[1 values with dtype=float64]
" + ], + "text/plain": [ + " Size: 8B\n", + "[1 values with dtype=float64]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.scalar" + ] + }, + { + "cell_type": "markdown", + "id": "bb84a7ad-84dc-4bb3-8636-3f9416953089", + "metadata": { + "tags": [] + }, + "source": [ + "## Lazy reading" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "1ecc39b1-b788-4831-9160-5b35afb83598", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
+       "[3869000 values with dtype=float64]\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
+       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
+       "Attributes:\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    units:         degK\n",
+       "    precision:     2\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    var_desc:      Air temperature\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    statistic:     Individual Obs\n",
+       "    parent_stat:   Other\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]
" + ], + "text/plain": [ + " Size: 31MB\n", + "[3869000 values with dtype=float64]\n", + "Coordinates:\n", + " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", + " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", + " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", + "Attributes:\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " units: degK\n", + " precision: 2\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " var_desc: Air temperature\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " statistic: Individual Obs\n", + " parent_stat: Other\n", + " actual_range: [185.16000366210938, 322.1000061035156]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.air" + ] + }, + { + "cell_type": "markdown", + "id": "7d366864-a2b3-4573-9bf7-41d1f6ee457c", + "metadata": { + "tags": [] + }, + "source": [ + "## Data load for repr" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "00205e73-9b43-4254-9cba-f75435251391", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (lon: 53)> Size: 424B\n",
+       "array([277.29, 277.4 , 277.79, 278.6 , 279.5 , 280.1 , 280.6 , 280.9 ,\n",
+       "       280.79, 280.7 , 280.79, 281.  , 280.29, 277.7 , 273.5 , 269.  ,\n",
+       "       265.5 , 264.  , 265.2 , 268.1 , 269.79, 267.9 , 263.  , 258.1 ,\n",
+       "       254.6 , 251.8 , 249.6 , 249.89, 252.3 , 254.  , 254.3 , 255.89,\n",
+       "       260.  , 263.  , 261.5 , 257.29, 255.5 , 258.29, 264.  , 268.7 ,\n",
+       "       270.5 , 270.6 , 271.2 , 272.9 , 274.79, 276.4 , 278.2 , 280.5 ,\n",
+       "       282.9 , 284.7 , 286.1 , 286.9 , 286.6 ])\n",
+       "Coordinates:\n",
+       "    lat      float32 4B 50.0\n",
+       "    time     datetime64[ns] 8B 2013-01-01\n",
+       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
+       "Attributes:\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    units:         degK\n",
+       "    precision:     2\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    var_desc:      Air temperature\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    statistic:     Individual Obs\n",
+       "    parent_stat:   Other\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]
" + ], + "text/plain": [ + " Size: 424B\n", + "array([277.29, 277.4 , 277.79, 278.6 , 279.5 , 280.1 , 280.6 , 280.9 ,\n", + " 280.79, 280.7 , 280.79, 281. , 280.29, 277.7 , 273.5 , 269. ,\n", + " 265.5 , 264. , 265.2 , 268.1 , 269.79, 267.9 , 263. , 258.1 ,\n", + " 254.6 , 251.8 , 249.6 , 249.89, 252.3 , 254. , 254.3 , 255.89,\n", + " 260. , 263. , 261.5 , 257.29, 255.5 , 258.29, 264. , 268.7 ,\n", + " 270.5 , 270.6 , 271.2 , 272.9 , 274.79, 276.4 , 278.2 , 280.5 ,\n", + " 282.9 , 284.7 , 286.1 , 286.9 , 286.6 ])\n", + "Coordinates:\n", + " lat float32 4B 50.0\n", + " time datetime64[ns] 8B 2013-01-01\n", + " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", + "Attributes:\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " units: degK\n", + " precision: 2\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " var_desc: Air temperature\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " statistic: Individual Obs\n", + " parent_stat: Other\n", + " actual_range: [185.16000366210938, 322.1000061035156]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds[\"air\"].isel(time=0, lat=10).load()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "80aa6892-8c7f-44b3-bd52-9795ec4ea6f3", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'scalar' ()> Size: 8B\n",
+       "[1 values with dtype=float64]
" + ], + "text/plain": [ + " Size: 8B\n", + "[1 values with dtype=float64]" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.scalar" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "ba48a2c0-96e0-41d7-9e07-381e05e8dc33", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
+       "[3869000 values with dtype=float64]\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
+       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
+       "Attributes:\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    units:         degK\n",
+       "    precision:     2\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    var_desc:      Air temperature\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    statistic:     Individual Obs\n",
+       "    parent_stat:   Other\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]
" + ], + "text/plain": [ + " Size: 31MB\n", + "[3869000 values with dtype=float64]\n", + "Coordinates:\n", + " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", + " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", + " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", + "Attributes:\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " units: degK\n", + " precision: 2\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " var_desc: Air temperature\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " statistic: Individual Obs\n", + " parent_stat: Other\n", + " actual_range: [185.16000366210938, 322.1000061035156]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.air" + ] + }, + { + "cell_type": "markdown", + "id": "d0ea31d2-6c52-4346-b489-fc1e43200213", + "metadata": { + "tags": [] + }, + "source": [ + "## CuPy array on load\n", + "\n", + "Configure Zarr to use GPU memory by setting `zarr.config.enable_gpu()`.\n", + "\n", + "See https://zarr.readthedocs.io/en/stable/user-guide/gpu.html#using-gpus-with-zarr" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "1b34a68a-a6b3-4273-bf7c-28814ebfce11", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "MemoryCachedArray(array=CopyOnWriteArray(array=LazilyIndexedArray(array=_ElementwiseFunctionArray(LazilyIndexedArray(array=, key=BasicIndexer((slice(None, None, None), slice(None, None, None), slice(None, None, None)))), func=functools.partial(, scale_factor=0.01, add_offset=None, dtype=), dtype=dtype('float64')), key=BasicIndexer((0, 10, slice(None, None, None))))))" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds[\"air\"].isel(time=0, lat=10).variable._data" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "db69559c-1fde-4b3b-914d-87d8437ec256", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "with zarr.config.enable_gpu():\n", + " print(type(ds[\"air\"].isel(time=0, lat=10).load().data))" + ] + }, + { + "cell_type": "markdown", + "id": "d34a5cce-7bbc-408f-b643-05da1e121c78", + "metadata": { + "tags": [] + }, + "source": [ + "## Load to host" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "84094bc6-7884-414a-89cf-4526c3a54aea", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "zarr.config.enable_gpu()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "09b40d7d-ed38-4a50-af11-c2e5f0242a97", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
+       "[3869000 values with dtype=float64]\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
+       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
+       "Attributes:\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    units:         degK\n",
+       "    precision:     2\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    var_desc:      Air temperature\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    statistic:     Individual Obs\n",
+       "    parent_stat:   Other\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]
" + ], + "text/plain": [ + " Size: 31MB\n", + "[3869000 values with dtype=float64]\n", + "Coordinates:\n", + " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", + " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", + " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", + "Attributes:\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " units: degK\n", + " precision: 2\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " var_desc: Air temperature\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " statistic: Individual Obs\n", + " parent_stat: Other\n", + " actual_range: [185.16000366210938, 322.1000061035156]" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.air" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "615efd76-2194-4604-9ab8-61499e7d725d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + } + ], + "source": [ + "print(type(ds[\"air\"].data))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "eeb9ad78-1353-464f-8419-4c44ea499f17", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "numpy.ndarray" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(ds.air.as_numpy().data)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "140fe3e2-ea9b-445d-8401-5c624384c182", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "cupy.ndarray" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(ds.air.mean(\"time\").load().data)" + ] + }, + { + "cell_type": "markdown", + "id": "cab539a7-d952-4b38-b515-712c52c62501", + "metadata": { + "tags": [] + }, + "source": [ + "## Doesn't work: Chunk with dask" + ] + }, + { + "cell_type": "markdown", + "id": "62c084eb-8df4-4b7f-a187-a736d68d430d", + "metadata": {}, + "source": [ + "`meta` is wrong" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "68f93bfe-fe56-488a-a10b-dc4f48029367", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
+       "dask.array<xarray-air, shape=(2920, 25, 53), dtype=float64, chunksize=(10, 25, 53), chunktype=numpy.ndarray>\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
+       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
+       "Attributes:\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    units:         degK\n",
+       "    precision:     2\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    var_desc:      Air temperature\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    statistic:     Individual Obs\n",
+       "    parent_stat:   Other\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]
" + ], + "text/plain": [ + " Size: 31MB\n", + "dask.array\n", + "Coordinates:\n", + " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", + " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", + " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", + "Attributes:\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " units: degK\n", + " precision: 2\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " var_desc: Air temperature\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " statistic: Individual Obs\n", + " parent_stat: Other\n", + " actual_range: [185.16000366210938, 322.1000061035156]" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds.chunk(time=10).air" + ] + }, + { + "cell_type": "markdown", + "id": "3f4c72f6-22e7-4e99-9f4e-2524d6ab4226", + "metadata": {}, + "source": [ + "`dask.array.core.getter` calls `np.asarray` on each chunk.\n", + "\n", + "This calls `ImplicitToExplicitIndexingAdapter.__array__` which calls `np.asarray(cupy.array)` which raises.\n", + "\n", + "Xarray uses `.get_duck_array` internally to remove these adapters. We might need to add\n", + "```python\n", + "# handle xarray internal classes that might wrap cupy\n", + "if hasattr(c, \"get_duck_array\"):\n", + " c = c.get_duck_array()\n", + "else:\n", + " c = np.asarray(c)\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "e1256d03-9701-433a-8291-80dc8dccffce", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "False" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from dask.utils import is_arraylike\n", + "\n", + "data = ds.air.variable._data\n", + "is_arraylike(data)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "308affa5-9fb9-4638-989b-97aac2604c16", + "metadata": {}, + "outputs": [], + "source": [ + "from xarray.core.indexing import ImplicitToExplicitIndexingAdapter" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "985cd2f8-406e-4e9e-8017-42efb16aa40e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([[[241.2 , 242.5 , 243.5 , ..., 232.8 , 235.5 , 238.6 ],\n", + " [243.8 , 244.5 , 244.7 , ..., 232.8 , 235.3 , 239.3 ],\n", + " [250. , 249.8 , 248.89, ..., 233.2 , 236.39, 241.7 ],\n", + " ...,\n", + " [296.6 , 296.2 , 296.4 , ..., 295.4 , 295.1 , 294.7 ],\n", + " [295.9 , 296.2 , 296.79, ..., 295.9 , 295.9 , 295.2 ],\n", + " [296.29, 296.79, 297.1 , ..., 296.9 , 296.79, 296.6 ]],\n", + "\n", + " [[242.1 , 242.7 , 243.1 , ..., 232. , 233.6 , 235.8 ],\n", + " [243.6 , 244.1 , 244.2 , ..., 231. , 232.5 , 235.7 ],\n", + " [253.2 , 252.89, 252.1 , ..., 230.8 , 233.39, 238.5 ],\n", + " ...,\n", + " [296.4 , 295.9 , 296.2 , ..., 295.4 , 295.1 , 294.79],\n", + " [296.2 , 296.7 , 296.79, ..., 295.6 , 295.5 , 295.1 ],\n", + " [296.29, 297.2 , 297.4 , ..., 296.4 , 296.4 , 296.6 ]],\n", + "\n", + " [[242.3 , 242.2 , 242.3 , ..., 234.3 , 236.1 , 238.7 ],\n", + " [244.6 , 244.39, 244. , ..., 230.3 , 232. , 235.7 ],\n", + " [256.2 , 255.5 , 254.2 , ..., 231.2 , 233.2 , 238.2 ],\n", + " ...,\n", + " [295.6 , 295.4 , 295.4 , ..., 296.29, 295.29, 295. ],\n", + " [296.2 , 296.5 , 296.29, ..., 296.4 , 296. , 295.6 ],\n", + " [296.4 , 296.29, 296.4 , ..., 297. , 297. , 296.79]],\n", + "\n", + " ...,\n", + "\n", + " [[243.49, 242.99, 242.09, ..., 244.19, 244.49, 244.89],\n", + " [249.09, 248.99, 248.59, ..., 240.59, 241.29, 242.69],\n", + " [262.69, 262.19, 261.69, ..., 239.39, 241.69, 245.19],\n", + " ...,\n", + " [294.79, 295.29, 297.49, ..., 295.49, 295.39, 294.69],\n", + " [296.79, 297.89, 298.29, ..., 295.49, 295.49, 294.79],\n", + " [298.19, 299.19, 298.79, ..., 296.09, 295.79, 295.79]],\n", + "\n", + " [[245.79, 244.79, 243.49, ..., 243.29, 243.99, 244.79],\n", + " [249.89, 249.29, 248.49, ..., 241.29, 242.49, 244.29],\n", + " [262.39, 261.79, 261.29, ..., 240.49, 243.09, 246.89],\n", + " ...,\n", + " [293.69, 293.89, 295.39, ..., 295.09, 294.69, 294.29],\n", + " [296.29, 297.19, 297.59, ..., 295.29, 295.09, 294.39],\n", + " [297.79, 298.39, 298.49, ..., 295.69, 295.49, 295.19]],\n", + "\n", + " [[245.09, 244.29, 243.29, ..., 241.69, 241.49, 241.79],\n", + " [249.89, 249.29, 248.39, ..., 239.59, 240.29, 241.69],\n", + " [262.99, 262.19, 261.39, ..., 239.89, 242.59, 246.29],\n", + " ...,\n", + " [293.79, 293.69, 295.09, ..., 295.29, 295.09, 294.69],\n", + " [296.09, 296.89, 297.19, ..., 295.69, 295.69, 295.19],\n", + " [297.69, 298.09, 298.09, ..., 296.49, 296.19, 295.69]]],\n", + " shape=(2920, 25, 53))" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ImplicitToExplicitIndexingAdapter(data).get_duck_array()" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "fa8ef4f7-5014-476f-b4c0-ec2f9abdb6e2", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.", + "output_type": "error", + "traceback": [ + "\u001b[31mTypeError\u001b[39m\u001b[31m:\u001b[39m Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.\n" + ] + } + ], + "source": [ + "ds.chunk(time=10).air.compute()" + ] + }, + { + "cell_type": "markdown", + "id": "17dc1bf6-7548-4eee-a5f3-ebcc20d41567", + "metadata": {}, + "source": [ + "### explicit meta" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "cdd4b4e6-d69a-4898-964a-0e6096ca1942", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.DataArray 'air' (time: 2920, lat: 25, lon: 53)> Size: 31MB\n",
+       "dask.array<xarray-air, shape=(2920, 25, 53), dtype=float64, chunksize=(10, 25, 53), chunktype=numpy.ndarray>\n",
+       "Coordinates:\n",
+       "  * lat      (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n",
+       "  * time     (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n",
+       "  * lon      (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n",
+       "Attributes:\n",
+       "    long_name:     4xDaily Air temperature at sigma level 995\n",
+       "    units:         degK\n",
+       "    precision:     2\n",
+       "    GRIB_id:       11\n",
+       "    GRIB_name:     TMP\n",
+       "    var_desc:      Air temperature\n",
+       "    dataset:       NMC Reanalysis\n",
+       "    level_desc:    Surface\n",
+       "    statistic:     Individual Obs\n",
+       "    parent_stat:   Other\n",
+       "    actual_range:  [185.16000366210938, 322.1000061035156]
" + ], + "text/plain": [ + " Size: 31MB\n", + "dask.array\n", + "Coordinates:\n", + " * lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0\n", + " * time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00\n", + " * lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0\n", + "Attributes:\n", + " long_name: 4xDaily Air temperature at sigma level 995\n", + " units: degK\n", + " precision: 2\n", + " GRIB_id: 11\n", + " GRIB_name: TMP\n", + " var_desc: Air temperature\n", + " dataset: NMC Reanalysis\n", + " level_desc: Surface\n", + " statistic: Individual Obs\n", + " parent_stat: Other\n", + " actual_range: [185.16000366210938, 322.1000061035156]" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import cupy as cp\n", + "\n", + "chunked = ds.chunk(time=10, from_array_kwargs={\"meta\": cp.array([])})\n", + "chunked.air" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "74f80d94-ebb6-43c3-9411-79e0442d894e", + "metadata": {}, + "outputs": [ + { + "ename": "TypeError", + "evalue": "Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.", + "output_type": "error", + "traceback": [ + "\u001b[31mTypeError\u001b[39m\u001b[31m:\u001b[39m Implicit conversion to a NumPy array is not allowed. Please use `.get()` to construct a NumPy array explicitly.\n" + ] + } + ], + "source": [ + "chunked.compute()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ac543634-80be-4e44-83e8-9e95a4955030", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "cupy-xarray-doc", + "language": "python", + "name": "cupy-xarray-doc" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + }, + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "state": {}, + "version_major": 2, + "version_minor": 0 + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/index.md b/docs/index.md index 3025717..3a52b2e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -59,6 +59,7 @@ Large parts of this documentations comes from [SciPy 2023 Xarray on GPUs tutoria examples/04_high-level-api examples/05_apply-ufunc examples/06_real-example + examples/07_kvikio **Tutorials & Presentations**: diff --git a/pyproject.toml b/pyproject.toml index d98b3fe..2d5094e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,9 @@ test = [ "pytest", ] +[project.entry-points."xarray.backends"] +kvikio = "cupy_xarray.kvikio:KvikioBackendEntrypoint" + [tool.ruff] line-length = 100 # E501 (line-too-long) exclude = [