Merge pull request #64 from esm-tools/test/timeaverage
Tests: Timeaverage Unit tests
pgierz authored Nov 27, 2024
2 parents 3042e47 + c000f42 commit c48f078
Showing 6 changed files with 301 additions and 43 deletions.
10 changes: 4 additions & 6 deletions src/pymorize/files.py
@@ -47,13 +47,13 @@
import xarray as xr
from xarray.core.utils import is_scalar

from .timeaverage import _frequency_from_approx_interval
from .dataset_helpers import (
get_time_label,
has_time_axis,
is_datetime_type,
needs_resampling,
)
from .timeaverage import _frequency_from_approx_interval


def _filename_time_range(ds, rule) -> str:
@@ -102,7 +102,6 @@ def _filename_time_range(ds, rule) -> str:
raise NotImplementedError(f"No implementation for {frequency_str} yet.")



def create_filepath(ds, rule):
"""
Generate a filepath when given an xarray dataset and a rule.
@@ -190,12 +189,11 @@ def save_dataset(da: xr.DataArray, rule):
return da.to_netcdf(filepath, mode="w", format="NETCDF4")
if isinstance(da, xr.DataArray):
da = da.to_dataset()
file_timespan = rule.file_timespan
frequency_str = _frequency_from_approx_interval(file_timespan)
if not needs_resampling(da, frequency_str):
file_timespan = getattr(rule, "file_timespan", None)
if not needs_resampling(da, file_timespan):
filepath = create_filepath(da, rule)
return da.to_netcdf(filepath, mode="w", format="NETCDF4")
groups = da.resample(time=frequency_str)
groups = da.resample(time=file_timespan)
paths = []
datasets = []
for group_name, group_ds in groups:
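With this change `save_dataset` no longer derives a frequency string itself; it reads `rule.file_timespan` (now set by `compute_average`, see below) and resamples on it directly. A minimal sketch of the resulting split-and-write loop, using a toy dataset in place of the pymorize rule object and assuming `file_timespan` holds a pandas-style frequency alias ("1YE" needs pandas ≥ 2.2):

```python
import numpy as np
import pandas as pd
import xarray as xr

# Toy stand-in for the dataset handed to save_dataset.
ds = xr.Dataset(
    {"tas": ("time", np.random.rand(730))},
    coords={"time": pd.date_range("2000-01-01", periods=730, freq="D")},
)

file_timespan = "1YE"  # illustrative value for rule.file_timespan
# Mirrors `groups = da.resample(time=file_timespan)`: one output file per group.
for label, group in ds.resample(time=file_timespan):
    print(label, group.time.size)  # one group per year: 366 and 364 steps
```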
62 changes: 35 additions & 27 deletions src/pymorize/timeaverage.py
@@ -13,8 +13,9 @@
_split_by_chunks(dataset: xr.DataArray) -> Tuple[Dict, xr.DataArray]:
Split a large dataset into sub-datasets for each chunk.
_get_time_method(table_id: str) -> str:
Determine the time method based on the table_id string.
_get_time_method(frequency: str) -> str:
Determine the time method based on the frequency string from
rule.data_request_variable.frequency.
_frequency_from_approx_interval(interval: str) -> str:
Convert an interval expressed in days to a frequency string.
@@ -32,8 +33,10 @@
"""

import functools
import itertools

import numpy as np
import pandas as pd
import xarray as xr

@@ -63,6 +66,8 @@ def _split_by_chunks(dataset: xr.DataArray):
"""
chunk_slices = {}
logger.info(f"{dataset.chunks=}")
if not dataset.chunks:
raise ValueError("Dataset has no chunks")
if isinstance(dataset, xr.Dataset):
chunker = dataset.chunks
elif isinstance(dataset, xr.DataArray):
@@ -80,28 +85,26 @@ def _split_by_chunks(dataset: xr.DataArray):
yield (selection, dataset[selection])
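`_split_by_chunks` walks the dask chunk layout and yields an index selection together with the matching subset; the new guard raises early when the input is not chunked at all. A rough, self-contained sketch of the same iteration (illustrative names, requires dask; not the pymorize API itself):

```python
import itertools

import numpy as np
import pandas as pd
import xarray as xr

ds = xr.Dataset(
    {"tas": ("time", np.arange(10.0))},
    coords={"time": pd.date_range("2000-01-01", periods=10, freq="D")},
).chunk({"time": 4})

if not ds.chunks:
    raise ValueError("Dataset has no chunks")

# One slice per chunk along each dimension...
chunk_slices = {}
for dim, sizes in ds.chunks.items():
    start, slices = 0, []
    for size in sizes:
        slices.append(slice(start, start + size))
        start += size
    chunk_slices[dim] = slices

# ...then every combination of per-dimension slices selects one chunk.
for indexes in itertools.product(*chunk_slices.values()):
    selection = dict(zip(chunk_slices.keys(), indexes))
    print(selection, ds.isel(selection).time.size)  # sizes 4, 4, 2
```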


def _get_time_method(table_id: str) -> str:
def _get_time_method(frequency: str) -> str:
"""
Determine the time method based on the table_id string.
Determine the time method based on the frequency string from CMIP6 table for
a specific variable (rule.data_request_variable.frequency).
This function checks the ending of the table_id string and returns a corresponding time method.
If the table_id ends with 'Pt', it returns 'INSTANTANEOUS'.
If the table_id ends with 'C' or 'CM', it returns 'CLIMATOLOGY'.
In all other cases, it returns 'MEAN'.
The type of time method influences how the data is processed for time averaging.
Parameters
----------
table_id : str
The table_id string to check.
frequency : str
The frequency string from CMIP6 tables (example: "mon").
Returns
-------
str
The corresponding time method ('INSTANTANEOUS', 'CLIMATOLOGY', or 'MEAN').
"""
if table_id.endswith("Pt"):
if frequency.endswith("Pt"):
return "INSTANTANEOUS"
if table_id.endswith("C") or table_id.endswith("CM"):
if frequency.endswith("C") or frequency.endswith("CM"):
return "CLIMATOLOGY"
return "MEAN"

@@ -135,23 +138,19 @@ def _frequency_from_approx_interval(interval: str):
("year", lambda x: f"{x}YE", 365),
("month", lambda x: f"{x}ME", 30),
("day", lambda x: f"{x}D", 1),
("hour", lambda x: f"{x}H", 24),
("minute", lambda x: f"{x}min", 24 * 60),
("second", lambda x: f"{x}s", 24 * 60 * 60),
("millisecond", lambda x: f"{x}ms", 24 * 60 * 60 * 1000),
("hour", lambda x: f"{x}H", 1 / 24),
("minute", lambda x: f"{x}min", 1.0 / (24 * 60)),
("second", lambda x: f"{x}s", 1.0 / (24 * 60 * 60)),
("millisecond", lambda x: f"{x}ms", 1.0 / (24 * 60 * 60 * 1000)),
]
try:
interval = float(interval)
except ValueError:
return interval
to_divide = {"decade", "year", "month", "day"}
raise ValueError(f"Invalid interval: {interval}")
isclose = functools.partial(np.isclose, rtol=1e-3)
for name, func, val in notation:
if name in to_divide:
value = interval // val
else:
value = interval * val
if value >= 1:
value = round(value)
if (interval >= val) or isclose(interval, val):
value = round(interval / val)
value = "" if value == 1 else value
return func(value)
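After the rewrite, every table entry stores its unit length in days (the old hour/minute/second values were counts per day, which inverted the comparison), so the single `interval / val` branch handles sub-daily intervals too. Some illustrative conversions, with `approx_interval` in days as in the CMIP6 tables (assuming the decade entry hidden above this hunk behaves analogously):

```python
_frequency_from_approx_interval("365")    # -> "YE"   (one year)
_frequency_from_approx_interval("30")     # -> "ME"   (one month)
_frequency_from_approx_interval("1")      # -> "D"    (one day)
_frequency_from_approx_interval("0.125")  # -> "3H"   (0.125 / (1/24) = 3 hours)
_frequency_from_approx_interval("0.5")    # -> "12H"
_frequency_from_approx_interval("x")      # raises ValueError (was returned as-is)
```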

@@ -175,20 +174,27 @@ def _compute_file_timespan(da: xr.DataArray):
The maximum timespan among all chunks of the data array.
"""
if "time" not in da.dims:
raise ValueError("missing the 'time' dimension")
# Check if "time" dimension is empty
if da.time.size == 0:
raise ValueError("no time values in this chunk")
chunks = _split_by_chunks(da)
tmp_file_timespan = []
for i in range(3):
try:
subset_name, subset = next(chunks)
except StopIteration:
pass
break
else:
logger.info(f"{subset_name=}")
logger.info(f"{subset.time.data[-1]=}")
logger.info(f"{subset.time.data[0]=}")
tmp_file_timespan.append(
pd.Timedelta(subset.time.data[-1] - subset.time.data[0]).days
)
if not tmp_file_timespan:
raise ValueError("No chunks found")
file_timespan = max(tmp_file_timespan)
return file_timespan
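`_compute_file_timespan` now fails loudly on a missing or empty time axis, stops cleanly when fewer than three chunks exist, and raises if no chunk produced a span. A self-contained approximation of the per-chunk span computation on toy daily data (mirrors the `pd.Timedelta(...).days` line above):

```python
import itertools

import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(
    np.arange(100.0),
    dims="time",
    coords={"time": pd.date_range("2000-01-01", periods=100, freq="D")},
).chunk({"time": 40})

spans, start = [], 0
for size in itertools.islice(da.chunks[0], 3):  # at most the first three chunks
    sub = da.isel(time=slice(start, start + size))
    spans.append(pd.Timedelta(sub.time.data[-1] - sub.time.data[0]).days)
    start += size

print(max(spans))  # 39: each full 40-step daily chunk spans 39 days
```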

@@ -214,15 +220,17 @@ def compute_average(da: xr.DataArray, rule):
The time averaged data array.
"""
file_timespan = _compute_file_timespan(da)
rule.file_timespan = file_timespan
rule.file_timespan = getattr(rule, "file_timespan", None) or pd.Timedelta(
file_timespan, unit="D"
)
drv = rule.data_request_variable
approx_interval = drv.table.approx_interval
approx_interval_in_hours = pd.offsets.Hour(float(approx_interval) * 24)
frequency_str = _frequency_from_approx_interval(approx_interval)
logger.debug(f"{approx_interval=} {frequency_str=}")
# attach the frequency_str to rule, it is referenced when creating file name
rule.frequency_str = frequency_str
time_method = _get_time_method(drv.table.table_id)
time_method = _get_time_method(drv.frequency)
rule.time_method = time_method
if time_method == "INSTANTANEOUS":
ds = da.resample(time=frequency_str).first()
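Only the `INSTANTANEOUS` branch is visible in this hunk; the remaining branches are cut off. A minimal sketch of the dispatch on toy hourly data, where the `.mean()` fallback is an assumption, not a quote from the source:

```python
import numpy as np
import pandas as pd
import xarray as xr

da = xr.DataArray(
    np.arange(48.0),
    dims="time",
    coords={"time": pd.date_range("2000-01-01", periods=48, freq="h")},
)

frequency_str = "D"   # as produced by _frequency_from_approx_interval
time_method = "MEAN"  # as produced by _get_time_method(drv.frequency)

if time_method == "INSTANTANEOUS":
    out = da.resample(time=frequency_str).first()  # the branch shown above
else:
    out = da.resample(time=frequency_str).mean()   # assumed averaging path
print(out.values)  # [11.5 35.5]: daily means of the hourly samples
```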
13 changes: 8 additions & 5 deletions tests/integration/test_basic_pipeline.py
@@ -10,6 +10,14 @@
from pymorize.cmorizer import CMORizer
from pymorize.logging import logger

def test_init(test_config):
logger.info(f"Processing {test_config}")
with open(test_config, "r") as f:
cfg = yaml.safe_load(f)
CMORizer.from_dict(cfg)
# If we get this far, it was possible to construct
# the object, so this test passes:
assert True

@pytest.mark.skipif(
shutil.which("sbatch") is None, reason="sbatch is not available on this host"
@@ -22,8 +30,3 @@ def test_process(test_config):
cmorizer.process()


def test_init(test_config):
logger.info(f"Processing {test_config}")
with open(test_config, "r") as f:
cfg = yaml.safe_load(f)
cmorizer = CMORizer.from_dict(cfg)
15 changes: 15 additions & 0 deletions tests/integration/test_fesom_2p6_pimesh_esm_tools.py
@@ -33,6 +33,21 @@ def test_process_progressive_pipeline(
cmorizer.process()


def test_init(fesom_2p6_pimesh_esm_tools_config, fesom_2p6_pimesh_esm_tools_data):
logger.info(f"Processing {fesom_2p6_pimesh_esm_tools_config}")
with open(fesom_2p6_pimesh_esm_tools_config, "r") as f:
cfg = yaml.safe_load(f)
for rule in cfg["rules"]:
for input in rule["inputs"]:
input["path"] = input["path"].replace(
"REPLACE_ME", str(fesom_2p6_pimesh_esm_tools_data)
)
CMORizer.from_dict(cfg)
# If we get this far, it was possible to construct
# the object, so this test passes:
assert True


def test_process(fesom_2p6_pimesh_esm_tools_config, fesom_2p6_pimesh_esm_tools_data):
logger.info(f"Processing {fesom_2p6_pimesh_esm_tools_config}")
with open(fesom_2p6_pimesh_esm_tools_config, "r") as f:
12 changes: 7 additions & 5 deletions tests/meta/test_xarray_open_mfdataset.py
@@ -11,12 +11,14 @@
],
)
def test_open_fesom_2p6_pimesh_esm_tools(fesom_2p6_pimesh_esm_tools_data, engine):
matching_files = [
f
for f in (fesom_2p6_pimesh_esm_tools_data / "outdata/fesom/").iterdir()
if f.name.startswith("temp.fesom")
]
assert len(matching_files) > 0
ds = xr.open_mfdataset(
(
f
for f in (fesom_2p6_pimesh_esm_tools_data / "outdata/fesom/").iterdir()
if f.name.startswith("temp")
),
matching_files,
engine=engine,
)
assert isinstance(ds, xr.Dataset)
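The test now materializes the glob into a list and asserts it is non-empty before calling `open_mfdataset`, so a missing fixture fails with a clear message instead of an opaque error from inside xarray. The same guard pattern in isolation (paths illustrative):

```python
from pathlib import Path

import xarray as xr

data_dir = Path("outdata/fesom")  # placeholder for the test fixture path
matching_files = [
    f for f in data_dir.iterdir() if f.name.startswith("temp.fesom")
]
assert len(matching_files) > 0, f"no 'temp.fesom*' files under {data_dir}"
ds = xr.open_mfdataset(matching_files, engine="netcdf4")
```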
(The diff for the sixth changed file did not load on this page.)
