Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

gh-405: array API support for glass.core.algorithm #423

Draft
wants to merge 7 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ jobs:
env:
FORCE_COLOR: 1

- name: Run tests and generate coverage report
- name: Run tests wih every array backend and generate coverage report
run: nox -s coverage-${{ matrix.python-version }} --verbose
env:
FORCE_COLOR: 1
GLASS_ARRAY_BACKEND: all

- name: Coveralls requires XML report
run: coverage xml
Expand Down
53 changes: 53 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,42 @@ following way -
python -m pytest --cov --doctest-plus
```

### Array API tests

One can specify a particular array backend for testing by setting the
`GLASS_ARRAY_BACKEND` environment variable. The default array backend is NumPy.
_GLASS_ can be tested with every supported array library available in the
environment by setting `GLASS_ARRAY_BACKEND` to `all`. The testing framework
only installs NumPy automatically; hence, remaining array libraries should
either be installed manually or developers should use `Nox`.

```bash
# run tests using numpy
python -m pytest
GLASS_ARRAY_BACKEND=numpy python -m pytest
# run tests using array_api_strict (should be installed manually)
GLASS_ARRAY_BACKEND=array_api_strict python -m pytest
# run tests using jax (should be installed manually)
GLASS_ARRAY_BACKEND=jax python -m pytest
# run tests using every supported array library available in the environment
GLASS_ARRAY_BACKEND=all python -m pytest
```

Moreover, one can mark a test to be compatible with the array API standard by
decorating it with `@array_api_compatible`. This will `parameterize` the test to
run on every array library specified through `GLASS_ARRAY_BACKEND` -

```py
import types
from tests.conftest import array_api_compatible


@array_api_compatible
def test_something(xp: types.ModuleType):
# use `xp.` to access the array library functionality
...
```

## Documenting

_GLASS_'s documentation is mainly written in the form of
Expand Down Expand Up @@ -173,6 +209,23 @@ syntax -
nox -s tests-3.11
```

One can specify a particular array backend for testing by setting the
`GLASS_ARRAY_BACKEND` environment variable. The default array backend is NumPy.
_GLASS_ can be tested with every supported array library by setting
`GLASS_ARRAY_BACKEND` to `all`.

```bash
# run tests using numpy
nox -s tests-3.11
GLASS_ARRAY_BACKEND=numpy nox -s tests-3.11
# run tests using array_api_strict
GLASS_ARRAY_BACKEND=array_api_strict nox -s tests-3.11
# run tests using jax
GLASS_ARRAY_BACKEND=jax nox -s tests-3.11
# run tests using every supported array library
GLASS_ARRAY_BACKEND=all nox -s tests-3.11
```

The following command can be used to deploy the docs on `localhost` -

