Skip to content

Commit

Permalink
Improve fieldlist indices implementation (#436)
Browse files Browse the repository at this point in the history
  • Loading branch information
sandorkertesz authored Aug 30, 2024
1 parent 8060550 commit fb8a720
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 41 deletions.
96 changes: 58 additions & 38 deletions src/earthkit/data/core/fieldlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import math
from abc import abstractmethod
from collections import defaultdict
from functools import cached_property

from earthkit.data.core import Base
from earthkit.data.core.index import Index
Expand All @@ -22,6 +23,58 @@
from earthkit.data.utils.metadata import metadata_argument


class FieldListIndices:
    """Compute and cache the unique, sorted metadata values of a fieldlist.

    Two groups of indices are maintained:

    * *default* indices: built lazily, in a single pass over the fields, for
      the keys reported by the first field's ``_metadata.index_keys()``;
    * *user* indices: built on demand by :meth:`index` for any other key and
      remembered in ``user_indices``.

    Parameters
    ----------
    field_list
        The object holding the fields. It must support ``len()``, iteration
        and ``[0]`` indexing. Each field must provide
        ``metadata(key_or_keys, default=None)``, and the first field must
        provide ``_metadata.index_keys()``.
    """

    def __init__(self, field_list):
        self.fs = field_list
        # key -> sorted list of unique values, filled in by index()
        self.user_indices = dict()

    @cached_property
    def default_index_keys(self):
        """Return the default index keys, taken from the first field.

        An empty list is returned when there are no fields.
        """
        # Explicit len() check: only len() is assumed of the field list, not
        # a meaningful truth value.
        if len(self.fs) > 0:
            return self.fs[0]._metadata.index_keys()
        else:
            return []

    def _index_value(self, key):
        """Return the sorted unique non-None values of ``key`` over all fields."""
        values = set()
        for f in self.fs:
            v = f.metadata(key, default=None)
            if v is not None:
                values.add(v)

        return sorted(values)

    @cached_property
    def default_indices(self):
        """Return a dict mapping each default key to its sorted unique values.

        Keys for which no field provides a non-None value are left out.
        """
        keys = self.default_index_keys
        if not keys:
            # Nothing to collect, so skip querying the fields altogether.
            return {}

        indices = defaultdict(set)
        for f in self.fs:
            # One metadata call per field fetches the values of all the keys.
            values = f.metadata(keys, default=None)
            for k, v in zip(keys, values):
                if v is not None:
                    indices[k].add(v)

        return {k: sorted(v) for k, v in indices.items()}

    def indices(self, squeeze=False):
        """Return the default and user indices merged into a single dict.

        Parameters
        ----------
        squeeze: bool
            When True, only keys having more than one value are returned.

        Returns
        -------
        dict
            Maps each key to the sorted list of its unique values. A user
            index takes precedence over a default one with the same key.
        """
        r = {**self.default_indices, **self.user_indices}

        if squeeze:
            return {k: v for k, v in r.items() if len(v) > 1}
        else:
            return r

    def index(self, key):
        """Return the sorted unique values of ``key`` over all the fields.

        Values for keys outside the default set are computed on first use and
        cached in ``user_indices``. An empty list is returned when no field
        has a value for ``key``.
        """
        if key in self.user_indices:
            return self.user_indices[key]

        if key in self.default_index_keys:
            # default_indices omits the default keys that have no value in
            # any field, so a plain subscript would raise KeyError here.
            return self.default_indices.get(key, [])

        self.user_indices[key] = self._index_value(key)
        return self.user_indices[key]


class Field(Base):
r"""Represent a Field."""

Expand Down Expand Up @@ -684,8 +737,6 @@ class FieldList(Index):
defaults to "numpy".
"""

_md_indices = {}

def __init__(self, array_backend=None, **kwargs):
self._array_backend = ensure_backend(array_backend)
super().__init__(**kwargs)
Expand Down Expand Up @@ -741,31 +792,9 @@ def ignore(self):
else:
return False

@cached_method
def _default_index_keys(self):
if len(self) > 0:
return self[0]._metadata.index_keys()
else:
return []

def _find_index_values(self, key):
values = set()
for f in self:
v = f.metadata(key, default=None)
if v is not None:
values.add(v)
return sorted(list(values))

def _find_all_index_dict(self):
indices = defaultdict(set)
for f in self:
for k in self._default_index_keys():
v = f.metadata(k, default=None)
if v is None:
continue
indices[k].add(v)

return {k: sorted(list(v)) for k, v in indices.items()}
@cached_property
def _md_indices(self):
return FieldListIndices(self)

def indices(self, squeeze=False):
r"""Return the unique, sorted values for a set of metadata keys (see below)
Expand Down Expand Up @@ -806,12 +835,7 @@ def indices(self, squeeze=False):
used in :obj:`indices`.
"""
if not self._md_indices:
self._md_indices = self._find_all_index_dict()
if squeeze:
return {k: v for k, v in self._md_indices.items() if len(v) > 1}
else:
return self._md_indices
return self._md_indices.indices(squeeze=squeeze)

def index(self, key):
r"""Return the unique, sorted values of the specified metadata ``key`` from all the fields.
Expand Down Expand Up @@ -840,11 +864,7 @@ def index(self, key):
[300, 400, 500, 700, 850, 1000]
"""
if key in self.indices():
return self.indices()[key]

self._md_indices[key] = self._find_index_values(key)
return self._md_indices[key]
return self._md_indices.index(key)

def to_numpy(self, **kwargs):
r"""Return all the fields' values as an ndarray. It is formed as the array of the
Expand Down
15 changes: 12 additions & 3 deletions tests/grib/test_grib_inidces.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
def test_grib_indices_base(fl_type, array_backend):
ds = load_grib_data("tuv_pl.grib", fl_type, array_backend)

ref = {
ref_full = {
"class": ["od"],
"stream": ["oper"],
"levtype": ["pl"],
Expand All @@ -42,7 +42,7 @@ def test_grib_indices_base(fl_type, array_backend):
}

r = ds.indices()
assert r == ref
assert r == ref_full

ref = {
"levelist": [300, 400, 500, 700, 850, 1000],
Expand All @@ -55,6 +55,15 @@ def test_grib_indices_base(fl_type, array_backend):
r = ds.index("param")
assert r == ref

ref = [300, 400, 500, 700, 850, 1000]
ref_full["level"] = ref

r = ds.index("level")
assert r == ref

r = ds.indices()
assert r == ref_full


@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
Expand Down Expand Up @@ -155,7 +164,7 @@ def test_grib_indices_multi(fl_type, array_backend):

@pytest.mark.parametrize("fl_type", FL_TYPES)
@pytest.mark.parametrize("array_backend", ARRAY_BACKENDS)
def test_grib_indices_multi_Del(fl_type, array_backend):
def test_grib_indices_multi_sel(fl_type, array_backend):
f1 = load_grib_data("tuv_pl.grib", fl_type, array_backend)
f2 = load_grib_data("ml_data.grib", fl_type, array_backend, folder="data")
ds = f1 + f2
Expand Down

0 comments on commit fb8a720

Please sign in to comment.