Skip to content

Commit

Permalink
Merge branch 'main' into handle-units
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Dec 4, 2024
2 parents 21d59a7 + e8f01b2 commit 14d5798
Show file tree
Hide file tree
Showing 7 changed files with 1,069 additions and 56 deletions.
18 changes: 17 additions & 1 deletion brainunit/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import operator
from contextlib import contextmanager
from copy import deepcopy
from functools import wraps
from functools import wraps, partial
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict

import jax
Expand Down Expand Up @@ -4444,6 +4444,22 @@ def new_f(*args, **kwds):
return do_check_units


def _check_unit(f, val, unit):
unit = UNITLESS if unit is None else unit
if not has_same_unit(val, unit):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"unit {get_unit(val)} but was "
f"'{val}'"
)
raise UnitMismatchError(error_message, get_unit(val))


def _is_quantity(x):
return isinstance(x, Quantity)


@set_module_as('brainunit')
def handle_units(**au):
"""
Expand Down
87 changes: 83 additions & 4 deletions brainunit/_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from numpy.testing import assert_equal

import brainunit as u
import brainunit as bu
from brainunit._base import (
DIMENSIONLESS,
UNITLESS,
Expand Down Expand Up @@ -1185,7 +1184,7 @@ def test_fail_for_dimension_mismatch(self):

def test_check_dims(self):
"""
Test the check_units decorator
Test the check_dims decorator
"""

@u.check_dims(v=volt.dim)
Expand Down Expand Up @@ -1245,8 +1244,49 @@ def c_function(a, b):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)

# with pytest.raises(TypeError):
# c_function(False, 1)

# Multiple results
@u.check_dims(result=(second.dim, volt.dim))
def d_function(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result:
return 5 * second, 3 * volt
else:
return 3 * volt, 5 * second

# Should work (returns second)
d_function(True)
# Should fail (returns volt)
with pytest.raises(u.DimensionMismatchError):
d_function(False)

# Multiple results
@u.check_dims(result={'u': second.dim, 'v': (volt.dim, metre.dim)})
def d_function2(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result == 0:
return {'u': 5 * second, 'v': (3 * volt, 2 * metre)}
elif true_result == 1:
return 3 * volt, 5 * second
else:
return {'u': 5 * second, 'v': (3 * volt, 2 * volt)}

d_function2(0)

with pytest.raises(TypeError):
c_function(False, 1)
d_function2(1)

with pytest.raises(u.DimensionMismatchError):
d_function2(2)

def test_check_units(self):
"""
Expand Down Expand Up @@ -1313,8 +1353,47 @@ def c_function(a, b):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)

