
Commit

Merge branch 'main' into assign-units
Routhleck committed Dec 4, 2024
2 parents 21d59a7 + e8f01b2 commit 5ccf35c
Showing 7 changed files with 1,112 additions and 90 deletions.
95 changes: 60 additions & 35 deletions brainunit/_base.py
@@ -19,7 +19,7 @@
import operator
from contextlib import contextmanager
from copy import deepcopy
from functools import wraps
from functools import wraps, partial
from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict

import jax
@@ -2411,6 +2411,10 @@ def size(self) -> int:
def T(self) -> 'Quantity':
return Quantity(jnp.asarray(self.mantissa).T, unit=self.unit)

@property
def mT(self) -> 'Quantity':
return Quantity(jnp.asarray(self.mantissa).mT, unit=self.unit)

@property
def isreal(self) -> jax.Array:
return jnp.isreal(self.mantissa)
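
For reference, a minimal usage sketch (not part of the diff) of the new mT property added above. It assumes brainunit's usual array-times-unit construction of a Quantity:

import jax.numpy as jnp
import brainunit as u

# Multiplying an array by a unit yields a Quantity (assumed brainunit convention).
q = jnp.arange(24.0).reshape(2, 3, 4) * u.mV

# mT transposes the last two axes of the mantissa and preserves the unit.
assert q.mT.shape == (2, 4, 3)
assert q.mT.unit == q.unit
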
@@ -4147,24 +4151,20 @@ def new_f(*args, **kwds):
expected_result = au["result"](*[get_dim(a) for a in args])
else:
expected_result = au["result"]
if au["result"] == bool:
if not isinstance(result, bool):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to be "
"a boolean value, but was of type "
f"{type(result)}"
)
raise TypeError(error_message)
elif not have_same_dim(result, expected_result):
unit = get_dim_for_display(expected_result)
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"dimension {unit} but was "
f"'{result}'"

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
)
raise DimensionMismatchError(error_message, get_dim(result))

jax.tree.map(
partial(_check_dim, f), result, expected_result,
is_leaf=_is_quantity
)
return result

new_f._orig_func = f
@@ -4203,6 +4203,19 @@ def new_f(*args, **kwds):
return do_check_units


def _check_dim(f, val, dim):
dim = DIMENSIONLESS if dim is None else dim
if not have_same_dim(val, dim):
unit = get_dim_for_display(dim)
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"dimension {unit} but was "
f"'{val}'"
)
raise DimensionMismatchError(error_message, get_dim(val))


@set_module_as('brainunit')
def check_units(**au):
"""
@@ -4389,23 +4402,20 @@ def new_f(*args, **kwds):
expected_result = au["result"](*[get_dim(a) for a in args])
else:
expected_result = au["result"]
if au["result"] == bool:
if not isinstance(result, bool):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to be "
"a boolean value, but was of type "
f"{type(result)}"
)
raise TypeError(error_message)
elif not has_same_unit(result, expected_result):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"unit {get_unit(expected_result)} but was "
f"'{result}'"

if (
jax.tree.structure(expected_result, is_leaf=_is_quantity)
!=
jax.tree.structure(result, is_leaf=_is_quantity)
):
raise TypeError(
f"Expected a return value of type {expected_result} but got {result}"
)
raise UnitMismatchError(error_message, get_unit(result))

jax.tree.map(
partial(_check_unit, f), result, expected_result,
is_leaf=_is_quantity
)
return result

new_f._orig_func = f
@@ -4520,4 +4530,19 @@ def new_f(*args, **kwds):

return new_f

return do_handle_units
return do_handle_units

def _check_unit(f, val, unit):
unit = UNITLESS if unit is None else unit
if not has_same_unit(val, unit):
error_message = (
"The return value of function "
f"'{f.__name__}' was expected to have "
f"unit {get_unit(val)} but was "
f"'{val}'"
)
raise UnitMismatchError(error_message, get_unit(val))


def _is_quantity(x):
return isinstance(x, Quantity)
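
A brief usage sketch of the reworked result check, mirroring the tests below: result may now be a pytree of dimensions whose structure must match the returned value, with each leaf validated through jax.tree.map. The unit names assume brainunit's top-level exports.

import brainunit as u

@u.check_dims(result=(u.second.dim, u.volt.dim))
def step(flag):
    # The returned tuple is compared leaf-by-leaf against (second.dim, volt.dim).
    if flag:
        return 5 * u.second, 3 * u.volt
    return 3 * u.volt, 5 * u.second  # wrong order -> DimensionMismatchError

step(True)      # passes
# step(False)   # would raise u.DimensionMismatchError
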
87 changes: 83 additions & 4 deletions brainunit/_base_test.py
@@ -30,7 +30,6 @@
from numpy.testing import assert_equal

import brainunit as u
import brainunit as bu
from brainunit._base import (
DIMENSIONLESS,
UNITLESS,
@@ -1185,7 +1184,7 @@ def test_fail_for_dimension_mismatch(self):

def test_check_dims(self):
"""
Test the check_units decorator
Test the check_dims decorator
"""

@u.check_dims(v=volt.dim)
@@ -1245,8 +1244,49 @@ def c_function(a, b):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)

# with pytest.raises(TypeError):
# c_function(False, 1)

# Multiple results
@u.check_dims(result=(second.dim, volt.dim))
def d_function(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result:
return 5 * second, 3 * volt
else:
return 3 * volt, 5 * second

# Should work (returns second)
d_function(True)
# Should fail (returns volt)
with pytest.raises(u.DimensionMismatchError):
d_function(False)

# Multiple results
@u.check_dims(result={'u': second.dim, 'v': (volt.dim, metre.dim)})
def d_function2(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result == 0:
return {'u': 5 * second, 'v': (3 * volt, 2 * metre)}
elif true_result == 1:
return 3 * volt, 5 * second
else:
return {'u': 5 * second, 'v': (3 * volt, 2 * volt)}

d_function2(0)

with pytest.raises(TypeError):
c_function(False, 1)
d_function2(1)

with pytest.raises(u.DimensionMismatchError):
d_function2(2)

def test_check_units(self):
"""
@@ -1313,8 +1353,47 @@ def c_function(a, b):
c_function(1, 1)
with pytest.raises(TypeError):
c_function(1 * mV, 1)

# Multiple results
@check_units(result=(second, volt))
def d_function(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result:
return 5 * second, 3 * volt
else:
return 3 * volt, 5 * second

# Should work (returns second)
d_function(True)
# Should fail (returns volt)
with pytest.raises(u.UnitMismatchError):
d_function(False)

# Multiple results
@check_units(result={'u': second, 'v': (volt, metre)})
def d_function2(true_result):
"""
Return a value in seconds if return_second is True, otherwise return
a value in volt.
"""
if true_result == 0:
return {'u': 5 * second, 'v': (3 * volt, 2 * metre)}
elif true_result == 1:
return 3 * volt, 5 * second
else:
return {'u': 5 * second, 'v': (3 * volt, 2 * volt)}

# Should work (returns second)
d_function2(0)
# Should fail (returns volt)
with pytest.raises(TypeError):
c_function(False, 1)
d_function2(1)

with pytest.raises(u.UnitMismatchError):
d_function2(2)

def test_handle_units(self):
"""
54 changes: 36 additions & 18 deletions brainunit/lax/_lax_linalg.py
@@ -1,9 +1,9 @@
from __future__ import annotations

from typing import Union, Callable, Any
from typing import Union, Callable, Any, Tuple, List

import jax
from jax import lax
from jax import lax, Array

from brainunit.lax._lax_change_unit import unit_change
from .._base import Quantity, maybe_decimal, fail_for_unit_mismatch
@@ -65,7 +65,7 @@ def eig(
x: Union[Quantity, jax.typing.ArrayLike],
compute_left_eigenvectors: bool = True,
compute_right_eigenvectors: bool = True
) -> tuple[Quantity, jax.Array, jax.Array] | list[jax.Array] | tuple[Quantity, jax.Array] | Quantity:
) -> tuple[Array | Quantity, Array, Array] | list[Array] | tuple[Array | Quantity, Array] | tuple[Array | Quantity]:
"""Eigendecomposition of a general matrix.
Nonsymmetric eigendecomposition is at present only implemented on CPU.
@@ -87,31 +87,37 @@
"""
if compute_left_eigenvectors and compute_right_eigenvectors:
if isinstance(x, Quantity):
w, vl, vr = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=True, compute_right_eigenvectors=True)
w, vl, vr = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
return maybe_decimal(Quantity(w, unit=x.unit)), vl, vr
else:
return lax.linalg.eig(x, compute_left_eigenvectors=True, compute_right_eigenvectors=True)
return lax.linalg.eig(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
elif compute_left_eigenvectors:
if isinstance(x, Quantity):
w, vl = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=True, compute_right_eigenvectors=False)
w, vl = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
return maybe_decimal(Quantity(w, unit=x.unit)), vl
else:
return lax.linalg.eig(x, compute_left_eigenvectors, compute_left_eigenvectors=True,
compute_right_eigenvectors=False)
return lax.linalg.eig(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)

elif compute_right_eigenvectors:
if isinstance(x, Quantity):
w, vr = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=False, compute_right_eigenvectors=True)
w, vr = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
return maybe_decimal(Quantity(w, unit=x.unit)), vr
else:
return lax.linalg.eig(x, compute_right_eigenvectors, compute_left_eigenvectors=False,
compute_right_eigenvectors=True)
return lax.linalg.eig(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
else:
if isinstance(x, Quantity):
w = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=False, compute_right_eigenvectors=False)
return maybe_decimal(Quantity(w, unit=x.unit))
w = lax.linalg.eig(x.mantissa, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
return (maybe_decimal(Quantity(w, unit=x.unit)),)
else:
return lax.linalg.eig(x, compute_left_eigenvectors=False, compute_right_eigenvectors=False)
return lax.linalg.eig(x, compute_left_eigenvectors=compute_left_eigenvectors,
compute_right_eigenvectors=compute_right_eigenvectors)
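
As a usage note, a hedged sketch of the wrapper above, assuming brainunit.lax re-exports eig as the set_module_as decorator suggests: eigenvalues inherit the input's unit, while the eigenvector factors come back as plain arrays.

import jax.numpy as jnp
import brainunit as u
import brainunit.lax as ulax

m = jnp.array([[2.0, 0.0],
               [0.0, 3.0]]) * u.mV

# Both eigenvector sets are requested by default; w keeps the mV unit.
# Note: nonsymmetric eigendecomposition is CPU-only, per the docstring above.
w, vl, vr = ulax.eig(m)
print(w)                   # Quantity with a complex mantissa, unit mV
print(vl.shape, vr.shape)  # plain jax.Array eigenvectors
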


@set_module_as('brainunit.lax')
@@ -344,18 +350,30 @@ def schur(
@set_module_as('brainunit.lax')
def svd(
x: Union[Quantity, jax.typing.ArrayLike],
) -> tuple[jax.Array, Quantity | jax.Array, jax.Array]:
*,
full_matrices: bool = True,
compute_uv: bool = True,
subset_by_index: tuple[int, int] | None = None,
algorithm: jax.lax.linalg.SvdAlgorithm | None = None,
) -> Union[Quantity, jax.typing.ArrayLike] | tuple[jax.Array, Quantity | jax.Array, jax.Array]:
"""Singular value decomposition.
Returns the singular values if compute_uv is False, otherwise returns a triple
containing the left singular vectors, the singular values and the adjoint of
the right singular vectors.
"""
if isinstance(x, Quantity):
u, s, vh = lax.linalg.svd(x.mantissa)
return u, maybe_decimal(Quantity(s, unit=x.unit)), vh
if compute_uv:
u, s, vh = lax.linalg.svd(x.mantissa, full_matrices=full_matrices, compute_uv=compute_uv,
subset_by_index=subset_by_index, algorithm=algorithm)
return u, maybe_decimal(Quantity(s, unit=x.unit)), vh
else:
s = lax.linalg.svd(x.mantissa, full_matrices=full_matrices, compute_uv=compute_uv,
subset_by_index=subset_by_index, algorithm=algorithm)
return maybe_decimal(Quantity(s, unit=x.unit))
else:
return lax.linalg.svd(x)
return lax.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv,
subset_by_index=subset_by_index, algorithm=algorithm)
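
Similarly, a minimal sketch of the extended svd wrapper, under the same assumptions as the eig example: singular values carry the input's unit, and compute_uv=False now returns just those values.

import jax.numpy as jnp
import brainunit as u
import brainunit.lax as ulax

a = jnp.array([[3.0, 0.0],
               [0.0, 4.0]]) * u.second

# 'u' is the brainunit alias here, so the left factor is named u_mat instead.
u_mat, s, vh = ulax.svd(a)              # s is a Quantity in seconds
s_only = ulax.svd(a, compute_uv=False)  # singular values only, still in seconds
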


@set_module_as('brainunit.lax')
