diff --git a/src/earthkit/data/core/fieldlist.py b/src/earthkit/data/core/fieldlist.py index 3adf33c0..f98768d6 100644 --- a/src/earthkit/data/core/fieldlist.py +++ b/src/earthkit/data/core/fieldlist.py @@ -79,12 +79,6 @@ def index(self, key): class Field(Base): r"""Represent a Field.""" - def __init__( - self, - metadata=None, - ): - self.__metadata = metadata - @abstractmethod def _values(self, dtype=None): r"""Return the raw values extracted from the underlying storage format @@ -113,12 +107,10 @@ def values(self): return self._flatten(self._values()) @property + @abstractmethod def _metadata(self): r"""Metadata: Get the object representing the field's metadata.""" - if self.__metadata is None: - # TODO: remove this legacy method - self.__metadata = self._make_metadata() - return self.__metadata + self._not_implemented() def to_numpy(self, flatten=False, dtype=None, index=None): r"""Return the values stored in the field as an ndarray. @@ -637,6 +629,16 @@ def metadata(self, *keys, astype=None, remapping=None, patches=None, **kwargs): else: return self._metadata.as_namespace(None) + @abstractmethod + def copy(self, **kwargs): + r"""Return a copy of the field. + + Returns + ------- + :obj:`Field` + """ + self._not_implemented() + def dump(self, namespace=all, **kwargs): r"""Generate dump with all the metadata keys belonging to ``namespace``. @@ -668,6 +670,28 @@ def dump(self, namespace=all, **kwargs): """ return self._metadata.dump(namespace=namespace, **kwargs) + def save(self, filename, append=False, **kwargs): + r"""Write the field into a file. + + Parameters + ---------- + filename: str, optional + The target file path, if not defined attempts will be made to detect the filename + append: bool, optional + When it is true append data to the target file. Otherwise + the target file be overwritten if already exists. Default is False + **kwargs: dict, optional + Other keyword arguments passed to :obj:`write`. + + See Also + -------- + :obj:`write` + + """ + flag = "wb" if not append else "ab" + with open(filename, flag) as f: + self.write(f, **kwargs) + def __getitem__(self, key): """Return the value of the metadata ``key``.""" return self._metadata.get(key) diff --git a/src/earthkit/data/core/metadata.py b/src/earthkit/data/core/metadata.py index 0e63dda5..36faece4 100644 --- a/src/earthkit/data/core/metadata.py +++ b/src/earthkit/data/core/metadata.py @@ -73,32 +73,56 @@ def get(self, md, key, default=None, *, astype=None, raise_on_missing=False): else: raise KeyError(f"{key}, reason={e}") + def __call__(self, func): + @functools.wraps(func) + def wrapped(cls, key, *args, **kwargs): + if key in self: + return self.get(cls, key, *args, **kwargs) + return func(cls, key, *args, **kwargs) -def cacheable_metadata(func): - @functools.wraps(func) - def wrapped(self, key, default=None, *, astype=None, raise_on_missing=False): - if self._cache is not None: - cache_id = (key, default, astype, raise_on_missing) - if cache_id in self._cache: - return self._cache[cache_id] - - v = func(self, key, default=default, astype=astype, raise_on_missing=raise_on_missing) - self._cache[cache_id] = v - return v - else: - return func(self, key, default=default, astype=astype, raise_on_missing=raise_on_missing) - - return wrapped + return wrapped -def custom_accessor(func): - @functools.wraps(func) - def wrapped(self, key, *args, **kwargs): - if self.CUSTOM_ACCESSOR and key in self.CUSTOM_ACCESSOR: - return self.CUSTOM_ACCESSOR.get(self, key, *args, **kwargs) - return func(self, key, *args, **kwargs) +class MetadataCacheHandler: + @staticmethod + def make(cache=None): + if cache is True: + return dict() + elif cache is not False and cache is not None: + return cache + + @staticmethod + def clone_empty(cache): + if cache is not None: + return cache.__class__() + + @staticmethod + def serialise(cache): + if cache is not None: + return cache.__class__ + + @staticmethod + def deserialise(state): + cache = state + if state is not None: + return cache() + + @staticmethod + def cache_get(func): + @functools.wraps(func) + def wrapped(self, key, default=None, *, astype=None, raise_on_missing=False): + if self._cache is not None: + cache_id = (key, default, astype, raise_on_missing) + if cache_id in self._cache: + return self._cache[cache_id] + + v = func(self, key, default=default, astype=astype, raise_on_missing=raise_on_missing) + self._cache[cache_id] = v + return v + else: + return func(self, key, default=default, astype=astype, raise_on_missing=raise_on_missing) - return wrapped + return wrapped class Metadata(metaclass=ABCMeta): @@ -121,22 +145,6 @@ class Metadata(metaclass=ABCMeta): """ - DATA_FORMAT = None - NAMESPACES = [] - LS_KEYS = [] - DESCRIBE_KEYS = [] - INDEX_KEYS = [] - ALIASES = [] - CUSTOM_ACCESSOR = None - - _cache = None - - def __init__(self, cache=False): - if cache is True: - self._cache = dict() - elif cache is not False and cache is not None: - self._cache = cache - @abstractmethod def __iter__(self): """Return an iterator over the metadata keys.""" @@ -189,13 +197,10 @@ def items(self): """ pass - @cacheable_metadata - @custom_accessor + @abstractmethod def get(self, key, default=None, *, astype=None, raise_on_missing=False): r"""Return the value for ``key``. - When the instance is created with ``cache=True`` all the result is cached. - Parameters ---------- key: str @@ -224,10 +229,6 @@ def get(self, key, default=None, *, astype=None, raise_on_missing=False): a missing value. """ - return self._get(key, default=default, astype=astype, raise_on_missing=raise_on_missing) - - @abstractmethod - def _get(self, key, astype=None, default=None, raise_on_missing=False): pass @abstractmethod @@ -249,6 +250,7 @@ def override(self, *args, **kwargs): """ pass + @abstractmethod def namespaces(self): r"""Return the available namespaces. @@ -256,8 +258,9 @@ def namespaces(self): ------- list of str """ - return self.NAMESPACES + pass + @abstractmethod def as_namespace(self, namespace=None): r"""Return all the keys/values from a namespace. @@ -273,14 +276,14 @@ def as_namespace(self, namespace=None): All the keys/values from the `namespace`. """ - if namespace is None or namespace == "": - return {k: v for k, v in self.items()} - return {} + pass + @abstractmethod def dump(self, **kwargs): r"""Generate a dump from the metadata content.""" - return None + pass + @abstractmethod def datetime(self): r"""Return the date and time of the field. @@ -297,10 +300,7 @@ def datetime(self): 'valid_time': datetime.datetime(2020, 12, 21, 18, 0)} """ - return { - "base_time": self.base_datetime(), - "valid_time": self.valid_datetime(), - } + pass @abstractmethod def base_datetime(self): @@ -311,33 +311,30 @@ def valid_datetime(self): pass @property + @abstractmethod def geography(self): r""":obj:`Geography`: Get geography description. If it is not available None is returned. """ - return None - - @property - def gridspec(self): - r""":class:`~data.core.gridspec.GridSpec`: Get grid description. - - If it is not available None is returned. - """ - return None if self.geography is None else self.geography.gridspec() + pass + @abstractmethod def ls_keys(self): r"""Return the keys to be used with the :meth:`ls` method.""" - return self.LS_KEYS + pass + @abstractmethod def describe_keys(self): r"""Return the keys to be used with the :meth:`describe` method.""" - return self.DESCRIBE_KEYS + pass + @abstractmethod def index_keys(self): r"""Return the keys to be used with the :meth:`indices` method.""" - return self.INDEX_KEYS + pass + @abstractmethod def data_format(self): r"""Return the underlying data format. @@ -346,17 +343,27 @@ def data_format(self): str """ - return self.DATA_FORMAT + pass + @abstractmethod def _hide_internal_keys(self): - return self + """If the metadata object has internal keys, hide them. + + Returns + ------- + WrappedMetadata, Metadata + If the metadata object has internal keys, return a new wrapped object with the + internal keys hidden. Otherwise return the metadata object itself. + """ + pass class WrappedMetadata: - def __init__(self, metadata, extra=None, hidden=None, merge=True): + def __init__(self, metadata, extra=None, hidden=None, owner=None, merge=True): self.metadata = metadata self.extra = extra if extra is not None else dict() self.hidden = hidden if hidden is not None else [] + self.owner = owner for k in self.hidden: if k in self.extra: @@ -429,7 +436,7 @@ def get(self, key, default=None, *, astype=None, raise_on_missing=False, **kwarg return default if key in self.extra: - v = self.extra[key] + v = self._extra_value(key) if astype is not None and v is not None: try: return astype(v) @@ -441,6 +448,25 @@ def get(self, key, default=None, *, astype=None, raise_on_missing=False, **kwarg key, default=default, astype=astype, raise_on_missing=raise_on_missing, **kwargs ) + def _extra_value(self, key): + v = self.extra[key] + if callable(v): + v = v(self.owner, key, self.metadata) + return v + + def as_namespace(self, namespace): + r = dict() + if namespace is None: + r = dict(self.items()) + for k, v in self.extra.items(): + if k in r: + r[k] = self._extra_value(k) + else: + r = self.metadata.as_namespace(namespace) + # TODO: add filtering based on extra + + return r + def override(self, *args, **kwargs): md = self.metadata.override(*args, **kwargs) return self.__class__(md, self.extra, hidden=self.hidden, merge=True) @@ -489,21 +515,25 @@ def __init__(self, *args, **kwargs): self._d = dict(*args, **kwargs) super().__init__() - def override(self, *args, **kwargs): - d = dict(**self._d) - d.update(*args, **kwargs) - return RawMetadata(d) - def __len__(self): return len(self._d) def __contains__(self, key): return key in self._d + def __getitem__(self, key): + return self.get(key, raise_on_missing=True) + def __iter__(self): return iter(self.keys()) - def _get(self, key, default=None, astype=None, raise_on_missing=False): + def keys(self): + return self._d.keys() + + def items(self): + return self._d.items() + + def get(self, key, default=None, *, astype=None, raise_on_missing=False): if not raise_on_missing: v = self._d.get(key, default) else: @@ -516,11 +546,19 @@ def _get(self, key, default=None, astype=None, raise_on_missing=False): return None return v - def keys(self): - return self._d.keys() + def override(self, *args, **kwargs): + d = dict(**self._d) + d.update(*args, **kwargs) + return RawMetadata(d) - def items(self): - return self._d.items() + def namespaces(self): + return [] + + def as_namespace(self, namespace): + return {} + + def datetime(self): + return None def base_datetime(self): return None @@ -528,8 +566,27 @@ def base_datetime(self): def valid_datetime(self): return None - def as_namespace(self, namespace): - return {} + @property + def geography(self): + return None + + def dump(self, **kwargs): + return None + + def ls_keys(self): + return [] + + def describe_keys(self): + return [] + + def index_keys(self): + return [] + + def data_format(self): + return None + + def _hide_internal_keys(self): + return self def __repr__(self): return f"{self.__class__.__name__}({self._d.__repr__()})" diff --git a/src/earthkit/data/indexing/fieldlist.py b/src/earthkit/data/indexing/fieldlist.py index c3c5170f..4e261ddf 100644 --- a/src/earthkit/data/indexing/fieldlist.py +++ b/src/earthkit/data/indexing/fieldlist.py @@ -85,5 +85,44 @@ def merge(cls, sources): return cls.from_fields(list(chain(*[f for f in sources]))) +class WrappedField: + def __init__(self, field): + self._field = field + + def __getattr__(self, name): + return getattr(self._field, name) + + def __repr__(self) -> str: + return repr(self._field) + + +# class NewDataField(WrappedField): +# def __init__(self, field, data): +# super().__init__(field) +# self._data = data +# self.shape = data.shape + +# def to_numpy(self, flatten=False, dtype=None, index=None): +# data = self._data +# if dtype is not None: +# data = data.astype(dtype) +# if flatten: +# data = data.flatten() +# if index is not None: +# data = data[index] +# return data + + +class NewFieldMetadataWrapper: + def __init__(self, field, **kwargs): + from earthkit.data.core.metadata import WrappedMetadata + + self.__metadata = WrappedMetadata(field._metadata, extra=kwargs, owner=field) + + @property + def _metadata(self): + return self.__metadata + + # For backwards compatibility FieldArray = SimpleFieldList diff --git a/src/earthkit/data/readers/grib/codes.py b/src/earthkit/data/readers/grib/codes.py index 7cbf014e..c30d7ebe 100644 --- a/src/earthkit/data/readers/grib/codes.py +++ b/src/earthkit/data/readers/grib/codes.py @@ -15,6 +15,7 @@ import numpy as np from earthkit.data.core.fieldlist import Field +from earthkit.data.indexing.fieldlist import NewFieldMetadataWrapper from earthkit.data.readers.grib.metadata import GribFieldMetadata from earthkit.data.utils.message import CodesHandle from earthkit.data.utils.message import CodesMessagePositionIndex @@ -300,7 +301,7 @@ def __repr__(self): self._metadata.get("number", None), ) - def write(self, f, bits_per_value=None): + def write(self, f, **kwargs): r"""Writes the message to a file object. Parameters @@ -311,14 +312,9 @@ def write(self, f, bits_per_value=None): Set the ``bitsPerValue`` GRIB key in the generated GRIB message. When None the ``bitsPerValue`` stored in the metadata will be used. """ - if bits_per_value is not None: - handle = self.handle.clone() - handle.set_long("bitsPerValue", bits_per_value) - else: - handle = self.handle + from earthkit.data.writers import write - # assert isinstance(f, io.IOBase) - handle.write_to(f) + write(f, self, **kwargs) def message(self): r"""Returns a buffer containing the encoded message. @@ -328,3 +324,20 @@ def message(self): bytes """ return self.handle.get_buffer() + + def copy(self, **kwargs): + return NewMetadataGribField(self, **kwargs) + + +class NewMetadataGribField(NewFieldMetadataWrapper, GribField): + def __init__(self, field, **kwargs): + NewFieldMetadataWrapper.__init__(self, field, **kwargs) + self._handle = field._handle + GribField.__init__( + self, + field.path, + field._offset, + field._length, + handle_manager=field._handle_manager, + use_metadata_cache=field._use_metadata_cache, + ) diff --git a/src/earthkit/data/readers/grib/memory.py b/src/earthkit/data/readers/grib/memory.py index 8f43d0ba..01d344f3 100644 --- a/src/earthkit/data/readers/grib/memory.py +++ b/src/earthkit/data/readers/grib/memory.py @@ -11,6 +11,7 @@ import eccodes +from earthkit.data.indexing.fieldlist import NewFieldMetadataWrapper from earthkit.data.indexing.fieldlist import SimpleFieldList from earthkit.data.readers import Reader from earthkit.data.readers.grib.codes import GribCodesHandle @@ -154,6 +155,20 @@ def from_buffer(buf): GribCodesHandle(handle, None, None), use_metadata_cache=get_use_grib_metadata_cache() ) + def copy(self, **kwargs): + return NewMetadataGribFieldInMemory(self, **kwargs) + + +class NewMetadataGribFieldInMemory(NewFieldMetadataWrapper, GribFieldInMemory): + def __init__(self, field, **kwargs): + NewFieldMetadataWrapper.__init__(self, field, **kwargs) + self._handle = field._handle + GribFieldInMemory.__init__( + self, + field._handle, + use_metadata_cache=field._use_metadata_cache, + ) + class GribFieldListInMemory(SimpleFieldList): """Represent a GRIB field list in memory loaded lazily""" diff --git a/src/earthkit/data/readers/grib/metadata.py b/src/earthkit/data/readers/grib/metadata.py index cdeeb241..f7e9ce3c 100644 --- a/src/earthkit/data/readers/grib/metadata.py +++ b/src/earthkit/data/readers/grib/metadata.py @@ -14,6 +14,7 @@ from earthkit.data.core.geography import Geography from earthkit.data.core.metadata import Metadata from earthkit.data.core.metadata import MetadataAccessor +from earthkit.data.core.metadata import MetadataCacheHandler from earthkit.data.core.metadata import WrappedMetadata from earthkit.data.indexing.database import GRIB_KEYS_NAMES from earthkit.data.readers.grib.gridspec import make_gridspec @@ -301,24 +302,20 @@ class GribMetadata(Metadata): "vertical", ] - DATA_FORMAT = "grib" - - CUSTOM_ACCESSOR = MetadataAccessor( - { - "valid_datetime": ["valid_datetime", "valid_time"], - "gridspec": ["grid_spec", "gridspec"], - "base_datetime": ["base_datetime", "forecast_reference_time", "base_time"], - "reference_datetime": "reference_datetime", - "indexing_datetime": ["indexing_time", "indexing_datetime"], - "step_timedelta": "step_timedelta", - "param_level": "param_level", - } - ) + ACCESSORS = { + "valid_datetime": ["valid_datetime", "valid_time"], + "gridspec": ["grid_spec", "gridspec"], + "base_datetime": ["base_datetime", "forecast_reference_time", "base_time"], + "reference_datetime": "reference_datetime", + "indexing_datetime": ["indexing_time", "indexing_datetime"], + "step_timedelta": "step_timedelta", + "param_level": "param_level", + } __handle_type = None - def __init__(self, **kwargs): - super().__init__(**kwargs) + def __init__(self, cache=None, **kwargs): + self._cache = MetadataCacheHandler.make(cache) @staticmethod def _handle_type(): @@ -350,7 +347,9 @@ def keys(self): def items(self): return self._handle.items() - def _get(self, key, default=None, astype=None, raise_on_missing=False): + @MetadataCacheHandler.cache_get + @MetadataAccessor(ACCESSORS) + def get(self, key, default=None, *, astype=None, raise_on_missing=False): def _key_name(key): if key == "param": key = "shortName" @@ -409,11 +408,10 @@ def override(self, *args, headers_only_clone=True, **kwargs): handle.set_values(vals) # ensure that the cache settings are the same - cache = self._cache - if cache is not None: - cache = cache.__class__() + return StandAloneGribMetadata(handle, cache=MetadataCacheHandler.clone_empty(self._cache)) - return StandAloneGribMetadata(handle, cache=cache) + def namespaces(self): + return self.NAMESPACES def as_namespace(self, namespace=None): r"""Return all the keys/values from a namespace. @@ -450,6 +448,12 @@ def as_namespace(self, namespace=None): def geography(self): return GribFieldGeography(self) + def datetime(self): + return { + "base_time": self.base_datetime(), + "valid_time": self.valid_datetime(), + } + def base_datetime(self): return self._datetime("dataDate", "dataTime") @@ -476,9 +480,6 @@ def _datetime(self, date_key, time_key): def param_level(self): return f"{self.get('shortName')}{self.get('level', default='')}" - def namespaces(self): - return self.NAMESPACES - def dump(self, namespace=all, **kwargs): r"""Generate dump with all the metadata keys belonging to ``namespace``. @@ -531,6 +532,22 @@ def dump(self, namespace=all, **kwargs): return format_namespace_dump(r, selected="parameter", details=self.__class__.__name__, **kwargs) + def ls_keys(self): + return self.LS_KEYS + + def describe_keys(self): + return self.DESCRIBE_KEYS + + def index_keys(self): + return self.INDEX_KEYS + + def data_format(self): + return "grib" + + @property + def gridspec(self): + return self.geography.gridspec() + class GribFieldMetadata(GribMetadata): """Represent the metadata of a GRIB field. @@ -594,18 +611,14 @@ def __getstate__(self) -> dict: ret = {} ret["_handle"] = self._handle.get_buffer() # we do not serialize the cache contents - ret["_cache"] = self._cache if self._cache is None else self._cache.__class__ + ret["_cache"] = MetadataCacheHandler.serialise(self._cache) return ret def __setstate__(self, state: dict): from earthkit.data.readers.grib.memory import GribMessageMemoryReader - cache = state.pop("_cache") - if cache is not None: - cache = cache() - + self._cache = MetadataCacheHandler.deserialise(state.pop("_cache")) self.__handle = next(GribMessageMemoryReader(state.pop("_handle"))).handle - self._cache = cache class RestrictedGribMetadata(WrappedMetadata): diff --git a/src/earthkit/data/readers/grib/output.py b/src/earthkit/data/readers/grib/output.py index e9a635a8..7cccebf4 100644 --- a/src/earthkit/data/readers/grib/output.py +++ b/src/earthkit/data/readers/grib/output.py @@ -77,6 +77,7 @@ def encode( metadata={}, template=None, return_bytes=False, + missing_value=9999, **kwargs, ): # Make a copy as we may modify it @@ -105,7 +106,7 @@ def encode( if np.isnan(values).any(): # missing_value = np.finfo(values.dtype).max - missing_value = 9999 + missing_value = missing_value values = np.nan_to_num(values, nan=missing_value) metadata["missingValue"] = missing_value metadata["bitmapPresent"] = 1 @@ -120,8 +121,13 @@ def encode( for k in ("class", "type", "stream", "expver", "setLocalDefinition"): metadata.pop(k, None) + # TODO: revisit that logic if "generatingProcessIdentifier" not in metadata: metadata["generatingProcessIdentifier"] = 255 + else: + # kee + if metadata["generatingProcessIdentifier"] is None: + metadata.pop("generatingProcessIdentifier") LOG.debug("GribOutput.metadata %s", metadata) diff --git a/src/earthkit/data/readers/netcdf/field.py b/src/earthkit/data/readers/netcdf/field.py index 45a84731..33d98005 100644 --- a/src/earthkit/data/readers/netcdf/field.py +++ b/src/earthkit/data/readers/netcdf/field.py @@ -17,6 +17,7 @@ from earthkit.data.core.geography import Geography from earthkit.data.core.metadata import MetadataAccessor from earthkit.data.core.metadata import RawMetadata +from earthkit.data.indexing.fieldlist import NewFieldMetadataWrapper from earthkit.data.utils.bbox import BoundingBox from earthkit.data.utils.dates import to_datetime @@ -98,12 +99,10 @@ class XArrayMetadata(RawMetadata): ] MARS_KEYS = ["param", "step", "levelist", "levtype", "number", "date", "time"] - CUSTOM_ACCESSOR = MetadataAccessor( - { - "valid_datetime": ["valid_datetime", "valid_time"], - "base_datetime": ["base_datetime", "forecast_reference_time", "base_time"], - } - ) + ACCESSORS = { + "valid_datetime": ["valid_datetime", "valid_time"], + "base_datetime": ["base_datetime", "forecast_reference_time", "base_time"], + } def __init__(self, field): if not isinstance(field, XArrayField): @@ -160,6 +159,9 @@ def override(self, *args, **kwargs): def geography(self): return XArrayFieldGeography(self, self._field._ds, self._field.variable) + def namespaces(self): + return self.NAMESPACES + def as_namespace(self, namespace=None): if not isinstance(namespace, str) and namespace is not None: raise TypeError("namespace must be a str or None") @@ -180,6 +182,12 @@ def _as_mars(self): time=self.get("time", None), ) + def datetime(self): + return { + "base_time": self.base_datetime(), + "valid_time": self.valid_datetime(), + } + def base_datetime(self): v = self.valid_datetime() if v is not None: @@ -189,7 +197,8 @@ def valid_datetime(self): if self.time is not None: return to_datetime(self.time) - def _get(self, key, default=None, raise_on_missing=False, **kwargs): + @MetadataAccessor(ACCESSORS) + def get(self, key, default=None, *, raise_on_missing=False, **kwargs): if key.startswith("mars."): key = key[5:] if key not in self.MARS_KEYS: @@ -209,7 +218,10 @@ def _key_name(key): key = "level" return key - return super()._get(_key_name(key), default=default, raise_on_missing=raise_on_missing, **kwargs) + return super().get(_key_name(key), default=default, raise_on_missing=raise_on_missing, **kwargs) + + def ls_keys(self): + return self.LS_KEYS class XArrayField(Field): @@ -283,6 +295,15 @@ def tidy(x): return tidy(self._ds[self._ds[self.variable].grid_mapping].attrs) + def copy(self, **kwargs): + return NewMetadataXarrayField(self, **kwargs) + + +class NewMetadataXarrayField(NewFieldMetadataWrapper, XArrayField): + def __init__(self, field, **kwargs): + NewFieldMetadataWrapper.__init__(self, field, **kwargs) + XArrayField.__init__(self, field.ds, field.variable, field.slices, field.non_dim_coords) + class NetCDFMetadata(XArrayMetadata): pass diff --git a/src/earthkit/data/sources/array_list.py b/src/earthkit/data/sources/array_list.py index 0d99c064..ec548ad4 100644 --- a/src/earthkit/data/sources/array_list.py +++ b/src/earthkit/data/sources/array_list.py @@ -11,6 +11,7 @@ import math from earthkit.data.core.fieldlist import Field +from earthkit.data.indexing.fieldlist import NewFieldMetadataWrapper from earthkit.data.utils.array import array_namespace LOG = logging.getLogger(__name__) @@ -39,9 +40,10 @@ def __init__(self, array, metadata): metadata = UserMetadata(metadata, values=array) # TODO: this solution is questionable due to performance reasons - metadata = metadata._hide_internal_keys() + if metadata is not None: + metadata = metadata._hide_internal_keys() - super().__init__(metadata=metadata) + self.__metadata = metadata self._array = array def _values(self, dtype=None): @@ -73,7 +75,11 @@ def write(self, f, **kwargs): """ from earthkit.data.writers import write - write(f, self.to_numpy(flatten=True), self._metadata, **kwargs) + write(f, self, values=self.to_numpy(flatten=True), **kwargs) + + @property + def _metadata(self): + return self.__metadata @property def handle(self): @@ -87,8 +93,16 @@ def __getstate__(self) -> dict: def __setstate__(self, state: dict): self._array = state.pop("_array") - metadata = state.pop("_metadata") - super().__init__(metadata=metadata) + self.__metadata = state.pop("_metadata") + + def copy(self, **kwargs): + return NewMetadataArrayField(self, **kwargs) + + +class NewMetadataArrayField(NewFieldMetadataWrapper, ArrayField): + def __init__(self, field, **kwargs): + NewFieldMetadataWrapper.__init__(self, field, **kwargs) + ArrayField.__init__(self, field._array, None) def from_array(array, metadata): diff --git a/src/earthkit/data/sources/forcings.py b/src/earthkit/data/sources/forcings.py index e8730f4e..b7cc9c5c 100644 --- a/src/earthkit/data/sources/forcings.py +++ b/src/earthkit/data/sources/forcings.py @@ -10,6 +10,7 @@ import datetime import itertools import logging +from functools import cached_property import numpy as np @@ -19,6 +20,7 @@ from earthkit.data.core.metadata import RawMetadata from earthkit.data.decorators import cached_method from earthkit.data.decorators import normalize +from earthkit.data.indexing.fieldlist import NewFieldMetadataWrapper from earthkit.data.utils.dates import to_datetime LOG = logging.getLogger(__name__) @@ -35,6 +37,12 @@ def __init__(self, d, geography): def geography(self): return self._geo + def datetime(self): + return { + "base_time": self.base_datetime(), + "valid_time": self.valid_datetime(), + } + def base_datetime(self): return None @@ -44,6 +52,9 @@ def valid_datetime(self): def step_timedelta(self): return datetime.timedelta() + def ls_keys(self): + return self.LS_KEYS + class ForcingMaker: def __init__(self, field): @@ -225,17 +236,18 @@ def __init__(self, maker, date, param, proc, number=None): self.number = number # self._shape = shape # self._geometry = self.maker.field.metadata().geography + + @cached_property + def _metadata(self): d = dict( - valid_datetime=date if isinstance(date, str) else date.isoformat(), - param=param, + valid_datetime=self.date if isinstance(self.date, str) else self.date.isoformat(), + param=self.param, level=None, levelist=None, - number=number, + number=self.number, levtype=None, ) - super().__init__( - metadata=ForcingMetadata(d, self.maker.field.metadata().geography), - ) + return ForcingMetadata(d, self.maker.field.metadata().geography) def _values(self, dtype=None): values = self.proc(self.date) @@ -243,10 +255,19 @@ def _values(self, dtype=None): values = values.astype(dtype) return values + def copy(self, **kwargs): + return NewMetadataForcingField(self, **kwargs) + def __repr__(self): return "ForcingField(%s,%s,%s)" % (self.param, self.date, self.number) +class NewMetadataForcingField(NewFieldMetadataWrapper, ForcingField): + def __init__(self, field, **kwargs): + NewFieldMetadataWrapper.__init__(self, field, **kwargs) + ForcingField.__init__(self, field.maker, field.date, field.param, field.proc, number=field.number) + + def make_datetime(date, time): if time is None: return date diff --git a/src/earthkit/data/utils/metadata/dict.py b/src/earthkit/data/utils/metadata/dict.py index 46b0d5f4..f714875a 100644 --- a/src/earthkit/data/utils/metadata/dict.py +++ b/src/earthkit/data/utils/metadata/dict.py @@ -231,22 +231,12 @@ class UserMetadata(Metadata): ("param", "shortName"), ] - CUSTOM_ACCESSOR = MetadataAccessor( - { - "base_datetime": "base_datetime", - "valid_datetime": "valid_datetime", - "step_timedelta": "step_timedelta", - "param_level": "param_level", - }, - aliases=[ - ("dataDate", "date"), - ("dataTime", "time"), - ("forecast_reference_time", "base_datetime"), - ("level", "levelist"), - ("step", "endStep", "stepRange"), - ("param", "shortName"), - ], - ) + ACCESSORS = { + "base_datetime": "base_datetime", + "valid_datetime": "valid_datetime", + "step_timedelta": "step_timedelta", + "param_level": "param_level", + } LS_KEYS = ["param", "level", "base_datetime", "valid_datetime", "step", "number"] @@ -259,6 +249,9 @@ def __len__(self): def __contains__(self, key): return key in self._data + def __getitem__(self, key): + return self.get(key, raise_on_missing=True) + def __iter__(self): return iter(self._keys()) @@ -268,7 +261,8 @@ def keys(self): def items(self): return self._data.items() - def _get(self, key, default=None, astype=None, raise_on_missing=False): + @MetadataAccessor(ACCESSORS, ALIASES) + def get(self, key, default=None, *, astype=None, raise_on_missing=False): def _key_name(key): if key in self._data: return key @@ -291,6 +285,12 @@ def _key_name(key): else: return astype(v) + def datetime(self): + return { + "base_time": self.base_datetime(), + "valid_time": self.valid_datetime(), + } + def base_datetime(self): v = self._get_one(["base_datetime", "forecast_reference_time"]) if v is not None: @@ -344,3 +344,27 @@ def geography(self): def override(self, *args, **kwargs): raise NotImplementedError("override is not implemented for UserMetadata") + + def namespaces(self): + return [] + + def as_namespace(self, namespace=None): + return {} + + def dump(self, **kwargs): + return None + + def ls_keys(self): + return self.LS_KEYS + + def describe_keys(self): + return [] + + def index_keys(self): + return None + + def data_format(self): + return "dict" + + def _hide_internal_keys(self): + return self diff --git a/src/earthkit/data/writers/__init__.py b/src/earthkit/data/writers/__init__.py index 0ce0a602..70b94ba7 100644 --- a/src/earthkit/data/writers/__init__.py +++ b/src/earthkit/data/writers/__init__.py @@ -27,10 +27,10 @@ def write(self, f, values, metadata, **kwargs): pass -def write(f, values, metadata, **kwargs): - x = _writers(metadata.data_format()) +def write(f, field, **kwargs): + x = _writers(field._metadata.data_format()) c = x() - c.write(f, values, metadata, **kwargs) + c.write(f, field, **kwargs) @locked diff --git a/src/earthkit/data/writers/grib.py b/src/earthkit/data/writers/grib.py index 2362ab5c..415b1363 100644 --- a/src/earthkit/data/writers/grib.py +++ b/src/earthkit/data/writers/grib.py @@ -13,7 +13,7 @@ class GribWriter(Writer): DATA_FORMAT = "grib" - def write(self, f, values, metadata, check_nans=True, bits_per_value=None): + def write(self, f, field, values=None, check_nans=True, bits_per_value=None): r"""Write a GRIB field to a file object. Parameters @@ -30,25 +30,25 @@ def write(self, f, values, metadata, check_nans=True, bits_per_value=None): Set the ``bitsPerValue`` GRIB key in the generated GRIB message. When None the ``bitsPerValue`` stored in the metadata will be used. """ - handle = metadata._handle.clone() - if bits_per_value is None: - bits_per_value = metadata.get("bitsPerValue", 0) + from earthkit.data.readers.grib.output import new_grib_output - if bits_per_value != 0: - handle.set_long("bitsPerValue", bits_per_value) + output = new_grib_output(f, template=field) - if check_nans: - import numpy as np + md = {} + # wrapped metadata + if hasattr(field._metadata, "extra"): + md = {k: field._metadata._extra_value(k) for k, v in field._metadata.extra.items()} - if np.isnan(values).any(): - missing_value = handle.MISSING_VALUE - values = np.nan_to_num(values, nan=missing_value) - handle.set_double("missingValue", missing_value) - handle.set_long("bitmapPresent", 1) + if bits_per_value is not None: + if field._metadata.get("bitsPerValue", 0) != bits_per_value: + md["bitsPerValue"] = bits_per_value - handle.set_values(values) - handle.write(f) + # keep the original generatingProcessIdentifier if not set + if "generatingProcessIdentifier" not in md: + md["generatingProcessIdentifier"] = None + + output.write(values, check_nans=check_nans, missing_value=field.handle.MISSING_VALUE, **md) Writer = GribWriter diff --git a/tests/core/test_metadata.py b/tests/core/test_metadata.py index 11779cec..45344009 100644 --- a/tests/core/test_metadata.py +++ b/tests/core/test_metadata.py @@ -303,7 +303,7 @@ def test_grib_metadata_override_headers_only_false(): md2["average"] -def test_grib_metadata_wrapped(): +def test_grib_metadata_wrapped_core(): ds = from_source("file", earthkit_examples_file("test.grib")) md = ds[0].metadata() md_num = len(md) @@ -399,6 +399,43 @@ def test_grib_metadata_wrapped(): break +def test_grib_metadata_wrapped_callable(): + ds = from_source("file", earthkit_examples_file("test4.grib")) + md = ds[0].metadata() + assert md["perturbationNumber"] == 0 + assert md["shortName"] == "t" + assert md["levelist"] == 500 + + def _func1(fs, key, original_metadata): + return original_metadata.get("param") + "_" + original_metadata.get("levelist", astype=str) + + def _func2(fs, key, original_metadata): + return fs.mars_area + + def _func3(fs, key, original_metadata): + return "_" + str(original_metadata.get(key)) + + extra = { + "my_custom_key": "2", + "name": _func1, + "mars_area": _func2, + "gridType": _func3, + "perturbationNumber": 3, + } + md_ori = StandAloneGribMetadata(md._handle) + from earthkit.data.core.metadata import WrappedMetadata + + # extra keys are not added to the metadata + md = WrappedMetadata(md_ori, extra=extra, owner=ds[0]) + + assert md["my_custom_key"] == "2" + assert md["perturbationNumber"] == 3 + assert md["name"] == "t_500" + assert np.allclose(np.array(md["mars_area"]), np.array([90.0, 0.0, -90.0, 359.0])) + assert md["gridType"] == "_regular_ll" + assert md["typeOfLevel"] == "isobaricInhPa" + + if __name__ == "__main__": from earthkit.data.testing import main diff --git a/tests/grib/test_grib_copy.py b/tests/grib/test_grib_copy.py new file mode 100644 index 00000000..e3cb1dce --- /dev/null +++ b/tests/grib/test_grib_copy.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 + +# (C) Copyright 2020 ECMWF. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. +# + +import os +import sys + +import numpy as np +import pytest + +from earthkit.data import FieldList +from earthkit.data import from_source +from earthkit.data.core.temporary import temp_file + +here = os.path.dirname(__file__) +sys.path.insert(0, here) +from grib_fixtures import load_grib_data # noqa: E402 + + +@pytest.mark.parametrize("fl_type", ["file", "array", "memory"]) +def test_grib_copy_core(fl_type): + ds_ori, _ = load_grib_data("test4.grib", fl_type) + + def _func1(field, key, original_metadata): + return original_metadata[key] + 100 + + def _func2(field, key, original_metadata): + return field.mars_area + + def _func3(field, key, original_metadata): + return original_metadata.get("param") + "_" + str(original_metadata.get("levelist")) + + # --------------- + # field + # --------------- + + f = ds_ori[0].copy( + param="q", + levelist=_func1, + mars_area=_func2, + name=_func3, + ) + + assert f.metadata("param") == "q" + assert f.metadata("shortName") == "t" + assert f.metadata("level") == 500 + assert f.metadata("levelist") == 600 + assert f.metadata("date", "param") == (20070101, "q") + assert f.metadata("param", "date") == ("q", 20070101) + assert np.allclose(np.array(f.metadata("mars_area")), np.array([90.0, 0.0, -90.0, 359.0])) + assert f.metadata("name") == "t_500" + + # TODO: apply wrapped metadata to namespaces + assert f.metadata(namespace="mars") == { + "class": "ea", + "date": 20070101, + "domain": "g", + "expver": "0001", + "levelist": 500, + "levtype": "pl", + "param": "t", + "step": 0, + "stream": "oper", + "time": 1200, + "type": "an", + } + + # write back to grib + # we can only have ecCodes keys + with temp_file() as tmp: + f = ds_ori[0].copy( + param="q", + levelist=_func1, + ) + + f.save(tmp) + f_saved = from_source("file", tmp)[0] + assert f_saved.metadata("param") == "q" + assert f_saved.metadata("shortName") == "q" + assert f_saved.metadata("level") == 600 + assert f_saved.metadata("levelist") == 600 + + # --------------- + # fieldlist + # --------------- + + fields = [] + for i in range(2): + f = ds_ori[i].copy( + param="q", + levelist=_func1, + ) + fields.append(f) + + ds = FieldList.from_fields(fields) + + assert ds.metadata("param") == ["q", "q"] + assert ds.metadata("shortName") == ["t", "z"] + assert ds.metadata("level") == [500, 500] + assert ds.metadata("levelist") == [600, 600] + + # write back to grib + with temp_file() as tmp: + ds.save(tmp) + ds_saved = from_source("file", tmp) + assert ds_saved.metadata("param") == ["q", "q"] + assert ds_saved.metadata("shortName") == ["q", "q"] + assert ds_saved.metadata("level") == [600, 600] + assert ds_saved.metadata("levelist") == [600, 600] + + # TODO: implement the following + # serialise + # pickled_f = pickle.dumps(ds) + # ds_1 = pickle.loads(pickled_f) + + # assert ds_1.metadata("param") == ["q", "q"] + # assert ds_1.metadata("shortName") == ["q", "q"] + # assert ds_1.metadata("level") == [600, 600] + # assert ds_1.metadata("levelist") == [600, 600]