# Multiple results
@check_units(result=(second, volt))
def d_function(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result:
return 5 * second, 3 * volt
else:
return 3 * volt, 5 * second

# Should work (returns second)
d_function(True)
# Should fail (returns volt)
with pytest.raises(u.UnitMismatchError):
d_function(False)

# Multiple results
@check_units(result={'u': second, 'v': (volt, metre)})
def d_function2(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result == 0:
return {'u': 5 * second, 'v': (3 * volt, 2 * metre)}
elif true_result == 1:
return 3 * volt, 5 * second
else:
return {'u': 5 * second, 'v': (3 * volt, 2 * volt)}

# Should work (returns second)
d_function2(0)
# Should fail (returns volt)
with pytest.raises(TypeError):
c_function(False, 1)
d_function2(1)

with pytest.raises(u.UnitMismatchError):
d_function2(2)

def test_handle_units(self):
"""
Expand Down
54 changes: 36 additions & 18 deletions brainunit/lax/_lax_linalg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from __future__ import annotations

from typing import Union, Callable, Any
from typing import Union, Callable, Any, Tuple, List

import jax
from jax import lax
from jax import lax, Array

from brainunit.lax._lax_change_unit import unit_change
from .._base import Quantity, maybe_decimal, fail_for_unit_mismatch
Expand Down Expand Up @@ -65,7 +65,7 @@ def eig(
x: Union[Quantity, jax.typing.ArrayLike],
compute_left_eigenvectors: bool = True,
compute_right_eigenvectors: bool = True
) -> tuple[Quantity, jax.Array, jax.Array] | list[jax.Array] | tuple[Quantity, jax.Array] | Quantity:
) -> tuple[Array | Quantity, Array, Array] | list[Array] | tuple[Array | Quantity, Array] | tuple[Array | Quantity]:
"""Eigendecomposition of a general matrix.
Nonsymmetric eigendecomposition is at present only implemented on CPU.
Expand All @@ -87,31 +87,37 @@ def eig(
"""
if compute_left_eigenvectors and compute_right_eigenvectors:
if isinstance(x, Quantity):
w, vl, vr = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=True, compute_right_eigenvectors=True)
w, vl, vr = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
return maybe_decimal(Quantity(w, unit=x.unit)), vl, vr
else:
return lax.linalg.eig(x, compute_left_eigenvectors=True, compute_right_eigenvectors=True)
return lax.linalg.eig(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
elif compute_left_eigenvectors:
if isinstance(x, Quantity):
w, vl = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=True, compute_right_eigenvectors=False)
w, vl = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
return maybe_decimal(Quantity(w, unit=x.unit)), vl
else:
return lax.linalg.eig(x, compute_left_eigenvectors, compute_left_eigenvectors=True,
compute_right_eigenvectors=False)
return lax.linalg.eig(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)

elif compute_right_eigenvectors:
if isinstance(x, Quantity):
w, vr = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=False, compute_right_eigenvectors=True)
w, vr = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
return maybe_decimal(Quantity(w, unit=x.unit)), vr
else:
return lax.linalg.eig(x, compute_right_eigenvectors, compute_left_eigenvectors=False,
compute_right_eigenvectors=True)
return lax.linalg.eig(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
else:
if isinstance(x, Quantity):
w = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=False, compute_right_eigenvectors=False)
return maybe_decimal(Quantity(w, unit=x.unit))
w = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
return (maybe_decimal(Quantity(w, unit=x.unit)),)
else:
return lax.linalg.eig(x, compute_left_eigenvectors=False, compute_right_eigenvectors=False)
return lax.linalg.eig(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)


@set_module_as('brainunit.lax')
Expand Down Expand Up @@ -344,18 +350,30 @@ def schur(
@set_module_as('brainunit.lax')
def svd(
x: Union[Quantity, jax.typing.ArrayLike],
) -> tuple[jax.Array, Quantity | jax.Array, jax.Array]:
*,
full_matrices: bool = True,
compute_uv: bool = True,
subset_by_index: tuple[int, int] | None = None,
algorithm: jax.lax.linalg.SvdAlgorithm | None = None,
) -> Union[Quantity, jax.typing.ArrayLike] | tuple[jax.Array, Quantity | jax.Array, jax.Array]:
"""Singular value decomposition.
Returns the singular values if compute_uv is False, otherwise returns a triple
containing the left singular vectors, the singular values and the adjoint of
the right singular vectors.
"""
if isinstance(x, Quantity):
u, s, vh = lax.linalg.svd(x.mantissa)
return u, maybe_decimal(Quantity(s, unit=x.unit)), vh
if compute_uv:
u, s, vh = lax.linalg.svd(x.mantissa, full_matrices=full_matrices, compute_uv=compute_uv,
subset_by_index=subset_by_index, algorithm=algorithm)
return u, maybe_decimal(Quantity(s, unit=x.unit)), vh
else:
s = lax.linalg.svd(x.mantissa, full_matrices=full_matrices, compute_uv=compute_uv,
subset_by_index=subset_by_index, algorithm=algorithm)
return maybe_decimal(Quantity(s, unit=x.unit))
else:
return lax.linalg.svd(x)
return lax.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv,
subset_by_index=subset_by_index, algorithm=algorithm)


@set_module_as('brainunit.lax')
Expand Down
Loading

0 comments on commit 14d5798

Please sign in to comment.