From 07d9d6dc56548edda315f9b23bd676fabda30fa0 Mon Sep 17 00:00:00 2001 From: Sandor Kertesz Date: Wed, 23 Oct 2024 20:13:40 +0100 Subject: [PATCH] Alter field values --- src/earthkit/data/core/fieldlist.py | 14 +- src/earthkit/data/indexing/fieldlist.py | 28 +++ src/earthkit/data/readers/grib/codes.py | 8 +- src/earthkit/data/readers/grib/memory.py | 8 +- src/earthkit/data/readers/netcdf/field.py | 8 +- src/earthkit/data/sources/array_list.py | 8 +- src/earthkit/data/sources/forcings.py | 8 +- src/earthkit/data/writers/grib.py | 7 + tests/grib/test_grib_copy.py | 241 +++++++++++++++++++++- 9 files changed, 307 insertions(+), 23 deletions(-) diff --git a/src/earthkit/data/core/fieldlist.py b/src/earthkit/data/core/fieldlist.py index f98768d6..2463cfcf 100644 --- a/src/earthkit/data/core/fieldlist.py +++ b/src/earthkit/data/core/fieldlist.py @@ -729,7 +729,7 @@ def _attributes(self, names, remapping=None, joiner=None, default=None): # return {name: metadata(name) for name in names} - def to_field(self, array_backend=None, **kwargs): + def to_field(self, flatten=False, dtype=None, array_backend=None, values=None, **kwargs): r"""Convert to a new :class:`Field`. Parameters @@ -748,7 +748,17 @@ def to_field(self, array_backend=None, **kwargs): """ from earthkit.data.sources.array_list import ArrayField - return ArrayField(self.to_array(array_backend=array_backend, **kwargs), self._metadata.override()) + if values is None: + values = self.to_array( + flatten=flatten, + dtype=dtype, + array_backend=array_backend, + ) + + return ArrayField( + values, + self._metadata.override(**kwargs), + ) @staticmethod def _flatten(v): diff --git a/src/earthkit/data/indexing/fieldlist.py b/src/earthkit/data/indexing/fieldlist.py index 4e261ddf..a2c1a529 100644 --- a/src/earthkit/data/indexing/fieldlist.py +++ b/src/earthkit/data/indexing/fieldlist.py @@ -124,5 +124,33 @@ def _metadata(self): return self.__metadata +class NewFieldWrapper: + def __init__(self, field, values=None, **kwargs): + self._field = field + self.__values = values + + if kwargs: + from earthkit.data.core.metadata import WrappedMetadata + + self.__metadata = WrappedMetadata(field._metadata, extra=kwargs, owner=field) + else: + self.__metadata = field._metadata + + def _values(self, dtype=None): + if self.__values is None: + return self._field._values(dtype=dtype) + else: + if dtype is None: + return self.__values + return self.__values.astype(dtype) + + @property + def _metadata(self): + return self.__metadata + + def _has_new_values(self): + return self.__values is not None + + # For backwards compatibility FieldArray = SimpleFieldList diff --git a/src/earthkit/data/readers/grib/codes.py b/src/earthkit/data/readers/grib/codes.py index c30d7ebe..f1b0d5a3 100644 --- a/src/earthkit/data/readers/grib/codes.py +++ b/src/earthkit/data/readers/grib/codes.py @@ -15,7 +15,7 @@ import numpy as np from earthkit.data.core.fieldlist import Field -from earthkit.data.indexing.fieldlist import NewFieldMetadataWrapper +from earthkit.data.indexing.fieldlist import NewFieldWrapper from earthkit.data.readers.grib.metadata import GribFieldMetadata from earthkit.data.utils.message import CodesHandle from earthkit.data.utils.message import CodesMessagePositionIndex @@ -326,12 +326,12 @@ def message(self): return self.handle.get_buffer() def copy(self, **kwargs): - return NewMetadataGribField(self, **kwargs) + return NewGribField(self, **kwargs) -class NewMetadataGribField(NewFieldMetadataWrapper, GribField): +class NewGribField(NewFieldWrapper, GribField): def __init__(self, field, **kwargs): - NewFieldMetadataWrapper.__init__(self, field, **kwargs) + NewFieldWrapper.__init__(self, field, **kwargs) self._handle = field._handle GribField.__init__( self, diff --git a/src/earthkit/data/readers/grib/memory.py b/src/earthkit/data/readers/grib/memory.py index 01d344f3..7702bba8 100644 --- a/src/earthkit/data/readers/grib/memory.py +++ b/src/earthkit/data/readers/grib/memory.py @@ -11,7 +11,7 @@ import eccodes -from earthkit.data.indexing.fieldlist import NewFieldMetadataWrapper +from earthkit.data.indexing.fieldlist import NewFieldWrapper from earthkit.data.indexing.fieldlist import SimpleFieldList from earthkit.data.readers import Reader from earthkit.data.readers.grib.codes import GribCodesHandle @@ -156,12 +156,12 @@ def from_buffer(buf): ) def copy(self, **kwargs): - return NewMetadataGribFieldInMemory(self, **kwargs) + return NewGribFieldInMemory(self, **kwargs) -class NewMetadataGribFieldInMemory(NewFieldMetadataWrapper, GribFieldInMemory): +class NewGribFieldInMemory(NewFieldWrapper, GribFieldInMemory): def __init__(self, field, **kwargs): - NewFieldMetadataWrapper.__init__(self, field, **kwargs) + NewFieldWrapper.__init__(self, field, **kwargs) self._handle = field._handle GribFieldInMemory.__init__( self, diff --git a/src/earthkit/data/readers/netcdf/field.py b/src/earthkit/data/readers/netcdf/field.py index 33d98005..3fc75d78 100644 --- a/src/earthkit/data/readers/netcdf/field.py +++ b/src/earthkit/data/readers/netcdf/field.py @@ -17,7 +17,7 @@ from earthkit.data.core.geography import Geography from earthkit.data.core.metadata import MetadataAccessor from earthkit.data.core.metadata import RawMetadata -from earthkit.data.indexing.fieldlist import NewFieldMetadataWrapper +from earthkit.data.indexing.fieldlist import NewFieldWrapper from earthkit.data.utils.bbox import BoundingBox from earthkit.data.utils.dates import to_datetime @@ -296,12 +296,12 @@ def tidy(x): return tidy(self._ds[self._ds[self.variable].grid_mapping].attrs) def copy(self, **kwargs): - return NewMetadataXarrayField(self, **kwargs) + return NewXarrayField(self, **kwargs) -class NewMetadataXarrayField(NewFieldMetadataWrapper, XArrayField): +class NewXarrayField(NewFieldWrapper, XArrayField): def __init__(self, field, **kwargs): - NewFieldMetadataWrapper.__init__(self, field, **kwargs) + NewFieldWrapper.__init__(self, field, **kwargs) XArrayField.__init__(self, field.ds, field.variable, field.slices, field.non_dim_coords) diff --git a/src/earthkit/data/sources/array_list.py b/src/earthkit/data/sources/array_list.py index ec548ad4..56753830 100644 --- a/src/earthkit/data/sources/array_list.py +++ b/src/earthkit/data/sources/array_list.py @@ -11,7 +11,7 @@ import math from earthkit.data.core.fieldlist import Field -from earthkit.data.indexing.fieldlist import NewFieldMetadataWrapper +from earthkit.data.indexing.fieldlist import NewFieldWrapper from earthkit.data.utils.array import array_namespace LOG = logging.getLogger(__name__) @@ -96,12 +96,12 @@ def __setstate__(self, state: dict): self.__metadata = state.pop("_metadata") def copy(self, **kwargs): - return NewMetadataArrayField(self, **kwargs) + return NewArrayField(self, **kwargs) -class NewMetadataArrayField(NewFieldMetadataWrapper, ArrayField): +class NewArrayField(NewFieldWrapper, ArrayField): def __init__(self, field, **kwargs): - NewFieldMetadataWrapper.__init__(self, field, **kwargs) + NewFieldWrapper.__init__(self, field, **kwargs) ArrayField.__init__(self, field._array, None) diff --git a/src/earthkit/data/sources/forcings.py b/src/earthkit/data/sources/forcings.py index fced7b04..a8bb1cb3 100644 --- a/src/earthkit/data/sources/forcings.py +++ b/src/earthkit/data/sources/forcings.py @@ -20,7 +20,7 @@ from earthkit.data.core.metadata import RawMetadata from earthkit.data.decorators import cached_method from earthkit.data.decorators import normalize -from earthkit.data.indexing.fieldlist import NewFieldMetadataWrapper +from earthkit.data.indexing.fieldlist import NewFieldWrapper from earthkit.data.utils.dates import to_datetime LOG = logging.getLogger(__name__) @@ -256,15 +256,15 @@ def _values(self, dtype=None): return values def copy(self, **kwargs): - return NewMetadataForcingField(self, **kwargs) + return NewForcingField(self, **kwargs) def __repr__(self): return "ForcingField(%s,%s,%s)" % (self.param, self.date, self.number) -class NewMetadataForcingField(NewFieldMetadataWrapper, ForcingField): +class NewForcingField(NewFieldWrapper, ForcingField): def __init__(self, field, **kwargs): - NewFieldMetadataWrapper.__init__(self, field, **kwargs) + NewFieldWrapper.__init__(self, field, **kwargs) ForcingField.__init__(self, field.maker, field.date, field.param, field.proc, number=field.number) diff --git a/src/earthkit/data/writers/grib.py b/src/earthkit/data/writers/grib.py index 415b1363..ca18038d 100644 --- a/src/earthkit/data/writers/grib.py +++ b/src/earthkit/data/writers/grib.py @@ -48,6 +48,13 @@ def write(self, f, field, values=None, check_nans=True, bits_per_value=None): if "generatingProcessIdentifier" not in md: md["generatingProcessIdentifier"] = None + if values is None: + try: + if field._has_new_values(): + values = field.values + except AttributeError: + pass + output.write(values, check_nans=check_nans, missing_value=field.handle.MISSING_VALUE, **md) diff --git a/tests/grib/test_grib_copy.py b/tests/grib/test_grib_copy.py index e3cb1dce..066ea03b 100644 --- a/tests/grib/test_grib_copy.py +++ b/tests/grib/test_grib_copy.py @@ -18,6 +18,7 @@ from earthkit.data import FieldList from earthkit.data import from_source from earthkit.data.core.temporary import temp_file +from earthkit.data.sources.array_list import ArrayField here = os.path.dirname(__file__) sys.path.insert(0, here) @@ -25,7 +26,7 @@ @pytest.mark.parametrize("fl_type", ["file", "array", "memory"]) -def test_grib_copy_core(fl_type): +def test_grib_copy_metadata(fl_type): ds_ori, _ = load_grib_data("test4.grib", fl_type) def _func1(field, key, original_metadata): @@ -124,3 +125,241 @@ def _func3(field, key, original_metadata): # assert ds_1.metadata("shortName") == ["q", "q"] # assert ds_1.metadata("level") == [600, 600] # assert ds_1.metadata("levelist") == [600, 600] + + +@pytest.mark.parametrize("fl_type", ["file", "array", "memory"]) +def test_grib_copy_values(fl_type): + ds_ori, _ = load_grib_data("test4.grib", fl_type) + + vals_ori = ds_ori[0].values + + # --------------- + # field + # --------------- + + f = ds_ori[0].copy(values=vals_ori + 1) + + assert f.metadata("param") == "t" + assert f.metadata("shortName") == "t" + assert f.metadata("level") == 500 + assert f.metadata("levelist") == 500 + assert f.metadata("date", "param") == (20070101, "t") + assert f.metadata("param", "date") == ("t", 20070101) + + assert np.allclose(f.values, vals_ori + 1) + assert np.allclose(ds_ori[0].values, vals_ori) + + # write back to grib + # we can only have ecCodes keys + with temp_file() as tmp: + f.save(tmp) + f_saved = from_source("file", tmp)[0] + assert f_saved.metadata("param") == "t" + assert f_saved.metadata("shortName") == "t" + assert f_saved.metadata("level") == 500 + assert f_saved.metadata("levelist") == 500 + assert np.allclose(f_saved.values, vals_ori + 1) + + # --------------- + # fieldlist + # --------------- + + fields = [] + for i in range(2): + f = ds_ori[i].copy(values=vals_ori + i + 1) + fields.append(f) + + ds = FieldList.from_fields(fields) + + assert ds.metadata("param") == ["t", "z"] + assert ds.metadata("shortName") == ["t", "z"] + assert ds.metadata("level") == [500, 500] + assert ds.metadata("levelist") == [500, 500] + assert np.allclose(ds[0].values, vals_ori + 1) + assert np.allclose(ds[1].values, vals_ori + 2) + + # write back to grib + with temp_file() as tmp: + ds.save(tmp) + ds_saved = from_source("file", tmp) + assert ds_saved.metadata("param") == ["t", "z"] + assert ds_saved.metadata("shortName") == ["t", "z"] + assert ds_saved.metadata("level") == [500, 500] + assert ds_saved.metadata("levelist") == [500, 500] + assert np.allclose(ds_saved[0].values, vals_ori + 1) + assert np.allclose(ds_saved[1].values, vals_ori + 2) + + # TODO: implement the following + # serialise + # pickled_f = pickle.dumps(ds) + # ds_1 = pickle.loads(pickled_f) + + # assert ds_1.metadata("param") == ["q", "q"] + # assert ds_1.metadata("shortName") == ["q", "q"] + # assert ds_1.metadata("level") == [600, 600] + # assert ds_1.metadata("levelist") == [600, 600] + + +@pytest.mark.parametrize("fl_type", ["file", "array", "memory"]) +def test_grib_copy_combined(fl_type): + ds_ori, _ = load_grib_data("test4.grib", fl_type) + + vals_ori = ds_ori[0].values + + def _func1(field, key, original_metadata): + return original_metadata[key] + 100 + + # --------------- + # field + # --------------- + + f = ds_ori[0].copy( + values=vals_ori + 1, + param="q", + levelist=_func1, + ) + + assert isinstance(f, ArrayField) + assert f.metadata("param") == "q" + assert f.metadata("shortName") == "t" + assert f.metadata("level") == 500 + assert f.metadata("levelist") == 600 + assert f.metadata("date", "param") == (20070101, "q") + assert f.metadata("param", "date") == ("q", 20070101) + assert np.allclose(f.values, vals_ori + 1) + assert np.allclose(ds_ori[0].values, vals_ori) + + # write back to grib + # we can only have ecCodes keys + with temp_file() as tmp: + f.save(tmp) + f_saved = from_source("file", tmp)[0] + assert f_saved.metadata("param") == "q" + assert f_saved.metadata("shortName") == "q" + assert f_saved.metadata("level") == 600 + assert f_saved.metadata("levelist") == 600 + assert np.allclose(f_saved.values, vals_ori + 1) + + # --------------- + # fieldlist + # --------------- + + fields = [] + for i in range(2): + f = ds_ori[i].copy( + values=vals_ori + i + 1, + param="q", + levelist=_func1, + ) + fields.append(f) + + ds = FieldList.from_fields(fields) + + assert ds.metadata("param") == ["q", "q"] + assert ds.metadata("shortName") == ["t", "z"] + assert ds.metadata("level") == [500, 500] + assert ds.metadata("levelist") == [600, 600] + assert np.allclose(ds[0].values, vals_ori + 1) + assert np.allclose(ds[1].values, vals_ori + 2) + + # write back to grib + with temp_file() as tmp: + ds.save(tmp) + ds_saved = from_source("file", tmp) + assert ds_saved.metadata("param") == ["q", "q"] + assert ds_saved.metadata("shortName") == ["q", "q"] + assert ds_saved.metadata("level") == [600, 600] + assert ds_saved.metadata("levelist") == [600, 600] + assert np.allclose(ds_saved[0].values, vals_ori + 1) + assert np.allclose(ds_saved[1].values, vals_ori + 2) + + # TODO: implement the following + # serialise + # pickled_f = pickle.dumps(ds) + # ds_1 = pickle.loads(pickled_f) + + # assert ds_1.metadata("param") == ["q", "q"] + # assert ds_1.metadata("shortName") == ["q", "q"] + # assert ds_1.metadata("level") == [600, 600] + # assert ds_1.metadata("levelist") == [600, 600] + + +@pytest.mark.parametrize("fl_type", ["file", "array", "memory"]) +def test_grib_copy_to_field(fl_type): + ds_ori, _ = load_grib_data("test4.grib", fl_type) + + vals_ori = ds_ori[0].values + + # --------------- + # field + # --------------- + + f = ds_ori[0].to_field( + values=vals_ori + 1, + shortName="q", + level=600, + ) + + assert f.metadata("param") == "q" + assert f.metadata("shortName") == "q" + assert f.metadata("level") == 600 + assert f.metadata("levelist") == 600 + assert f.metadata("date", "param") == (20070101, "q") + assert f.metadata("param", "date") == ("q", 20070101) + assert np.allclose(f.values, vals_ori + 1) + assert np.allclose(ds_ori[0].values, vals_ori) + + # write back to grib + # we can only have ecCodes keys + with temp_file() as tmp: + f.save(tmp) + f_saved = from_source("file", tmp)[0] + assert f_saved.metadata("param") == "q" + assert f_saved.metadata("shortName") == "q" + assert f_saved.metadata("level") == 600 + assert f_saved.metadata("levelist") == 600 + assert np.allclose(f_saved.values, vals_ori + 1) + + # --------------- + # fieldlist + # --------------- + + fields = [] + for i in range(2): + f = ds_ori[i].to_field( + values=vals_ori + i + 1, + shortName="q", + level=600, + ) + assert isinstance(f, ArrayField) + fields.append(f) + + ds = FieldList.from_fields(fields) + + assert ds.metadata("param") == ["q", "q"] + assert ds.metadata("shortName") == ["q", "q"] + assert ds.metadata("level") == [600, 600] + assert ds.metadata("levelist") == [600, 600] + assert np.allclose(ds[0].values, vals_ori + 1) + assert np.allclose(ds[1].values, vals_ori + 2) + + # write back to grib + with temp_file() as tmp: + ds.save(tmp) + ds_saved = from_source("file", tmp) + assert ds_saved.metadata("param") == ["q", "q"] + assert ds_saved.metadata("shortName") == ["q", "q"] + assert ds_saved.metadata("level") == [600, 600] + assert ds_saved.metadata("levelist") == [600, 600] + assert np.allclose(ds_saved[0].values, vals_ori + 1) + assert np.allclose(ds_saved[1].values, vals_ori + 2) + + # TODO: implement the following + # serialise + # pickled_f = pickle.dumps(ds) + # ds_1 = pickle.loads(pickled_f) + + # assert ds_1.metadata("param") == ["q", "q"] + # assert ds_1.metadata("shortName") == ["q", "q"] + # assert ds_1.metadata("level") == [600, 600] + # assert ds_1.metadata("levelist") == [600, 600]