```bash
Expand Down
42 changes: 25 additions & 17 deletions glass/core/algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,22 @@

from __future__ import annotations

import numpy as np
import numpy.typing as npt
import typing

if typing.TYPE_CHECKING:
import cupy as cp
import jax.typing as jxt
import numpy as np
import numpy.typing as npt


def nnls(
a: npt.NDArray[np.float64],
b: npt.NDArray[np.float64],
a: npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike,
b: npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike,
*,
tol: float = 0.0,
maxiter: int | None = None,
) -> npt.NDArray[np.float64]:
) -> npt.NDArray[np.float64] | cp.ndarray | jxt.ArrayLike:
"""
Compute a non-negative least squares solution.

Expand Down Expand Up @@ -51,8 +56,11 @@ def nnls(
Chemometrics, 11, 393-401.

"""
a = np.asanyarray(a)
b = np.asanyarray(b)
if a.__array_namespace__() != b.__array_namespace__():
msg = "input arrays should belong to the same array library"
raise ValueError(msg)

xp = a.__array_namespace__()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was talking to @matt-graham and he wondered if we should be using array-api-compat here. That should solve the cupy issue. I believe only NumPy and JAX are stable at this point.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CuPy support will become official in their next major release (v14) where they will start supporting NumPy v2. I hope they push the release in the next 3 months, so that we don't have to experiment with array-api-compat.


if a.ndim != 2:
msg = "input `a` is not a matrix"
Expand All @@ -69,25 +77,25 @@ def nnls(
if maxiter is None:
maxiter = 3 * n

index = np.arange(n)
p = np.full(n, fill_value=False)
x = np.zeros(n)
index = xp.arange(n)
p = xp.full(n, fill_value=False)
x = xp.zeros(n)
for _ in range(maxiter):
if np.all(p):
if xp.all(p):
break
w = np.dot(b - a @ x, a)
m = index[~p][np.argmax(w[~p])]
w = xp.linalg.vecdot(b - a @ x, a, axis=0)
m = index[~p][xp.argmax(w[~p])]
if w[m] <= tol:
break
p[m] = True
while True:
ap = a[:, p]
xp = x[p]
sp = np.linalg.solve(ap.T @ ap, b @ ap)
x_new = x[p]
sp = xp.linalg.solve(ap.T @ ap, b @ ap)
t = sp <= 0
if not np.any(t):
if not xp.any(t):
break
alpha = -np.min(xp[t] / (xp[t] - sp[t]))
alpha = -xp.min(xp[t] / (x_new[t] - sp[t]))
x[p] += alpha * (sp - xp)
p[x <= 0] = False
x[p] = sp
Expand Down
10 changes: 10 additions & 0 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import os
from pathlib import Path

import nox
Expand Down Expand Up @@ -29,6 +30,15 @@ def lint(session: nox.Session) -> None:
def tests(session: nox.Session) -> None:
"""Run the unit tests."""
session.install("-c", ".github/test-constraints.txt", "-e", ".[test]")

array_backend = os.environ.get("GLASS_ARRAY_BACKEND")
if array_backend == "array_api_strict":
session.install("array_api_strict>=2")
elif array_backend == "jax":
session.install("jax>=0.4.32")
elif array_backend == "all":
session.install("array_api_strict>=2", "jax>=0.4.32")

session.run(
"pytest",
*session.posargs,
Expand Down
89 changes: 89 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,100 @@
import contextlib
import importlib.metadata
import os
import types

import numpy as np
import numpy.typing as npt
import packaging.version
import pytest

from cosmology import Cosmology

from glass import RadialWindow

# environment variable to specify array backends for testing
# can be:
# a particular array library (numpy, jax, array_api_strict, ...)
# all (try finding every supported array library available in the environment)
GLASS_ARRAY_BACKEND: str | bool = os.environ.get("GLASS_ARRAY_BACKEND", False)


def _check_version(lib: str, array_api_compliant_version: str) -> None:
"""
Check if installed library's version is compliant with the array API standard.

Parameters
----------
lib
name of the library.
array_api_compliant_version
version of the library compliant with the array API standard.

Raises
------
ImportError
If the installed version is not compliant with the array API standard.
"""
lib_version = packaging.version.Version(importlib.metadata.version(lib))
if lib_version < packaging.version.Version(array_api_compliant_version):
msg = f"{lib} must be >= {array_api_compliant_version}; found {lib_version}"
raise ImportError(msg)


def _import_and_add_numpy(xp_available_backends: dict[str, types.ModuleType]) -> None:
"""Add numpy to the backends dictionary."""
_check_version("numpy", "2.1.0")
xp_available_backends.update({"numpy": np})


def _import_and_add_array_api_strict(
xp_available_backends: dict[str, types.ModuleType],
) -> None:
"""Add array_api_strict to the backends dictionary."""
import array_api_strict

_check_version("array_api_strict", "2.0.0")
xp_available_backends.update({"array_api_strict": array_api_strict})
array_api_strict.set_array_api_strict_flags(api_version="2023.12")


def _import_and_add_jax(xp_available_backends: dict[str, types.ModuleType]) -> None:
"""Add jax to the backends dictionary."""
import jax

_check_version("jax", "0.4.32")
xp_available_backends.update({"jax.numpy": jax.numpy})
# enable 64 bit numbers
jax.config.update("jax_enable_x64", val=True)


# a dictionary with all array backends to test
xp_available_backends: dict[str, types.ModuleType] = {}

# if no backend passed, use numpy by default
if not GLASS_ARRAY_BACKEND or GLASS_ARRAY_BACKEND == "numpy":
_import_and_add_numpy(xp_available_backends)
elif GLASS_ARRAY_BACKEND == "array_api_strict":
_import_and_add_array_api_strict(xp_available_backends)
elif GLASS_ARRAY_BACKEND == "jax":
_import_and_add_jax(xp_available_backends)
# if all, try importing every backend
elif GLASS_ARRAY_BACKEND == "all":
with contextlib.suppress(ImportError):
_import_and_add_numpy(xp_available_backends)

with contextlib.suppress(ImportError):
_import_and_add_array_api_strict(xp_available_backends)

with contextlib.suppress(ImportError):
_import_and_add_jax(xp_available_backends)
else:
msg = f"unsupported array backend: {GLASS_ARRAY_BACKEND}"
raise ValueError(msg)

# use this as a decorator for tests involving array API compatible functions
array_api_compatible = pytest.mark.parametrize("xp", xp_available_backends.values())


@pytest.fixture(scope="session")
def cosmo() -> Cosmology:
Expand Down
9 changes: 6 additions & 3 deletions tests/core/test_algorithm.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
import importlib.util
import types

import numpy as np
import pytest
from tests.conftest import array_api_compatible

from glass.core.algorithm import nnls as nnls_glass

# check if scipy is available for testing
HAVE_SCIPY = importlib.util.find_spec("scipy") is not None


@array_api_compatible
@pytest.mark.skipif(not HAVE_SCIPY, reason="test requires SciPy")
def test_nnls(rng: np.random.Generator) -> None:
def test_nnls(rng: np.random.Generator, xp: types.ModuleType) -> None:
from scipy.optimize import nnls as nnls_scipy

# cross-check output with scipy's nnls

a = rng.standard_normal((100, 20))
b = rng.standard_normal((100,))
a = xp.asarray(rng.standard_normal((100, 20)))
b = xp.asarray(rng.standard_normal((100,)))

x_glass = nnls_glass(a, b)
x_scipy, _ = nnls_scipy(a, b)
Expand Down
Loading