Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/registry #35

Merged
merged 15 commits into from
Nov 1, 2024
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ Keep it human-readable, your future self will thank you!
## [Unreleased](https://github.com/ecmwf/anemoi-utils/compare/0.4.1...HEAD)

### Added
- Add supporting_arrays to checkpoints
- Add factories registry
- Optional renaming of subcommands via `command` attribute [#34](https://github.com/ecmwf/anemoi-utils/pull/34)


## [0.4.1](https://github.com/ecmwf/anemoi-utils/compare/0.4.0...0.4.1) - 2024-10-23

## Fixed
Expand Down
4 changes: 2 additions & 2 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,15 @@

project = "Anemoi Utils"

author = "ECMWF"
author = "Anemoi contributors"

year = datetime.datetime.now().year
if year == 2024:
years = "2024"
else:
years = "2024-%s" % (year,)

copyright = "%s, ECMWF" % (years,)
copyright = "%s, Anemoi contributors" % (years,)

try:
from anemoi.utils._version import __version__
Expand Down
4 changes: 3 additions & 1 deletion src/anemoi/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts.
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
Expand Down
83 changes: 74 additions & 9 deletions src/anemoi/utils/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
DEFAULT_FOLDER = "anemoi-metadata"


def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool:
def has_metadata(path: str, *, name: str = DEFAULT_NAME) -> bool:
"""Check if a checkpoint file has a metadata file

Parameters
Expand All @@ -49,13 +49,26 @@ def has_metadata(path: str, name: str = DEFAULT_NAME) -> bool:
return False


def load_metadata(path: str, name: str = DEFAULT_NAME) -> dict:
def metadata_root(path: str, *, name: str = DEFAULT_NAME) -> bool:

with zipfile.ZipFile(path, "r") as f:
for b in f.namelist():
if os.path.basename(b) == name:
return os.path.dirname(b)
raise ValueError(f"Could not find '{name}' in {path}.")


def load_metadata(path: str, *, supporting_arrays=False, name: str = DEFAULT_NAME) -> dict:
"""Load metadata from a checkpoint file

Parameters
----------
path : str
The path to the checkpoint file

supporting_arrays: bool, optional
If True, the function will return a dictionary with the supporting arrays

name : str, optional
The name of the metadata file in the zip archive

Expand All @@ -79,12 +92,29 @@ def load_metadata(path: str, name: str = DEFAULT_NAME) -> dict:

if metadata is not None:
with zipfile.ZipFile(path, "r") as f:
return json.load(f.open(metadata, "r"))
metadata = json.load(f.open(metadata, "r"))
if supporting_arrays:
metadata["supporting_arrays"] = load_supporting_arrays(f, metadata.get("supporting_arrays", {}))
return metadata, supporting_arrays

return metadata
else:
raise ValueError(f"Could not find '{name}' in {path}.")


def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> None:
def load_supporting_arrays(zipf, entries) -> dict:
import numpy as np

supporting_arrays = {}
for key, entry in entries.items():
supporting_arrays[key] = np.frombuffer(
zipf.read(entry["path"]),
dtype=entry["dtype"],
).reshape(entry["shape"])
return supporting_arrays


def save_metadata(path, metadata, *, supporting_arrays=None, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> None:
"""Save metadata to a checkpoint file

Parameters
Expand All @@ -93,6 +123,8 @@ def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> N
The path to the checkpoint file
metadata : JSON
A JSON serializable object
supporting_arrays: dict, optional
A dictionary of supporting NumPy arrays
name : str, optional
The name of the metadata file in the zip archive
folder : str, optional
Expand All @@ -118,19 +150,41 @@ def save_metadata(path, metadata, name=DEFAULT_NAME, folder=DEFAULT_FOLDER) -> N

directory = list(directories)[0]

LOG.info("Adding extra information to checkpoint %s", path)
LOG.info("Saving metadata to %s/%s/%s", directory, folder, name)

metadata = metadata.copy()
if supporting_arrays is not None:
metadata["supporting_arrays_paths"] = {
key: dict(path=f"{directory}/{folder}/{key}.numpy", shape=value.shape, dtype=str(value.dtype))
for key, value in supporting_arrays.items()
}
else:
metadata["supporting_arrays_paths"] = {}

zipf.writestr(
f"{directory}/{folder}/{name}",
json.dumps(metadata),
)

for name, entry in metadata["supporting_arrays_paths"].items():
value = supporting_arrays[name]
LOG.info(
"Saving supporting array `%s` to %s (shape=%s, dtype=%s)",
name,
entry["path"],
entry["shape"],
entry["dtype"],
)
zipf.writestr(entry["path"], value.tobytes())


def _edit_metadata(path, name, callback):
def _edit_metadata(path, name, callback, supporting_arrays=None):
new_path = f"{path}.anemoi-edit-{time.time()}-{os.getpid()}.tmp"

found = False

directory = None
with TemporaryDirectory() as temp_dir:
zipfile.ZipFile(path, "r").extractall(temp_dir)
total = 0
Expand All @@ -141,10 +195,21 @@ def _edit_metadata(path, name, callback):
if f == name:
found = True
callback(full)
directory = os.path.dirname(full)

if not found:
raise ValueError(f"Could not find '{name}' in {path}")

if supporting_arrays is not None:

for key, entry in supporting_arrays.items():
value = entry.tobytes()
fname = os.path.join(directory, f"{key}.numpy")
os.makedirs(os.path.dirname(fname), exist_ok=True)
with open(fname, "wb") as f:
f.write(value)
total += 1

with zipfile.ZipFile(new_path, "w", zipfile.ZIP_DEFLATED) as zipf:
with tqdm.tqdm(total=total, desc="Rebuilding checkpoint") as pbar:
for root, dirs, files in os.walk(temp_dir):
Expand All @@ -158,7 +223,7 @@ def _edit_metadata(path, name, callback):
LOG.info("Updated metadata in %s", path)


def replace_metadata(path, metadata, name=DEFAULT_NAME):
def replace_metadata(path, metadata, supporting_arrays=None, *, name=DEFAULT_NAME):

if not isinstance(metadata, dict):
raise ValueError(f"metadata must be a dict, got {type(metadata)}")
Expand All @@ -170,14 +235,14 @@ def callback(full):
with open(full, "w") as f:
json.dump(metadata, f)

_edit_metadata(path, name, callback)
return _edit_metadata(path, name, callback, supporting_arrays)


def remove_metadata(path, name=DEFAULT_NAME):
def remove_metadata(path, *, name=DEFAULT_NAME):

LOG.info("Removing metadata '%s' from %s", name, path)

def callback(full):
os.remove(full)

_edit_metadata(path, name, callback)
return _edit_metadata(path, name, callback)
5 changes: 3 additions & 2 deletions src/anemoi/utils/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def check_config_mode(name="settings.toml", secrets_name=None, secrets=None) ->
CHECKED[name] = True


def find(metadata, what, result=None):
def find(metadata, what, result=None, *, select: callable = None):
if result is None:
result = []

Expand All @@ -369,7 +369,8 @@ def find(metadata, what, result=None):

if isinstance(metadata, dict):
if what in metadata:
result.append(metadata[what])
if select is None or select(metadata[what]):
result.append(metadata[what])

for k, v in metadata.items():
find(v, what, result)
Expand Down
98 changes: 98 additions & 0 deletions src/anemoi/utils/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
# (C) Copyright 2024 Anemoi contributors.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
#
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.


import importlib
import logging
import os
import sys

import entrypoints

LOG = logging.getLogger(__name__)


class Wrapper:
"""A wrapper for the registry"""

def __init__(self, name, registry):
self.name = name
self.registry = registry

def __call__(self, factory):
self.registry.register(self.name, factory)
return factory


class Registry:
"""A registry of factories"""

def __init__(self, package):

self.package = package
self.registered = {}
self.kind = package.split(".")[-1]

def register(self, name: str, factory: callable = None):

if factory is None:
return Wrapper(name, self)

self.registered[name] = factory

def _load(self, file):
name, _ = os.path.splitext(file)
try:
importlib.import_module(f".{name}", package=self.package)
except Exception:
LOG.warning(f"Error loading filter '{self.package}.{name}'", exc_info=True)

def lookup(self, name: str) -> callable:
if name in self.registered:
return self.registered[name]

directory = sys.modules[self.package].__path__[0]

for file in os.listdir(directory):

if file[0] == ".":
continue

if file == "__init__.py":
continue

full = os.path.join(directory, file)
if os.path.isdir(full):
if os.path.exists(os.path.join(full, "__init__.py")):
self._load(file)
continue

if file.endswith(".py"):
self._load(file)

entrypoint_group = f"anemoi.{self.kind}"
for entry_point in entrypoints.get_group_all(entrypoint_group):
if entry_point.name == name:
if name in self.registered:
LOG.warning(
f"Overwriting builtin '{name}' from {self.package} with plugin '{entry_point.module_name}'"
)
self.registered[name] = entry_point.load()

if name not in self.registered:
raise ValueError(f"Cannot load '{name}' from {self.package}")

return self.registered[name]

def create(self, name: str, *args, **kwargs):
factory = self.lookup(name)
return factory(*args, **kwargs)

def __call__(self, name: str, *args, **kwargs):
return self.create(name, *args, **kwargs)
Loading