Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add pinv and fix solve for numpy>=2.1 #63

Merged
merged 6 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/dependabot.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
version: 2
updates:
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.10"
python-version: "3.11"
- name: Install build dependencies
run: python -m pip install build
- name: Build package
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]
fail-fast: false
steps:
- uses: actions/checkout@v3
Expand Down
6 changes: 3 additions & 3 deletions docs/source/changelog.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
# Change Log

## v0.x.x (Unreleased)
## v0.8.0 (2024 Sep 19)
### New features
* Add `numpy.linalg.pinv` wrapper {pull}`63`

### Maintenance and fixes

### Documentation
* Update to handle modified behaviour of `numpy.linalg.solve` {pull}`63`

## v0.7.0 (2024 Jan 17)
### New features
Expand Down
497 changes: 316 additions & 181 deletions docs/source/tutorials/linalg_tutorial.ipynb

Large diffs are not rendered by default.

389 changes: 316 additions & 73 deletions docs/source/tutorials/np_linalg_tutorial_port.ipynb

Large diffs are not rendered by default.

450 changes: 256 additions & 194 deletions docs/source/tutorials/stats_tutorial.ipynb

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ build-backend = "flit_core.buildapi"
name = "xarray-einstats"
description = "Stats, linear algebra and einops for xarray"
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "ArviZ team", email = "arviz.devs@gmail.com"}
Expand All @@ -27,8 +27,8 @@ classifiers = [
]
dynamic = ["version"]
dependencies = [
"numpy>=1.22",
"scipy>=1.8",
"numpy>=1.23",
"scipy>=1.9",
"xarray>=2022.09.0",
]

Expand Down
2 changes: 1 addition & 1 deletion src/xarray_einstats/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"EinopsAccessor",
]

__version__ = "0.8.0.dev0"
__version__ = "0.8.0"


