Skip to content

Commit

Permalink
Implement a generic to_fieldlist method instead of to_decoded (#258)
Browse files Browse the repository at this point in the history
  • Loading branch information
sandorkertesz authored Nov 14, 2023
1 parent e5042d6 commit fe88303
Show file tree
Hide file tree
Showing 7 changed files with 85 additions and 97 deletions.
60 changes: 44 additions & 16 deletions earthkit/data/core/fieldlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,11 +533,6 @@ def _attributes(self, names):
result[name] = self._metadata.get(name, None)
return result

def to_decoded(self, **kwargs):
from earthkit.data.sources.numpy_list import NumpyField

return NumpyField(self.to_numpy(**kwargs), self.metadata())


class FieldList(Index):
r"""Represents a list of :obj:`Field` \s."""
Expand Down Expand Up @@ -1195,26 +1190,59 @@ def write(self, f):
for s in self:
s.write(f)

def to_decoded(self, **kwargs):
r"""Convert the FieldList into a :class:`NumpyFieldList` storing all the
data in memory.
def to_fieldlist(self, backend, **kwargs):
r"""Convert to a new :class:`FieldList` based on the ``backend``.
In the resulting object each field is represented by an ndarray storing the
field values and a :class:`MetaData` object holding the field metadata. The
shape and dtype of the ndarray is controlled by the ``kwargs``.
Internally, the generated :class:`NumpyFieldList` will always store all the
field values in a single ndarray.
When the :class:`FieldList` is already in the required format no new
:class:`FieldList` is created but the current one is returned.
Parameters
----------
backend: str
Specifies the backend for the generated fieldlist. The supported values are as follows:
- "numpy": the generated fieldlist is a :class:`NumpyFieldList`, which represents
each field by an ndarray storing the field values and a :class:`MetaData` object holding
the field metadata. The shape and dtype of the ndarray is controlled by the ``kwargs``.
Please note that generated :class:`NumpyFieldList` stores all the field values in
a single ndarray.
**kwargs: dict, optional
Keyword arguments passed to :obj:`to_numpy`
When ``backend`` is "numpy" ``kwargs`` are passed to :obj:`to_numpy` to
extract the field values the resulting object will store.
Returns
-------
:class:`NumpyFieldList`
:class:`FieldList`
- the current :class:`FieldList` if it is already in the required format
- :class:`NumpyFieldList` when ``backend`` is "numpy"
Examples
--------
The following example will convert a fieldlist read from a file into a
:class:`NumpyFieldList` storing single precision field values.
>>> import numpy as np
>>> import earthkit.data
>>> ds = earthkit.data.from_source("file", "docs/examples/tuv_pl.grib")
>>> ds.path
'docs/examples/tuv_pl.grib'
>>> r = ds.to_fieldlist("numpy", dtype=np.float32)
>>> r
NumpyFieldList(fields=18)
>>> hasattr(r, "path")
False
>>> r.to_numpy().dtype
dtype('float32')
"""
converter = fieldlist_converters.get(backend, None)
if converter is not None:
return getattr(self, converter)(**kwargs)

def _to_numpy_fieldlist(self, **kwargs):
md = [f.metadata() for f in self]
return self.from_numpy(self.to_numpy(**kwargs), md)


fieldlist_converters = {"numpy": "_to_numpy_fieldlist"}
3 changes: 3 additions & 0 deletions earthkit/data/sources/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,9 @@ def to_numpy(self, **kwargs):
def values(self):
return self._reader.values

def to_fieldlist(self, *args, **kwargs):
return self._reader.to_fieldlist(*args, **kwargs)

def save(self, path):
return self._reader.save(path)

Expand Down
12 changes: 3 additions & 9 deletions earthkit/data/sources/numpy_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,6 @@ def _values(self, dtype=None):
else:
return self._array.astype(dtype, copy=False)

def to_decoded(self, **kwargs):
if self._array_matches(self._array, **kwargs):
return self
else:
return NumpyField(self.to_numpy(**kwargs), self.metadata())

def __repr__(self):
return f"{self.__class__.__name__}()"

Expand Down Expand Up @@ -124,7 +118,7 @@ def merge(cls, sources):
def __repr__(self):
return f"{self.__class__.__name__}(fields={len(self)})"

def to_decoded(self, **kwargs):
def _to_numpy_fieldlist(self, **kwargs):
if self[0]._array_matches(self._array[0], **kwargs):
return self
else:
Expand Down Expand Up @@ -167,12 +161,12 @@ class NumpyFieldList(NumpyFieldListCore):
r"""Represents a list of :obj:`NumpyField <data.sources.numpy_list.NumpyField>`\ s.
The preferred way to create a NumpyFieldList is to use either the
static :obj:`from_numpy` method or the :obj:`to_decoded` method.
static :obj:`from_numpy` method or the :obj:`to_fieldlist` method.
See Also
--------
from_numpy
to_decoded
to_fieldlist
"""

Expand Down
3 changes: 0 additions & 3 deletions earthkit/data/sources/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ def __init__(self, reader, **kwargs):
def __iter__(self):
return iter(self._reader)

def to_decoded(self, **kwargs):
return self._reader.to_decoded(**kwargs)


class StreamSource(Source):
def __init__(self, stream, group_by=None, **kwargs):
Expand Down
81 changes: 22 additions & 59 deletions tests/grib/test_grib_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,14 @@ def test_grib_from_stream_group_by(group_by):
for i, f in enumerate(fs):
assert len(f) == 3
assert f.metadata(("param", "level")) == ref[i]
assert f.to_decoded() is not f
assert f.to_fieldlist("numpy") is not f

# stream consumed, no data is available
assert sum([1 for _ in fs]) == 0


@pytest.mark.parametrize(
"decode_kwargs,expected_shape",
"convert_kwargs,expected_shape",
[
({}, (3, 7, 12)),
(None, (3, 7, 12)),
Expand All @@ -75,7 +75,7 @@ def test_grib_from_stream_group_by(group_by):
({"flatten": True}, (3, 84)),
],
)
def test_grib_from_stream_group_by_decode(decode_kwargs, expected_shape):
def test_grib_from_stream_group_by_convert_to_numpy(convert_kwargs, expected_shape):
group_by = "level"
with open(earthkit_examples_file("test6.grib"), "rb") as stream:
ds = from_source("stream", stream, group_by=group_by)
Expand All @@ -89,16 +89,16 @@ def test_grib_from_stream_group_by_decode(decode_kwargs, expected_shape):
[("t", 850), ("u", 850), ("v", 850)],
]

if decode_kwargs is None:
decode_kwargs = {}
if convert_kwargs is None:
convert_kwargs = {}

for i, f in enumerate(ds):
df = f.to_decoded(**decode_kwargs)
df = f.to_fieldlist("numpy", **convert_kwargs)
assert len(df) == 3
assert df.metadata(("param", "level")) == ref[i]
assert df._array.shape == expected_shape
assert df.to_numpy(**decode_kwargs).shape == expected_shape
assert df.to_decoded(**decode_kwargs) is df
assert df.to_numpy(**convert_kwargs).shape == expected_shape
assert df.to_fieldlist("numpy", **convert_kwargs) is df

# stream consumed, no data is available
assert sum([1 for _ in ds]) == 0
Expand Down Expand Up @@ -128,43 +128,6 @@ def test_grib_from_stream_single_batch():
assert sum([1 for _ in ds]) == 0


@pytest.mark.parametrize(
"decode_kwargs,expected_shape",
[
({}, (7, 12)),
(None, (7, 12)),
(None, (7, 12)),
({"flatten": False}, (7, 12)),
({"flatten": True}, (84,)),
],
)
def test_grib_from_stream_single_batch_decode(decode_kwargs, expected_shape):
with open(earthkit_examples_file("test6.grib"), "rb") as stream:
ds = from_source("stream", stream)

ref = [
("t", 1000),
("u", 1000),
("v", 1000),
("t", 850),
("u", 850),
("v", 850),
]

if decode_kwargs is None:
decode_kwargs = {}

for i, f in enumerate(ds):
df = f.to_decoded(**decode_kwargs)
assert df.metadata(("param", "level")) == ref[i], i
assert df._array.shape == expected_shape, i
assert df.to_numpy(**decode_kwargs).shape == expected_shape, i
assert df.to_decoded(**decode_kwargs) is df, i

# stream consumed, no data is available
assert sum([1 for _ in ds]) == 0


def test_grib_from_stream_multi_batch():
with open(earthkit_examples_file("test6.grib"), "rb") as stream:
fs = from_source("stream", stream, batch_size=2)
Expand All @@ -183,7 +146,7 @@ def test_grib_from_stream_multi_batch():


@pytest.mark.parametrize(
"decode_kwargs,expected_shape",
"convert_kwargs,expected_shape",
[
({}, (2, 7, 12)),
(None, (2, 7, 12)),
Expand All @@ -198,7 +161,7 @@ def test_grib_from_stream_multi_batch():
),
],
)
def test_grib_from_stream_multi_batch_decode(decode_kwargs, expected_shape):
def test_grib_from_stream_multi_batch_convert_to_numpy(convert_kwargs, expected_shape):
with open(earthkit_examples_file("test6.grib"), "rb") as stream:
ds = from_source("stream", stream, batch_size=2)

Expand All @@ -208,15 +171,15 @@ def test_grib_from_stream_multi_batch_decode(decode_kwargs, expected_shape):
[("u", 850), ("v", 850)],
]

if decode_kwargs is None:
decode_kwargs = {}
if convert_kwargs is None:
convert_kwargs = {}

for i, f in enumerate(ds):
df = f.to_decoded(**decode_kwargs)
df = f.to_fieldlist("numpy", **convert_kwargs)
assert df.metadata(("param", "level")) == ref[i], i
assert df._array.shape == expected_shape, i
assert df.to_numpy(**decode_kwargs).shape == expected_shape, i
assert df.to_decoded(**decode_kwargs) is df, i
assert df.to_numpy(**convert_kwargs).shape == expected_shape, i
assert df.to_fieldlist("numpy", **convert_kwargs) is df, i

# stream consumed, no data is available
assert sum([1 for _ in ds]) == 0
Expand Down Expand Up @@ -267,22 +230,22 @@ def test_grib_from_stream_in_memory():


@pytest.mark.parametrize(
"decode_kwargs,expected_shape",
"convert_kwargs,expected_shape",
[
({}, (6, 7, 12)),
({"flatten": False}, (6, 7, 12)),
({"flatten": True}, (6, 84)),
],
)
def test_grib_from_stream_in_memory_decode(decode_kwargs, expected_shape):
def test_grib_from_stream_in_memory_convert_to_numpy(convert_kwargs, expected_shape):
with open(earthkit_examples_file("test6.grib"), "rb") as stream:
ds_s = from_source(
"stream",
stream,
batch_size=0,
)

ds = ds_s.to_decoded(**decode_kwargs)
ds = ds_s.to_fieldlist("numpy", **convert_kwargs)

assert len(ds) == 6

Expand All @@ -302,7 +265,7 @@ def test_grib_from_stream_in_memory_decode(decode_kwargs, expected_shape):
assert val == ref, "method"

# data
assert ds.to_numpy(**decode_kwargs).shape == expected_shape
assert ds.to_numpy(**convert_kwargs).shape == expected_shape

ref = np.array(
[
Expand All @@ -316,13 +279,13 @@ def test_grib_from_stream_in_memory_decode(decode_kwargs, expected_shape):
)

if len(expected_shape) == 3:
vals = ds.to_numpy(**decode_kwargs)[:, 0, 0]
vals = ds.to_numpy(**convert_kwargs)[:, 0, 0]
else:
vals = ds.to_numpy(**decode_kwargs)[:, 0]
vals = ds.to_numpy(**convert_kwargs)[:, 0]

assert np.allclose(vals, ref)
assert ds._array.shape == expected_shape
assert ds.to_decoded(**decode_kwargs) is ds
assert ds.to_fieldlist("numpy", **convert_kwargs) is ds


def test_grib_save_when_loaded_from_stream():
Expand Down
2 changes: 1 addition & 1 deletion tests/numpy_fs/numpy_fs_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def check_numpy_fs(ds, ds_input, md_full):
assert r[2].metadata("param") == "u"


def check_numpy_fs_decoded(ds, ds_input, md_full, flatten=False, dtype=None):
def check_numpy_fs_from_to_fieldlist(ds, ds_input, md_full, flatten=False, dtype=None):
assert len(ds_input) in [1, 2, 3]
assert len(ds) == len(md_full)
assert ds.metadata("param") == md_full
Expand Down
21 changes: 12 additions & 9 deletions tests/numpy_fs/test_numpy_fs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@

here = os.path.dirname(__file__)
sys.path.insert(0, here)
from numpy_fs_fixtures import check_numpy_fs, check_numpy_fs_decoded # noqa: E402
from numpy_fs_fixtures import ( # noqa: E402
check_numpy_fs,
check_numpy_fs_from_to_fieldlist,
)


def test_numpy_fs_grib_single_field():
Expand Down Expand Up @@ -123,28 +126,28 @@ def test_numpy_fs_grib_from_list_of_arrays_bad():
{"flatten": True, "dtype": np.float32},
],
)
def test_numpy_fs_grib_from_to_decoded(kwargs):
def test_numpy_fs_grib_from_to_fieldlist(kwargs):
ds = from_source("file", earthkit_examples_file("test.grib"))
md_full = ds.metadata("param")
assert len(ds) == 2

r = ds.to_decoded(**kwargs)
check_numpy_fs_decoded(r, [ds], md_full, **kwargs)
r = ds.to_fieldlist("numpy", **kwargs)
check_numpy_fs_from_to_fieldlist(r, [ds], md_full, **kwargs)


def test_numpy_fs_grib_from_to_decoded_repeat():
def test_numpy_fs_grib_from_to_fieldlist_repeat():
ds = from_source("file", earthkit_examples_file("test.grib"))
md_full = ds.metadata("param")
assert len(ds) == 2

kwargs = {}
r = ds.to_decoded(**kwargs)
check_numpy_fs_decoded(r, [ds], md_full, **kwargs)
r = ds.to_fieldlist("numpy", **kwargs)
check_numpy_fs_from_to_fieldlist(r, [ds], md_full, **kwargs)

kwargs = {"flatten": True, "dtype": np.float32}
r1 = r.to_decoded(**kwargs)
r1 = r.to_fieldlist("numpy", **kwargs)
assert r1 is not r
check_numpy_fs_decoded(r1, [ds], md_full, **kwargs)
check_numpy_fs_from_to_fieldlist(r1, [ds], md_full, **kwargs)


if __name__ == "__main__":
Expand Down

0 comments on commit fe88303

Please sign in to comment.