From fb8a720686ffbe8e2c70a67b3e1e447b156e24b9 Mon Sep 17 00:00:00 2001
From: Sandor Kertesz
Date: Fri, 30 Aug 2024 17:10:39 +0100
Subject: [PATCH] Improve fieldlist indices implementation (#436)

---

Notes:
    Amended in review.
    - FieldListIndices.index() now returns [] for a default index key that
      has no valid value in any field. Such a key is never inserted into
      default_indices, so the plain subscript raised KeyError; the code
      removed by this patch returned [] in that case.
    - sorted(list(x)) simplified to sorted(x); docstrings/comments added to
      FieldListIndices.
    - Hunk offsets and the diffstat below account for the added lines; the
      blob ids on the "index" lines are those of the original commit.

 src/earthkit/data/core/fieldlist.py | 113 ++++++++++++++++++----------
 tests/grib/test_grib_inidces.py     |  15 +++-
 2 files changed, 87 insertions(+), 41 deletions(-)

diff --git a/src/earthkit/data/core/fieldlist.py b/src/earthkit/data/core/fieldlist.py
index ee0dbe5e..c2ca9425 100644
--- a/src/earthkit/data/core/fieldlist.py
+++ b/src/earthkit/data/core/fieldlist.py
@@ -10,6 +10,7 @@
 import math
 from abc import abstractmethod
 from collections import defaultdict
+from functools import cached_property
 
 from earthkit.data.core import Base
 from earthkit.data.core.index import Index
@@ -22,6 +23,75 @@
 from earthkit.data.utils.metadata import metadata_argument
 
 
+class FieldListIndices:
+    """Compute and cache the unique, sorted metadata values of a fieldlist."""
+
+    def __init__(self, field_list):
+        self.fs = field_list
+        # Indices of the non-default keys requested via index(), filled lazily.
+        self.user_indices = {}
+
+    @cached_property
+    def default_index_keys(self):
+        """Return the index keys of the first field, or [] for an empty fieldlist."""
+        if len(self.fs) > 0:
+            return self.fs[0]._metadata.index_keys()
+        else:
+            return []
+
+    def _index_value(self, key):
+        """Return the sorted unique non-None values of ``key`` over all the fields."""
+        values = set()
+        for f in self.fs:
+            v = f.metadata(key, default=None)
+            if v is not None:
+                values.add(v)
+
+        return sorted(values)
+
+    @cached_property
+    def default_indices(self):
+        """Return the sorted unique values of each default index key.
+
+        Keys with no non-None value in any of the fields are omitted.
+        """
+        indices = defaultdict(set)
+        keys = self.default_index_keys
+        for f in self.fs:
+            # v holds one value per key, in the same order as keys.
+            v = f.metadata(keys, default=None)
+            for i, k in enumerate(keys):
+                if v[i] is not None:
+                    indices[k].add(v[i])
+
+        return {k: sorted(v) for k, v in indices.items()}
+
+    def indices(self, squeeze=False):
+        """Return the default and user indices merged into a new dict.
+
+        When ``squeeze`` is True only the keys with more than one value are kept.
+        """
+        r = {**self.default_indices, **self.user_indices}
+
+        if squeeze:
+            return {k: v for k, v in r.items() if len(v) > 1}
+        else:
+            return r
+
+    def index(self, key):
+        """Return the sorted unique values of ``key``, caching non-default keys."""
+        if key in self.user_indices:
+            return self.user_indices[key]
+
+        if key in self.default_index_keys:
+            # A default key with no valid values is absent from default_indices;
+            # return an empty list for it instead of raising KeyError.
+            return self.default_indices.get(key, [])
+
+        self.user_indices[key] = self._index_value(key)
+        return self.user_indices[key]
+
+
 class Field(Base):
     r"""Represent a Field."""
 
@@ -684,8 +754,6 @@ class FieldList(Index):
         defaults to "numpy".
     """
 
-    _md_indices = {}
-
     def __init__(self, array_backend=None, **kwargs):
         self._array_backend = ensure_backend(array_backend)
         super().__init__(**kwargs)
@@ -741,31 +809,9 @@ def ignore(self):
         else:
             return False
 
-    @cached_method
-    def _default_index_keys(self):
-        if len(self) > 0:
-            return self[0]._metadata.index_keys()
-        else:
-            return []
-
-    def _find_index_values(self, key):
-        values = set()
-        for f in self:
-            v = f.metadata(key, default=None)
-            if v is not None:
-                values.add(v)
-        return sorted(list(values))
-
-    def _find_all_index_dict(self):
-        indices = defaultdict(set)
-        for f in self:
-            for k in self._default_index_keys():
-                v = f.metadata(k, default=None)
-                if v is None:
-                    continue
-                indices[k].add(v)
-
-        return {k: sorted(list(v)) for k, v in indices.items()}
+    @cached_property
+    def _md_indices(self):
+        return FieldListIndices(self)
 
     def indices(self, squeeze=False):
         r"""Return the unique, sorted values for a set of metadata keys (see below)
@@ -806,12 +852,7 @@ def indices(self, squeeze=False):
         used in :obj:`indices`.
 
         """
-        if not self._md_indices:
-            self._md_indices = self._find_all_index_dict()
-        if squeeze:
-            return {k: v for k, v in self._md_indices.items() if len(v) > 1}
-        else:
-            return self._md_indices
+        return self._md_indices.indices(squeeze=squeeze)
 
     def index(self, key):
         r"""Return the unique, sorted values of the specified metadata ``key`` from all the fields.
@@ -840,11 +881,7 @@ def index(self, key):
         [300, 400, 500, 700, 850, 1000]
 
         """
-        if key in self.indices():
-            return self.indices()[key]
-
-        self._md_indices[key] = self._find_index_values(key)
-        return self._md_indices[key]
+        return self._md_indices.index(key)
 
     def to_numpy(self, **kwargs):
         r"""Return all the fields' values as an ndarray. It is formed as the array of the
diff --git a/tests/grib/test_grib_inidces.py b/tests/grib/test_grib_inidces.py
index ffbb8ad8..37d722a2 100644
--- a/tests/grib/test_grib_inidces.py
+++ b/tests/grib/test_grib_inidces.py
@@ -27,7 +27,7 @@
 def test_grib_indices_base(fl_type, array_backend):
     ds = load_grib_data("tuv_pl.grib", fl_type, array_backend)
 
-    ref = {
+    ref_full = {
         "class": ["od"],
         "stream": ["oper"],
         "levtype": ["pl"],
@@ -42,7 +42,7 @@ def test_grib_indices_base(fl_type, array_backend):
     }
 
     r = ds.indices()
-    assert r == ref
+    assert r == ref_full
 
     ref = {
         "levelist": [300, 400, 500, 700, 850, 1000],
@@ -55,6 +55,15 @@ def test_grib_indices_base(fl_type, array_backend):
     r = ds.index("param")
     assert r == ref
 
+    ref = [300, 400, 500, 700, 850, 1000]
+    ref_full["level"] = ref
+
+    r = ds.index("level")
+    assert r == ref
+
+    r = ds.indices()
+    assert r == ref_full
+
 
 @pytest.mark.parametrize("fl_type", FL_TYPES)
 @pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
@@ -155,7 +164,7 @@ def test_grib_indices_multi(fl_type, array_backend):
 
 @pytest.mark.parametrize("fl_type", FL_TYPES)
 @pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
-def test_grib_indices_multi_Del(fl_type, array_backend):
+def test_grib_indices_multi_sel(fl_type, array_backend):
     f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend)
     f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data")
     ds = f1 + f2