Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add subsetting option to data access methods #407

Merged
merged 3 commits into from
Jul 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 61 additions & 22 deletions src/earthkit/data/core/fieldlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _metadata(self):
self.__metadata = self._make_metadata()
return self.__metadata

def to_numpy(self, flatten=False, dtype=None):
def to_numpy(self, flatten=False, dtype=None, index=None):
r"""Return the values stored in the field as an ndarray.

Parameters
Expand All @@ -137,6 +137,9 @@ def to_numpy(self, flatten=False, dtype=None):
dtype: str, numpy.dtype or None
Typecode or data-type of the array. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is ``float64``.
index: ndarray indexing object, optional
The index of the values to be extracted. When it
is None all the values are extracted.

Returns
-------
Expand All @@ -148,10 +151,12 @@ def to_numpy(self, flatten=False, dtype=None):
v = numpy_backend().to_array(v, self.raw_values_backend)
shape = self._required_shape(flatten)
if shape != v.shape:
return v.reshape(shape)
v = v.reshape(shape)
if index is not None:
v = v[index]
return v

def to_array(self, flatten=False, dtype=None, array_backend=None):
def to_array(self, flatten=False, dtype=None, array_backend=None, index=None):
r"""Return the values stored in the field in the
format of :attr:`array_backend`.

Expand All @@ -163,6 +168,9 @@ def to_array(self, flatten=False, dtype=None, array_backend=None):
dtype: str, array.dtype or None
Typecode or data-type of the array. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is ``float64``.
index: array indexing object, optional
The index of the values to be extracted. When it
is None all the values are extracted.

Returns
-------
Expand All @@ -177,17 +185,21 @@ def to_array(self, flatten=False, dtype=None, array_backend=None):
)
shape = self._required_shape(flatten)
if shape != v.shape:
return self._array_backend.array_ns.reshape(v, shape)
v = self._array_backend.array_ns.reshape(v, shape)
if index is not None:
v = v[index]
return v

def _required_shape(self, flatten):
return self.shape if not flatten else (math.prod(self.shape),)
def _required_shape(self, flatten, shape=None):
if shape is None:
shape = self.shape
return shape if not flatten else (math.prod(shape),)

def _array_matches(self, array, flatten=False, dtype=None):
shape = self._required_shape(flatten)
return shape == array.shape and (dtype is None or dtype == array.dtype)

def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None, index=None):
r"""Return the values and/or the geographical coordinates for each grid point.

Parameters
Expand All @@ -201,6 +213,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
dtype: str, array.dtype or None
Typecode or data-type of the arrays. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is ``float64``.
index: array indexing object, optional
The index of the values and/or the latitudes/longitudes to be extracted. When it
is None all the values and/or coordinates are extracted.

Returns
-------
Expand Down Expand Up @@ -252,18 +267,22 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
if k not in _keys:
raise ValueError(f"data: invalid argument: {k}")

r = [self._to_array(_keys[k][0](dtype=dtype), source_backend=_keys[k][1]) for k in keys]
shape = self._required_shape(flatten)
if shape != r[0].shape:
# r = [x.reshape(shape) for x in r]
r = [self._array_backend.array_ns.reshape(x, shape) for x in r]
r = []
for k in keys:
v = self._to_array(_keys[k][0](dtype=dtype), source_backend=_keys[k][1])
shape = self._required_shape(flatten)
if shape != v.shape:
v = self._array_backend.array_ns.reshape(v, shape)
if index is not None:
v = v[index]
r.append(v)

if len(r) == 1:
return r[0]
else:
return self._array_backend.array_ns.stack(r)

def to_points(self, flatten=False, dtype=None):
def to_points(self, flatten=False, dtype=None, index=None):
r"""Return the geographical coordinates in the data's original
Coordinate Reference System (CRS).

Expand All @@ -276,6 +295,9 @@ def to_points(self, flatten=False, dtype=None):
Typecode or data-type of the arrays. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is
``float64``.
index: array indexing object, optional
The index of the coordinates to be extracted. When it is None
all the coordinates are extracted.

