Skip to content
This repository has been archived by the owner on Jan 10, 2025. It is now read-only.

Commit

Permalink
cleanup config
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Mar 7, 2024
1 parent be74713 commit b1e2069
Showing 1 changed file with 48 additions and 91 deletions.
139 changes: 48 additions & 91 deletions ecml_tools/create/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,34 @@
LOG = logging.getLogger(__name__)


def _get_first_key_if_dict(x):
if isinstance(x, str):
return x
return list(x.keys())[0]


def ensure_element_in_list(lst, elt, index):
if elt in lst:
assert lst[index] == elt
return lst

_lst = [_get_first_key_if_dict(d) for d in lst]
if elt in _lst:
assert _lst[index] == elt
return lst

return lst[:index] + [elt] + lst[index:]


def check_dict_value_and_set(dic, key, value):
if key in dic:
if dic[key] == value:
return
raise ValueError(f"Cannot use {key}={dic[key]}. Must use {value}.")
print(f"Setting {key}={value} in config")
dic[key] = value


class DictObj(dict):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -111,11 +139,12 @@ def statistics(self):


class LoadersConfig(Config):
purpose = "undefined"

def __init__(self, config, *args, **kwargs):
super().__init__(config, *args, **kwargs)

# TODO: should use a json schema to validate the config

if "description" not in self:
raise ValueError("Must provide a description in the config.")

Expand All @@ -129,6 +158,8 @@ def __init__(self, config, *args, **kwargs):

if "dates" in self.output:
raise ValueError("Obsolete: Dates should not be provided in output config.")
if not isinstance(self.dates, dict):
raise ValueError(f"Dates must be a dict. Got {self.dates}")

# deprecated/obsolete
if "order" in self.output:
Expand All @@ -138,17 +169,24 @@ def __init__(self, config, *args, **kwargs):
if "loop" in self:
raise ValueError(f"Do not use 'loop'. Use dates instead. {list(self.keys())}")

self.normalise()

if "licence" not in self:
raise ValueError("Must provide a licence in the config.")
self.licence = "unknown"
print(f"❗ Setting licence={self.licence} because it was not provided.")
if "copyright" not in self:
raise ValueError("Must provide a copyright in the config.")
self.copyright = "unknown"
print(f"❗ Setting copyright={self.copyright} because it was not provided.")

if not isinstance(self.dates, dict):
raise ValueError(f"Dates must be a dict. Got {self.dates}")
check_dict_value_and_set(self.output, "flatten_grid", True)
check_dict_value_and_set(self.output, "ensemble_dimension", 2)

assert isinstance(self.output.order_by, (list, tuple)), self.output.order_by
self.output.order_by = ensure_element_in_list(self.output.order_by, "number", self.output.ensemble_dimension)

order_by = self.output.order_by
assert len(order_by) == 3, order_by
assert _get_first_key_if_dict(order_by[0]) == "valid_datetime", order_by
assert _get_first_key_if_dict(order_by[2]) == "number", order_by

def normalise(self):
if "order_by" in self.output:
self.output.order_by = normalize_order_by(self.output.order_by)

Expand All @@ -158,46 +196,7 @@ def normalise(self):
self.reading_chunks = self.get("reading_chunks")
assert "flatten_values" not in self.output
assert "flatten_grid" in self.output, self.output

assert "statistics" in self.output
statistics_axis_name = self.output.statistics
statistics_axis = -1
for i, k in enumerate(self.output.order_by):
if k == statistics_axis_name:
statistics_axis = i

assert statistics_axis >= 0, f"{self.output.statistics} not in {list(self.output.order_by.keys())}"

self.statistics_names = self.output.order_by[statistics_axis_name]

# TODO: consider 2D grid points
self.statistics_axis = statistics_axis

@classmethod
def _get_first_key_if_dict(cls, x):
if isinstance(x, str):
return x
return list(x.keys())[0]

def check_dict_value_and_set(self, dic, key, value):
if key in dic:
if dic[key] == value:
return
raise ValueError(f"Cannot use {key}={dic[key]} with {self.purpose} purpose. Must use {value}.")
print(f"Setting {key}={value} because purpose={self.purpose}")
dic[key] = value

def ensure_element_in_list(self, lst, elt, index):
if elt in lst:
assert lst[index] == elt
return lst

_lst = [self._get_first_key_if_dict(d) for d in lst]
if elt in _lst:
assert _lst[index] == elt
return lst

return lst[:index] + [elt] + lst[index:]

def get_serialisable_dict(self):
return _prepare_serialisation(self)
Expand All @@ -206,42 +205,6 @@ def get_variables_names(self):
return self.output.order_by[self.output.statistics]


class UnknownPurposeConfig(LoadersConfig):
purpose = "unknown"

def normalise(self):
self.output.flatten_grid = self.output.get("flatten_grid", False)
self.output.ensemble_dimension = self.output.get("ensemble_dimension", False)
super().normalise() # must be called last


class AifsPurposeConfig(LoadersConfig):
purpose = "aifs"

def normalise(self):
if "licence" not in self:
self.licence = "CC-BY-4.0"
print(f"❗ Setting licence={self.licence} because it was not provided.")
if "copyright" not in self:
self.copyright = "ecmwf"
print(f"❗ Setting copyright={self.copyright} because it was not provided.")

self.check_dict_value_and_set(self.output, "flatten_grid", True)
self.check_dict_value_and_set(self.output, "ensemble_dimension", 2)

assert isinstance(self.output.order_by, (list, tuple)), self.output.order_by
self.output.order_by = self.ensure_element_in_list(
self.output.order_by, "number", self.output.ensemble_dimension
)

order_by = self.output.order_by
assert len(order_by) == 3, order_by
assert self._get_first_key_if_dict(order_by[0]) == "valid_datetime", order_by
assert self._get_first_key_if_dict(order_by[2]) == "number", order_by

super().normalise() # must be called last


def _prepare_serialisation(o):
if isinstance(o, dict):
dic = {}
Expand Down Expand Up @@ -272,21 +235,15 @@ def _prepare_serialisation(o):
return str(o)


CONFIGS = {
None: UnknownPurposeConfig,
"aifs": AifsPurposeConfig,
}


def loader_config(config):
config = Config(config)
obj = CONFIGS[config.get("purpose")](config)
obj = LoadersConfig(config)

# yaml round trip to check that serialisation works as expected
copy = obj.get_serialisable_dict()
copy = yaml.load(yaml.dump(copy), Loader=yaml.SafeLoader)
copy = Config(copy)
copy = CONFIGS[config.get("purpose")](config)
copy = LoadersConfig(config)
assert yaml.dump(obj) == yaml.dump(copy), (obj, copy)

return copy
Expand Down

0 comments on commit b1e2069

Please sign in to comment.