diff --git a/.github/workflows/changelog-pr-update.yml b/.github/workflows/changelog-pr-update.yml index 73cb1eb..f1fc6bc 100644 --- a/.github/workflows/changelog-pr-update.yml +++ b/.github/workflows/changelog-pr-update.yml @@ -1,18 +1,18 @@ -name: Check Changelog Update on PR -on: - pull_request: - types: [assigned, opened, synchronize, reopened, labeled, unlabeled] - branches: - - main - - develop - paths-ignore: - - .pre-commit-config.yaml - - .readthedocs.yaml -jobs: - Check-Changelog: - name: Check Changelog Action - runs-on: ubuntu-20.04 - steps: - - uses: tarides/changelog-check-action@v2 - with: - changelog: CHANGELOG.md +# name: Check Changelog Update on PR +# on: +# pull_request: +# types: [assigned, opened, synchronize, reopened, labeled, unlabeled] +# branches: +# - main +# - develop +# paths-ignore: +# - .pre-commit-config.yaml +# - .readthedocs.yaml +# jobs: +# Check-Changelog: +# name: Check Changelog Action +# runs-on: ubuntu-20.04 +# steps: +# - uses: tarides/changelog-check-action@v2 +# with: +# changelog: CHANGELOG.md diff --git a/.github/workflows/changelog-release-update.yml b/.github/workflows/changelog-release-update.yml index 17d9525..305ed24 100644 --- a/.github/workflows/changelog-release-update.yml +++ b/.github/workflows/changelog-release-update.yml @@ -25,6 +25,7 @@ jobs: with: latest-version: ${{ github.event.release.tag_name }} heading-text: ${{ github.event.release.name }} + release-notes: ${{ github.event.release.body }} - name: Create Pull Request uses: peter-evans/create-pull-request@v6 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6e88f0b..245cf29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,7 +27,7 @@ repos: - id: python-check-blanket-noqa # Check for # noqa: all - id: python-no-log-warn # Check for log.warn - repo: https://github.com/psf/black-pre-commit-mirror - rev: 24.8.0 + rev: 24.10.0 hooks: - id: black args: [--line-length=120] @@ -40,7 +40,7 @@ repos: - --force-single-line-imports - --profile black - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.6.9 + rev: v0.7.2 hooks: - id: ruff args: @@ -65,7 +65,7 @@ repos: - id: docconvert args: ["numpy"] - repo: https://github.com/tox-dev/pyproject-fmt - rev: "2.2.4" + rev: "v2.5.0" hooks: - id: pyproject-fmt - repo: https://github.com/jshwi/docsig # Check docstrings against function sig diff --git a/CHANGELOG.md b/CHANGELOG.md index 687403e..ef40ff4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,8 +8,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 Please add your functional changes to the appropriate section in the PR. Keep it human-readable, your future self will thank you! -## [Unreleased](https://github.com/ecmwf/anemoi-utils/compare/0.4.4...HEAD) - ## [0.4.4](https://github.com/ecmwf/anemoi-utils/compare/0.4.3...0.4.4) - 2024-11-01 ## [0.4.3](https://github.com/ecmwf/anemoi-utils/compare/0.4.1...0.4.3) - 2024-10-26 @@ -17,7 +15,6 @@ Keep it human-readable, your future self will thank you! ## [0.4.2](https://github.com/ecmwf/anemoi-utils/compare/0.4.1...0.4.2) - 2024-10-25 ### Added - - Add supporting_arrays to checkpoints - Add factories registry - Optional renaming of subcommands via `command` attribute [#34](https://github.com/ecmwf/anemoi-utils/pull/34) @@ -53,7 +50,9 @@ Keep it human-readable, your future self will thank you! - Changelog merge strategy- Codeowners file - Create dependency on wcwidth. MIT licence. 
- Add distribution name dictionary to provenance [#15](https://github.com/ecmwf/anemoi-utils/pull/15) & [#19](https://github.com/ecmwf/anemoi-utils/pull/19) -- Add anonimize() function. +- Add anonymize() function. +- Add transfer to ssh:// target (experimental) +- Deprecated 'anemoi.utils.s3' ### Changed diff --git a/pyproject.toml b/pyproject.toml index 3848964..414078c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,14 +1,12 @@ -#!/usr/bin/env python -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# https://packaging.python.org/en/latest/guides/writing-pyproject-toml/ - [build-system] requires = [ "setuptools>=60", "setuptools-scm>=8" ] @@ -35,6 +33,7 @@ classifiers = [ "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", "Programming Language :: Python :: Implementation :: CPython", "Programming Language :: Python :: Implementation :: PyPy", ] diff --git a/src/anemoi/utils/__main__.py b/src/anemoi/utils/__main__.py index be940c2..0057b94 100644 --- a/src/anemoi/utils/__main__.py +++ b/src/anemoi/utils/__main__.py @@ -1,12 +1,11 @@ -#!/usr/bin/env python -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. -# from anemoi.utils.cli import cli_main from anemoi.utils.cli import make_parser diff --git a/src/anemoi/utils/checkpoints.py b/src/anemoi/utils/checkpoints.py index 83d386d..35e06ac 100644 --- a/src/anemoi/utils/checkpoints.py +++ b/src/anemoi/utils/checkpoints.py @@ -94,8 +94,8 @@ def load_metadata(path: str, *, supporting_arrays=False, name: str = DEFAULT_NAM with zipfile.ZipFile(path, "r") as f: metadata = json.load(f.open(metadata, "r")) if supporting_arrays: - metadata["supporting_arrays"] = load_supporting_arrays(f, metadata.get("supporting_arrays", {})) - return metadata, supporting_arrays + arrays = load_supporting_arrays(f, metadata.get("supporting_arrays_paths", {})) + return metadata, arrays return metadata else: diff --git a/src/anemoi/utils/commands/__init__.py b/src/anemoi/utils/commands/__init__.py index cebb539..e5e2219 100644 --- a/src/anemoi/utils/commands/__init__.py +++ b/src/anemoi/utils/commands/__init__.py @@ -1,12 +1,11 @@ -#!/usr/bin/env python -# (C) Copyright 2024 ECMWF. +# (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
-# import os diff --git a/src/anemoi/utils/commands/config.py b/src/anemoi/utils/commands/config.py index 167163f..549fbea 100644 --- a/src/anemoi/utils/commands/config.py +++ b/src/anemoi/utils/commands/config.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 # (C) Copyright 2024 Anemoi contributors. # # This software is licensed under the terms of the Apache Licence Version 2.0 diff --git a/src/anemoi/utils/compatibility.py b/src/anemoi/utils/compatibility.py new file mode 100644 index 0000000..699cdf8 --- /dev/null +++ b/src/anemoi/utils/compatibility.py @@ -0,0 +1,76 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +from __future__ import annotations + +import functools +from typing import Any +from typing import Callable + + +def aliases( + aliases: dict[str, str | list[str]] | None = None, **kwargs: str | list[str] +) -> Callable[[Callable], Callable]: + """Alias keyword arguments in a function call. + + Allows for dynamically renaming keyword arguments in a function call. + + Parameters + ---------- + aliases : dict[str, str | list[str]] | None, optional + Key, value pair of aliases, with keys being the true name, and value being a str or list of aliases, + by default None + **kwargs : str | list[str] + Kwargs form of aliases + + Returns + ------- + Callable + Decorator function that renames keyword arguments in a function call. + + Raises + ------ + ValueError + If the aliasing would result in duplicate keys. + + Examples + -------- + ```python + @aliases(a="b", c=["d", "e"]) + def func(a, c): + return a, c + + func(a=1, c=2) # (1, 2) + func(b=1, d=2) # (1, 2) + ``` + + """ + + if aliases is None: + aliases = {} + aliases.update(kwargs) + + aliases = {v: k for k, vs in aliases.items() for v in (vs if isinstance(vs, list) else [vs])} + + def decorator(func: Callable) -> Callable: + @functools.wraps(func) + def wrapper(*args, **kwargs) -> Any: + keys = kwargs.keys() + for k in set(keys).intersection(set(aliases.keys())): + if aliases[k] in keys: + raise ValueError( + f"When aliasing {k} with {aliases[k]} duplicate keys were present. Cannot include both." + ) + kwargs[aliases[k]] = kwargs.pop(k) + + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/src/anemoi/utils/mars/__init__.py b/src/anemoi/utils/mars/__init__.py index 6130b82..66b9de0 100644 --- a/src/anemoi/utils/mars/__init__.py +++ b/src/anemoi/utils/mars/__init__.py @@ -1,6 +1,8 @@ -# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# (C) Copyright 2024 Anemoi contributors. +# # This software is licensed under the terms of the Apache Licence Version 2.0 # which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# # In applying this licence, ECMWF does not waive the privileges and immunities # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. 
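For context, the checkpoints.py hunk above changes what `load_metadata` returns when supporting arrays are requested: the arrays are now loaded from the metadata's `supporting_arrays_paths` entry and returned alongside the metadata, instead of the boolean flag being echoed back. A minimal usage sketch of the corrected contract (the checkpoint path is a placeholder):

```python
from anemoi.utils.checkpoints import load_metadata

# Metadata only (behaviour unchanged by this diff).
metadata = load_metadata("inference.ckpt")

# With supporting_arrays=True the call now returns a (metadata, arrays) tuple,
# with the arrays loaded from the "supporting_arrays_paths" entry of the metadata.
metadata, arrays = load_metadata("inference.ckpt", supporting_arrays=True)
```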
diff --git a/src/anemoi/utils/registry.py b/src/anemoi/utils/registry.py index 9d4bcce..03ee6ed 100644 --- a/src/anemoi/utils/registry.py +++ b/src/anemoi/utils/registry.py @@ -33,11 +33,12 @@ def __call__(self, factory): class Registry: """A registry of factories""" - def __init__(self, package): + def __init__(self, package, key="_type"): self.package = package self.registered = {} self.kind = package.split(".")[-1] + self.key = key def register(self, name: str, factory: callable = None): @@ -86,6 +87,8 @@ def lookup(self, name: str) -> callable: self.registered[name] = entry_point.load() if name not in self.registered: + for e in self.registered: + LOG.info(f"Registered: {e}") raise ValueError(f"Cannot load '{name}' from {self.package}") return self.registered[name] @@ -96,3 +99,31 @@ def create(self, name: str, *args, **kwargs): def __call__(self, name: str, *args, **kwargs): return self.create(name, *args, **kwargs) + + def from_config(self, config, *args, **kwargs): + if isinstance(config, str): + config = {config: {}} + + if not isinstance(config, dict): + raise ValueError(f"Invalid config: {config}") + + if self.key in config: + config = config.copy() + key = config.pop(self.key) + return self.create(key, *args, **config, **kwargs) + + if len(config) == 1: + key = list(config.keys())[0] + value = config[key] + + if isinstance(value, dict): + return self.create(key, *args, **value, **kwargs) + + if isinstance(value, list): + return self.create(key, *args, *value, **kwargs) + + return self.create(key, *args, value, **kwargs) + + raise ValueError( + f"Entry '{config}' must either be a string, a dictionray with a single entry, or a dictionary with a '{self.key}' key" + ) diff --git a/src/anemoi/utils/remote/__init__.py b/src/anemoi/utils/remote/__init__.py new file mode 100644 index 0000000..d814119 --- /dev/null +++ b/src/anemoi/utils/remote/__init__.py @@ -0,0 +1,328 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import concurrent.futures +import logging +import os +import shutil +from abc import abstractmethod + +import tqdm + +from ..humanize import bytes_to_human + +LOGGER = logging.getLogger(__name__) + + +def _ignore(number_of_files, total_size, total_transferred, transfering): + pass + + +class Loader: + + def transfer_folder(self, *, source, target, overwrite=False, resume=False, verbosity=1, threads=1, progress=None): + assert verbosity == 1, verbosity + + if progress is None: + progress = _ignore + + # from boto3.s3.transfer import TransferConfig + # config = TransferConfig(use_threads=False) + config = None + with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor: + try: + if verbosity > 0: + LOGGER.info(f"{self.action} {source} to {target}") + + total_size = 0 + total_transferred = 0 + + futures = [] + for name in self.list_source(source): + + futures.append( + executor.submit( + self.transfer_file, + source=self.source_path(name, source), + target=self.target_path(name, source, target), + overwrite=overwrite, + resume=resume, + verbosity=verbosity - 1, + config=config, + ) + ) + total_size += self.source_size(name) + + if len(futures) % 10000 == 0: + + progress(len(futures), total_size, 0, False) + + if verbosity > 0: + LOGGER.info(f"Preparing transfer, {len(futures):,} files... ({bytes_to_human(total_size)})") + done, _ = concurrent.futures.wait( + futures, + timeout=0.001, + return_when=concurrent.futures.FIRST_EXCEPTION, + ) + # Trigger exceptions if any + for future in done: + future.result() + + number_of_files = len(futures) + progress(number_of_files, total_size, 0, True) + + if verbosity > 0: + LOGGER.info(f"{self.action} {number_of_files:,} files ({bytes_to_human(total_size)})") + with tqdm.tqdm(total=total_size, unit="B", unit_scale=True, unit_divisor=1024) as pbar: + for future in concurrent.futures.as_completed(futures): + size = future.result() + pbar.update(size) + total_transferred += size + progress(number_of_files, total_size, total_transferred, True) + else: + for future in concurrent.futures.as_completed(futures): + size = future.result() + total_transferred += size + progress(number_of_files, total_size, total_transferred, True) + + except Exception: + executor.shutdown(wait=False, cancel_futures=True) + raise + + def transfer_file(self, source, target, overwrite, resume, verbosity, threads=1, progress=None, config=None): + try: + return self._transfer_file(source, target, overwrite, resume, verbosity, threads=threads, config=config) + except Exception as e: + LOGGER.exception(f"Error transferring {source} to {target}") + LOGGER.error(e) + raise + + @abstractmethod + def list_source(self, source): + raise NotImplementedError + + @abstractmethod + def source_path(self, local_path, source): + raise NotImplementedError + + @abstractmethod + def target_path(self, source_path, source, target): + raise NotImplementedError + + @abstractmethod + def source_size(self, local_path): + raise NotImplementedError + + @abstractmethod + def copy(self, source, target, **kwargs): + raise NotImplementedError + + @abstractmethod + def get_temporary_target(self, target, pattern): + raise NotImplementedError + + @abstractmethod + def rename_target(self, target, temporary_target): + raise NotImplementedError + + +class BaseDownload(Loader): + action = "Downloading" + + @abstractmethod + def copy(self, source, target, **kwargs): + raise NotImplementedError + + def get_temporary_target(self, target, pattern): + dirname, basename = os.path.split(target) + return 
pattern.format(dirname=dirname, basename=basename) + + def rename_target(self, target, new_target): + os.rename(target, new_target) + + def delete_target(self, target): + if os.path.exists(target): + shutil.rmtree(target) + + +class BaseUpload(Loader): + action = "Uploading" + + def copy(self, source, target, **kwargs): + if os.path.isdir(source): + self.transfer_folder(source=source, target=target, **kwargs) + else: + self.transfer_file(source=source, target=target, **kwargs) + + def list_source(self, source): + for root, _, files in os.walk(source): + for file in files: + yield os.path.join(root, file) + + def source_path(self, local_path, source): + return local_path + + def target_path(self, source_path, source, target): + relative_path = os.path.relpath(source_path, source) + path = os.path.join(target, relative_path) + return path + + def source_size(self, local_path): + return os.path.getsize(local_path) + + +class TransferMethodNotImplementedError(NotImplementedError): + pass + + +class Transfer: + """This is the internal API and should not be used directly. Use the transfer function instead.""" + + TransferMethodNotImplementedError = TransferMethodNotImplementedError + + def __init__( + self, + source, + target, + overwrite=False, + resume=False, + verbosity=1, + threads=1, + progress=None, + temporary_target=False, + ): + if target == ".": + target = os.path.basename(source) + + temporary_target = { + False: "{dirname}/{basename}", + True: "{dirname}-downloading/{basename}", + "-tmp/*": "{dirname}-tmp/{basename}", + "*-tmp": "{dirname}/{basename}-tmp", + "tmp-*": "{dirname}/tmp-{basename}", + }.get(temporary_target, temporary_target) + assert isinstance(temporary_target, str), (type(temporary_target), temporary_target) + + self.source = source + self.target = target + self.overwrite = overwrite + self.resume = resume + self.verbosity = verbosity + self.threads = threads + self.progress = progress + self.temporary_target = temporary_target + + cls = _find_transfer_class(self.source, self.target) + self.loader = cls() + + def run(self): + + target = self.loader.get_temporary_target(self.target, self.temporary_target) + if target != self.target: + LOGGER.info(f"Using temporary target {target} to copy to {self.target}") + + if self.overwrite: + # delete the target if it exists + LOGGER.info(f"Deleting {self.target}") + self.delete_target(target) + + # carefully delete the temporary target if it exists + head, tail = os.path.split(self.target) + head_, tail_ = os.path.split(target) + if not head_.startswith(head) or tail not in tail_: + LOGGER.info(f"{target} is too different from {self.target} to delete it automatically.") + else: + self.delete_target(target) + + self.loader.copy( + self.source, + target, + overwrite=self.overwrite, + resume=self.resume, + verbosity=self.verbosity, + threads=self.threads, + progress=self.progress, + ) + + self.rename_target(target, self.target) + + return self + + def rename_target(self, target, new_target): + if target != new_target: + LOGGER.info(f"Renaming temporary target {target} into {self.target}") + return self.loader.rename_target(target, new_target) + + def delete_target(self, target): + return self.loader.delete_target(target) + + +def _find_transfer_class(source, target): + from_ssh = source.startswith("ssh://") + into_ssh = target.startswith("ssh://") + + from_s3 = source.startswith("s3://") + into_s3 = target.startswith("s3://") + + from_local = not from_ssh and not from_s3 + into_local = not into_ssh and not into_s3 + + # check that 
exactly one source type and one target type is specified + assert sum([into_ssh, into_local, into_s3]) == 1, (into_ssh, into_local, into_s3) + assert sum([from_ssh, from_local, from_s3]) == 1, (from_ssh, from_local, from_s3) + + if from_local and into_ssh: # local -> ssh + from .ssh import RsyncUpload + + return RsyncUpload + + if from_s3 and into_local: # local <- S3 + from .s3 import S3Download + + return S3Download + + if from_local and into_s3: # local -> S3 + from .s3 import S3Upload + + return S3Upload + + raise TransferMethodNotImplementedError(f"Transfer from {source} to {target} is not implemented") + + +# this is the public API +def transfer(*args, **kwargs) -> Loader: + """Parameters + ---------- + source : str + A path to a local file or folder or a URL to a file or a folder on S3. + The url should start with 's3://'. + target : str + A path to a local file or folder or a URL to a file or a folder on S3 or a remote folder. + The url should start with 's3://' or 'ssh://'. + overwrite : bool, optional + If the data is alreay on in the target location it will be overwritten. + By default False + resume : bool, optional + If the data is alreay on S3 it will not be uploaded, unless the remote file has a different size + Ignored if the target is an SSH remote folder (ssh://). + By default False + verbosity : int, optional + The level of verbosity, by default 1 + progress: callable, optional + A callable that will be called with the number of files, the total size of the files, the total size + transferred and a boolean indicating if the transfer has started. By default None + threads : int, optional + The number of threads to use when uploading a directory, by default 1 + temporary_target : bool, optional + Experimental feature + If True and if the target location supports it, the data will be uploaded to a temporary location + then renamed to the final location. Supported by SSH and local targets, not supported by S3. + By default False + """ + copier = Transfer(*args, **kwargs) + copier.run() + return copier diff --git a/src/anemoi/utils/remote/s3.py b/src/anemoi/utils/remote/s3.py new file mode 100644 index 0000000..672d182 --- /dev/null +++ b/src/anemoi/utils/remote/s3.py @@ -0,0 +1,386 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +"""This module provides functions to upload, download, list and delete files and folders on S3. +The functions of this package expect that the AWS credentials are set up in the environment +typicaly by setting the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables or +by creating a `~/.aws/credentials` file. It is also possible to set the `endpoint_url` in the same file +to use a different S3 compatible service:: + + [default] + endpoint_url = https://some-storage.somewhere.world + aws_access_key_id = xxxxxxxxxxxxxxxxxxxxxxxx + aws_secret_access_key = xxxxxxxxxxxxxxxxxxxxxxxx + +Alternatively, the `endpoint_url`, and keys can be set in one of +the `~/.config/anemoi/settings.toml` +or `~/.config/anemoi/settings-secrets.toml` files. 
+ +""" + +import logging +import os +import threading +from copy import deepcopy +from typing import Iterable + +import tqdm + +from ..config import load_config +from ..humanize import bytes_to_human +from . import BaseDownload +from . import BaseUpload + +LOGGER = logging.getLogger(__name__) + + +# s3_clients are not thread-safe, so we need to create a new client for each thread + +thread_local = threading.local() + + +def s3_client(bucket, region=None): + import boto3 + from botocore import UNSIGNED + from botocore.client import Config + + if not hasattr(thread_local, "s3_clients"): + thread_local.s3_clients = {} + + key = f"{bucket}-{region}" + + boto3_config = dict(max_pool_connections=25) + + if key in thread_local.s3_clients: + return thread_local.s3_clients[key] + + boto3_config = dict(max_pool_connections=25) + + if region: + # This is using AWS + + options = {"region_name": region} + + # Anonymous access + if not ( + os.path.exists(os.path.expanduser("~/.aws/credentials")) + or ("AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ) + ): + boto3_config["signature_version"] = UNSIGNED + + else: + + # We may be accessing a different S3 compatible service + # Use anemoi.config to get the configuration + + options = {} + config = load_config(secrets=["aws_access_key_id", "aws_secret_access_key"]) + + cfg = config.get("object-storage", {}) + for k, v in cfg.items(): + if isinstance(v, (str, int, float, bool)): + options[k] = v + + for k, v in cfg.get(bucket, {}).items(): + if isinstance(v, (str, int, float, bool)): + options[k] = v + + type = options.pop("type", "s3") + if type != "s3": + raise ValueError(f"Unsupported object storage type {type}") + + if "config" in options: + boto3_config.update(options["config"]) + del options["config"] + from botocore.client import Config + + options["config"] = Config(**boto3_config) + + thread_local.s3_clients[key] = boto3.client("s3", **options) + + return thread_local.s3_clients[key] + + +class S3Upload(BaseUpload): + + def get_temporary_target(self, target, pattern): + return target + + def rename_target(self, target, temporary_target): + pass + + def delete_target(self, target): + pass + # delete(target) + + def _transfer_file(self, source, target, overwrite, resume, verbosity, threads, config=None): + + from botocore.exceptions import ClientError + + assert target.startswith("s3://") + + _, _, bucket, key = target.split("/", 3) + s3 = s3_client(bucket) + + size = os.path.getsize(source) + + if verbosity > 0: + LOGGER.info(f"{self.action} {source} to {target} ({bytes_to_human(size)})") + + try: + results = s3.head_object(Bucket=bucket, Key=key) + remote_size = int(results["ContentLength"]) + except ClientError as e: + if e.response["Error"]["Code"] != "404": + raise + remote_size = None + + if remote_size is not None: + if remote_size != size: + LOGGER.warning( + f"{target} already exists, but with different size, re-uploading (remote={remote_size}, local={size})" + ) + elif resume: + # LOGGER.info(f"{target} already exists, skipping") + return size + + if remote_size is not None and not overwrite and not resume: + raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip") + + if verbosity > 0: + with tqdm.tqdm(total=size, unit="B", unit_scale=True, unit_divisor=1024, leave=False) as pbar: + s3.upload_file(source, bucket, key, Callback=lambda x: pbar.update(x), Config=config) + else: + s3.upload_file(source, bucket, key, Config=config) + + return size + + +class S3Download(BaseDownload): 
+ + def copy(self, source, target, **kwargs): + assert source.startswith("s3://") + + if source.endswith("/"): + self.transfer_folder(source=source, target=target, **kwargs) + else: + self.transfer_file(source=source, target=target, **kwargs) + + def list_source(self, source): + yield from _list_objects(source) + + def source_path(self, s3_object, source): + _, _, bucket, _ = source.split("/", 3) + return f"s3://{bucket}/{s3_object['Key']}" + + def target_path(self, s3_object, source, target): + _, _, _, folder = source.split("/", 3) + local_path = os.path.join(target, os.path.relpath(s3_object["Key"], folder)) + os.makedirs(os.path.dirname(local_path), exist_ok=True) + return local_path + + def source_size(self, s3_object): + return s3_object["Size"] + + def _transfer_file(self, source, target, overwrite, resume, verbosity, threads, config=None): + # from boto3.s3.transfer import TransferConfig + + _, _, bucket, key = source.split("/", 3) + s3 = s3_client(bucket) + + try: + response = s3.head_object(Bucket=bucket, Key=key) + except s3.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "404": + raise ValueError(f"{source} does not exist ({bucket}, {key})") + raise + + size = int(response["ContentLength"]) + + if verbosity > 0: + LOGGER.info(f"{self.action} {source} to {target} ({bytes_to_human(size)})") + + if overwrite: + resume = False + + if resume: + if os.path.exists(target): + local_size = os.path.getsize(target) + if local_size != size: + LOGGER.warning( + f"{target} already with different size, re-downloading (remote={size}, local={local_size})" + ) + else: + # if verbosity > 0: + # LOGGER.info(f"{target} already exists, skipping") + return size + + if os.path.exists(target) and not overwrite: + raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip") + + if verbosity > 0: + with tqdm.tqdm(total=size, unit="B", unit_scale=True, unit_divisor=1024, leave=False) as pbar: + s3.download_file(bucket, key, target, Callback=lambda x: pbar.update(x), Config=config) + else: + s3.download_file(bucket, key, target, Config=config) + + return size + + +def _list_objects(target, batch=False): + _, _, bucket, prefix = target.split("/", 3) + s3 = s3_client(bucket) + + paginator = s3.get_paginator("list_objects_v2") + + for page in paginator.paginate(Bucket=bucket, Prefix=prefix): + if "Contents" in page: + objects = deepcopy(page["Contents"]) + if batch: + yield objects + else: + yield from objects + + +def _delete_folder(target) -> None: + _, _, bucket, _ = target.split("/", 3) + s3 = s3_client(bucket) + + total = 0 + for batch in _list_objects(target, batch=True): + LOGGER.info(f"Deleting {len(batch):,} objects from {target}") + s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": o["Key"]} for o in batch]}) + total += len(batch) + LOGGER.info(f"Deleted {len(batch):,} objects (total={total:,})") + + +def _delete_file(target) -> None: + from botocore.exceptions import ClientError + + _, _, bucket, key = target.split("/", 3) + s3 = s3_client(bucket) + + try: + s3.head_object(Bucket=bucket, Key=key) + exits = True + except ClientError as e: + if e.response["Error"]["Code"] != "404": + raise + exits = False + + if not exits: + LOGGER.warning(f"{target} does not exist. Did you mean to delete a folder? Then add a trailing '/'") + return + + LOGGER.info(f"Deleting {target}") + s3.delete_object(Bucket=bucket, Key=key) + LOGGER.info(f"{target} is deleted") + + +def delete(target) -> None: + """Delete a file or a folder from S3. 
+ + Parameters + ---------- + target : str + The URL of a file or a folder on S3. The url should start with 's3://'. If the URL ends with a '/' it is + assumed to be a folder, otherwise it is assumed to be a file. + """ + + assert target.startswith("s3://") + + if target.endswith("/"): + _delete_folder(target) + else: + _delete_file(target) + + +def list_folder(folder) -> Iterable: + """List the sub folders in a folder on S3. + + Parameters + ---------- + folder : str + The URL of a folder on S3. The url should start with 's3://'. + + Returns + ------- + list + A list of the subfolders names in the folder. + """ + + assert folder.startswith("s3://") + if not folder.endswith("/"): + folder += "/" + + _, _, bucket, prefix = folder.split("/", 3) + + s3 = s3_client(bucket) + paginator = s3.get_paginator("list_objects_v2") + + for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"): + if "CommonPrefixes" in page: + yield from [folder + _["Prefix"] for _ in page.get("CommonPrefixes")] + + +def object_info(target) -> dict: + """Get information about an object on S3. + + Parameters + ---------- + target : str + The URL of a file or a folder on S3. The url should start with 's3://'. + + Returns + ------- + dict + A dictionary with information about the object. + """ + + _, _, bucket, key = target.split("/", 3) + s3 = s3_client(bucket) + + try: + return s3.head_object(Bucket=bucket, Key=key) + except s3.exceptions.ClientError as e: + if e.response["Error"]["Code"] == "404": + raise ValueError(f"{target} does not exist") + raise + + +def object_acl(target) -> dict: + """Get information about an object's ACL on S3. + + Parameters + ---------- + target : str + The URL of a file or a folder on S3. The url should start with 's3://'. + + Returns + ------- + dict + A dictionary with information about the object's ACL. + """ + + _, _, bucket, key = target.split("/", 3) + s3 = s3_client() + + return s3.get_object_acl(Bucket=bucket, Key=key) + + +def download(source, target, *args, **kwargs): + from . import transfer + + assert source.startswith("s3://"), f"source {source} should start with 's3://'" + return transfer(source, target, *args, **kwargs) + + +def upload(source, target, *args, **kwargs): + from . import transfer + + assert target.startswith("s3://"), f"target {target} should start with 's3://'" + return transfer(source, target, *args, **kwargs) diff --git a/src/anemoi/utils/remote/ssh.py b/src/anemoi/utils/remote/ssh.py new file mode 100644 index 0000000..9ba49ef --- /dev/null +++ b/src/anemoi/utils/remote/ssh.py @@ -0,0 +1,133 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import logging +import os +import random +import shlex +import subprocess + +from ..humanize import bytes_to_human +from . 
import BaseUpload + +LOGGER = logging.getLogger(__name__) + + +def call_process(*args): + proc = subprocess.Popen( + args, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + stdout, stderr = proc.communicate() + if proc.returncode != 0: + print(stdout) + msg = f"{' '.join(args)} failed: {stderr}" + raise RuntimeError(msg) + + return stdout.decode("utf-8").strip() + + +class SshBaseUpload(BaseUpload): + + def _parse_target(self, target): + assert target.startswith("ssh://"), target + + target = target[6:] + hostname, path = target.split(":") + + if "+" in hostname: + hostnames = hostname.split("+") + hostname = hostnames[random.randint(0, len(hostnames) - 1)] + + return hostname, path + + def get_temporary_target(self, target, pattern): + hostname, path = self._parse_target(target) + dirname, basename = os.path.split(path) + path = pattern.format(dirname=dirname, basename=basename) + return f"ssh://{hostname}:{path}" + + def rename_target(self, target, new_target): + hostname, path = self._parse_target(target) + hostname, new_path = self._parse_target(new_target) + call_process("ssh", hostname, "mkdir", "-p", shlex.quote(os.path.dirname(new_path))) + call_process("ssh", hostname, "mv", shlex.quote(path), shlex.quote(new_path)) + + def delete_target(self, target): + pass + # hostname, path = self._parse_target(target) + # LOGGER.info(f"Deleting {target}") + # call_process("ssh", hostname, "rm", "-rf", shlex.quote(path)) + + +class RsyncUpload(SshBaseUpload): + + def _transfer_file(self, source, target, overwrite, resume, verbosity, threads, config=None): + hostname, path = self._parse_target(target) + + size = os.path.getsize(source) + + if verbosity > 0: + LOGGER.info(f"{self.action} {source} to {target} ({bytes_to_human(size)})") + + call_process("ssh", hostname, "mkdir", "-p", shlex.quote(os.path.dirname(path))) + call_process( + "rsync", + "-av", + "--partial", + # it would be nice to avoid two ssh calls, but the following is not possible, + # this is because it requires a shell command and would not be safe. 
+ # # f"--rsync-path='mkdir -p {os.path.dirname(path)} && rsync'", + source, + f"{hostname}:{path}", + ) + return size + + +class ScpUpload(SshBaseUpload): + + def _transfer_file(self, source, target, overwrite, resume, verbosity, threads, config=None): + hostname, path = self._parse_target(target) + + size = os.path.getsize(source) + + if verbosity > 0: + LOGGER.info(f"{self.action} {source} to {target} ({bytes_to_human(size)})") + + remote_size = None + try: + out = call_process("ssh", hostname, "stat", "-c", "%s", shlex.quote(path)) + remote_size = int(out) + except RuntimeError: + remote_size = None + + if remote_size is not None: + if remote_size != size: + LOGGER.warning( + f"{target} already exists, but with different size, re-uploading (remote={remote_size}, local={size})" + ) + elif resume: + # LOGGER.info(f"{target} already exists, skipping") + return size + + if remote_size is not None and not overwrite and not resume: + raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip") + + call_process("ssh", hostname, "mkdir", "-p", shlex.quote(os.path.dirname(path))) + call_process("scp", source, shlex.quote(f"{hostname}:{path}")) + + return size + + +def upload(source, target, **kwargs) -> None: + uploader = RsyncUpload() + + if os.path.isdir(source): + uploader.transfer_folder(source=source, target=target, **kwargs) + else: + uploader.transfer_file(source=source, target=target, **kwargs) diff --git a/src/anemoi/utils/s3.py b/src/anemoi/utils/s3.py index 66af693..dd68c2d 100644 --- a/src/anemoi/utils/s3.py +++ b/src/anemoi/utils/s3.py @@ -7,554 +7,57 @@ # granted to it by virtue of its status as an intergovernmental organisation # nor does it submit to any jurisdiction. +import warnings -"""This module provides functions to upload, download, list and delete files and folders on S3. -The functions of this package expect that the AWS credentials are set up in the environment -typicaly by setting the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables or -by creating a `~/.aws/credentials` file. It is also possible to set the `endpoint_url` in the same file -to use a different S3 compatible service:: +from .remote import transfer +from .remote.s3 import delete as delete_ +from .remote.s3 import s3_client as s3_client_ - [default] - endpoint_url = https://some-storage.somewhere.world - aws_access_key_id = xxxxxxxxxxxxxxxxxxxxxxxx - aws_secret_access_key = xxxxxxxxxxxxxxxxxxxxxxxx +warnings.warn( + "The anemoi.utils.s3 module is deprecated and will be removed in a future release. " + "Please use the 'anemoi.utils.remote' or 'anemoi.utils.remote.s3' module instead.", + DeprecationWarning, + stacklevel=2, +) -Alternatively, the `endpoint_url`, and keys can be set in one of -the `~/.config/anemoi/settings.toml` -or `~/.config/anemoi/settings-secrets.toml` files. 
-""" - -import concurrent.futures -import logging -import os -import threading -from copy import deepcopy - -import tqdm - -from .config import load_config -from .humanize import bytes_to_human - -LOGGER = logging.getLogger(__name__) - - -# s3_clients are not thread-safe, so we need to create a new client for each thread - -thread_local = threading.local() - - -def s3_client(bucket, region=None): - import boto3 - from botocore import UNSIGNED - from botocore.client import Config - - if not hasattr(thread_local, "s3_clients"): - thread_local.s3_clients = {} - - key = f"{bucket}-{region}" - - boto3_config = dict(max_pool_connections=25) - - if key in thread_local.s3_clients: - return thread_local.s3_clients[key] - - boto3_config = dict(max_pool_connections=25) - - if region: - # This is using AWS - - options = {"region_name": region} - - # Anonymous access - if not ( - os.path.exists(os.path.expanduser("~/.aws/credentials")) - or ("AWS_ACCESS_KEY_ID" in os.environ and "AWS_SECRET_ACCESS_KEY" in os.environ) - ): - boto3_config["signature_version"] = UNSIGNED - - else: - - # We may be accessing a different S3 compatible service - # Use anemoi.config to get the configuration - - options = {} - config = load_config(secrets=["aws_access_key_id", "aws_secret_access_key"]) - - cfg = config.get("object-storage", {}) - for k, v in cfg.items(): - if isinstance(v, (str, int, float, bool)): - options[k] = v - - for k, v in cfg.get(bucket, {}).items(): - if isinstance(v, (str, int, float, bool)): - options[k] = v - - type = options.pop("type", "s3") - if type != "s3": - raise ValueError(f"Unsupported object storage type {type}") - - if "config" in options: - boto3_config.update(options["config"]) - del options["config"] - from botocore.client import Config - - options["config"] = Config(**boto3_config) - - thread_local.s3_clients[key] = boto3.client("s3", **options) - - return thread_local.s3_clients[key] - - -def _ignore(number_of_files, total_size, total_transferred, transfering): - pass - - -class Transfer: - - def transfer_folder(self, *, source, target, overwrite=False, resume=False, verbosity=1, threads=1, progress=None): - assert verbosity == 1, verbosity - - if progress is None: - progress = _ignore - - # from boto3.s3.transfer import TransferConfig - # config = TransferConfig(use_threads=False) - config = None - with concurrent.futures.ThreadPoolExecutor(max_workers=threads) as executor: - try: - if verbosity > 0: - LOGGER.info(f"{self.action} {source} to {target}") - - total_size = 0 - total_transferred = 0 - - futures = [] - for name in self.list_source(source): - - futures.append( - executor.submit( - self.transfer_file, - source=self.source_path(name, source), - target=self.target_path(name, source, target), - overwrite=overwrite, - resume=resume, - verbosity=verbosity - 1, - config=config, - ) - ) - total_size += self.source_size(name) - - if len(futures) % 10000 == 0: - - progress(len(futures), total_size, 0, False) - - if verbosity > 0: - LOGGER.info(f"Preparing transfer, {len(futures):,} files... 
({bytes_to_human(total_size)})") - done, _ = concurrent.futures.wait( - futures, - timeout=0.001, - return_when=concurrent.futures.FIRST_EXCEPTION, - ) - # Trigger exceptions if any - for future in done: - future.result() - - number_of_files = len(futures) - progress(number_of_files, total_size, 0, True) - - if verbosity > 0: - LOGGER.info(f"{self.action} {number_of_files:,} files ({bytes_to_human(total_size)})") - with tqdm.tqdm(total=total_size, unit="B", unit_scale=True, unit_divisor=1024) as pbar: - for future in concurrent.futures.as_completed(futures): - size = future.result() - pbar.update(size) - total_transferred += size - progress(number_of_files, total_size, total_transferred, True) - else: - for future in concurrent.futures.as_completed(futures): - size = future.result() - total_transferred += size - progress(number_of_files, total_size, total_transferred, True) - - except Exception: - executor.shutdown(wait=False, cancel_futures=True) - raise - - -class Upload(Transfer): - action = "Uploading" - - def list_source(self, source): - for root, _, files in os.walk(source): - for file in files: - yield os.path.join(root, file) - - def source_path(self, local_path, source): - return local_path - - def target_path(self, source_path, source, target): - relative_path = os.path.relpath(source_path, source) - s3_path = os.path.join(target, relative_path) - return s3_path - - def source_size(self, local_path): - return os.path.getsize(local_path) - - def transfer_file(self, source, target, overwrite, resume, verbosity, progress=None, config=None): - try: - return self._transfer_file(source, target, overwrite, resume, verbosity, config=config) - except Exception as e: - LOGGER.exception(f"Error transferring {source} to {target}") - LOGGER.error(e) - raise - - def _transfer_file(self, source, target, overwrite, resume, verbosity, config=None): - - from botocore.exceptions import ClientError - - assert target.startswith("s3://") - - _, _, bucket, key = target.split("/", 3) - s3 = s3_client(bucket) - - size = os.path.getsize(source) - - if verbosity > 0: - LOGGER.info(f"{self.action} {source} to {target} ({bytes_to_human(size)})") - - try: - results = s3.head_object(Bucket=bucket, Key=key) - remote_size = int(results["ContentLength"]) - except ClientError as e: - if e.response["Error"]["Code"] != "404": - raise - remote_size = None - - if remote_size is not None: - if remote_size != size: - LOGGER.warning( - f"{target} already exists, but with different size, re-uploading (remote={remote_size}, local={size})" - ) - elif resume: - # LOGGER.info(f"{target} already exists, skipping") - return size - - if remote_size is not None and not overwrite and not resume: - raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip") - - if verbosity > 0: - with tqdm.tqdm(total=size, unit="B", unit_scale=True, unit_divisor=1024, leave=False) as pbar: - s3.upload_file(source, bucket, key, Callback=lambda x: pbar.update(x), Config=config) - else: - s3.upload_file(source, bucket, key, Config=config) - - return size - - -class Download(Transfer): - action = "Downloading" - - def list_source(self, source): - yield from _list_objects(source) - - def source_path(self, s3_object, source): - _, _, bucket, _ = source.split("/", 3) - return f"s3://{bucket}/{s3_object['Key']}" - - def target_path(self, s3_object, source, target): - _, _, _, folder = source.split("/", 3) - local_path = os.path.join(target, os.path.relpath(s3_object["Key"], folder)) - 
os.makedirs(os.path.dirname(local_path), exist_ok=True) - return local_path - - def source_size(self, s3_object): - return s3_object["Size"] - - def transfer_file(self, source, target, overwrite, resume, verbosity, progress=None, config=None): - try: - return self._transfer_file(source, target, overwrite, resume, verbosity, config=config) - except Exception as e: - LOGGER.exception(f"Error transferring {source} to {target}") - LOGGER.error(e) - raise - - def _transfer_file(self, source, target, overwrite, resume, verbosity, config=None): - # from boto3.s3.transfer import TransferConfig - - _, _, bucket, key = source.split("/", 3) - s3 = s3_client(bucket) - - try: - response = s3.head_object(Bucket=bucket, Key=key) - except s3.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "404": - raise ValueError(f"{source} does not exist ({bucket}, {key})") - raise - - size = int(response["ContentLength"]) - - if verbosity > 0: - LOGGER.info(f"Downloading {source} to {target} ({bytes_to_human(size)})") - - if overwrite: - resume = False - - if resume: - if os.path.exists(target): - local_size = os.path.getsize(target) - if local_size != size: - LOGGER.warning( - f"{target} already with different size, re-downloading (remote={size}, local={size})" - ) - else: - # if verbosity > 0: - # LOGGER.info(f"{target} already exists, skipping") - return size - - if os.path.exists(target) and not overwrite: - raise ValueError(f"{target} already exists, use 'overwrite' to replace or 'resume' to skip") - - if verbosity > 0: - with tqdm.tqdm(total=size, unit="B", unit_scale=True, unit_divisor=1024, leave=False) as pbar: - s3.download_file(bucket, key, target, Callback=lambda x: pbar.update(x), Config=config) - else: - s3.download_file(bucket, key, target, Config=config) - - return size +def s3_client(*args, **kwargs): + warnings.warn( + "The 's3_client' function (from anemoi.utils.s3 import s3_client) function is deprecated and will be removed in a future release. " + "Please use the 's3_client' function (from anemoi.utils.remote.s3 import s3_client) instead.", + DeprecationWarning, + stacklevel=2, + ) + return s3_client_(*args, **kwargs) def upload(source, target, *, overwrite=False, resume=False, verbosity=1, progress=None, threads=1) -> None: - """Upload a file or a folder to S3. - - Parameters - ---------- - source : str - A path to a file or a folder to upload. - target : str - A URL to a file or a folder on S3. The url should start with 's3://'. - overwrite : bool, optional - If the data is alreay on S3 it will be overwritten, by default False - resume : bool, optional - If the data is alreay on S3 it will not be uploaded, unless the remote file - has a different size, by default False - verbosity : int, optional - The level of verbosity, by default 1 - progress: callable, optional - A callable that will be called with the number of files, the total size of the files, the total size - transferred and a boolean indicating if the transfer has started. 
By default None - threads : int, optional - The number of threads to use when uploading a directory, by default 1 - """ - - uploader = Upload() - - if os.path.isdir(source): - uploader.transfer_folder( - source=source, - target=target, - overwrite=overwrite, - resume=resume, - verbosity=verbosity, - progress=progress, - threads=threads, - ) - else: - uploader.transfer_file( - source=source, - target=target, - overwrite=overwrite, - resume=resume, - verbosity=verbosity, - progress=progress, - ) - - -def download(source, target, *, overwrite=False, resume=False, verbosity=1, progress=None, threads=1) -> None: - """Download a file or a folder from S3. - - Parameters - ---------- - source : str - The URL of a file or a folder on S3. The url should start with 's3://'. If the URL ends with a '/' it is - assumed to be a folder, otherwise it is assumed to be a file. - target : str - The local path where the file or folder will be downloaded. - overwrite : bool, optional - If false, files which have already been download will be skipped, unless their size - does not match their size on S3 , by default False - resume : bool, optional - If the data is alreay on local it will not be downloaded, unless the remote file - has a different size, by default False - verbosity : int, optional - The level of verbosity, by default 1 - progress: callable, optional - A callable that will be called with the number of files, the total size of the files, the total size - transferred and a boolean indicating if the transfer has started. By default None - threads : int, optional - The number of threads to use when downloading a directory, by default 1 - """ - assert source.startswith("s3://") - - downloader = Download() - - if source.endswith("/"): - downloader.transfer_folder( - source=source, - target=target, - overwrite=overwrite, - resume=resume, - verbosity=verbosity, - progress=progress, - threads=threads, - ) - else: - downloader.transfer_file( - source=source, - target=target, - overwrite=overwrite, - resume=resume, - verbosity=verbosity, - progress=progress, - ) - - -def _list_objects(target, batch=False): - _, _, bucket, prefix = target.split("/", 3) - s3 = s3_client(bucket) - - paginator = s3.get_paginator("list_objects_v2") - - for page in paginator.paginate(Bucket=bucket, Prefix=prefix): - if "Contents" in page: - objects = deepcopy(page["Contents"]) - if batch: - yield objects - else: - yield from objects - - -def _delete_folder(target) -> None: - _, _, bucket, _ = target.split("/", 3) - s3 = s3_client(bucket) - - total = 0 - for batch in _list_objects(target, batch=True): - LOGGER.info(f"Deleting {len(batch):,} objects from {target}") - s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": o["Key"]} for o in batch]}) - total += len(batch) - LOGGER.info(f"Deleted {len(batch):,} objects (total={total:,})") - - -def _delete_file(target) -> None: - from botocore.exceptions import ClientError - - _, _, bucket, key = target.split("/", 3) - s3 = s3_client(bucket) - - try: - s3.head_object(Bucket=bucket, Key=key) - exits = True - except ClientError as e: - if e.response["Error"]["Code"] != "404": - raise - exits = False - - if not exits: - LOGGER.warning(f"{target} does not exist. Did you mean to delete a folder? Then add a trailing '/'") - return - - LOGGER.info(f"Deleting {target}") - s3.delete_object(Bucket=bucket, Key=key) - LOGGER.info(f"{target} is deleted") - - -def delete(target) -> None: - """Delete a file or a folder from S3. 
- - Parameters - ---------- - target : str - The URL of a file or a folder on S3. The url should start with 's3://'. If the URL ends with a '/' it is - assumed to be a folder, otherwise it is assumed to be a file. - """ - - assert target.startswith("s3://") - - if target.endswith("/"): - _delete_folder(target) - else: - _delete_file(target) - - -def list_folder(folder) -> list: - """List the sub folders in a folder on S3. - - Parameters - ---------- - folder : str - The URL of a folder on S3. The url should start with 's3://'. - - Returns - ------- - list - A list of the subfolders names in the folder. - """ - - assert folder.startswith("s3://") - if not folder.endswith("/"): - folder += "/" - - _, _, bucket, prefix = folder.split("/", 3) - - s3 = s3_client(bucket) - paginator = s3.get_paginator("list_objects_v2") - - for page in paginator.paginate(Bucket=bucket, Prefix=prefix, Delimiter="/"): - if "CommonPrefixes" in page: - yield from [folder + _["Prefix"] for _ in page.get("CommonPrefixes")] - - -def object_info(target) -> dict: - """Get information about an object on S3. - - Parameters - ---------- - target : str - The URL of a file or a folder on S3. The url should start with 's3://'. - - Returns - ------- - dict - A dictionary with information about the object. - """ - - _, _, bucket, key = target.split("/", 3) - s3 = s3_client(bucket) - - try: - return s3.head_object(Bucket=bucket, Key=key) - except s3.exceptions.ClientError as e: - if e.response["Error"]["Code"] == "404": - raise ValueError(f"{target} does not exist") - raise - - -def object_acl(target) -> dict: - """Get information about an object's ACL on S3. - - Parameters - ---------- - target : str - The URL of a file or a folder on S3. The url should start with 's3://'. - - Returns - ------- - dict - A dictionary with information about the object's ACL. - """ - - _, _, bucket, key = target.split("/", 3) - s3 = s3_client() - - return s3.get_object_acl(Bucket=bucket, Key=key) + warnings.warn( + "The 'upload' function (from anemoi.utils.s3 import upload) function is deprecated and will be removed in a future release. " + "Please use the 'transfer' function (from anemoi.utils.remote import transfer) instead.", + DeprecationWarning, + stacklevel=2, + ) + return transfer( + source, target, overwrite=overwrite, resume=resume, verbosity=verbosity, progress=progress, threads=threads + ) + + +def download(*args, **kwargs): + warnings.warn( + "The 'download' function (from anemoi.utils.s3 import download) function is deprecated and will be removed in a future release. " + "Please use the 'transfer' function (from anemoi.utils.remote import transfer) instead.", + DeprecationWarning, + stacklevel=2, + ) + return transfer(*args, **kwargs) + + +def delete(*args, **kwargs): + warnings.warn( + "The 'delete' function (from anemoi.utils.s3 import delete) function is deprecated and will be removed in a future release. 
" + "Please use the 'transfer' function (from anemoi.utils.remote.s3 import delete) instead.", + DeprecationWarning, + stacklevel=2, + ) + return delete_(*args, **kwargs) diff --git a/tests/test-transfer-data/directory/b/c/x b/tests/test-transfer-data/directory/b/c/x new file mode 100644 index 0000000..587be6b --- /dev/null +++ b/tests/test-transfer-data/directory/b/c/x @@ -0,0 +1 @@ +x diff --git a/tests/test-transfer-data/directory/b/y b/tests/test-transfer-data/directory/b/y new file mode 100644 index 0000000..975fbec --- /dev/null +++ b/tests/test-transfer-data/directory/b/y @@ -0,0 +1 @@ +y diff --git "a/tests/test-transfer-data/directory/exotic filename ;^\"'[=.,#]()\303\252\303\274\303\247\303\262\342\234\205.txt" "b/tests/test-transfer-data/directory/exotic filename ;^\"'[=.,#]()\303\252\303\274\303\247\303\262\342\234\205.txt" new file mode 100644 index 0000000..174cbe5 --- /dev/null +++ "b/tests/test-transfer-data/directory/exotic filename ;^\"'[=.,#]()\303\252\303\274\303\247\303\262\342\234\205.txt" @@ -0,0 +1 @@ +exotic diff --git a/tests/test-transfer-data/directory/z b/tests/test-transfer-data/directory/z new file mode 100644 index 0000000..b680253 --- /dev/null +++ b/tests/test-transfer-data/directory/z @@ -0,0 +1 @@ +z diff --git a/tests/test-transfer-data/file b/tests/test-transfer-data/file new file mode 100644 index 0000000..f73f309 --- /dev/null +++ b/tests/test-transfer-data/file @@ -0,0 +1 @@ +file diff --git a/tests/test_compatibility.py b/tests/test_compatibility.py new file mode 100644 index 0000000..9f270f6 --- /dev/null +++ b/tests/test_compatibility.py @@ -0,0 +1,32 @@ +# (C) Copyright 2024 Anemoi contributors. +# +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. + +import pytest + +from anemoi.utils.compatibility import aliases + + +def test_aliases() -> None: + + @aliases(a="b", c=["d", "e"]) + def func(a, c): + return a, c + + assert func(a=1, c=2) == (1, 2) + assert func(a=1, d=2) == (1, 2) + assert func(b=1, d=2) == (1, 2) + + +def test_duplicate_values() -> None: + @aliases(a="b", c=["d", "e"]) + def func(a, c): + return a, c + + with pytest.raises(ValueError): + func(a=1, b=2) diff --git a/tests/test_remote.py b/tests/test_remote.py new file mode 100644 index 0000000..7e5bab8 --- /dev/null +++ b/tests/test_remote.py @@ -0,0 +1,175 @@ +# (C) Copyright 2024 European Centre for Medium-Range Weather Forecasts. +# This software is licensed under the terms of the Apache Licence Version 2.0 +# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0. +# In applying this licence, ECMWF does not waive the privileges and immunities +# granted to it by virtue of its status as an intergovernmental organisation +# nor does it submit to any jurisdiction. 
+ +import os +import shutil + +import pytest + +from anemoi.utils.remote import TransferMethodNotImplementedError +from anemoi.utils.remote import _find_transfer_class +from anemoi.utils.remote import transfer + +IN_CI = (os.environ.get("GITHUB_WORKFLOW") is not None) or (os.environ.get("IN_CI_HPC") is not None) + +LOCAL = [ + "/absolute/path/to/file", + "relative/file", + "/absolute/path/to/dir/", + "relative/dir/", + ".", + "..", + "./", + "file", + "dir/", + "/dir/", + "/dir", + "/file", +] +S3 = ["s3://bucket/key/", "s3://bucket/key"] +SSH = [ + "ssh://hostname:/absolute/file", + "ssh://hostname:relative/file", + "ssh://hostname:/absolute/dir/", + "ssh://hostname:relative/dir/", +] + +ROOT_S3_READ = "s3://ml-tests/test-data/anemoi-utils/pytest/transfer" +ROOT_S3_WRITE = f"s3://ml-tmp/anemoi-utils/pytest/transfer/test-{os.getpid()}" + +LOCAL_TEST_DATA = os.path.dirname(__file__) + "/test-transfer-data" + + +@pytest.mark.parametrize("source", LOCAL) +@pytest.mark.parametrize("target", S3) +def test_transfer_find_s3_upload(source, target): + from anemoi.utils.remote.s3 import S3Upload + + assert _find_transfer_class(source, target) == S3Upload + + +@pytest.mark.parametrize("source", S3) +@pytest.mark.parametrize("target", LOCAL) +def test_transfer_find_s3_download(source, target): + from anemoi.utils.remote.s3 import S3Download + + assert _find_transfer_class(source, target) == S3Download + + +@pytest.mark.parametrize("source", LOCAL) +@pytest.mark.parametrize("target", SSH) +def test_transfer_find_ssh_upload(source, target): + from anemoi.utils.remote.ssh import RsyncUpload + + assert _find_transfer_class(source, target) == RsyncUpload + + +@pytest.mark.parametrize("source", S3 + SSH) +@pytest.mark.parametrize("target", S3 + SSH) +def test_transfer_find_none(source, target): + with pytest.raises(TransferMethodNotImplementedError): + assert _find_transfer_class(source, target) + + +@pytest.mark.skipif(IN_CI, reason="Test requires access to S3") +def test_transfer_zarr_s3_to_local(tmpdir): + source = "s3://ml-datasets/aifs-ea-an-oper-0001-mars-20p0-2000-2000-12h-v0-TESTING2.zarr/" + tmp = tmpdir.strpath + "/test" + + transfer(source, tmp) + with pytest.raises(ValueError, match="already exists"): + transfer(source, tmp) + + transfer(source, tmp, resume=True) + transfer(source, tmp, overwrite=True) + + +@pytest.mark.skipif(IN_CI, reason="Test requires access to S3") +def test_transfer_zarr_local_to_s3(tmpdir): + fixture = "s3://ml-datasets/aifs-ea-an-oper-0001-mars-20p0-2000-2000-12h-v0-TESTING2.zarr/" + source = tmpdir.strpath + "/test" + target = ROOT_S3_WRITE + "/test.zarr" + + transfer(fixture, source) + transfer(source, target) + + with pytest.raises(ValueError, match="already exists"): + transfer(source, target) + + transfer(source, target, resume=True) + transfer(source, target, overwrite=True) + + +def _delete_file_or_directory(path): + if os.path.isdir(path): + shutil.rmtree(path, ignore_errors=True) + else: + if os.path.exists(path): + os.remove(path) + + +def compare(local1, local2): + if os.path.isdir(local1): + for root, dirs, files in os.walk(local1): + for file in files: + file1 = os.path.join(root, file) + file2 = file1.replace(local1, local2) + assert os.path.exists(file2) + with open(file1, "rb") as f1, open(file2, "rb") as f2: + assert f1.read() == f2.read() + else: + with open(local1, "rb") as f1, open(local2, "rb") as f2: + assert f1.read() == f2.read() + + +@pytest.mark.skipif(IN_CI, reason="Test requires access to S3") +@pytest.mark.parametrize("path", ["directory/", 
"file"]) +def test_transfer_local_to_s3_to_local(path): + local = LOCAL_TEST_DATA + "/" + path + remote = ROOT_S3_WRITE + "/" + path + local2 = LOCAL_TEST_DATA + "-copy-" + path + + transfer(local, remote, overwrite=True) + transfer(local, remote, resume=True) + with pytest.raises(ValueError, match="already exists"): + transfer(local, remote) + + _delete_file_or_directory(local2) + transfer(remote, local2) + with pytest.raises(ValueError, match="already exists"): + transfer(remote, local2) + transfer(local, remote, overwrite=True) + transfer(local, remote, resume=True) + + compare(local, local2) + + _delete_file_or_directory(local2) + + +@pytest.mark.skipif(IN_CI, reason="Test requires ssh access to localhost") +@pytest.mark.parametrize("path", ["directory", "directory/", "file"]) +@pytest.mark.parametrize("temporary_target", [True, False]) +def test_transfer_local_to_ssh(path, temporary_target): + local = LOCAL_TEST_DATA + "/" + path + remote_path = LOCAL_TEST_DATA + "-as-ssh-" + path + assert os.path.isabs(remote_path), remote_path + + remote = "ssh://localhost:" + remote_path + + transfer(local, remote, temporary_target=temporary_target) + transfer(local, remote, temporary_target=temporary_target) + + compare(local, remote_path) + + _delete_file_or_directory(remote_path) + + +if __name__ == "__main__": + for name, obj in list(globals().items()): + if name.startswith("test_") and callable(obj): + print(f"Running {name}...") + obj()