Returns
-------
Expand Down Expand Up @@ -303,14 +325,17 @@ def to_points(self, flatten=False, dtype=None):
if shape != x.shape:
x = self._array_backend.array_ns.reshape(x, shape)
y = self._array_backend.array_ns.reshape(y, shape)
if index is not None:
x = x[index]
y = y[index]
return dict(x=x, y=y)
elif self.projection().CARTOPY_CRS == "PlateCarree":
lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype)
lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype, index=index)
return dict(x=lon, y=lat)
else:
raise ValueError("to_points(): geographical coordinates in original CRS are not available")

def to_latlon(self, flatten=False, dtype=None):
def to_latlon(self, flatten=False, dtype=None, index=None):
r"""Return the latitudes/longitudes of all the gridpoints in the field.

Parameters
Expand All @@ -322,6 +347,9 @@ def to_latlon(self, flatten=False, dtype=None):
Typecode or data-type of the arrays. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is
``float64``.
index: array indexing object, optional
The index of the latitudes/longitudes to be extracted. When it is None
all the values are extracted.

Returns
-------
Expand All @@ -335,7 +363,7 @@ def to_latlon(self, flatten=False, dtype=None):
to_points

"""
lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype)
lon, lat = self.data(("lon", "lat"), flatten=flatten, dtype=dtype, index=index)
return dict(lat=lat, lon=lon)

def grid_points(self):
Expand Down Expand Up @@ -869,7 +897,7 @@ def to_array(self, **kwargs):

@property
def values(self):
r"""array-likr: Get all the fields' values as a 2D array. It is formed as the array of
r"""array-like: Get all the fields' values as a 2D array. It is formed as the array of
:obj:`GribField.values <data.readers.grib.codes.GribField.values>` per field.

See Also
Expand All @@ -893,7 +921,13 @@ def values(self):
x = [f.values for f in self]
return self._array_backend.array_ns.stack(x)

def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
def data(
self,
keys=("lat", "lon", "value"),
flatten=False,
dtype=None,
index=None,
):
r"""Return the values and/or the geographical coordinates.

Only works when all the fields have the same grid geometry.
Expand All @@ -910,6 +944,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
Typecode or data-type of the arrays. When it is :obj:`None` the default
type used by the underlying data accessor is used. For GRIB it is
``float64``.
index: array indexing object, optional
The index of the values to be extracted from each field. When it is None all the
values are extracted.

Returns
-------
Expand Down Expand Up @@ -962,7 +999,7 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
keys = [keys]

if "lat" in keys or "lon" in keys:
latlon = self[0].to_latlon(flatten=flatten, dtype=dtype)
latlon = self[0].to_latlon(flatten=flatten, dtype=dtype, index=index)

r = []
for k in keys:
Expand All @@ -971,10 +1008,9 @@ def data(self, keys=("lat", "lon", "value"), flatten=False, dtype=None):
elif k == "lon":
r.append(latlon["lon"])
elif k == "value":
r.extend([f.to_array(flatten=flatten, dtype=dtype) for f in self])
r.extend([f.to_array(flatten=flatten, dtype=dtype, index=index) for f in self])
else:
raise ValueError(f"data: invalid argument: {k}")

return self._array_backend.array_ns.stack(r)

elif len(self) == 0:
Expand Down Expand Up @@ -1226,11 +1262,14 @@ def to_points(self, **kwargs):
else:
raise ValueError("Fields do not have the same grid geometry")

def to_latlon(self, **kwargs):
def to_latlon(self, index=None, **kwargs):
r"""Return the latitudes/longitudes shared by all the fields.

Parameters
----------
index: array indexing object, optional
The index of the latitudes/longitudes to be extracted. When it is None
all the values are extracted.
**kwargs: dict, optional
Keyword arguments passed to
:meth:`Field.to_latlon() <data.core.fieldlist.Field.to_latlon>`
Expand Down
Loading
Loading