Skip to content

Commit

Permalink
Enable using forcings without providing a source (#495)
Browse files Browse the repository at this point in the history
* Enable using forcings without providing a source
  • Loading branch information
sandorkertesz authored Oct 23, 2024
1 parent a2804c0 commit 527ce1b
Show file tree
Hide file tree
Showing 8 changed files with 178 additions and 38 deletions.
33 changes: 32 additions & 1 deletion src/earthkit/data/sources/forcings.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,12 +305,27 @@ def index_to_coords(index: int, shape):


class ForcingsFieldListCore(FieldList):
def __init__(self, source_or_dataset, request={}, **kwargs):
def __init__(self, source_or_dataset=None, *, request={}, **kwargs):
request = dict(**request)
request.update(kwargs)

self.request = self._request(**request)

def find_latlon():
lats = None
for k in ["latitudes", "latitude"]:
if k in self.request:
lats = np.asarray(self.request.pop(k))
break

lons = None
for k in ["longitudes", "longitude"]:
if k in self.request:
lons = np.asarray(self.request.pop(k))
break

return lats, lons

def find_numbers(source_or_dataset):
if "number" in self.request:
return self.request["number"]
Expand Down Expand Up @@ -345,6 +360,22 @@ def find_dates(source_or_dataset):

return source_or_dataset.unique_values("valid_datetime")["valid_datetime"]

if source_or_dataset is None:
lats, lons = find_latlon()
if lats is None:
raise ValueError("latitudes must be specified when no source or dataset provided")

if lons is None:
raise ValueError("longitudes must be specified when no source or dataset provided")

from earthkit.data import from_source

vals = np.ones(lats.shape)
d = {"values": vals, "latitudes": lats, "longitudes": lons}
# d.update(self.request)

source_or_dataset = from_source("list-of-dicts", [d])

self.dates = find_dates(source_or_dataset)

self.params = self.request["param"]
Expand Down
31 changes: 24 additions & 7 deletions tests/forcings/forcings_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
]


def load_forcings_fs(params=None, first_step=6, last_step=72):
def load_forcings_fs(params=None, first_step=6, last_step=72, input_data="grib"):
sample = from_source("file", earthkit_examples_file("test.grib"))

if params is None:
Expand All @@ -57,12 +57,29 @@ def load_forcings_fs(params=None, first_step=6, last_step=72):
for step in range(first_step, last_step + step_increment, step_increment):
dates.append(start + datetime.timedelta(hours=step))

ds = from_source(
"forcings",
sample,
date=dates,
param=params,
)
if input_data == "grib":
ds = from_source(
"forcings",
sample,
date=dates,
param=params,
)
elif input_data == "latlon":
d = {}
ll = sample[0].to_latlon()
d["latitudes"] = ll["lat"]
d["longitudes"] = ll["lon"]
# d["date"] = sample[0].metadata("date")
# d["param"] = sample[0].metadata("param")
ds = from_source(
"forcings",
**d,
date=dates,
param=params,
)
else:
raise ValueError(f"Unknown input_data: {input_data}")

assert len(ds) == len(dates) * len(params)

md = [[d.isoformat(), p] for d, p in itertools.product(dates, params)]
Expand Down
12 changes: 8 additions & 4 deletions tests/forcings/test_forcings_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,16 @@
import os
import sys

import pytest

here = os.path.dirname(__file__)
sys.path.insert(0, here)
from forcings_fixtures import load_forcings_fs # noqa: E402


def test_forcings_datetime():
ds, _ = load_forcings_fs(last_step=12)
@pytest.mark.parametrize("input_data", ["grib", "latlon"])
def test_forcings_datetime(input_data):
ds, _ = load_forcings_fs(last_step=12, input_data=input_data)

