Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Support multiple array formats #13

Draft
wants to merge 6 commits into
base: develop
Choose a base branch
from
Draft
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Support multiple array formats
  • Loading branch information
sandorkertesz committed May 8, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
commit d088112e484636a911f593b4e8b8bb227cc6e6eb
55 changes: 55 additions & 0 deletions src/earthkit/meteo/utils/array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# (C) Copyright 2021 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

import threading


class ArrayNamespace:
def __init__(self):
self._api = False
self.lock = threading.Lock()

@property
def api(self):
if self._api is False:
with self.lock:
if self._api is False:
try:
import array_api_compat

self._api = array_api_compat
except Exception:
self._api = None
return self._api

def namespace(self, arrays):
if not arrays:
import numpy as np

return np

if self.api is not None:
return self.api.array_namespace(*arrays)

import numpy as np

if isinstance(arrays[0], np.ndarray):
return np
else:
raise ValueError(
"Can't find namespace for array. Please install array_api_compat package"
)


def array_namespace(*args):
arrays = [a for a in args if hasattr(a, "shape")]
return _NAMESPACE.namespace(arrays)


_NAMESPACE = ArrayNamespace()
82 changes: 82 additions & 0 deletions src/earthkit/meteo/utils/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# (C) Copyright 2021 ECMWF.
#
# This software is licensed under the terms of the Apache Licence Version 2.0
# which can be obtained at http://www.apache.org/licenses/LICENSE-2.0.
# In applying this licence, ECMWF does not waive the privileges and immunities
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.
#

from importlib import import_module

import numpy as np


def modules_installed(*modules):
for module in modules:
try:
import_module(module)
except ImportError:
return False
return True


NO_PYTORCH = not modules_installed("torch")
NO_CUPY = not modules_installed("cupy")
if not NO_CUPY:
try:
import cupy as cp

a = cp.ones(2)
except Exception:
NO_CUPY = True


def is_scalar(data):
return isinstance(data, (int, float)) or data is not data


class ArrayBackend:
ns = None

def asarray(self, *data, **kwargs):
res = [self.ns.asarray(d, **kwargs) for d in data]
r = res if len(res) > 1 else res[0]
return r

def allclose(self, *args, **kwargs):
if is_scalar(args[0]):
v = [self.asarray(a, dtype=self.dtype) for a in args]
else:
v = args
return self.ns.allclose(*v, **kwargs)


class NumpyBackend(ArrayBackend):
def __init__(self):
self.ns = np
self.dtype = np.float64


class PytorchBackend(ArrayBackend):
def __init__(self):
import torch

self.ns = torch
self.dtype = torch.float64


class CupyBackend(ArrayBackend):
def __init__(self):
import cupy

self.ns = cupy
self.dtype = cupy.float64


ARRAY_BACKENDS = [NumpyBackend()]
if not NO_PYTORCH:
ARRAY_BACKENDS.append(PytorchBackend())

if not NO_CUPY:
ARRAY_BACKENDS.append(CupyBackend())
Loading