From 9c76c22e20c595ea03aa6fe910ad90a8acc12858 Mon Sep 17 00:00:00 2001
From: Sandor Kertesz
Date: Tue, 18 Jun 2024 18:39:34 +0100
Subject: [PATCH] Add subsetting option to to_numpy method

---
 src/earthkit/data/core/fieldlist.py |  83 ++++++++---
 tests/grib/test_grib_values.py      | 216 ++++++++++++++++++++++++++++
 2 files changed, 277 insertions(+), 22 deletions(-)

diff --git a/src/earthkit/data/core/fieldlist.py b/src/earthkit/data/core/fieldlist.py
index 7d6167df..de3d5739 100644
--- a/src/earthkit/data/core/fieldlist.py
+++ b/src/earthkit/data/core/fieldlist.py
@@ -126,7 +126,7 @@ def _metadata(self):
             self.__metadata = self._make_metadata()
         return self.__metadata

-    def to_numpy(self, flatten=False, dtype=None):
+    def to_numpy(self, flatten=False, dtype=None, index=None):
         r"""Return the values stored in the field as an ndarray.

         Parameters
@@ -137,6 +137,9 @@ def to_numpy(self, flatten=False, dtype=None):
         dtype: str, numpy.dtype or None
             Typecode or data-type of the array. When it is :obj:`None` the default
             type used by the underlying data accessor is used. For GRIB it is ``float64``.
+        index: ndarray indexing object, optional
+            The index of the values to be extracted. When it
+            is None all the values are extracted.

         Returns
         -------
@@ -148,10 +151,12 @@
         v = numpy_backend().to_array(v, self.raw_values_backend)
         shape = self._required_shape(flatten)
         if shape != v.shape:
-            return v.reshape(shape)
+            v = v.reshape(shape)
+        if index is not None:
+            v = v[index]
         return v

-    def to_array(self, flatten=False, dtype=None, array_backend=None):
+    def to_array(self, flatten=False, dtype=None, array_backend=None, index=None):
         r"""Return the values stored in the field in the format of
         :attr:`array_backend`.

@@ -163,6 +168,9 @@
         dtype: str, array.dtype or None
             Typecode or data-type of the array. When it is :obj:`None` the default
             type used by the underlying data accessor is used. For GRIB it is ``float64``.
+        index: array indexing object, optional
+            The index of the values to be extracted. When it
+            is None all the values are extracted.

         Returns
         -------
@@ -177,17 +185,21 @@
         )
         shape = self._required_shape(flatten)
         if shape != v.shape:
-            return self._array_backend.array_ns.reshape(v, shape)
+            v = self._array_backend.array_ns.reshape(v, shape)
+        if index is not None:
+            v = v[index]
         return v

-    def _required_shape(self, flatten):
-        return self.shape if not flatten else (math.prod(self.shape),)
+    def _required_shape(self, flatten, shape=None):
+        if shape is None:
+            shape = self.shape
+        return shape if not flatten else (math.prod(shape),)

     def _array_matches(self, array, flatten=False, dtype=None):
         shape = self._required_shape(flatten)
         return shape == array.shape and (dtype is None or dtype == array.dtype)

-    def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
+    def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None, index=None):
         r"""Return the values and/or the geographical coordinates for each grid point.

         Parameters
@@ -201,6 +213,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
         dtype: str, array.dtype or None
             Typecode or data-type of the arrays. When it is :obj:`None` the default
             type used by the underlying data accessor is used. For GRIB it is ``float64``.
+        index: array indexing object, optional
+            The index of the values and/or the latitudes/longitudes to be extracted.
When it + is None all the values and/or coordinates are extracted. Returns ------- @@ -252,18 +267,22 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): if k not in _keys: raise ValueError(f"data: invalid argument: {k}") - r = [self._to_array(_keys[k][0](dtype=dtype), source_backend=_keys[k][1]) for k in keys] - shape = self._required_shape(flatten) - if shape != r[0].shape: - # r = [x.reshape(shape) for x in r] - r = [self._array_backend.array_ns.reshape(x, shape) for x in r] + r = [] + for k in keys: + v = self._to_array(_keys[k][0](dtype=dtype), source_backend=_keys[k][1]) + shape = self._required_shape(flatten) + if shape != v.shape: + v = self._array_backend.array_ns.reshape(v, shape) + if index is not None: + v = v[index] + r.append(v) if len(r) == 1: return r[0] else: return self._array_backend.array_ns.stack(r) - def to_points(self, flatten=False, dtype=None): + def to_points(self, flatten=False, dtype=None, index=None): r"""Return the geographical coordinates in the data's original Coordinate Reference System (CRS). @@ -276,6 +295,9 @@ def to_points(self, flatten=False, dtype=None): Typecode or data-type of the arrays. When it is :obj:`None` the default type used by the underlying data accessor is used. For GRIB it is ``float64``. + index: array indexing object, optional + The index of the coordinates to be extracted. When it is None + all the values are extracted. Returns ------- @@ -303,14 +325,17 @@ def to_points(self, flatten=False, dtype=None): if shape != x.shape: x = self._array_backend.array_ns.reshape(x, shape) y = self._array_backend.array_ns.reshape(y, shape) + if index is not None: + x = x[index] + y = y[index] return dict(x=x, y=y) elif self.projection().CARTOPY_CRS == "PlateCarree": - lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype) + lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype, index=index) return dict(x=lon, y=lat) else: raise ValueError("to_points(): geographical coordinates in original CRS are not available") - def to_latlon(self, flatten=False, dtype=None): + def to_latlon(self, flatten=False, dtype=None, index=None): r"""Return the latitudes/longitudes of all the gridpoints in the field. Parameters @@ -322,6 +347,9 @@ def to_latlon(self, flatten=False, dtype=None): Typecode or data-type of the arrays. When it is :obj:`None` the default type used by the underlying data accessor is used. For GRIB it is ``float64``. + index: array indexing object, optional + The index of the latitudes/longitudes to be extracted. When it is None + all the values are extracted. Returns ------- @@ -335,7 +363,7 @@ def to_latlon(self, flatten=False, dtype=None): to_points """ - lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype) + lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype, index=index) return dict(lat=lat, lon=lon) def grid_points(self): @@ -869,7 +897,7 @@ def to_array(self, **kwargs): @property def values(self): - r"""array-likr: Get all the fields' values as a 2D array. It is formed as the array of + r"""array-like: Get all the fields' values as a 2D array. It is formed as the array of :obj:`GribField.values ` per field. See Also @@ -893,7 +921,13 @@ def values(self): x = [f.values for f in self] return self._array_backend.array_ns.stack(x) - def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): + def data( + self, + keys=("lat", "lon", "value"), + flatten=False, + dtype=None, + index=None, + ): r"""Return the values and/or the geographical coordinates. 
Only works when all the fields have the same grid geometry. @@ -910,6 +944,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): Typecode or data-type of the arrays. When it is :obj:`None` the default type used by the underlying data accessor is used. For GRIB it is ``float64``. + index: array indexing object, optional + The index of the values to be extracted from each field. When it is None all the + values are extracted. Returns ------- @@ -962,7 +999,7 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): keys = [keys] if "lat" in keys or "lon" in keys: - latlon = self[0].to_latlon(flatten=flatten, dtype=dtype) + latlon = self[0].to_latlon(flatten=flatten, dtype=dtype, index=index) r = [] for k in keys: @@ -971,10 +1008,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None): elif k == "lon": r.append(latlon["lon"]) elif k == "value": - r.extend([f.to_array(flatten=flatten, dtype=dtype) for f in self]) + r.extend([f.to_array(flatten=flatten, dtype=dtype, index=index) for f in self]) else: raise ValueError(f"data: invalid argument: {k}") - return self._array_backend.array_ns.stack(r) elif len(self) == 0: @@ -1226,11 +1262,14 @@ def to_points(self, **kwargs): else: raise ValueError("Fields do not have the same grid geometry") - def to_latlon(self, **kwargs): + def to_latlon(self, index=None, **kwargs): r"""Return the latitudes/longitudes shared by all the fields. Parameters ---------- + index: array indexing object, optional + The index of the latitudes/longitudes to be extracted. When it is None + all the values are extracted. **kwargs: dict, optional Keyword arguments passed to :meth:`Field.to_latlon() ` diff --git a/tests/grib/test_grib_values.py b/tests/grib/test_grib_values.py index 716e77d3..14faac06 100644 --- a/tests/grib/test_grib_values.py +++ b/tests/grib/test_grib_values.py @@ -249,6 +249,123 @@ def test_grib_to_numpy_18_dtype(fl_type, array_backend, dtype): assert v.dtype == dtype +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_to_numpy_1_index(fl_type, array_backend): + ds = load_grib_data("test_single.grib", fl_type, array_backend, folder="data") + + eps = 1e-5 + + v = ds[0].to_numpy(flatten=True, index=[0, -1]) + assert isinstance(v, np.ndarray) + assert v.dtype == np.float64 + assert v.shape == (2,) + assert np.allclose(v, [260.43560791015625, 227.18560791015625]) + + v = ds[0].to_numpy(flatten=True, index=slice(None, None)) + assert isinstance(v, np.ndarray) + assert v.dtype == np.float64 + check_array( + v, + (84,), + first=260.43560791015625, + last=227.18560791015625, + meanv=274.36566743396577, + eps=eps, + ) + + v = ds[0].to_numpy(index=(slice(None, 2), slice(None, 3))) + assert isinstance(v, np.ndarray) + assert v.dtype == np.float64 + assert v.shape == (2, 3) + assert np.allclose( + v, + [ + [260.43560791, 260.43560791, 260.43560791], + [280.81060791, 277.06060791, 284.43560791], + ], + ) + + +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) +def test_grib_to_numpy_18_index(fl_type, array_backend): + ds = load_grib_data("tuv_pl.grib", fl_type, array_backend) + + eps = 1e-5 + + v = ds.to_numpy(flatten=True, index=[0, -1]) + assert isinstance(v, np.ndarray) + assert v.dtype == np.float64 + assert v.shape == (18, 2) + vf0 = v[0].flatten() + assert np.allclose(vf0, [272.5642, 240.56417846679688]) + vf15 = v[15].flatten() + assert np.allclose(vf15, [226.6531524658203, 
206.6531524658203]) + + v = ds.to_numpy(flatten=True, index=slice(None, 2)) + assert isinstance(v, np.ndarray) + assert v.dtype == np.float64 + assert v.shape == (18, 2) + vf0 = v[0].flatten() + assert np.allclose(vf0, [272.56417847, 272.56417847]) + vf15 = v[15].flatten() + assert np.allclose(vf15, [226.65315247, 226.65315247]) + + v = ds.to_numpy(flatten=True, index=slice(None, None)) + assert isinstance(v, np.ndarray) + assert v.dtype == np.float64 + assert v.shape == (18, 84) + vf0 = v[0].flatten() + check_array( + vf0, + (84,), + first=272.5642, + last=240.56417846679688, + meanv=279.70703560965404, + eps=eps, + ) + + vf15 = v[15].flatten() + check_array( + vf15, + (84,), + first=226.6531524658203, + last=206.6531524658203, + meanv=227.84362865629652, + eps=eps, + ) + + v = ds.to_numpy(index=(slice(None, 2), slice(None, 3))) + assert isinstance(v, np.ndarray) + assert v.dtype == np.float64 + assert v.shape == (18, 2, 3) + vf0 = v[0].flatten() + assert np.allclose( + vf0, + [ + 272.56417847, + 272.56417847, + 272.56417847, + 288.56417847, + 296.56417847, + 288.56417847, + ], + ) + vf15 = v[15].flatten() + assert np.allclose( + vf15, + [ + 226.65315247, + 226.65315247, + 226.65315247, + 230.65315247, + 230.65315247, + 230.65315247, + ], + ) + + @pytest.mark.parametrize("fl_type", FL_TYPES) @pytest.mark.parametrize("array_backend", ["numpy"]) @pytest.mark.parametrize( @@ -354,6 +471,105 @@ def test_grib_fieldlist_data(fl_type, array_backend, kwarg, expected_shape, expe assert np.allclose(d[2], latlon["lon"]) +@pytest.mark.parametrize("fl_type", FL_TYPES) +@pytest.mark.parametrize("array_backend", ["numpy"]) +def test_grib_fieldlist_data_index(fl_type, array_backend): + ds = load_grib_data("tuv_pl.grib", fl_type, array_backend) + + eps = 1e-5 + + latlon = ds.to_latlon(flatten=True) + lat = latlon["lat"] + lon = latlon["lon"] + + index = [0, -1] + v = ds.data(flatten=True, index=index) + assert isinstance(v, np.ndarray) + assert v.dtype == np.float64 + assert v.shape == (18 + 2, 2) + assert np.allclose(v[0].flatten(), lat[index]) + assert np.allclose(v[1].flatten(), lon[index]) + vf0 = v[2 + 0].flatten() + assert np.allclose(vf0, [272.5642, 240.56417846679688]) + vf15 = v[2 + 15].flatten() + assert np.allclose(vf15, [226.6531524658203, 206.6531524658203]) + + index = slice(None, 2) + v = ds.data(flatten=True, index=index) + assert isinstance(v, np.ndarray) + assert v.dtype == np.float64 + assert v.shape == (18 + 2, 2) + assert np.allclose(v[0].flatten(), lat[index]) + assert np.allclose(v[1].flatten(), lon[index]) + vf0 = v[2 + 0].flatten() + assert np.allclose(vf0, [272.56417847, 272.56417847]) + vf15 = v[2 + 15].flatten() + assert np.allclose(vf15, [226.65315247, 226.65315247]) + + index = slice(None, None) + v = ds.data(flatten=True, index=index) + assert isinstance(v, np.ndarray) + assert v.dtype == np.float64 + assert v.shape == (18 + 2, 84) + assert np.allclose(v[0].flatten(), lat) + assert np.allclose(v[1].flatten(), lon) + vf0 = v[2 + 0].flatten() + check_array( + vf0, + (84,), + first=272.5642, + last=240.56417846679688, + meanv=279.70703560965404, + eps=eps, + ) + + vf15 = v[2 + 15].flatten() + check_array( + vf15, + (84,), + first=226.6531524658203, + last=206.6531524658203, + meanv=227.84362865629652, + eps=eps, + ) + + index = (slice(None, 2), slice(None, 3)) + v = ds.data(index=index) + assert isinstance(v, np.ndarray) + assert v.dtype == np.float64 + assert v.shape == (2 + 18, 2, 3) + latlon = ds.to_latlon() + lat = latlon["lat"] + lon = latlon["lon"] + assert 
np.allclose(v[0], lat[index]) + assert np.allclose(v[1], lon[index]) + + vf0 = v[2 + 0].flatten() + assert np.allclose( + vf0, + [ + 272.56417847, + 272.56417847, + 272.56417847, + 288.56417847, + 296.56417847, + 288.56417847, + ], + ) + vf15 = v[2 + 15].flatten() + assert np.allclose( + vf15, + [ + 226.65315247, + 226.65315247, + 226.65315247, + 230.65315247, + 230.65315247, + 230.65315247, + ], + ) + + @pytest.mark.parametrize("fl_type", FL_TYPES) @pytest.mark.parametrize("array_backend", ARRAY_BACKENDS) def test_grib_values_with_missing(fl_type, array_backend):
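Example usage of the new ``index`` option (a minimal sketch; it assumes earthkit-data with this change installed and a local copy of the tests' ``test_single.grib``):

import numpy as np

import earthkit.data

ds = earthkit.data.from_source("file", "test_single.grib")

# full flattened values of the first field
full = ds[0].to_numpy(flatten=True)

# the same call restricted to the first and last grid point via ``index``
sub = ds[0].to_numpy(flatten=True, index=[0, -1])
assert sub.shape == (2,)
assert np.allclose(sub, full[[0, -1]])

# ``index`` is applied after the optional reshape, so multi-dimensional
# indexing also works on the non-flattened array
block = ds[0].to_numpy(index=(slice(None, 2), slice(None, 3)))
assert block.shape == (2, 3)

# to_latlon()/data() accept the same option, so coordinates and values
# can be subset consistently
ll = ds[0].to_latlon(flatten=True, index=[0, -1])
assert ll["lat"].shape == (2,)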