ref = {
"base_time": [None],
Expand All @@ -32,8 +35,9 @@ def test_forcings_datetime():
assert ds.datetime() == ref


def test_forcings_valid_datetime():
ds, _ = load_forcings_fs(last_step=12)
@pytest.mark.parametrize("input_data", ["grib", "latlon"])
def test_forcings_valid_datetime(input_data):
ds, _ = load_forcings_fs(last_step=12, input_data=input_data)
f = ds[4]

assert f.metadata("valid_datetime") == "2020-05-13T18:00:00"
Expand Down
15 changes: 9 additions & 6 deletions tests/forcings/test_forcings_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
from forcings_fixtures import load_forcings_fs # noqa: E402


def _build_proc_ref():
@pytest.mark.parametrize("input_data", ["grib", "latlon"])
def _build_proc_ref(input_data):
import yaml

ds, _ = load_forcings_fs(params=all_params, last_step=12)
ds, _ = load_forcings_fs(params=all_params, last_step=12, input_data=input_data)
d = {}
for p in all_params:
# print(f"p={p}")
Expand All @@ -44,11 +45,12 @@ def _build_proc_ref():
yaml.dump(d, outfile, sort_keys=True)


def test_forcings_proc():
@pytest.mark.parametrize("input_data", ["grib", "latlon"])
def test_forcings_proc(input_data):
with open(earthkit_test_data_file(os.path.join("forcings", "proc.yaml")), "r") as f:
ref = yaml.safe_load(f)

ds, _ = load_forcings_fs(params=all_params, last_step=12)
ds, _ = load_forcings_fs(params=all_params, last_step=12, input_data=input_data)

for p in all_params:
f = ds.sel(param=p, valid_datetime="2020-05-13T18:00:00")
Expand All @@ -60,9 +62,10 @@ def test_forcings_proc():
assert np.isclose(np.nanmean(v), r["mean"])


@pytest.mark.parametrize("input_data", ["grib", "latlon"])
@pytest.mark.parametrize("param,coord", [("latitude", "lat"), ("longitude", "lon")])
def test_forcings_proc_latlon(param, coord):
ds, _ = load_forcings_fs(params=all_params, last_step=12)
def test_forcings_proc_latlon(input_data, param, coord):
ds, _ = load_forcings_fs(params=all_params, last_step=12, input_data=input_data)

latlon = ds[0].to_latlon(flatten=True)
coord = latlon[coord]
Expand Down
10 changes: 6 additions & 4 deletions tests/forcings/test_forcings_sel.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from forcings_fixtures import load_forcings_fs # noqa: E402


@pytest.mark.parametrize("input_data", ["grib", "latlon"])
@pytest.mark.parametrize(
"params,expected_meta",
[
Expand All @@ -43,8 +44,8 @@
(dict(INVALIDKEY="sin_logitude"), []),
],
)
def test_forcings_sel_single_file_1(params, expected_meta):
ds, _ = load_forcings_fs()
def test_forcings_sel_single_file_1(input_data, params, expected_meta):
ds, _ = load_forcings_fs(input_data=input_data)

g = ds.sel(**params)
assert len(g) == len(expected_meta)
Expand All @@ -54,8 +55,9 @@ def test_forcings_sel_single_file_1(params, expected_meta):
return


def test_forcings_sel_single_file_as_dict():
ds, _ = load_forcings_fs()
@pytest.mark.parametrize("input_data", ["grib", "latlon"])
def test_forcings_sel_single_file_as_dict(input_data):
ds, _ = load_forcings_fs(input_data=input_data)

g = ds.sel(
{
Expand Down
30 changes: 18 additions & 12 deletions tests/forcings/test_forcings_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,18 @@
from forcings_fixtures import load_forcings_fs # noqa: E402


def test_forcings_single_index_bad():
ds, _ = load_forcings_fs()
@pytest.mark.parametrize("input_data", ["grib", "latlon"])
def test_forcings_single_index_bad(input_data):
ds, _ = load_forcings_fs(input_data=input_data)
idx = len(ds) + 10
with pytest.raises(IndexError):
ds[idx]


@pytest.mark.parametrize("input_data", ["grib", "latlon"])
@pytest.mark.parametrize("index", [0, 2, 95, -1, -96])
def test_forcings_single_index(index):
ds, md = load_forcings_fs()
def test_forcings_single_index(input_data, index):
ds, md = load_forcings_fs(input_data=input_data)
num = len(ds)
r = ds[index]

Expand All @@ -43,6 +45,7 @@ def test_forcings_single_index(index):
assert len(ds) == num


@pytest.mark.parametrize("input_data", ["grib", "latlon"])
@pytest.mark.parametrize(
"indexes",
[
Expand All @@ -54,8 +57,8 @@ def test_forcings_single_index(index):
slice(91, None),
],
)
def test_forcings_slice(indexes):
ds, md = load_forcings_fs()
def test_forcings_slice(input_data, indexes):
ds, md = load_forcings_fs(input_data=input_data)
num = len(ds)
r = ds[indexes]

Expand All @@ -71,6 +74,7 @@ def test_forcings_slice(indexes):
assert len(ds) == num


@pytest.mark.parametrize("input_data", ["grib", "latlon"])
@pytest.mark.parametrize(
"indexes1,indexes2",
[
Expand All @@ -79,8 +83,8 @@ def test_forcings_slice(indexes):
((1, 16, 5, 9), (1, 3)),
],
)
def test_forcings_array_indexing(indexes1, indexes2):
ds, md = load_forcings_fs()
def test_forcings_array_indexing(input_data, indexes1, indexes2):
ds, md = load_forcings_fs(input_data=input_data)

# first subset
r = ds[indexes1]
Expand All @@ -97,6 +101,7 @@ def test_forcings_array_indexing(indexes1, indexes2):
assert r1.metadata(["valid_datetime", "param"]) == ref_md


@pytest.mark.parametrize("input_data", ["grib", "latlon"])
@pytest.mark.skip(reason="Index range checking disabled")
@pytest.mark.parametrize(
"indexes",
Expand All @@ -106,14 +111,15 @@ def test_forcings_array_indexing(indexes1, indexes2):
((1, 16, 5, 9), (1, 3)),
],
)
def test_forcings_array_indexing_bad(indexes):
ds, _ = load_forcings_fs()
def test_forcings_array_indexing_bad(input_data, indexes):
ds, _ = load_forcings_fs(input_data=input_data)
with pytest.raises(IndexError):
ds[indexes]


def test_forcings_fieldlist_iterator():
ds, md = load_forcings_fs()
@pytest.mark.parametrize("input_data", ["grib", "latlon"])
def test_forcings_fieldlist_iterator(input_data):
ds, md = load_forcings_fs(input_data=input_data)
# sn = ds.metadata(["valid_datetime", "param"])
sn = md
assert len(sn) == len(ds)
Expand Down
75 changes: 75 additions & 0 deletions tests/forcings/test_forcings_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import os
import sys

import pytest

from earthkit.data import from_source
from earthkit.data.testing import earthkit_examples_file
from earthkit.data.testing import earthkit_test_data_file
Expand Down Expand Up @@ -115,6 +117,79 @@ def test_forcings_3():
assert f.metadata("param") == r[1]


@pytest.mark.parametrize("lat_key,lon_key", [("latitudes", "longitudes"), ("latitude", "longitude")])
@pytest.mark.parametrize(
    "filename", ["t_time_series.grib", "rgg_small_subarea_cellarea_ref.grib", "mercator.grib"]
)
def test_forcings_from_lat_lon_core(lat_key, lon_key, filename):
    """Create forcings from latitude/longitude arrays alone, without a source.

    The coordinates are taken from the first field of the GRIB test file
    ``filename`` and handed to the "forcings" source under the keyword names
    ``lat_key`` and ``lon_key``, so both the plural and the singular spelling
    of the keywords are exercised.

    The test checks that one field is produced per (date, param) pair, that
    the fields come out in ``itertools.product(dates, params)`` order, and
    that every field's array has the same shape as the latitude array.
    """
    grib = from_source("file", earthkit_test_data_file(filename))

    # NOTE(review): the comment that used to sit on this call said
    # "flatten=True is important here", but no ``flatten`` argument is
    # passed - confirm which of the two is intended.
    geo = grib[0].to_latlon()
    lats = geo["lat"]
    lons = geo["lon"]

    dates = [
        datetime.datetime(2020, 12, 21, 12, 0),
        datetime.datetime(2020, 12, 21, 15, 0),
        datetime.datetime(2020, 12, 21, 18, 0),
        datetime.datetime(2020, 12, 21, 21, 0),
        datetime.datetime(2020, 12, 23, 12, 0),
    ]
    params = all_params

    # Only the coordinate arrays are supplied; no source or dataset is given.
    coords = {lat_key: lats, lon_key: lons}
    ds = from_source("forcings", **coords, date=dates, param=params)

    num = len(dates) * len(params)
    assert len(ds) == num

    # Reference metadata: one (date, param) pair per expected field.
    ref = list(itertools.product(dates, params))
    assert len(ds) == len(ref)
    for field, (valid_date, param) in zip(ds, ref):
        assert field.metadata("valid_datetime") == valid_date.isoformat()
        assert field.metadata("param") == param
        assert field.to_numpy().shape == lats.shape


def test_forcings_from_lat_lon_bad():
    """Check that giving only one of the two coordinate arrays is rejected.

    With no source or dataset, the "forcings" source is called with just the
    latitudes or just the longitudes (in both plural and singular keyword
    spelling); each call is expected to raise ``ValueError``.
    """
    grib = from_source("file", earthkit_test_data_file("t_time_series.grib"))

    params = all_params

    geo = grib[0].to_latlon()
    lats = geo["lat"]
    lons = geo["lon"]

    # Each entry supplies exactly one coordinate array, leaving the other out.
    incomplete_coords = [
        {"latitudes": lats},
        {"latitude": lats},
        {"longitudes": lons},
        {"longitude": lons},
    ]

    for coords in incomplete_coords:
        with pytest.raises(ValueError):
            from_source("forcings", **coords, param=params)


if __name__ == "__main__":
from earthkit.data.testing import main

Expand Down
Loading

0 comments on commit 527ce1b

Please sign in to comment.