From 07525bcd1160a433edd61dc8203beeb8e91bafa3 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 26 Oct 2024 20:39:10 +0000 Subject: [PATCH 01/12] add supporting_arrays to checkpoints --- CHANGELOG.md | 3 ++ src/anemoi/utils/checkpoints.py | 50 +++++++++++++++++++++++++++++---- 2 files changed, 47 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 98b2ada..0c70b37 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-utils/compare/0.4.1...HEAD) +### Added +- Add supporting_arrays to checkpoints + ## [0.4.1](https://github.com/ecmwf/anemoi-utils/compare/0.4.0...0.4.1) - 2024-10-23 ## Fixed diff --git a/src/anemoi/utils/checkpoints.py b/src/anemoi/utils/checkpoints.py index 085e4d1..9fcd501 100644 --- a/src/anemoi/utils/checkpoints.py +++ b/src/anemoi/utils/checkpoints.py @@ -27,7 +27,7 @@ DEFAULT_FOLDER = "anemoi-metadata" -def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool: +def has_metadata(path: str, *, name: str = DEFAULT_NAME) -> bool: """Check if a checkpoint file has a metadata file Parameters @@ -49,13 +49,17 @@ def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool: return False -def load_metadata(path: str, name: str = DEFAULT_NAME) -> dict: +def load_metadata(path: str, *, supporting_arrays=False, name: str = DEFAULT_NAME) -> dict: """Load metadata from a checkpoint file Parameters ---------- path : str The path to the checkpoint file + + supporting_arrays: bool, optional + If True, the function will return a dictionary with the supporting arrays + name : str, optional The name of the metadata file in the zip archive @@ -79,12 +83,23 @@ def load_metadata(path: str, name: str = DEFAULT_NAME) -> dict: if metadata is not None: with zipfile.ZipFile(path, "r") as f: - return json.load(f.open(metadata, "r")) + metadata = json.load(f.open(metadata, "r")) + if supporting_arrays: + import numpy as np + + supporting_arrays = {} + for key, entry in metadata.get("supporting_arrays", {}): + supporting_arrays[key] = np.frombuffer( + f.read(entry["path"]), + dtype=entry["dtype"], + ).reshape(entry["shape"]) + metadata["supporting_arrays"] = supporting_arrays + return metadata else: raise ValueError(f"Could not find '{name}' in {path}.") -def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> None: +def save_metadata(path, metadata, *, supporting_arrays=None, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> None: """Save metadata to a checkpoint file Parameters @@ -93,6 +108,8 @@ def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> N The path to the checkpoint file metadata : JSON A JSON serializable object + supporting_arrays: dict, optional + A dictionary of supporting NumPy arrays name : str, optional The name of the metadata file in the zip archive folder : str, optional @@ -118,13 +135,34 @@ def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> N directory = list(directories)[0] + LOG.info("Adding extra information to checkpoint %s", path) LOG.info("Saving metadata to %s/%s/%s", directory, folder, name) + metadata = metadata.copy() + if supporting_arrays is not None: + metadata["supporting_arrays"] = { + key: dict(path=f"{directory}/{folder}/{key}.numpy", shape=value.shape, dtype=str(value.dtype)) + for key, value in supporting_arrays.items() + } + else: + metadata["supporting_arrays"] = {} + zipf.writestr( f"{directory}/{folder}/{name}", json.dumps(metadata), ) + for name, entry in metadata["supporting_arrays"].items(): + value = supporting_arrays[name] + LOG.info( + "Saving supporting array `%s` to %s (shape=%s, dtype=%s)", + name, + entry["path"], + entry["shape"], + entry["dtype"], + ) + zipf.writestr(entry["path"], value.tobytes()) + def _edit_metadata(path, name, callback): new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp" @@ -158,7 +196,7 @@ def _edit_metadata(path, name, callback): LOG.info("Updated metadata in %s", path) -def replace_metadata(path, metadata, name=DEFAULT_NAME): +def replace_metadata(path, metadata, *, name=DEFAULT_NAME): if not isinstance(metadata, dict): raise ValueError(f"metadata must be a dict, got {type(metadata)}") @@ -173,7 +211,7 @@ def callback(full): _edit_metadata(path, name, callback) -def remove_metadata(path, name=DEFAULT_NAME): +def remove_metadata(path, *, name=DEFAULT_NAME): LOG.info("Removing metadata '%s' from %s", name, path) From f4cfe9a672255f4b39894143142580e7cb82cb6e Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 26 Oct 2024 20:39:27 +0000 Subject: [PATCH 02/12] add supporting_arrays to checkpoints --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 0c70b37..f93d9ea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ Keep it human-readable, your future self will thank you! ## [Unreleased](https://github.com/ecmwf/anemoi-utils/compare/0.4.1...HEAD) ### Added -- Add supporting_arrays to checkpoints +- Add supporting_arrays to checkpoints ## [0.4.1](https://github.com/ecmwf/anemoi-utils/compare/0.4.0...0.4.1) - 2024-10-23 From ee5d47de58a06a76f2f28f8cdb1d766fdd986d33 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 26 Oct 2024 20:45:33 +0000 Subject: [PATCH 03/12] add supporting_arrays to checkpoints --- src/anemoi/utils/checkpoints.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/anemoi/utils/checkpoints.py b/src/anemoi/utils/checkpoints.py index 9fcd501..9783178 100644 --- a/src/anemoi/utils/checkpoints.py +++ b/src/anemoi/utils/checkpoints.py @@ -94,6 +94,8 @@ def load_metadata(path: str, *, supporting_arrays=False, name: str = DEFAULT_NAM dtype=entry["dtype"], ).reshape(entry["shape"]) metadata["supporting_arrays"] = supporting_arrays + return metadata, supporting_arrays + return metadata else: raise ValueError(f"Could not find '{name}' in {path}.") From 736a75830dcae6a104eddb8055f401bd75839d5a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Sat, 26 Oct 2024 20:49:27 +0000 Subject: [PATCH 04/12] add supporting_arrays to checkpoints --- src/anemoi/utils/checkpoints.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/anemoi/utils/checkpoints.py b/src/anemoi/utils/checkpoints.py index 9783178..aa94cce 100644 --- a/src/anemoi/utils/checkpoints.py +++ b/src/anemoi/utils/checkpoints.py @@ -88,7 +88,7 @@ def load_metadata(path: str, *, supporting_arrays=False, name: str = DEFAULT_NAM import numpy as np supporting_arrays = {} - for key, entry in metadata.get("supporting_arrays", {}): + for key, entry in metadata.get("supporting_arrays", {}).items(): supporting_arrays[key] = np.frombuffer( f.read(entry["path"]), dtype=entry["dtype"], From c5cd4544f862b231ab841c69a46faf5f83ba04eb Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 28 Oct 2024 21:32:26 +0000 Subject: [PATCH 05/12] update --- src/anemoi/utils/checkpoints.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/anemoi/utils/checkpoints.py b/src/anemoi/utils/checkpoints.py index aa94cce..8b2be7b 100644 --- a/src/anemoi/utils/checkpoints.py +++ b/src/anemoi/utils/checkpoints.py @@ -142,19 +142,19 @@ def save_metadata(path, metadata, *, supporting_arrays=None, name=DEFAULT_NAME, metadata = metadata.copy() if supporting_arrays is not None: - metadata["supporting_arrays"] = { + metadata["supporting_arrays_paths"] = { key: dict(path=f"{directory}/{folder}/{key}.numpy", shape=value.shape, dtype=str(value.dtype)) for key, value in supporting_arrays.items() } else: - metadata["supporting_arrays"] = {} + metadata["supporting_arrays_paths"] = {} zipf.writestr( f"{directory}/{folder}/{name}", json.dumps(metadata), ) - for name, entry in metadata["supporting_arrays"].items(): + for name, entry in metadata["supporting_arrays_paths"].items(): value = supporting_arrays[name] LOG.info( "Saving supporting array `%s` to %s (shape=%s, dtype=%s)", From 1337c9318097a722583cf7b426eba23f29173c4a Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 28 Oct 2024 21:47:32 +0000 Subject: [PATCH 06/12] add supporting arrays --- src/anemoi/utils/checkpoints.py | 37 ++++++++++++++++++++++----------- 1 file changed, 25 insertions(+), 12 deletions(-) diff --git a/src/anemoi/utils/checkpoints.py b/src/anemoi/utils/checkpoints.py index 8b2be7b..451713c 100644 --- a/src/anemoi/utils/checkpoints.py +++ b/src/anemoi/utils/checkpoints.py @@ -85,15 +85,7 @@ def load_metadata(path: str, *, supporting_arrays=False, name: str = DEFAULT_NAM with zipfile.ZipFile(path, "r") as f: metadata = json.load(f.open(metadata, "r")) if supporting_arrays: - import numpy as np - - supporting_arrays = {} - for key, entry in metadata.get("supporting_arrays", {}).items(): - supporting_arrays[key] = np.frombuffer( - f.read(entry["path"]), - dtype=entry["dtype"], - ).reshape(entry["shape"]) - metadata["supporting_arrays"] = supporting_arrays + metadata["supporting_arrays"] = load_supporting_arrays(f, metadata.get("supporting_arrays", {})) return metadata, supporting_arrays return metadata @@ -101,6 +93,18 @@ def load_metadata(path: str, *, supporting_arrays=False, name: str = DEFAULT_NAM raise ValueError(f"Could not find '{name}' in {path}.") +def load_supporting_arrays(zipf, entries) -> dict: + import numpy as np + + supporting_arrays = {} + for key, entry in entries.items(): + supporting_arrays[key] = np.frombuffer( + zipf.read(entry["path"]), + dtype=entry["dtype"], + ).reshape(entry["shape"]) + return supporting_arrays + + def save_metadata(path, metadata, *, supporting_arrays=None, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> None: """Save metadata to a checkpoint file @@ -166,7 +170,7 @@ def save_metadata(path, metadata, *, supporting_arrays=None, name=DEFAULT_NAME, zipf.writestr(entry["path"], value.tobytes()) -def _edit_metadata(path, name, callback): +def _edit_metadata(path, name, callback, supporting_arrays=None): new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp" found = False @@ -185,6 +189,15 @@ def _edit_metadata(path, name, callback): if not found: raise ValueError(f"Could not find '{name}' in {path}") + if supporting_arrays is not None: + for key, entry in supporting_arrays.items(): + value = entry.tobytes() + fname = os.path.join(temp_dir, key) + os.makedirs(os.path.dirname(fname), exist_ok=True) + with open(fname, "wb") as f: + f.write(value) + total += 1 + with zipfile.ZipFile(new_path, "w", zipfile.ZIP_DEFLATED) as zipf: with tqdm.tqdm(total=total, desc="Rebuilding checkpoint") as pbar: for root, dirs, files in os.walk(temp_dir): @@ -198,7 +211,7 @@ def _edit_metadata(path, name, callback): LOG.info("Updated metadata in %s", path) -def replace_metadata(path, metadata, *, name=DEFAULT_NAME): +def replace_metadata(path, metadata, supporting_arrays=None, *, name=DEFAULT_NAME): if not isinstance(metadata, dict): raise ValueError(f"metadata must be a dict, got {type(metadata)}") @@ -210,7 +223,7 @@ def callback(full): with open(full, "w") as f: json.dump(metadata, f) - _edit_metadata(path, name, callback) + _edit_metadata(path, name, callback, supporting_arrays) def remove_metadata(path, *, name=DEFAULT_NAME): From 7261743260781c1159c550f9ba1d6cc3d23db166 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Mon, 28 Oct 2024 22:27:06 +0000 Subject: [PATCH 07/12] add supporting arrays --- src/anemoi/utils/checkpoints.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/anemoi/utils/checkpoints.py b/src/anemoi/utils/checkpoints.py index 451713c..83d386d 100644 --- a/src/anemoi/utils/checkpoints.py +++ b/src/anemoi/utils/checkpoints.py @@ -49,6 +49,15 @@ def has_metadata(path: str, *, name: str = DEFAULT_NAME) -> bool: return False +def metadata_root(path: str, *, name: str = DEFAULT_NAME) -> bool: + + with zipfile.ZipFile(path, "r") as f: + for b in f.namelist(): + if os.path.basename(b) == name: + return os.path.dirname(b) + raise ValueError(f"Could not find '{name}' in {path}.") + + def load_metadata(path: str, *, supporting_arrays=False, name: str = DEFAULT_NAME) -> dict: """Load metadata from a checkpoint file @@ -175,6 +184,7 @@ def _edit_metadata(path, name, callback, supporting_arrays=None): found = False + directory = None with TemporaryDirectory() as temp_dir: zipfile.ZipFile(path, "r").extractall(temp_dir) total = 0 @@ -185,14 +195,16 @@ def _edit_metadata(path, name, callback, supporting_arrays=None): if f == name: found = True callback(full) + directory = os.path.dirname(full) if not found: raise ValueError(f"Could not find '{name}' in {path}") if supporting_arrays is not None: + for key, entry in supporting_arrays.items(): value = entry.tobytes() - fname = os.path.join(temp_dir, key) + fname = os.path.join(directory, f"{key}.numpy") os.makedirs(os.path.dirname(fname), exist_ok=True) with open(fname, "wb") as f: f.write(value) @@ -223,7 +235,7 @@ def callback(full): with open(full, "w") as f: json.dump(metadata, f) - _edit_metadata(path, name, callback, supporting_arrays) + return _edit_metadata(path, name, callback, supporting_arrays) def remove_metadata(path, *, name=DEFAULT_NAME): @@ -233,4 +245,4 @@ def remove_metadata(path, *, name=DEFAULT_NAME): def callback(full): os.remove(full) - _edit_metadata(path, name, callback) + return _edit_metadata(path, name, callback) From 54afe53dc4c248eda35035705ce4e6959ad14b2b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Tue, 29 Oct 2024 20:04:18 +0000 Subject: [PATCH 08/12] copyright notice --- docs/conf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index 7760336..5d812ac 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,7 +29,7 @@ project = "Anemoi Utils" -author = "ECMWF" +author = "Anemoi contributors" year = datetime.datetime.now().year if year == 2024: @@ -37,7 +37,7 @@ else: years = "2024-%s" % (year,) -copyright = "%s, ECMWF" % (years,) +copyright = "%s, Anemoi contributors" % (years,) try: from anemoi.utils._version import __version__ From 26c5ee8ef976cde5dec76e5f5b8ccd07ee02722b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 31 Oct 2024 16:46:10 +0000 Subject: [PATCH 09/12] add registry --- src/anemoi/utils/__init__.py | 4 +- src/anemoi/utils/registry.py | 79 ++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 1 deletion(-) create mode 100644 src/anemoi/utils/registry.py diff --git a/src/anemoi/utils/__init__.py b/src/anemoi/utils/__init__.py index 9733be2..7b9efcd 100644 --- a/src/anemoi/utils/__init__.py +++ b/src/anemoi/utils/__init__.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. diff --git a/src/anemoi/utils/registry.py b/src/anemoi/utils/registry.py new file mode 100644 index 0000000..2d5e939 --- /dev/null +++ b/src/anemoi/utils/registry.py @@ -0,0 +1,79 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + + +import importlib +import logging +import os +import sys + +import entrypoints + +LOG = logging.getLogger(__name__) + + +class Registry: + """A registry of factories""" + + def __init__(self, package): + + self.package = package + self.registered = {} + self.kind = package.split(".")[-1] + + def register(self, name: str, factory: callable): + self.registered[name] = factory + + def _load(self, file): + name, _ = os.path.splitext(file) + try: + importlib.import_module(f".{name}", package=self.package) + except Exception: + LOG.warning(f"Error loading filter '{self.package}.{name}'", exc_info=True) + + def lookup(self, name: str) -> callable: + if name in self.registered: + return self.registered[name] + + directory = sys.modules[self.package].__path__[0] + + for file in os.listdir(directory): + + if file[0] == ".": + continue + + if file == "__init__.py": + continue + + full = os.path.join(directory, file) + if os.path.isdir(full): + if os.path.exists(os.path.join(full, "__init__.py")): + self._load(file) + continue + + if file.endswith(".py"): + self._load(file) + + entrypoint_group = f"anemoi.{self.kind}" + for entry_point in entrypoints.get_group_all(entrypoint_group): + if entry_point.name == name: + if name in self.registered: + LOG.warning( + f"Overwriting builtin '{name}' from {self.package} with plugin '{entry_point.module_name}'" + ) + self.registered[name] = entry_point.load() + + if name not in self.registered: + raise ValueError(f"Cannot load '{name}' from {self.package}") + + return self.registered[name] + + def create(self, name: str, *args, **kwargs): + factory = self.lookup(name) + return factory(*args, **kwargs) From bcc7cfa105720135b1199c0ffe6fe1242cb8899b Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 31 Oct 2024 16:55:16 +0000 Subject: [PATCH 10/12] update registry --- CHANGELOG.md | 1 + src/anemoi/utils/registry.py | 10 +++++++++- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d0eeeca..d5650e1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ Keep it human-readable, your future self will thank you! ### Added - Add supporting_arrays to checkpoints +- Add factories registry ## [0.4.1](https://github.com/ecmwf/anemoi-utils/compare/0.4.0...0.4.1) - 2024-10-23 diff --git a/src/anemoi/utils/registry.py b/src/anemoi/utils/registry.py index 2d5e939..9fad5c5 100644 --- a/src/anemoi/utils/registry.py +++ b/src/anemoi/utils/registry.py @@ -27,7 +27,12 @@ def __init__(self, package): self.registered = {} self.kind = package.split(".")[-1] - def register(self, name: str, factory: callable): + def register(self, name: str, factory: callable = None): + + # Decorator version + if factory is None: + return lambda f: self.register(name, f) + self.registered[name] = factory def _load(self, file): @@ -77,3 +82,6 @@ def lookup(self, name: str) -> callable: def create(self, name: str, *args, **kwargs): factory = self.lookup(name) return factory(*args, **kwargs) + + def __call__(self, name: str, *args, **kwargs): + return self.create(name, *args, **kwargs) From 362fe70206d3aec3bf790a37998a0a9b63ecee43 Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Thu, 31 Oct 2024 17:16:55 +0000 Subject: [PATCH 11/12] add registry --- src/anemoi/utils/registry.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/src/anemoi/utils/registry.py b/src/anemoi/utils/registry.py index 9fad5c5..9d4bcce 100644 --- a/src/anemoi/utils/registry.py +++ b/src/anemoi/utils/registry.py @@ -18,6 +18,18 @@ LOG = logging.getLogger(__name__) +class Wrapper: + """A wrapper for the registry""" + + def __init__(self, name, registry): + self.name = name + self.registry = registry + + def __call__(self, factory): + self.registry.register(self.name, factory) + return factory + + class Registry: """A registry of factories""" @@ -29,9 +41,8 @@ def __init__(self, package): def register(self, name: str, factory: callable = None): - # Decorator version if factory is None: - return lambda f: self.register(name, f) + return Wrapper(name, self) self.registered[name] = factory From 2e161d9d7590e0bad983f82504ce3a6ff212518f Mon Sep 17 00:00:00 2001 From: Baudouin Raoult Date: Fri, 1 Nov 2024 08:26:13 +0000 Subject: [PATCH 12/12] add select to find function --- src/anemoi/utils/config.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/anemoi/utils/config.py b/src/anemoi/utils/config.py index a6a9cb9..3a9406a 100644 --- a/src/anemoi/utils/config.py +++ b/src/anemoi/utils/config.py @@ -358,7 +358,7 @@ def check_config_mode(name="settings.toml", secrets_name=None, secrets=None) -> CHECKED[name] = True -def find(metadata, what, result=None): +def find(metadata, what, result=None, *, select: callable = None): if result is None: result = [] @@ -369,7 +369,8 @@ def find(metadata, what, result=None): if isinstance(metadata, dict): if what in metadata: - result.append(metadata[what]) + if select is None or select(metadata[what]): + result.append(metadata[what]) for k, v in metadata.items(): find(v, what, result)