def sort(da, dim, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions src/xarray_einstats/accessors.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Accessors for xarray_einstats features."""

import xarray as xr

from .linalg import (
Expand Down
1 change: 1 addition & 0 deletions src/xarray_einstats/einops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
example usage.

"""

import warnings
from collections.abc import Hashable

Expand Down
90 changes: 84 additions & 6 deletions src/xarray_einstats/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
``matmul`` and ``get_default_dims``.

"""

import warnings

import numpy as np
Expand All @@ -34,6 +35,7 @@
"diagonal",
"solve",
"inv",
"pinv",
]


Expand Down Expand Up @@ -709,19 +711,75 @@ def solve(da, db, dims=None, **kwargs):
"""Wrap :func:`numpy.linalg.solve`.

Usage examples of all arguments is available at the :ref:`linalg_tutorial` page.

Parameters
----------
da : DataArray
db : DataArray
dims : sequence of hashable, optional
It can have either length 2 or 3. If length 2, both dimensions should have the
same length and be present in `da`, and only one of them should also be present in `db`.
If length 3, the first two elements behave the same; the third element is a dimension
of arbitrary length which can only present in `db`.

From NumPy's docstring, a has ``(..., M, M)`` shape and b has ``(M,) or (..., M, K)``.
Here, b can be ``(..., M)`` this case is not limited to 1d, so dims with length two
indicates the two dimensions of length M, with length 3 it is something like (M, M, K),
which can be done thanks to named dimensions.
**kwargs : mapping
Passed to :func:`xarray.apply_ufunc`

Examples
--------
Dimension naming conventions are designed to ease inverse operation with :func:`xarray.dot`.

The following example illustrates what this means and how to check that solve
worked correctly

.. jupyter-execute::

import xarray as xr
import numpy as np
from xarray_einstats.linalg import solve
from xarray_einstats.tutorial import generate_matrices_dataarray

matrices = generate_matrices_dataarray()
matrices

.. jupyter-execute::

b = matrices.std("dim2") # dims (batch, experiment, dim)
y2 = solve(matrices, b, dims=("dim", "dim2")) # dims (batch, experiment, dim2)
np.allclose(b, xr.dot(matrices, y2, dims="dim2"))

"""
if dims is None:
dims = _attempt_default_dims("solve", da.dims, db.dims)
if len(dims) == 3:
b_dim = dims[0] if dims[0] in db.dims else dims[1]
in_dims = [dims[:2], [b_dim, dims[-1]]]
out_dims = [[b_dim, dims[-1]]]
# solve(a, b) in numpy has signature a: (..., M, M) and b: (..., M, K)
# we look which dim is in b -> represents the M
k_dim = dims[-1] # the last element in dims represents the K
remove_k = False
if k_dim in da:
raise ValueError(
f"Found {k_dim} in `da`. If provided, the 3rd element of 'dims' "
"can only be in `db`."
)
else:
in_dims = [dims, dims[:1]]
out_dims = [dims[:1]]
return xr.apply_ufunc(
# a: (..., M, M) and b: (..., M) is not supported, so we add a dummy K
k_dim = "__k_aux_dim__"
remove_k = True
db = db.expand_dims(k_dim)
b_dim = dims[0] if dims[0] in db.dims else dims[1]
y_dim = dims[1] if dims[0] in db.dims else dims[0]
in_dims = [dims[:2], [b_dim, k_dim]]
out_dims = [[y_dim, k_dim]]
da_out = xr.apply_ufunc(
np.linalg.solve, da, db, input_core_dims=in_dims, output_core_dims=out_dims, **kwargs
)
if remove_k:
return da_out.squeeze(k_dim, drop=True)
return da_out


def inv(da, dims=None, **kwargs):
Expand All @@ -734,3 +792,23 @@ def inv(da, dims=None, **kwargs):
return xr.apply_ufunc(
np.linalg.inv, da, input_core_dims=[dims], output_core_dims=[dims], **kwargs
)


def pinv(da, dims=None, **kwargs):
"""Wrap :func:`numpy.linalg.pinv`.

Usage examples of all arguments is available at the :ref:`linalg_tutorial` page.
If both "rtol" and "rcond" are provided, "rtol" will be ignored.
"""
if dims is None:
dims = _attempt_default_dims("pinv", da.dims)
rcond = kwargs.pop("rtol", None)
rcond = kwargs.pop("rcond", rcond)
return xr.apply_ufunc(
np.linalg.pinv,
da,
rcond,
input_core_dims=[dims, []],
output_core_dims=[dims[::-1]],
**kwargs,
)
1 change: 1 addition & 0 deletions src/xarray_einstats/numba.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Module with numba enhanced functions."""

import numba
import numpy as np
import xarray as xr
Expand Down
1 change: 1 addition & 0 deletions src/xarray_einstats/tutorial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tutorial module with data for docs and quick testing."""

import numpy as np
import xarray as xr

Expand Down
17 changes: 15 additions & 2 deletions tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
matrix_rank,
matrix_transpose,
norm,
pinv,
qr,
slogdet,
solve,
Expand Down Expand Up @@ -140,6 +141,12 @@ def test_inv(self, matrices):
assert out.shape == matrices.shape
assert out.dims == matrices.dims

def test_pinv(self, matrices):
out = pinv(matrices, dims=("experiment", "dim"))
out_dims_exp = ("batch", "dim2", "dim", "experiment")
assert out.dims == out_dims_exp
assert out.shape == tuple(out.sizes[dim] for dim in out_dims_exp)

def test_transpose(self, hermitian):
assert_equal(hermitian, matrix_transpose(hermitian, dims=("dim", "dim2")))

Expand Down Expand Up @@ -272,10 +279,16 @@ def test_slogdet_det(self, matrices):
det_da = det(matrices, dims=("dim", "dim2"))
assert_allclose(sign * np.exp(logdet), det_da)

def test_solve(self, matrices):
def test_solve_two_dims(self, matrices):
b = matrices.std("dim2")
y = solve(matrices, b, dims=("dim", "dim2"))
assert_allclose(b, xr.dot(matrices, y.rename(dim="dim2"), dims="dim2"), atol=1e-14)
assert_allclose(b, xr.dot(matrices, y, dim="dim2"), atol=1e-14)

def test_solve_three_dims(self, matrices):
b = matrices.std("dim2")
a = matrices.isel(batch=0)
y = solve(a, b, dims=("dim", "dim2", "batch"))
assert_allclose(b, xr.dot(a, y, dim="dim2").transpose(*b.dims), atol=1e-14)

def test_diagonal(self, matrices):
idx = xr.DataArray(np.arange(len(matrices["dim"])), dims="pointwise_sel")
Expand Down
Loading