From c17b9fad690d0613469fe9da3678d72a8487bea3 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 10 Jun 2024 17:04:46 +0800 Subject: [PATCH 01/23] Update _compat_numpy.py --- brainunit/math/_compat_numpy.py | 463 ++++++++++++++++++++++++++++---- 1 file changed, 410 insertions(+), 53 deletions(-) diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py index 0dbc908..09dda8f 100644 --- a/brainunit/math/_compat_numpy.py +++ b/brainunit/math/_compat_numpy.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - +import functools from collections.abc import Sequence from functools import wraps -from typing import (Callable, Union, Optional) +from typing import (Callable, Union, Optional, Any) import brainstate as bst import jax @@ -162,85 +162,335 @@ def f(*args, unit: Unit = None, **kwargs): ones = wrap_array_creation_function(jnp.ones) zeros = wrap_array_creation_function(jnp.zeros) +# docs for full, eye, identity, tri, empty, ones, zeros + +full.__doc__ = """ +Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. +else return an array of `shape` filled with `fill_value`. + + Args: + shape: sequence of integers, describing the shape of the output array. + fill_value: the value to fill the new array with. + dtype: the type of the output array, or `None`. If not `None`, `fill_value` + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +eye.__doc__ = """ +Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. +else return an identity matrix of `shape`. + + Args: + n: the number of rows (and columns) in the output array. + k: the index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value to a + lower diagonal. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +identity.__doc__ = """ +Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. +else return an identity matrix of `shape`. + + Args: + n: the number of rows (and columns) in the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +tri.__doc__ = """ +Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. +else return a triangular matrix of `shape`. + + Args: + n: the number of rows in the output array. + m: the number of columns with default being `n`. + k: the index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value to a + lower diagonal. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +# empty +empty.__doc__ = """ +Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. +else return an array of `shape` with uninitialized values. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be of type `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +# ones +ones.__doc__ = """ +Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. +else return an array of `shape` filled with 1. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + +# zeros +zeros.__doc__ = """ +Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. +else return an array of `shape` filled with 0. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + out: Quantity if `unit` is provided, else an array. +""" + @set_module_as('brainunit.math') -def full_like(a, fill_value, dtype=None, shape=None): - if isinstance(a, Quantity) and isinstance(fill_value, Quantity): - fail_for_dimension_mismatch(a, fill_value, error_message='Units do not match for full_like operation.') - return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)) and not isinstance(fill_value, Quantity): - return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) +def full_like(a: Union[Quantity, jax.Array, np.ndarray], + fill_value: Union[jax.Array, np.ndarray], + unit: Unit = None, + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with `fill_value` if `unit` is provided. + else return an array of `a` filled with `fill_value`. + + Args: + a: array_like, Quantity, shape, or dtype + fill_value: scalar or array_like + unit: Unit, optional + dtype: data-type, optional + shape: sequence of ints, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) * unit + else: + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) * unit else: - raise ValueError(f'Unsupported types : {type(a)} abd {type(fill_value)} for full_like') + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) @set_module_as('brainunit.math') -def diag(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.diag(a.value, k=k), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.diag(a, k=k) +def diag(a: Union[Quantity, jax.Array, np.ndarray], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Extract a diagonal or construct a diagonal array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.diag(a.value, k=k) * unit + else: + return jnp.diag(a, k=k) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for diag') + return jnp.diag(a, k=k) @set_module_as('brainunit.math') -def tril(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.tril(a.value, k=k), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.tril(a, k=k) +def tril(a: Union[Quantity, jax.Array, np.ndarray], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Lower triangle of an array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.tril(a.value, k=k) * unit + else: + return jnp.tril(a, k=k) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for tril') + return jnp.tril(a, k=k) @set_module_as('brainunit.math') -def triu(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.triu(a.value, k=k), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.triu(a, k=k) +def triu(a: Union[Quantity, jax.Array, np.ndarray], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Upper triangle of an array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.triu(a.value, k=k) * unit + else: + return jnp.triu(a, k=k) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for triu') + return jnp.triu(a, k=k) @set_module_as('brainunit.math') -def empty_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.empty_like(a, dtype=dtype, shape=shape) +def empty_like(a: Union[Quantity, jax.Array, np.ndarray], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, with uninitialized values if `unit` is provided. + else return an array of `a` with uninitialized values. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.empty_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.empty_like(a, dtype=dtype, shape=shape) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for empty_like') + return jnp.empty_like(a, dtype=dtype, shape=shape) @set_module_as('brainunit.math') -def ones_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.ones_like(a, dtype=dtype, shape=shape) +def ones_like(a: Union[Quantity, jax.Array, np.ndarray], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with 1 if `unit` is provided. + else return an array of `a` filled with 1. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.ones_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.ones_like(a, dtype=dtype, shape=shape) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for ones_like') + return jnp.ones_like(a, dtype=dtype, shape=shape) @set_module_as('brainunit.math') -def zeros_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.zeros_like(a, dtype=dtype, shape=shape) +def zeros_like(a: Union[Quantity, jax.Array, np.ndarray], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with 0 if `unit` is provided. + else return an array of `a` filled with 0. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + out: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.zeros_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.zeros_like(a, dtype=dtype, shape=shape) * unit else: - raise ValueError(f'Unsupported type: {type(a)} for zeros_like') + return jnp.zeros_like(a, dtype=dtype, shape=shape) @set_module_as('brainunit.math') def asarray( - a, + a: Union[Quantity, jax.Array, np.ndarray, Sequence[Quantity]], dtype: Optional[bst.typing.DTypeLike] = None, order: Optional[str] = None, unit: Optional[Unit] = None, -): +) -> Union[Quantity, jax.Array]: from builtins import all as origin_all from builtins import any as origin_any if isinstance(a, Quantity): @@ -265,6 +515,19 @@ def asarray( @set_module_as('brainunit.math') def arange(*args, **kwargs): + ''' + Return a Quantity of `arange` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity, optional + stop: number, Quantity, optional + step: number, optional + dtype: dtype, optional + unit: Unit, optional + + Returns: + out: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' # arange has a bit of a complicated argument structure unfortunately # we leave the actual checking of the number of arguments to numpy, though @@ -343,7 +606,26 @@ def arange(*args, **kwargs): @set_module_as('brainunit.math') -def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None): +def linspace(start: Union[Quantity, jax.Array, np.ndarray], + stop: Union[Quantity, jax.Array, np.ndarray], + num: int = 50, + endpoint: bool = True, + retstep: bool = False, + dtype: bst.typing.DTypeLike = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity + stop: number, Quantity + num: int, optional + endpoint: bool, optional + retstep: bool, optional + dtype: dtype, optional + + Returns: + out: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' fail_for_dimension_mismatch( start, stop, @@ -360,7 +642,26 @@ def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None): @set_module_as('brainunit.math') -def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None): +def logspace(start: Union[Quantity, jax.Array, np.ndarray], + stop: Union[Quantity, jax.Array, np.ndarray], + num: int = 50, + endpoint: bool = True, + base: float = 10.0, + dtype: bst.typing.DTypeLike = None): + ''' + Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity + stop: number, Quantity + num: int, optional + endpoint: bool, optional + base: float, optional + dtype: dtype, optional + + Returns: + out: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' fail_for_dimension_mismatch( start, stop, @@ -377,7 +678,22 @@ def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None): @set_module_as('brainunit.math') -def fill_diagonal(a, val, wrap=False, inplace=True): +def fill_diagonal(a: Union[Quantity, jax.Array, np.ndarray], + val: Union[Quantity, jax.Array, np.ndarray], + wrap: bool = False, + inplace: bool = True) -> Union[Quantity, jax.Array]: + ''' + Fill the main diagonal of the given array of `a` with `val`. + + Args: + a: array_like, Quantity + val: scalar, Quantity + wrap: bool, optional + inplace: bool, optional + + Returns: + out: Quantity if `a` and `val` are Quantities that have the same unit, else an array. + ''' if isinstance(a, Quantity) and isinstance(val, Quantity): fail_for_dimension_mismatch(a, val) return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), unit=a.unit) @@ -390,7 +706,20 @@ def fill_diagonal(a, val, wrap=False, inplace=True): @set_module_as('brainunit.math') -def array_split(ary, indices_or_sections, axis=0): +def array_split(ary: Union[Quantity, jax.Array, np.ndarray], + indices_or_sections: Union[int, jax.Array, np.ndarray], + axis: int = 0) -> Union[Quantity, jax.Array]: + ''' + Split an array into multiple sub-arrays. + + Args: + ary: array_like, Quantity + indices_or_sections: int, array_like + axis: int, optional + + Returns: + out: Quantity if `ary` is a Quantity, else an array. + ''' if isinstance(ary, Quantity): return Quantity(jnp.array_split(ary.value, indices_or_sections, axis), unit=ary.unit) elif isinstance(ary, (jax.Array, np.ndarray)): @@ -400,7 +729,22 @@ def array_split(ary, indices_or_sections, axis=0): @set_module_as('brainunit.math') -def meshgrid(*xi, copy=True, sparse=False, indexing='xy'): +def meshgrid(*xi: Union[Quantity, jax.Array, np.ndarray], + copy: bool = True, + sparse: bool = False, + indexing: str = 'xy'): + ''' + Return coordinate matrices from coordinate vectors. + + Args: + xi: array_like, Quantity + copy: bool, optional + sparse: bool, optional + indexing: str, optional + + Returns: + out: Quantity if `xi` are Quantities that have the same unit, else an array. + ''' from builtins import all as origin_all if origin_all(isinstance(x, Quantity) for x in xi): fail_for_dimension_mismatch(*xi) @@ -412,7 +756,20 @@ def meshgrid(*xi, copy=True, sparse=False, indexing='xy'): @set_module_as('brainunit.math') -def vander(x, N=None, increasing=False): +def vander(x: Union[Quantity, jax.Array, np.ndarray], + N: bool=None, + increasing: bool=False) -> Union[Quantity, jax.Array]: + ''' + Generate a Vandermonde matrix. + + Args: + x: array_like, Quantity + N: int, optional + increasing: bool, optional + + Returns: + out: Quantity if `x` is a Quantity, else an array. + ''' if isinstance(x, Quantity): return Quantity(jnp.vander(x.value, N=N, increasing=increasing), unit=x.unit) elif isinstance(x, (jax.Array, np.ndarray)): From 72ccd90a20537051b2ba8f292951650d41c56b21 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Mon, 10 Jun 2024 23:39:46 +0800 Subject: [PATCH 02/23] Update _compat_numpy.py --- brainunit/math/_compat_numpy.py | 1756 ++++++++++++++++++++++++++----- 1 file changed, 1500 insertions(+), 256 deletions(-) diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py index 09dda8f..0dae459 100644 --- a/brainunit/math/_compat_numpy.py +++ b/brainunit/math/_compat_numpy.py @@ -15,7 +15,7 @@ import functools from collections.abc import Sequence from functools import wraps -from typing import (Callable, Union, Optional, Any) +from typing import (Callable, Union, Optional, Any, List) import brainstate as bst import jax @@ -23,9 +23,11 @@ import numpy as np import opt_einsum from brainstate._utils import set_module_as +from jax import Array from jax._src.numpy.lax_numpy import _einsum from ._utils import _compatible_with_quantity +from .. import Quantity from .._base import (DIMENSIONLESS, Quantity, Unit, @@ -165,8 +167,8 @@ def f(*args, unit: Unit = None, **kwargs): # docs for full, eye, identity, tri, empty, ones, zeros full.__doc__ = """ -Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. -else return an array of `shape` filled with `fill_value`. + Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. + else return an array of `shape` filled with `fill_value`. Args: shape: sequence of integers, describing the shape of the output array. @@ -183,8 +185,8 @@ def f(*args, unit: Unit = None, **kwargs): """ eye.__doc__ = """ -Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. -else return an identity matrix of `shape`. + Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. + else return an identity matrix of `shape`. Args: n: the number of rows (and columns) in the output array. @@ -203,8 +205,8 @@ def f(*args, unit: Unit = None, **kwargs): """ identity.__doc__ = """ -Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. -else return an identity matrix of `shape`. + Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. + else return an identity matrix of `shape`. Args: n: the number of rows (and columns) in the output array. @@ -220,8 +222,8 @@ def f(*args, unit: Unit = None, **kwargs): """ tri.__doc__ = """ -Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. -else return a triangular matrix of `shape`. + Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. + else return a triangular matrix of `shape`. Args: n: the number of rows in the output array. @@ -242,8 +244,8 @@ def f(*args, unit: Unit = None, **kwargs): # empty empty.__doc__ = """ -Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. -else return an array of `shape` with uninitialized values. + Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. + else return an array of `shape` with uninitialized values. Args: shape: sequence of integers, describing the shape of the output array. @@ -260,8 +262,8 @@ def f(*args, unit: Unit = None, **kwargs): # ones ones.__doc__ = """ -Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. -else return an array of `shape` filled with 1. + Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. + else return an array of `shape` filled with 1. Args: shape: sequence of integers, describing the shape of the output array. @@ -278,8 +280,8 @@ def f(*args, unit: Unit = None, **kwargs): # zeros zeros.__doc__ = """ -Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. -else return an array of `shape` filled with 0. + Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. + else return an array of `shape` filled with 0. Args: shape: sequence of integers, describing the shape of the output array. @@ -296,8 +298,8 @@ def f(*args, unit: Unit = None, **kwargs): @set_module_as('brainunit.math') -def full_like(a: Union[Quantity, jax.Array, np.ndarray], - fill_value: Union[jax.Array, np.ndarray], +def full_like(a: Union[Quantity, bst.typing.ArrayLike], + fill_value: Union[bst.typing.ArrayLike], unit: Unit = None, dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None) -> Union[Quantity, jax.Array]: @@ -326,7 +328,7 @@ def full_like(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def diag(a: Union[Quantity, jax.Array, np.ndarray], +def diag(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, unit: Unit = None) -> Union[Quantity, jax.Array]: ''' @@ -351,7 +353,7 @@ def diag(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def tril(a: Union[Quantity, jax.Array, np.ndarray], +def tril(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, unit: Unit = None) -> Union[Quantity, jax.Array]: ''' @@ -376,7 +378,7 @@ def tril(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def triu(a: Union[Quantity, jax.Array, np.ndarray], +def triu(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, unit: Unit = None) -> Union[Quantity, jax.Array]: ''' @@ -401,7 +403,7 @@ def triu(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def empty_like(a: Union[Quantity, jax.Array, np.ndarray], +def empty_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, unit: Unit = None) -> Union[Quantity, jax.Array]: @@ -429,7 +431,7 @@ def empty_like(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def ones_like(a: Union[Quantity, jax.Array, np.ndarray], +def ones_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, unit: Unit = None) -> Union[Quantity, jax.Array]: @@ -457,7 +459,7 @@ def ones_like(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def zeros_like(a: Union[Quantity, jax.Array, np.ndarray], +def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, unit: Unit = None) -> Union[Quantity, jax.Array]: @@ -486,7 +488,7 @@ def zeros_like(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') def asarray( - a: Union[Quantity, jax.Array, np.ndarray, Sequence[Quantity]], + a: Union[Quantity, bst.typing.ArrayLike, Sequence[Quantity]], dtype: Optional[bst.typing.DTypeLike] = None, order: Optional[str] = None, unit: Optional[Unit] = None, @@ -606,12 +608,12 @@ def arange(*args, **kwargs): @set_module_as('brainunit.math') -def linspace(start: Union[Quantity, jax.Array, np.ndarray], - stop: Union[Quantity, jax.Array, np.ndarray], +def linspace(start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], num: int = 50, - endpoint: bool = True, - retstep: bool = False, - dtype: bst.typing.DTypeLike = None) -> Union[Quantity, jax.Array]: + endpoint: Optional[bool] = True, + retstep: Optional[bool] = False, + dtype: Optional[bst.typing.DTypeLike] = None) -> Union[Quantity, jax.Array]: ''' Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. @@ -642,12 +644,12 @@ def linspace(start: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def logspace(start: Union[Quantity, jax.Array, np.ndarray], - stop: Union[Quantity, jax.Array, np.ndarray], - num: int = 50, - endpoint: bool = True, - base: float = 10.0, - dtype: bst.typing.DTypeLike = None): +def logspace(start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], + num: Optional[int] = 50, + endpoint: Optional[bool] = True, + base: Optional[float] = 10.0, + dtype: Optional[bst.typing.DTypeLike] = None): ''' Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided. @@ -678,10 +680,10 @@ def logspace(start: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def fill_diagonal(a: Union[Quantity, jax.Array, np.ndarray], - val: Union[Quantity, jax.Array, np.ndarray], - wrap: bool = False, - inplace: bool = True) -> Union[Quantity, jax.Array]: +def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], + val: Union[Quantity, bst.typing.ArrayLike], + wrap: Optional[bool] = False, + inplace: Optional[bool] = True) -> Union[Quantity, jax.Array]: ''' Fill the main diagonal of the given array of `a` with `val`. @@ -706,9 +708,9 @@ def fill_diagonal(a: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def array_split(ary: Union[Quantity, jax.Array, np.ndarray], - indices_or_sections: Union[int, jax.Array, np.ndarray], - axis: int = 0) -> Union[Quantity, jax.Array]: +def array_split(ary: Union[Quantity, bst.typing.ArrayLike], + indices_or_sections: Union[int, bst.typing.ArrayLike], + axis: Optional[int] = 0) -> list[Quantity] | list[Array]: ''' Split an array into multiple sub-arrays. @@ -721,18 +723,18 @@ def array_split(ary: Union[Quantity, jax.Array, np.ndarray], out: Quantity if `ary` is a Quantity, else an array. ''' if isinstance(ary, Quantity): - return Quantity(jnp.array_split(ary.value, indices_or_sections, axis), unit=ary.unit) - elif isinstance(ary, (jax.Array, np.ndarray)): + return [Quantity(x, unit=ary.unit) for x in jnp.array_split(ary.value, indices_or_sections, axis)] + elif isinstance(ary, (bst.typing.ArrayLike)): return jnp.array_split(ary, indices_or_sections, axis) else: raise ValueError(f'Unsupported type: {type(ary)} for array_split') @set_module_as('brainunit.math') -def meshgrid(*xi: Union[Quantity, jax.Array, np.ndarray], - copy: bool = True, - sparse: bool = False, - indexing: str = 'xy'): +def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike], + copy: Optional[bool] = True, + sparse: Optional[bool] = False, + indexing: Optional[str] = 'xy'): ''' Return coordinate matrices from coordinate vectors. @@ -756,9 +758,9 @@ def meshgrid(*xi: Union[Quantity, jax.Array, np.ndarray], @set_module_as('brainunit.math') -def vander(x: Union[Quantity, jax.Array, np.ndarray], - N: bool=None, - increasing: bool=False) -> Union[Quantity, jax.Array]: +def vander(x: Union[Quantity, bst.typing.ArrayLike], + N: Optional[bool] = None, + increasing: Optional[bool] = False) -> Union[Quantity, jax.Array]: ''' Generate a Vandermonde matrix. @@ -782,7 +784,16 @@ def vander(x: Union[Quantity, jax.Array, np.ndarray], # ----------------------- @set_module_as('brainunit.math') -def ndim(a): +def ndim(a: Union[Quantity, bst.typing.ArrayLike]) -> int: + ''' + Return the number of dimensions of an array. + + Args: + a: array_like, Quantity + + Returns: + out: int + ''' if isinstance(a, Quantity): return a.ndim else: @@ -790,7 +801,16 @@ def ndim(a): @set_module_as('brainunit.math') -def isreal(a): +def isreal(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return True if the input array is real. + + Args: + a: array_like, Quantity + + Returns: + out: boolean array + ''' if isinstance(a, Quantity): return a.isreal else: @@ -798,7 +818,16 @@ def isreal(a): @set_module_as('brainunit.math') -def isscalar(a): +def isscalar(a: Union[Quantity, bst.typing.ArrayLike]) -> bool: + ''' + Return True if the input is a scalar. + + Args: + a: array_like, Quantity + + Returns: + out: boolean array + ''' if isinstance(a, Quantity): return a.isscalar else: @@ -806,7 +835,16 @@ def isscalar(a): @set_module_as('brainunit.math') -def isfinite(a): +def isfinite(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is finite or not. + + Args: + a: array_like, Quantity + + Returns: + out: boolean array + ''' if isinstance(a, Quantity): return a.isfinite else: @@ -814,7 +852,16 @@ def isfinite(a): @set_module_as('brainunit.math') -def isinf(a): +def isinf(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is infinite or not. + + Args: + a: array_like, Quantity + + Returns: + out: boolean array + ''' if isinstance(a, Quantity): return a.isinf else: @@ -822,7 +869,16 @@ def isinf(a): @set_module_as('brainunit.math') -def isnan(a): +def isnan(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is NaN or not. + + Args: + a: array_like, Quantity + + Returns: + out: boolean array + ''' if isinstance(a, Quantity): return a.isnan else: @@ -830,7 +886,7 @@ def isnan(a): @set_module_as('brainunit.math') -def shape(a): +def shape(a: Union[Quantity, bst.typing.ArrayLike]) -> tuple[int, ...]: """ Return the shape of an array. @@ -870,7 +926,7 @@ def shape(a): @set_module_as('brainunit.math') -def size(a, axis=None): +def size(a: Union[Quantity, bst.typing.ArrayLike], axis: int = None) -> int: """ Return the number of elements along a given axis. @@ -963,276 +1019,1042 @@ def f(x, *args, **kwargs): diff = wrap_math_funcs_keep_unit_unary(jnp.diff) modf = wrap_math_funcs_keep_unit_unary(jnp.modf) +# docs for the functions above +real.__doc__ = ''' + Return the real part of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -# math funcs keep unit (binary) -# ----------------------------- +imag.__doc__ = ''' + Return the imaginary part of the complex argument. -def wrap_math_funcs_keep_unit_binary(func): - def f(x1, x2, *args, **kwargs): - if isinstance(x1, Quantity) and isinstance(x2, Quantity): - return Quantity(func(x1.value, x2.value, *args, **kwargs), unit=x1.unit) - elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): - return func(x1, x2, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') + Args: + x: array_like, Quantity - f.__module__ = 'brainunit.math' - return f + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' +conj.__doc__ = ''' + Return the complex conjugate of the argument. -fmod = wrap_math_funcs_keep_unit_binary(jnp.fmod) -mod = wrap_math_funcs_keep_unit_binary(jnp.mod) -copysign = wrap_math_funcs_keep_unit_binary(jnp.copysign) -heaviside = wrap_math_funcs_keep_unit_binary(jnp.heaviside) -maximum = wrap_math_funcs_keep_unit_binary(jnp.maximum) -minimum = wrap_math_funcs_keep_unit_binary(jnp.minimum) -fmax = wrap_math_funcs_keep_unit_binary(jnp.fmax) -fmin = wrap_math_funcs_keep_unit_binary(jnp.fmin) -lcm = wrap_math_funcs_keep_unit_binary(jnp.lcm) -gcd = wrap_math_funcs_keep_unit_binary(jnp.gcd) + Args: + x: array_like, Quantity + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -# math funcs keep unit (n-ary) -# ---------------------------- -@set_module_as('brainunit.math') -def interp(x, xp, fp, left=None, right=None, period=None): - unit = None - if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): - unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit - if isinstance(x, Quantity): - x_value = x.value - else: - x_value = x - if isinstance(xp, Quantity): - xp_value = xp.value - else: - xp_value = xp - if isinstance(fp, Quantity): - fp_value = fp.value - else: - fp_value = fp - result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) - if unit is not None: - return Quantity(result, unit=unit) - else: - return result +conjugate.__doc__ = ''' + Return the complex conjugate of the argument. + Args: + x: array_like, Quantity -@set_module_as('brainunit.math') -def clip(a, a_min, a_max): - unit = None - if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): - unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit - if isinstance(a, Quantity): - a_value = a.value - else: - a_value = a - if isinstance(a_min, Quantity): - a_min_value = a_min.value - else: - a_min_value = a_min - if isinstance(a_max, Quantity): - a_max_value = a_max.value - else: - a_max_value = a_max - result = jnp.clip(a_value, a_min_value, a_max_value) - if unit is not None: - return Quantity(result, unit=unit) - else: - return result + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' +negative.__doc__ = ''' + Return the negative of the argument. -# math funcs match unit (binary) -# ------------------------------ + Args: + x: array_like, Quantity -def wrap_math_funcs_match_unit_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - fail_for_dimension_mismatch(x, y) - return Quantity(func(x.value, y.value, *args, **kwargs), unit=x.unit) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - if x.is_unitless: - return Quantity(func(x.value, y, *args, **kwargs), unit=x.unit) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - elif isinstance(y, Quantity): - if y.is_unitless: - return Quantity(func(x, y.value, *args, **kwargs), unit=y.unit) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' - f.__module__ = 'brainunit.math' - return f +positive.__doc__ = ''' + Return the positive of the argument. + Args: + x: array_like, Quantity -add = wrap_math_funcs_match_unit_binary(jnp.add) -subtract = wrap_math_funcs_match_unit_binary(jnp.subtract) -nextafter = wrap_math_funcs_match_unit_binary(jnp.nextafter) + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' +abs.__doc__ = ''' + Return the absolute value of the argument. -# math funcs change unit (unary) -# ------------------------------ + Args: + x: array_like, Quantity -def wrap_math_funcs_change_unit_unary(func, change_unit_func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), unit=change_unit_func(x.unit))) - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' - f.__module__ = 'brainunit.math' - return f +round_.__doc__ = ''' + Round an array to the nearest integer. + Args: + x: array_like, Quantity -reciprocal = wrap_math_funcs_change_unit_unary(jnp.reciprocal, lambda x: x ** -1) + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' +around.__doc__ = ''' + Round an array to the nearest integer. -@set_module_as('brainunit.math') -def prod(x, axis=None, dtype=None, out=None, keepdims=False, initial=None): - if isinstance(x, Quantity): - return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - else: - return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) + Args: + x: array_like, Quantity + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -@set_module_as('brainunit.math') -def nanprod(x, axis=None, dtype=None, out=None, keepdims=False, initial=None): - if isinstance(x, Quantity): - return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - else: - return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) +round.__doc__ = ''' + Round an array to the nearest integer. + Args: + x: array_like, Quantity -product = prod + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' +rint.__doc__ = ''' + Round an array to the nearest integer. -@set_module_as('brainunit.math') -def cumprod(x, axis=None, dtype=None, out=None): - if isinstance(x, Quantity): - return x.cumprod(axis=axis, dtype=dtype, out=out) - else: - return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) + Args: + x: array_like, Quantity + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -@set_module_as('brainunit.math') -def nancumprod(x, axis=None, dtype=None, out=None): - if isinstance(x, Quantity): - return x.nancumprod(axis=axis, dtype=dtype, out=out) - else: - return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) +floor.__doc__ = ''' + Return the floor of the argument. + Args: + x: array_like, Quantity -cumproduct = cumprod + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -var = wrap_math_funcs_change_unit_unary(jnp.var, lambda x: x ** 2) -nanvar = wrap_math_funcs_change_unit_unary(jnp.nanvar, lambda x: x ** 2) -frexp = wrap_math_funcs_change_unit_unary(jnp.frexp, lambda x, y: x * 2 ** y) -sqrt = wrap_math_funcs_change_unit_unary(jnp.sqrt, lambda x: x ** 0.5) -cbrt = wrap_math_funcs_change_unit_unary(jnp.cbrt, lambda x: x ** (1 / 3)) -square = wrap_math_funcs_change_unit_unary(jnp.square, lambda x: x ** 2) +ceil.__doc__ = ''' + Return the ceiling of the argument. + Args: + x: array_like, Quantity -# math funcs change unit (binary) -# ------------------------------- + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -def wrap_math_funcs_change_unit_binary(func, change_unit_func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless( - Quantity(func(x.value, y.value, *args, **kwargs), unit=change_unit_func(x.unit, y.unit)) - ) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless( - Quantity(func(x.value, y, *args, **kwargs), unit=change_unit_func(x.unit, DIMENSIONLESS))) - elif isinstance(y, Quantity): - return _return_check_unitless( - Quantity(func(x, y.value, *args, **kwargs), unit=change_unit_func(DIMENSIONLESS, y.unit))) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') +trunc.__doc__ = ''' + Return the truncated value of the argument. - f.__module__ = 'brainunit.math' - return f + Args: + x: array_like, Quantity + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' -multiply = wrap_math_funcs_change_unit_binary(jnp.multiply, lambda x, y: x * y) -divide = wrap_math_funcs_change_unit_binary(jnp.divide, lambda x, y: x / y) +fix.__doc__ = ''' + Return the nearest integer towards zero. + Args: + x: array_like, Quantity -@set_module_as('brainunit.math') -def power(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y.value, *args, **kwargs), unit=x.unit ** y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.power(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y, *args, **kwargs), unit=x.unit ** y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x, y.value, *args, **kwargs), unit=x ** y.unit)) + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +sum.__doc__ = ''' + Return the sum of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nancumsum.__doc__ = ''' + Return the cumulative sum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nansum.__doc__ = ''' + Return the sum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +cumsum.__doc__ = ''' + Return the cumulative sum of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +ediff1d.__doc__ = ''' + Return the differences between consecutive elements of the array. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +absolute.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +fabs.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +median.__doc__ = ''' + Return the median of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nanmin.__doc__ = ''' + Return the minimum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nanmax.__doc__ = ''' + Return the maximum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +ptp.__doc__ = ''' + Return the range of the array elements (maximum - minimum). + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +average.__doc__ = ''' + Return the weighted average of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +mean.__doc__ = ''' + Return the mean of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +std.__doc__ = ''' + Return the standard deviation of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nanmedian.__doc__ = ''' + Return the median of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nanmean.__doc__ = ''' + Return the mean of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +nanstd.__doc__ = ''' + Return the standard deviation of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +diff.__doc__ = ''' + Return the differences between consecutive elements of the array. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + +modf.__doc__ = ''' + Return the fractional and integer parts of the array elements. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity tuple if `x` is a Quantity, else an array tuple. +''' + + +# math funcs keep unit (binary) +# ----------------------------- + +def wrap_math_funcs_keep_unit_binary(func): + def f(x1, x2, *args, **kwargs): + if isinstance(x1, Quantity) and isinstance(x2, Quantity): + return Quantity(func(x1.value, x2.value, *args, **kwargs), unit=x1.unit) + elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): + return func(x1, x2, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +fmod = wrap_math_funcs_keep_unit_binary(jnp.fmod) +mod = wrap_math_funcs_keep_unit_binary(jnp.mod) +copysign = wrap_math_funcs_keep_unit_binary(jnp.copysign) +heaviside = wrap_math_funcs_keep_unit_binary(jnp.heaviside) +maximum = wrap_math_funcs_keep_unit_binary(jnp.maximum) +minimum = wrap_math_funcs_keep_unit_binary(jnp.minimum) +fmax = wrap_math_funcs_keep_unit_binary(jnp.fmax) +fmin = wrap_math_funcs_keep_unit_binary(jnp.fmin) +lcm = wrap_math_funcs_keep_unit_binary(jnp.lcm) +gcd = wrap_math_funcs_keep_unit_binary(jnp.gcd) + +# docs for the functions above +fmod.__doc__ = ''' + Return the element-wise remainder of division. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +mod.__doc__ = ''' + Return the element-wise modulus of division. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +copysign.__doc__ = ''' + Return a copy of the first array elements with the sign of the second array. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +heaviside.__doc__ = ''' + Compute the Heaviside step function. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +maximum.__doc__ = ''' + Element-wise maximum of array elements. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +minimum.__doc__ = ''' + Element-wise minimum of array elements. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +fmax.__doc__ = ''' + Element-wise maximum of array elements ignoring NaNs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +fmin.__doc__ = ''' + Element-wise minimum of array elements ignoring NaNs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +lcm.__doc__ = ''' + Return the least common multiple of `x1` and `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +gcd.__doc__ = ''' + Return the greatest common divisor of `x1` and `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + + +# math funcs keep unit (n-ary) +# ---------------------------- +@set_module_as('brainunit.math') +def interp(x: Union[Quantity, bst.typing.ArrayLike], + xp: Union[Quantity, bst.typing.ArrayLike], + fp: Union[Quantity, bst.typing.ArrayLike], + left: Union[Quantity, bst.typing.ArrayLike] = None, + right: Union[Quantity, bst.typing.ArrayLike] = None, + period: Union[Quantity, bst.typing.ArrayLike] = None) -> Union[Quantity, jax.Array]: + ''' + One-dimensional linear interpolation. + + Args: + x: array_like, Quantity + xp: array_like, Quantity + fp: array_like, Quantity + left: array_like, Quantity, optional + right: array_like, Quantity, optional + period: array_like, Quantity, optional + + Returns: + out: Quantity if `x`, `xp`, and `fp` are Quantities that have the same unit, else an array. + ''' + unit = None + if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): + unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit + if isinstance(x, Quantity): + x_value = x.value + else: + x_value = x + if isinstance(xp, Quantity): + xp_value = xp.value + else: + xp_value = xp + if isinstance(fp, Quantity): + fp_value = fp.value + else: + fp_value = fp + result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) + if unit is not None: + return Quantity(result, unit=unit) + else: + return result + + +@set_module_as('brainunit.math') +def clip(a: Union[Quantity, bst.typing.ArrayLike], + a_min: Union[Quantity, bst.typing.ArrayLike], + a_max: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Clip (limit) the values in an array. + + Args: + a: array_like, Quantity + a_min: array_like, Quantity + a_max: array_like, Quantity + + Returns: + out: Quantity if `a`, `a_min`, and `a_max` are Quantities that have the same unit, else an array. + ''' + unit = None + if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): + unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit + if isinstance(a, Quantity): + a_value = a.value + else: + a_value = a + if isinstance(a_min, Quantity): + a_min_value = a_min.value + else: + a_min_value = a_min + if isinstance(a_max, Quantity): + a_max_value = a_max.value + else: + a_max_value = a_max + result = jnp.clip(a_value, a_min_value, a_max_value) + if unit is not None: + return Quantity(result, unit=unit) + else: + return result + + +# math funcs match unit (binary) +# ------------------------------ + +def wrap_math_funcs_match_unit_binary(func): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return Quantity(func(x.value, y.value, *args, **kwargs), unit=x.unit) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + if x.is_unitless: + return Quantity(func(x.value, y, *args, **kwargs), unit=x.unit) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + elif isinstance(y, Quantity): + if y.is_unitless: + return Quantity(func(x, y.value, *args, **kwargs), unit=y.unit) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +add = wrap_math_funcs_match_unit_binary(jnp.add) +subtract = wrap_math_funcs_match_unit_binary(jnp.subtract) +nextafter = wrap_math_funcs_match_unit_binary(jnp.nextafter) + +# docs for the functions above +add.__doc__ = ''' + Add arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if `x` and `y` are Quantities that have the same unit, else an array. +''' + +subtract.__doc__ = ''' + Subtract arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if `x` and `y` are Quantities that have the same unit, else an array. +''' + +nextafter.__doc__ = ''' + Return the next floating-point value after `x1` towards `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + + +# math funcs change unit (unary) +# ------------------------------ + +def wrap_math_funcs_change_unit_unary(func, change_unit_func): + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), unit=change_unit_func(x.unit))) + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +reciprocal = wrap_math_funcs_change_unit_unary(jnp.reciprocal, lambda x: x ** -1) +reciprocal.__doc__ = ''' + Return the reciprocal of the argument. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if `x` is a Quantity, else an array. +''' + + +@set_module_as('brainunit.math') +def prod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: Optional[...] = None, + keepdims: Optional[bool] = False, + initial: Union[Quantity, bst.typing.ArrayLike] = None, + where: Union[Quantity, bst.typing.ArrayLike] = None, + promote_integers: bool = True) -> Union[Quantity, jax.Array]: + ''' + Return the product of array elements over a given axis. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + keepdims: bool, optional + initial: array_like, Quantity, optional + where: array_like, Quantity, optional + promote_integers: bool, optional + + Returns: + out: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + promote_integers=promote_integers) else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') + return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + promote_integers=promote_integers) + + +@set_module_as('brainunit.math') +def nanprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: Optional[...] = None, + keepdims: Optional[...] = False, + initial: Union[Quantity, bst.typing.ArrayLike] = None, + where: Union[Quantity, bst.typing.ArrayLike] = None): + ''' + Return the product of array elements over a given axis treating Not a Numbers (NaNs) as one. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + keepdims: bool, optional + initial: array_like, Quantity, optional + where: array_like, Quantity, optional + + Returns: + out: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + else: + return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + + +product = prod + + +@set_module_as('brainunit.math') +def cumprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: Optional[...] = None) -> Union[Quantity, bst.typing.ArrayLike]: + ''' + Return the cumulative product of elements along a given axis. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + + Returns: + out: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.cumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) + + +@set_module_as('brainunit.math') +def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: Optional[...] = None) -> Union[Quantity, bst.typing.ArrayLike]: + ''' + Return the cumulative product of elements along a given axis treating Not a Numbers (NaNs) as one. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + + Returns: + out: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.nancumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) + + +cumproduct = cumprod + +var = wrap_math_funcs_change_unit_unary(jnp.var, lambda x: x ** 2) +nanvar = wrap_math_funcs_change_unit_unary(jnp.nanvar, lambda x: x ** 2) +frexp = wrap_math_funcs_change_unit_unary(jnp.frexp, lambda x, y: x * 2 ** y) +sqrt = wrap_math_funcs_change_unit_unary(jnp.sqrt, lambda x: x ** 0.5) +cbrt = wrap_math_funcs_change_unit_unary(jnp.cbrt, lambda x: x ** (1 / 3)) +square = wrap_math_funcs_change_unit_unary(jnp.square, lambda x: x ** 2) + +# docs for the functions above +var.__doc__ = ''' + Compute the variance along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + +nanvar.__doc__ = ''' + Compute the variance along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + +frexp.__doc__ = ''' + Decompose a floating-point number into its mantissa and exponent. + + Args: + x: array_like, Quantity + + Returns: + out: Tuple of Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the exponent, else a tuple of arrays. +''' + +sqrt.__doc__ = ''' + Compute the square root of each element. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the square root of the unit of `x`, else an array. +''' + +cbrt.__doc__ = ''' + Compute the cube root of each element. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the cube root of the unit of `x`, else an array. +''' + +square.__doc__ = ''' + Compute the square of each element. + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + + +# math funcs change unit (binary) +# ------------------------------- + +def wrap_math_funcs_change_unit_binary(func, change_unit_func): + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y.value, *args, **kwargs), unit=change_unit_func(x.unit, y.unit)) + ) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y, *args, **kwargs), unit=change_unit_func(x.unit, DIMENSIONLESS))) + elif isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x, y.value, *args, **kwargs), unit=change_unit_func(DIMENSIONLESS, y.unit))) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + +multiply = wrap_math_funcs_change_unit_binary(jnp.multiply, lambda x, y: x * y) +divide = wrap_math_funcs_change_unit_binary(jnp.divide, lambda x, y: x / y) cross = wrap_math_funcs_change_unit_binary(jnp.cross, lambda x, y: x * y) ldexp = wrap_math_funcs_change_unit_binary(jnp.ldexp, lambda x, y: x * 2 ** y) true_divide = wrap_math_funcs_change_unit_binary(jnp.true_divide, lambda x, y: x / y) +divmod = wrap_math_funcs_change_unit_binary(jnp.divmod, lambda x, y: x / y) +convolve = wrap_math_funcs_change_unit_binary(jnp.convolve, lambda x, y: x * y) + +# docs for the functions above +multiply.__doc__ = ''' + Multiply arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +divide.__doc__ = ''' + Divide arguments element-wise. + + Args: + x: array_like, Quantity + + Returns: + out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +cross.__doc__ = ''' + Return the cross product of two (arrays of) vectors. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +ldexp.__doc__ = ''' + Return x1 * 2**x2, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the unit of `y`, else an array. +''' + +true_divide.__doc__ = ''' + Returns a true division of the inputs, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +divmod.__doc__ = ''' + Return element-wise quotient and remainder simultaneously. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +convolve.__doc__ = ''' + Returns the discrete, linear convolution of two one-dimensional sequences. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' @set_module_as('brainunit.math') -def floor_divide(x, y, *args, **kwargs): +def power(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike], ) -> Union[Quantity, jax.Array]: + ''' + First array elements raised to powers from second array, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + ''' if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value, *args, **kwargs), unit=x.unit / y.unit)) + return _return_check_unitless(Quantity(jnp.power(x.value, y.value), unit=x.unit ** y.unit)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.floor_divide(x, y, *args, **kwargs) + return jnp.power(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y, *args, **kwargs), unit=x.unit / y)) + return _return_check_unitless(Quantity(jnp.power(x.value, y), unit=x.unit ** y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value, *args, **kwargs), unit=x / y.unit)) + return _return_check_unitless(Quantity(jnp.power(x, y.value), unit=x ** y.unit)) else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') @set_module_as('brainunit.math') -def float_power(x, y, *args, **kwargs): +def floor_divide(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Return the largest integer smaller or equal to the division of the inputs. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + ''' if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y.value, *args, **kwargs), unit=x.unit ** y.unit)) + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), unit=x.unit / y.unit)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.float_power(x, y, *args, **kwargs) + return jnp.floor_divide(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y, *args, **kwargs), unit=x.unit ** y)) + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), unit=x.unit / y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x, y.value, *args, **kwargs), unit=x ** y.unit)) + return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), unit=x / y.unit)) else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.float_power.__name__}') + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') + + +@set_module_as('brainunit.math') +def float_power(x: Union[Quantity, bst.typing.ArrayLike], + y: bst.typing.ArrayLike) -> Union[Quantity, jax.Array]: + ''' + First array elements raised to powers from second array, element-wise. + Args: + x: array_like, Quantity + y: array_like -divmod = wrap_math_funcs_change_unit_binary(jnp.divmod, lambda x, y: x / y) + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + ''' + assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' + if isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.float_power(x.value, y), unit=x.unit ** y.unit)) + elif isinstance(x, (jax.Array, np.ndarray)): + return jnp.float_power(x, y) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.float_power.__name__}') @set_module_as('brainunit.math') -def remainder(x, y, *args, **kwargs): +def remainder(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike]): if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value, *args, **kwargs), unit=x.unit / y.unit)) + return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), unit=x.unit / y.unit)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.remainder(x, y, *args, **kwargs) + return jnp.remainder(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y, *args, **kwargs), unit=x.unit % y)) + return _return_check_unitless(Quantity(jnp.remainder(x.value, y), unit=x.unit % y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x, y.value, *args, **kwargs), unit=x % y.unit)) + return _return_check_unitless(Quantity(jnp.remainder(x, y.value), unit=x % y.unit)) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}') -convolve = wrap_math_funcs_change_unit_binary(jnp.convolve, lambda x, y: x * y) - - # math funcs only accept unitless (unary) # --------------------------------------- @@ -1282,6 +2104,297 @@ def f(x, *args, **kwargs): quantile = wrap_math_funcs_only_accept_unitless_unary(jnp.quantile) nanquantile = wrap_math_funcs_only_accept_unitless_unary(jnp.nanquantile) +# docs for the functions above +exp.__doc__ = ''' + Calculate the exponential of all elements in the input array. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +exp2.__doc__ = ''' + Calculate 2 raised to the power of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +expm1.__doc__ = ''' + Calculate the exponential of the input elements minus 1. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +log.__doc__ = ''' + Natural logarithm, element-wise. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +log10.__doc__ = ''' + Base-10 logarithm of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +log1p.__doc__ = ''' + Natural logarithm of 1 + the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +log2.__doc__ = ''' + Base-2 logarithm of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arccos.__doc__ = ''' + Compute the arccosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arccosh.__doc__ = ''' + Compute the hyperbolic arccosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arcsin.__doc__ = ''' + Compute the arcsine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arcsinh.__doc__ = ''' + Compute the hyperbolic arcsine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arctan.__doc__ = ''' + Compute the arctangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +arctanh.__doc__ = ''' + Compute the hyperbolic arctangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +cos.__doc__ = ''' + Compute the cosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +cosh.__doc__ = ''' + Compute the hyperbolic cosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +sin.__doc__ = ''' + Compute the sine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +sinc.__doc__ = ''' + Compute the sinc function of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +sinh.__doc__ = ''' + Compute the hyperbolic sine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +tan.__doc__ = ''' + Compute the tangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +tanh.__doc__ = ''' + Compute the hyperbolic tangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +deg2rad.__doc__ = ''' + Convert angles from degrees to radians. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +rad2deg.__doc__ = ''' + Convert angles from radians to degrees. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +degrees.__doc__ = ''' + Convert angles from radians to degrees. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +radians.__doc__ = ''' + Convert angles from degrees to radians. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +angle.__doc__ = ''' + Return the angle of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +percentile.__doc__ = ''' + Compute the nth percentile of the input array along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +nanpercentile.__doc__ = ''' + Compute the nth percentile of the input array along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +quantile.__doc__ = ''' + Compute the qth quantile of the input array along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +nanquantile.__doc__ = ''' + Compute the qth quantile of the input array along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + # math funcs only accept unitless (binary) # ---------------------------------------- @@ -1316,6 +2429,51 @@ def f(x, y, *args, **kwargs): logaddexp = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp) logaddexp2 = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp2) +# docs for the functions above +hypot.__doc__ = ''' + Given the “legs” of a right triangle, return its hypotenuse. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: an array +''' + +arctan2.__doc__ = ''' + Element-wise arc tangent of `x1/x2` choosing the quadrant correctly. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: an array +''' + +logaddexp.__doc__ = ''' + Logarithm of the sum of exponentiations of the inputs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: an array +''' + +logaddexp2.__doc__ = ''' + Logarithm of the sum of exponentiations of the inputs in base-2. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: an array +''' + # math funcs remove unit (unary) # ------------------------------ @@ -1335,6 +2493,47 @@ def f(x, *args, **kwargs): histogram = wrap_math_funcs_remove_unit_unary(jnp.histogram) bincount = wrap_math_funcs_remove_unit_unary(jnp.bincount) +# docs for the functions above +signbit.__doc__ = ''' + Returns element-wise True where signbit is set (less than zero). + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +sign.__doc__ = ''' + Returns the sign of each element in the input array. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + +histogram.__doc__ = ''' + Compute the histogram of a set of data. + + Args: + x: array_like, Quantity + + Returns: + out: Tuple of arrays (hist, bin_edges) +''' + +bincount.__doc__ = ''' + Count number of occurrences of each value in array of non-negative integers. + + Args: + x: array_like, Quantity + + Returns: + out: an array +''' + # math funcs remove unit (binary) # ------------------------------- @@ -1358,6 +2557,51 @@ def f(x, y, *args, **kwargs): cov = wrap_math_funcs_remove_unit_binary(jnp.cov) digitize = wrap_math_funcs_remove_unit_binary(jnp.digitize) +# docs for the functions above +corrcoef.__doc__ = ''' + Return Pearson product-moment correlation coefficients. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: an array +''' + +correlate.__doc__ = ''' + Cross-correlation of two sequences. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: an array +''' + +cov.__doc__ = ''' + Covariance matrix. + + Args: + x: array_like, Quantity + y: array_like, Quantity (optional, if not provided, x is assumed to be a 2D array) + + Returns: + out: an array +''' + +digitize.__doc__ = ''' + Return the indices of the bins to which each value in input array belongs. + + Args: + x: array_like, Quantity + bins: array_like, Quantity + + Returns: + out: an array +''' + # array manipulation # ------------------ @@ -1751,8 +2995,8 @@ def einsum( @set_module_as('brainunit.math') def gradient( - f: Union[jax.Array, np.ndarray, Quantity], - *varargs: Union[jax.Array, np.ndarray, Quantity], + f: Union[bst.typing.ArrayLike, Quantity], + *varargs: Union[bst.typing.ArrayLike, Quantity], axis: Union[int, Sequence[int], None] = None, edge_order: Union[int, None] = None, ) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: @@ -1780,8 +3024,8 @@ def gradient( @set_module_as('brainunit.math') def intersect1d( - ar1: Union[jax.Array, np.ndarray], - ar2: Union[jax.Array, np.ndarray], + ar1: Union[bst.typing.ArrayLike], + ar2: Union[bst.typing.ArrayLike], assume_unique: bool = False, return_indices: bool = False ) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: From 4ee28ede8830a5990fad4904f9d3c4e9432de8d9 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 11:38:38 +0800 Subject: [PATCH 03/23] Update --- brainunit/math/_compat_numpy.py | 1054 ++++++++++++++++++++++++++++++- brainunit/math/_utils.py | 5 +- 2 files changed, 1037 insertions(+), 22 deletions(-) diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py index 0dae459..184d3dc 100644 --- a/brainunit/math/_compat_numpy.py +++ b/brainunit/math/_compat_numpy.py @@ -101,10 +101,10 @@ 'diagflat', 'diagonal', 'choose', 'ravel', # Elementwise bit operations (unary) - 'bitwise_not', 'invert', 'left_shift', 'right_shift', + 'bitwise_not', 'invert', # Elementwise bit operations (binary) - 'bitwise_and', 'bitwise_or', 'bitwise_xor', + 'bitwise_and', 'bitwise_or', 'bitwise_xor', 'left_shift', 'right_shift', # logic funcs (unary) 'all', 'any', 'logical_not', @@ -2657,6 +2657,490 @@ def f(x, y, *args, **kwargs): extract = _compatible_with_quantity(jnp.extract, return_quantity=False) count_nonzero = _compatible_with_quantity(jnp.count_nonzero, return_quantity=False) +# docs for the functions above +reshape.__doc__ = ''' + Return a reshaped copy of an array or a Quantity. + + Args: + a: input array or Quantity to reshape + shape: integer or sequence of integers giving the new shape, which must match the + size of the input array. If any single dimension is given size ``-1``, it will be + replaced with a value such that the output has the correct size. + order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major + (fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``. + brainunit does not support ``order="A"``. + + Returns: + reshaped copy of input array with the specified shape. +''' + +moveaxis.__doc__ = ''' + Moves axes of an array to new positions. Other axes remain in their original order. + + Args: + a: array_like, Quantity + source: int or sequence of ints + destination: int or sequence of ints + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +transpose.__doc__ = ''' + Returns a view of the array with axes transposed. + + Args: + a: array_like, Quantity + axes: tuple or list of ints, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +swapaxes.__doc__ = ''' + Interchanges two axes of an array. + + Args: + a: array_like, Quantity + axis1: int + axis2: int + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +concatenate.__doc__ = ''' + Join a sequence of arrays along an existing axis. + + Args: + arrays: sequence of array_like, Quantity + axis: int, optional + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +stack.__doc__ = ''' + Join a sequence of arrays along a new axis. + + Args: + arrays: sequence of array_like, Quantity + axis: int + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +vstack.__doc__ = ''' + Stack arrays in sequence vertically (row wise). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +hstack.__doc__ = ''' + Stack arrays in sequence horizontally (column wise). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +dstack.__doc__ = ''' + Stack arrays in sequence depth wise (along third axis). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +column_stack.__doc__ = ''' + Stack 1-D arrays as columns into a 2-D array. + + Args: + arrays: sequence of 1-D or 2-D array_like, Quantity + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +split.__doc__ = ''' + Split an array into multiple sub-arrays. + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + axis: int, optional + + Returns: + out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array +''' + +dsplit.__doc__ = ''' + Split array along third axis (depth). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array +''' + +hsplit.__doc__ = ''' + Split an array into multiple sub-arrays horizontally (column-wise). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array +''' + +vsplit.__doc__ = ''' + Split an array into multiple sub-arrays vertically (row-wise). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array +''' + +tile.__doc__ = ''' + Construct an array by repeating A the number of times given by reps. + + Args: + A: array_like, Quantity + reps: array_like + + Returns: + out: a Quantity if A is a Quantity, otherwise a jax.numpy.Array +''' + +repeat.__doc__ = ''' + Repeat elements of an array. + + Args: + a: array_like, Quantity + repeats: array_like + axis: int, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +unique.__doc__ = ''' + Find the unique elements of an array. + + Args: + a: array_like, Quantity + return_index: bool, optional + return_inverse: bool, optional + return_counts: bool, optional + axis: int or None, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +append.__doc__ = ''' + Append values to the end of an array. + + Args: + arr: array_like, Quantity + values: array_like, Quantity + axis: int, optional + + Returns: + out: a Quantity if arr and values are Quantity, otherwise a jax.numpy.Array +''' + +flip.__doc__ = ''' + Reverse the order of elements in an array along the given axis. + + Args: + m: array_like, Quantity + axis: int or tuple of ints, optional + + Returns: + out: a Quantity if m is a Quantity, otherwise a jax.numpy.Array +''' + +fliplr.__doc__ = ''' + Flip array in the left/right direction. + + Args: + m: array_like, Quantity + + Returns: + out: a Quantity if m is a Quantity, otherwise a jax.numpy.Array +''' + +flipud.__doc__ = ''' + Flip array in the up/down direction. + + Args: + m: array_like, Quantity + + Returns: + out: a Quantity if m is a Quantity, otherwise a jax.numpy.Array +''' + +roll.__doc__ = ''' + Roll array elements along a given axis. + + Args: + a: array_like, Quantity + shift: int or tuple of ints + axis: int or tuple of ints, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +atleast_1d.__doc__ = ''' + View inputs as arrays with at least one dimension. + + Args: + *args: array_like, Quantity + + Returns: + out: a Quantity if any input is a Quantity, otherwise a jax.numpy.Array +''' + +atleast_2d.__doc__ = ''' + View inputs as arrays with at least two dimensions. + + Args: + *args: array_like, Quantity + + Returns: + out: a Quantity if any input is a Quantity, otherwise a jax.numpy.Array +''' + +atleast_3d.__doc__ = ''' + View inputs as arrays with at least three dimensions. + + Args: + *args: array_like, Quantity + + Returns: + out: a Quantity if any input is a Quantity, otherwise a jax.numpy.Array +''' + +expand_dims.__doc__ = ''' + Expand the shape of an array. + + Args: + a: array_like, Quantity + axis: int or tuple of ints + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +squeeze.__doc__ = ''' + Remove single-dimensional entries from the shape of an array. + + Args: + a: array_like, Quantity + axis: None or int or tuple of ints, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +sort.__doc__ = ''' + Return a sorted copy of an array. + + Args: + a: array_like, Quantity + axis: int or None, optional + kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional + order: str or list of str, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' +max.__doc__ = ''' + Return the maximum of an array or maximum along an axis. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + keepdims: bool, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +min.__doc__ = ''' + Return the minimum of an array or minimum along an axis. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + keepdims: bool, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +choose.__doc__ = ''' + Use an index array to construct a new array from a set of choices. + + Args: + a: array_like, Quantity + choices: array_like, Quantity + + Returns: + out: a Quantity if a and choices are Quantity, otherwise a jax.numpy.Array +''' + +block.__doc__ = ''' + Assemble an nd-array from nested lists of blocks. + + Args: + arrays: sequence of array_like, Quantity + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +compress.__doc__ = ''' + Return selected slices of an array along given axis. + + Args: + condition: array_like, Quantity + a: array_like, Quantity + axis: int, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +diagflat.__doc__ = ''' + Create a two-dimensional array with the flattened input as a diagonal. + + Args: + a: array_like, Quantity + offset: int, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +argsort.__doc__ = ''' + Returns the indices that would sort an array. + + Args: + a: array_like, Quantity + axis: int or None, optional + kind: {'quicksort', 'mergesort', 'heapsort'}, optional + order: str or list of str, optional + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +argmax.__doc__ = ''' + Returns indices of the max value along an axis. + + Args: + a: array_like, Quantity + axis: int, optional + out: array, optional + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +argmin.__doc__ = ''' + Returns indices of the min value along an axis. + + Args: + a: array_like, Quantity + axis: int, optional + out: array, optional + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +argwhere.__doc__ = ''' + Find indices of non-zero elements. + + Args: + a: array_like, Quantity + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +nonzero.__doc__ = ''' + Return the indices of the elements that are non-zero. + + Args: + a: array_like, Quantity + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +flatnonzero.__doc__ = ''' + Return indices that are non-zero in the flattened version of a. + + Args: + a: array_like, Quantity + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +searchsorted.__doc__ = ''' + Find indices where elements should be inserted to maintain order. + + Args: + a: array_like, Quantity + v: array_like, Quantity + side: {'left', 'right'}, optional + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +extract.__doc__ = ''' + Return the elements of an array that satisfy some condition. + + Args: + condition: array_like, Quantity + a: array_like, Quantity + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +count_nonzero.__doc__ = ''' + Counts the number of non-zero values in the array a. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + def wrap_function_to_method(func): @wraps(func) @@ -2673,6 +3157,30 @@ def f(x, *args, **kwargs): diagonal = wrap_function_to_method(jnp.diagonal) ravel = wrap_function_to_method(jnp.ravel) +diagonal.__doc__ = ''' + Return specified diagonals. + + Args: + a: array_like, Quantity + offset: int, optional + axis1: int, optional + axis2: int, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +ravel.__doc__ = ''' + Return a contiguous flattened array. + + Args: + a: array_like, Quantity + order: {'C', 'F', 'A', 'K'}, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + # Elementwise bit operations (unary) # ---------------------------------- @@ -2692,8 +3200,27 @@ def f(x, *args, **kwargs): bitwise_not = wrap_elementwise_bit_operation_unary(jnp.bitwise_not) invert = wrap_elementwise_bit_operation_unary(jnp.invert) -left_shift = wrap_elementwise_bit_operation_unary(jnp.left_shift) -right_shift = wrap_elementwise_bit_operation_unary(jnp.right_shift) + +# docs for functions above +bitwise_not.__doc__ = ''' + Compute the bit-wise NOT of an array, element-wise. + + Args: + x: array_like + + Returns: + out: an array +''' + +invert.__doc__ = ''' + Compute bit-wise inversion, or bit-wise NOT, element-wise. + + Args: + x: array_like + + Returns: + out: an array +''' # Elementwise bit operations (binary) @@ -2708,13 +3235,71 @@ def f(x, y, *args, **kwargs): else: raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') - f.__module__ = 'brainunit.math' - return f + f.__module__ = 'brainunit.math' + return f + + +bitwise_and = wrap_elementwise_bit_operation_binary(jnp.bitwise_and) +bitwise_or = wrap_elementwise_bit_operation_binary(jnp.bitwise_or) +bitwise_xor = wrap_elementwise_bit_operation_binary(jnp.bitwise_xor) +left_shift = wrap_elementwise_bit_operation_binary(jnp.left_shift) +right_shift = wrap_elementwise_bit_operation_binary(jnp.right_shift) + +# docs for functions above +bitwise_and.__doc__ = ''' + Compute the bit-wise AND of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + out: an array +''' + +bitwise_or.__doc__ = ''' + Compute the bit-wise OR of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + out: an array +''' + +bitwise_xor.__doc__ = ''' + Compute the bit-wise XOR of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + out: an array +''' + +left_shift.__doc__ = ''' + Shift the bits of an integer to the left. + + Args: + x: array_like + y: array_like + + Returns: + out: an array +''' +right_shift.__doc__ = ''' + Shift the bits of an integer to the right. -bitwise_and = wrap_elementwise_bit_operation_binary(jnp.bitwise_and) -bitwise_or = wrap_elementwise_bit_operation_binary(jnp.bitwise_or) -bitwise_xor = wrap_elementwise_bit_operation_binary(jnp.bitwise_xor) + Args: + x: array_like + y: array_like + + Returns: + out: an array +''' # logic funcs (unary) @@ -2739,6 +3324,46 @@ def f(x, *args, **kwargs): sometrue = any logical_not = wrap_logic_func_unary(jnp.logical_not) +# docs for functions above +all.__doc__ = ''' + Test whether all array elements along a given axis evaluate to True. + + Args: + a: array_like + axis: int, optional + out: array, optional + keepdims: bool, optional + where: array_like of bool, optional + + Returns: + out: bool or array +''' + +any.__doc__ = ''' + Test whether any array element along a given axis evaluates to True. + + Args: + a: array_like + axis: int, optional + out: array, optional + keepdims: bool, optional + where: array_like of bool, optional + + Returns: + out: bool or array +''' + +logical_not.__doc__ = ''' + Compute the truth value of NOT x element-wise. + + Args: + x: array_like + out: array, optional + + Returns: + out: bool or array +''' + # logic funcs (binary) # -------------------- @@ -2771,11 +3396,155 @@ def f(x, y, *args, **kwargs): logical_or = wrap_logic_func_binary(jnp.logical_or) logical_xor = wrap_logic_func_binary(jnp.logical_xor) +# docs for functions above +equal.__doc__ = ''' + Return (x == y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +not_equal.__doc__ = ''' + Return (x != y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +greater.__doc__ = ''' + Return (x > y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +greater_equal.__doc__ = ''' + Return (x >= y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +less.__doc__ = ''' + Return (x < y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +less_equal.__doc__ = ''' + Return (x <= y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +array_equal.__doc__ = ''' + Return True if two arrays have the same shape, elements, and units (if they are Quantity), False otherwise. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: bool or array +''' + +isclose.__doc__ = ''' + Returns a boolean array where two arrays are element-wise equal within a tolerance and have the same unit if they are Quantity. + + Args: + a: array_like, Quantity + b: array_like, Quantity + rtol: float, optional + atol: float, optional + equal_nan: bool, optional + + Returns: + out: bool or array +''' + +allclose.__doc__ = ''' + Returns True if the two arrays are equal within the given tolerance and have the same unit if they are Quantity; False otherwise. + + Args: + a: array_like, Quantity + b: array_like, Quantity + rtol: float, optional + atol: float, optional + equal_nan: bool, optional + + Returns: + out: bool +''' + +logical_and.__doc__ = ''' + Compute the truth value of x AND y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + out: bool or array +''' + +logical_or.__doc__ = ''' + Compute the truth value of x OR y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + out: bool or array +''' + +logical_xor.__doc__ = ''' + Compute the truth value of x XOR y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + out: bool or array +''' + # indexing funcs # -------------- @set_module_as('brainunit.math') -def where(condition, *args, **kwds): # pylint: disable=C0111 +def where(condition: Union[bool, bst.typing.ArrayLike], + *args: Union[Quantity, bst.typing.ArrayLike], + **kwds) -> Union[Quantity, jax.Array]: condition = jnp.asarray(condition) if len(args) == 0: # nothing to do @@ -2809,10 +3578,32 @@ def where(condition, *args, **kwds): # pylint: disable=C0111 tril_indices = jnp.tril_indices +tril_indices.__doc__ = ''' + Return the indices for the lower-triangle of an (n, m) array. + + Args: + n: int + m: int + k: int, optional + + Returns: + out: tuple[array] +''' @set_module_as('brainunit.math') -def tril_indices_from(arr, k=0): +def tril_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], + k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: + ''' + Return the indices for the lower-triangle of an (n, m) array. + + Args: + arr: array_like, Quantity + k: int, optional + + Returns: + out: tuple[array] + ''' if isinstance(arr, Quantity): return jnp.tril_indices_from(arr.value, k=k) else: @@ -2820,10 +3611,32 @@ def tril_indices_from(arr, k=0): triu_indices = jnp.triu_indices +triu_indices.__doc__ = ''' + Return the indices for the upper-triangle of an (n, m) array. + + Args: + n: int + m: int + k: int, optional + + Returns: + out: tuple[array] +''' @set_module_as('brainunit.math') -def triu_indices_from(arr, k=0): +def triu_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], + k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: + ''' + Return the indices for the upper-triangle of an (n, m) array. + + Args: + arr: array_like, Quantity + k: int, optional + + Returns: + out: tuple[array] + ''' if isinstance(arr, Quantity): return jnp.triu_indices_from(arr.value, k=k) else: @@ -2831,7 +3644,10 @@ def triu_indices_from(arr, k=0): @set_module_as('brainunit.math') -def take(a, indices, axis=None, mode=None): +def take(a: Union[Quantity, bst.typing.ArrayLike], + indices: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + mode: Optional[str] = None) -> Union[Quantity, jax.Array]: if isinstance(a, Quantity): return a.take(indices, axis=axis, mode=mode) else: @@ -2839,7 +3655,9 @@ def take(a, indices, axis=None, mode=None): @set_module_as('brainunit.math') -def select(condlist: list[Union[jnp.array, np.ndarray]], choicelist: Union[Quantity, jax.Array, np.ndarray], default=0): +def select(condlist: list[Union[bst.typing.ArrayLike]], + choicelist: Union[Quantity, bst.typing.ArrayLike], + default: int = 0) -> Union[Quantity, jax.Array]: from builtins import all as origin_all from builtins import any as origin_any if origin_all(isinstance(choice, Quantity) for choice in choicelist): @@ -2859,7 +3677,7 @@ def select(condlist: list[Union[jnp.array, np.ndarray]], choicelist: Union[Quant def wrap_window_funcs(func): def f(*args, **kwargs): - return Quantity(func(*args, **kwargs)) + return func(*args, **kwargs) f.__module__ = 'brainunit.math' return f @@ -2871,6 +3689,13 @@ def f(*args, **kwargs): hanning = wrap_window_funcs(jnp.hanning) kaiser = wrap_window_funcs(jnp.kaiser) +# docs for functions above +bartlett.__doc__ = jnp.bartlett.__doc__ +blackman.__doc__ = jnp.blackman.__doc__ +hamming.__doc__ = jnp.hamming.__doc__ +hanning.__doc__ = jnp.hanning.__doc__ +kaiser.__doc__ = jnp.kaiser.__doc__ + # constants # --------- e = jnp.e @@ -2887,13 +3712,91 @@ def f(*args, **kwargs): matmul = wrap_math_funcs_change_unit_binary(jnp.matmul, lambda x, y: x * y) trace = wrap_math_funcs_keep_unit_unary(jnp.trace) +# docs for functions above +dot.__doc__ = ''' + Dot product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +vdot.__doc__ = ''' + Return the dot product of two vectors or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +inner.__doc__ = ''' + Inner product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +outer.__doc__ = ''' + Compute the outer product of two vectors or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +kron.__doc__ = ''' + Compute the Kronecker product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +matmul.__doc__ = ''' + Matrix product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +trace.__doc__ = ''' + Return the sum of the diagonal elements of a matrix or quantity. + + Args: + a: array_like, Quantity + offset: int, optional + + Returns: + out: Quantity if the input is a Quantity, else an array. +''' + # data types # ---------- dtype = jnp.dtype @set_module_as('brainunit.math') -def finfo(a): +def finfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.finfo: if isinstance(a, Quantity): return jnp.finfo(a.value) else: @@ -2901,7 +3804,7 @@ def finfo(a): @set_module_as('brainunit.math') -def iinfo(a): +def iinfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.iinfo: if isinstance(a, Quantity): return jnp.iinfo(a.value) else: @@ -2911,7 +3814,7 @@ def iinfo(a): # more # ---- @set_module_as('brainunit.math') -def broadcast_arrays(*args): +def broadcast_arrays(*args: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: from builtins import all as origin_all from builtins import any as origin_any if origin_all(isinstance(arg, Quantity) for arg in args): @@ -2929,14 +3832,37 @@ def broadcast_arrays(*args): @set_module_as('brainunit.math') def einsum( - subscripts, /, - *operands, + subscripts: str, + /, + *operands: Union[Quantity, jax.Array], out: None = None, optimize: Union[str, bool] = "optimal", precision: jax.lax.PrecisionLike = None, preferred_element_type: Union[jax.typing.DTypeLike, None] = None, _dot_general: Callable[..., jax.Array] = jax.lax.dot_general, ) -> Union[jax.Array, Quantity]: + ''' + Evaluates the Einstein summation convention on the operands. + + Args: + subscripts: string containing axes names separated by commas. + *operands: sequence of one or more arrays or quantities corresponding to the subscripts. + optimize: determine whether to optimize the order of computation. In JAX + this defaults to ``"optimize"`` which produces optimized expressions via + the opt_einsum_ package. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``). + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + out: unsupported by JAX + _dot_general: optionally override the ``dot_general`` callable used by ``einsum``. + This parameter is experimental, and may be removed without warning at any time. + + Returns: + array containing the result of the einstein summation. + ''' operands = (subscripts, *operands) if out is not None: raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.") @@ -3000,6 +3926,18 @@ def gradient( axis: Union[int, Sequence[int], None] = None, edge_order: Union[int, None] = None, ) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: + ''' + Computes the gradient of a scalar field. + + Args: + f: input array. + *varargs: list of scalar fields to compute the gradient. + axis: axis or axes along which to compute the gradient. The default is to compute the gradient along all axes. + edge_order: order of the edge used for the finite difference computation. The default is 1. + + Returns: + array containing the gradient of the scalar field. + ''' if edge_order is not None: raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") @@ -3029,6 +3967,18 @@ def intersect1d( assume_unique: bool = False, return_indices: bool = False ) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: + ''' + Find the intersection of two arrays. + + Args: + ar1: input array. + ar2: input array. + assume_unique: if True, the input arrays are both assumed to be unique. + return_indices: if True, the indices which correspond to the intersection of the two arrays are returned. + + Returns: + array containing the intersection of the two arrays. + ''' fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') unit = None if isinstance(ar1, Quantity): @@ -3054,3 +4004,67 @@ def intersect1d( rot90 = wrap_math_funcs_keep_unit_unary(jnp.rot90) tensordot = wrap_math_funcs_change_unit_binary(jnp.tensordot, lambda x, y: x * y) + +# docs for functions above +nan_to_num.__doc__ = ''' + Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the `nan`, `posinf` and `neginf` arguments. + + Args: + x: input array. + nan: value to replace NaNs with. + posinf: value to replace positive infinity with. + neginf: value to replace negative infinity with. + + Returns: + array with NaNs replaced by zero and infinities replaced by large finite numbers. +''' + +nanargmax.__doc__ = ''' + Return the index of the maximum value in an array, ignoring NaNs. + + Args: + a: array like, Quantity. + axis: axis along which to operate. The default is to compute the index of the maximum over all the dimensions of the input array. + out: output array, optional. + keepdims: if True, the result is broadcast to the input array with the same number of dimensions. + + Returns: + index of the maximum value in the array. +''' + +nanargmin.__doc__ = ''' + Return the index of the minimum value in an array, ignoring NaNs. + + Args: + a: array like, Quantity. + axis: axis along which to operate. The default is to compute the index of the minimum over all the dimensions of the input array. + out: output array, optional. + keepdims: if True, the result is broadcast to the input array with the same number of dimensions. + + Returns: + index of the minimum value in the array. +''' + +rot90.__doc__ = ''' + Rotate an array by 90 degrees in the plane specified by axes. + + Args: + m: array like, Quantity. + k: number of times the array is rotated by 90 degrees. + axes: plane of rotation. Default is the last two axes. + + Returns: + rotated array. +''' + +tensordot.__doc__ = ''' + Compute tensor dot product along specified axes for arrays. + + Args: + a: array like, Quantity. + b: array like, Quantity. + axes: axes along which to compute the tensor dot product. + + Returns: + tensor dot product of the two arrays. +''' \ No newline at end of file diff --git a/brainunit/math/_utils.py b/brainunit/math/_utils.py index f30ec85..ae66103 100644 --- a/brainunit/math/_utils.py +++ b/brainunit/math/_utils.py @@ -15,8 +15,9 @@ import functools -from typing import Callable +from typing import Callable, Union +import jax from jax.tree_util import tree_map from .._base import Quantity @@ -38,7 +39,7 @@ def _compatible_with_quantity( func_to_wrap = fun.__np_wrapped__ if hasattr(fun, '__np_wrapped__') else fun @functools.wraps(func_to_wrap) - def new_fun(*args, **kwargs): + def new_fun(*args, **kwargs) -> Union[list[Quantity], Quantity, jax.Array]: unit = None if isinstance(args[0], Quantity): unit = args[0].unit From fc2b978b9af138bbf6cbd3eb03b7a2851219977c Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 11:57:06 +0800 Subject: [PATCH 04/23] Update _compat_numpy.py --- brainunit/math/_compat_numpy.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py index 184d3dc..f12b33b 100644 --- a/brainunit/math/_compat_numpy.py +++ b/brainunit/math/_compat_numpy.py @@ -12,10 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -import functools from collections.abc import Sequence from functools import wraps -from typing import (Callable, Union, Optional, Any, List) +from typing import (Callable, Union, Optional, Any) import brainstate as bst import jax @@ -27,7 +26,6 @@ from jax._src.numpy.lax_numpy import _einsum from ._utils import _compatible_with_quantity -from .. import Quantity from .._base import (DIMENSIONLESS, Quantity, Unit, From 017eb6ff52ee2f876b0363c17ba787972e8d9c88 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 12:07:24 +0800 Subject: [PATCH 05/23] Fix --- brainunit/math/_compat_numpy.py | 19 ++++++------- brainunit/math/_compat_numpy_test.py | 40 ++++++++++++++-------------- 2 files changed, 30 insertions(+), 29 deletions(-) diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py index f12b33b..4c0b917 100644 --- a/brainunit/math/_compat_numpy.py +++ b/brainunit/math/_compat_numpy.py @@ -708,7 +708,7 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') def array_split(ary: Union[Quantity, bst.typing.ArrayLike], indices_or_sections: Union[int, bst.typing.ArrayLike], - axis: Optional[int] = 0) -> list[Quantity] | list[Array]: + axis: Optional[int] = 0) -> Union[list[Quantity], list[Array]]: ''' Split an array into multiple sub-arrays. @@ -1678,7 +1678,7 @@ def f(x, *args, **kwargs): def prod(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, dtype: Optional[bst.typing.DTypeLike] = None, - out: Optional[...] = None, + out: None = None, keepdims: Optional[bool] = False, initial: Union[Quantity, bst.typing.ArrayLike] = None, where: Union[Quantity, bst.typing.ArrayLike] = None, @@ -1711,8 +1711,8 @@ def prod(x: Union[Quantity, bst.typing.ArrayLike], def nanprod(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, dtype: Optional[bst.typing.DTypeLike] = None, - out: Optional[...] = None, - keepdims: Optional[...] = False, + out: None = None, + keepdims: None = False, initial: Union[Quantity, bst.typing.ArrayLike] = None, where: Union[Quantity, bst.typing.ArrayLike] = None): ''' @@ -1743,7 +1743,7 @@ def nanprod(x: Union[Quantity, bst.typing.ArrayLike], def cumprod(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, dtype: Optional[bst.typing.DTypeLike] = None, - out: Optional[...] = None) -> Union[Quantity, bst.typing.ArrayLike]: + out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: ''' Return the cumulative product of elements along a given axis. @@ -1766,7 +1766,7 @@ def cumprod(x: Union[Quantity, bst.typing.ArrayLike], def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, dtype: Optional[bst.typing.DTypeLike] = None, - out: Optional[...] = None) -> Union[Quantity, bst.typing.ArrayLike]: + out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: ''' Return the cumulative product of elements along a given axis treating Not a Numbers (NaNs) as one. @@ -2029,9 +2029,10 @@ def float_power(x: Union[Quantity, bst.typing.ArrayLike], Returns: out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. ''' - assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' + if isinstance(y, Quantity): + assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' if isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y), unit=x.unit ** y.unit)) + return _return_check_unitless(Quantity(jnp.float_power(x.value, y), unit=x.unit ** y)) elif isinstance(x, (jax.Array, np.ndarray)): return jnp.float_power(x, y) else: @@ -3228,7 +3229,7 @@ def wrap_elementwise_bit_operation_binary(func): def f(x, y, *args, **kwargs): if isinstance(x, Quantity) or isinstance(y, Quantity): raise ValueError(f'Expected integers, got {x} and {y}') - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + elif isinstance(x, bst.typing.ArrayLike) and isinstance(y, bst.typing.ArrayLike): return func(x, y, *args, **kwargs) else: raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 24bfa7e..9cfec3e 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -87,7 +87,7 @@ def test_full_like(self): self.assertTrue(jnp.all(result == 4)) q = [1, 2, 3] * bu.second - result_q = bu.math.full_like(q, 4 * bu.second) + result_q = bu.math.full_like(q, 4, unit=bu.second) assert_quantity(result_q, jnp.full_like(jnp.array([1, 2, 3]), 4), bu.second) def test_diag(self): @@ -97,7 +97,7 @@ def test_diag(self): self.assertTrue(jnp.all(result == jnp.diag(array))) q = [1, 2, 3] * bu.second - result_q = bu.math.diag(q) + result_q = bu.math.diag(q, unit=bu.second) assert_quantity(result_q, jnp.diag(jnp.array([1, 2, 3])), bu.second) def test_tril(self): @@ -107,7 +107,7 @@ def test_tril(self): self.assertTrue(jnp.all(result == jnp.tril(array))) q = jnp.ones((3, 3)) * bu.second - result_q = bu.math.tril(q) + result_q = bu.math.tril(q, unit=bu.second) assert_quantity(result_q, jnp.tril(jnp.ones((3, 3))), bu.second) def test_triu(self): @@ -117,7 +117,7 @@ def test_triu(self): self.assertTrue(jnp.all(result == jnp.triu(array))) q = jnp.ones((3, 3)) * bu.second - result_q = bu.math.triu(q) + result_q = bu.math.triu(q, unit=bu.second) assert_quantity(result_q, jnp.triu(jnp.ones((3, 3))), bu.second) def test_empty_like(self): @@ -1810,22 +1810,6 @@ def test_invert(self): q = [0b1100] * bu.second result_q = bu.math.invert(q) - def test_left_shift(self): - result = bu.math.left_shift(jnp.array([0b0100]), 2) - self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b0100]), 2))) - - with pytest.raises(ValueError): - q = [0b0100] * bu.second - result_q = bu.math.left_shift(q, 2) - - def test_right_shift(self): - result = bu.math.right_shift(jnp.array([0b0100]), 2) - self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b0100]), 2))) - - with pytest.raises(ValueError): - q = [0b0100] * bu.second - result_q = bu.math.right_shift(q, 2) - class TestElementwiseBitOperationsBinary(unittest.TestCase): @@ -1856,6 +1840,22 @@ def test_bitwise_xor(self): q2 = [0b1010] * bu.second result_q = bu.math.bitwise_xor(q1, q2) + def test_left_shift(self): + result = bu.math.left_shift(jnp.array([0b1100]), 2) + self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b1100]), 2))) + + with pytest.raises(ValueError): + q = [0b1100] * bu.second + result_q = bu.math.left_shift(q, 2) + + def test_right_shift(self): + result = bu.math.right_shift(jnp.array([0b1100]), 2) + self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b1100]), 2))) + + with pytest.raises(ValueError): + q = [0b1100] * bu.second + result_q = bu.math.right_shift(q, 2) + class TestLogicFuncsUnary(unittest.TestCase): def test_all(self): From 22ef3b68e841edf924b16d6bfc0bc47ae3b814b7 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 12:15:46 +0800 Subject: [PATCH 06/23] Update brainunit.math.rst --- docs/apis/brainunit.math.rst | 456 +++++++++++++++++++++++++++++++++++ 1 file changed, 456 insertions(+) diff --git a/docs/apis/brainunit.math.rst b/docs/apis/brainunit.math.rst index a6ab19c..2b303fc 100644 --- a/docs/apis/brainunit.math.rst +++ b/docs/apis/brainunit.math.rst @@ -4,6 +4,462 @@ .. currentmodule:: brainunit.math .. automodule:: brainunit.math +Array Creation +-------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + full + full_like + eye + identity + diag + tri + tril + triu + empty + empty_like + ones + ones_like + zeros + zeros_like + array + asarray + arange + linspace + logspace + fill_diagonal + array_split + meshgrid + vander + +Getting Attribute Funcs +----------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ndim + isreal + isscalar + isfinite + isinf + isnan + shape + size + +Math Funcs Keep Unit (Unary) +----------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + real + imag + conj + conjugate + negative + positive + abs + round + around + round_ + rint + floor + ceil + trunc + fix + sum + nancumsum + nansum + cumsum + ediff1d + absolute + fabs + median + nanmin + nanmax + ptp + average + mean + std + nanmedian + nanmean + nanstd + diff + modf + +Math Funcs Keep Unit (Binary) +------------------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + fmod + mod + copysign + heaviside + maximum + minimum + fmax + fmin + lcm + gcd + +Math Funcs Keep Unit (N-ary) +----------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + interp + clip + +Math Funcs Match Unit (Binary) +------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + add + subtract + nextafter + +Math Funcs Change Unit (Unary) +------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + reciprocal + prod + product + nancumprod + nanprod + cumprod + cumproduct + var + nanvar + cbrt + square + frexp + sqrt + +Math Funcs Change Unit (Binary) +-------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + multiply + divide + power + cross + ldexp + true_divide + floor_divide + float_power + divmod + remainder + convolve + +Math Funcs Only Accept Unitless (Unary) +--------------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + exp + exp2 + expm1 + log + log10 + log1p + log2 + arccos + arccosh + arcsin + arcsinh + arctan + arctanh + cos + cosh + sin + sinc + sinh + tan + tanh + deg2rad + rad2deg + degrees + radians + angle + percentile + nanpercentile + quantile + nanquantile + +Math Funcs Only Accept Unitless (Binary) +---------------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + hypot + arctan2 + logaddexp + logaddexp2 + +Math Funcs Remove Unit (Unary) +------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + signbit + sign + histogram + bincount + +Math Funcs Remove Unit (Binary) +-------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + corrcoef + correlate + cov + digitize + +Array Manipulation +------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + reshape + moveaxis + transpose + swapaxes + row_stack + concatenate + stack + vstack + hstack + dstack + column_stack + split + dsplit + hsplit + vsplit + tile + repeat + unique + append + flip + fliplr + flipud + roll + atleast_1d + atleast_2d + atleast_3d + expand_dims + squeeze + sort + argsort + argmax + argmin + argwhere + nonzero + flatnonzero + searchsorted + extract + count_nonzero + max + min + amax + amin + block + compress + diagflat + diagonal + choose + ravel + +Elementwise Bit Operations (Unary) +---------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + bitwise_not + invert + +Elementwise Bit Operations (Binary) +----------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + bitwise_and + bitwise_or + bitwise_xor + left_shift + right_shift + +Logic Funcs (Unary) +-------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + all + any + logical_not + +Logic Funcs (Binary) +--------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + equal + not_equal + greater + greater_equal + less + less_equal + array_equal + isclose + allclose + logical_and + logical_or + logical_xor + alltrue + sometrue + +Indexing Funcs +--------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + nonzero + where + tril_indices + tril_indices_from + triu_indices + triu_indices_from + take + select + +Window Funcs +------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + bartlett + blackman + hamming + hanning + kaiser + +Constants +---------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + e + pi + inf + +Linear Algebra +--------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + dot + vdot + inner + outer + kron + matmul + trace + +Data Types +----------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + dtype + finfo + iinfo + +More +----- + .. autosummary:: :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + broadcast_arrays + broadcast_shapes + einsum + gradient + intersect1d + nan_to_num + nanargmax + nanargmin + rot90 + tensordot From cbbcfc97ee063cfad083cb93b7d04eb109b7cd9e Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 13:13:26 +0800 Subject: [PATCH 07/23] Update _compat_numpy.py --- brainunit/math/_compat_numpy.py | 446 ++++++++++++++++---------------- 1 file changed, 223 insertions(+), 223 deletions(-) diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py index 4c0b917..360c088 100644 --- a/brainunit/math/_compat_numpy.py +++ b/brainunit/math/_compat_numpy.py @@ -154,13 +154,13 @@ def f(*args, unit: Unit = None, **kwargs): # array creation # -------------- -full = wrap_array_creation_function(jnp.full) -eye = wrap_array_creation_function(jnp.eye) -identity = wrap_array_creation_function(jnp.identity) -tri = wrap_array_creation_function(jnp.tri) -empty = wrap_array_creation_function(jnp.empty) -ones = wrap_array_creation_function(jnp.ones) -zeros = wrap_array_creation_function(jnp.zeros) +full: Callable = wrap_array_creation_function(jnp.full) +eye: Callable = wrap_array_creation_function(jnp.eye) +identity: Callable = wrap_array_creation_function(jnp.identity) +tri: Callable = wrap_array_creation_function(jnp.tri) +empty: Callable = wrap_array_creation_function(jnp.empty) +ones: Callable = wrap_array_creation_function(jnp.ones) +zeros: Callable = wrap_array_creation_function(jnp.zeros) # docs for full, eye, identity, tri, empty, ones, zeros @@ -179,7 +179,7 @@ def f(*args, unit: Unit = None, **kwargs): unit: the unit of the output array, or `None`. Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. """ eye.__doc__ = """ @@ -199,7 +199,7 @@ def f(*args, unit: Unit = None, **kwargs): unit: the unit of the output array, or `None`. Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. """ identity.__doc__ = """ @@ -216,7 +216,7 @@ def f(*args, unit: Unit = None, **kwargs): unit: the unit of the output array, or `None`. Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. """ tri.__doc__ = """ @@ -237,7 +237,7 @@ def f(*args, unit: Unit = None, **kwargs): unit: the unit of the output array, or `None`. Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. """ # empty @@ -255,7 +255,7 @@ def f(*args, unit: Unit = None, **kwargs): unit: the unit of the output array, or `None`. Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. """ # ones @@ -273,7 +273,7 @@ def f(*args, unit: Unit = None, **kwargs): unit: the unit of the output array, or `None`. Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. """ # zeros @@ -291,7 +291,7 @@ def f(*args, unit: Unit = None, **kwargs): unit: the unit of the output array, or `None`. Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. """ @@ -313,7 +313,7 @@ def full_like(a: Union[Quantity, bst.typing.ArrayLike], shape: sequence of ints, optional Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' if unit is not None: assert isinstance(unit, Unit) @@ -338,7 +338,7 @@ def diag(a: Union[Quantity, bst.typing.ArrayLike], unit: Unit, optional Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' if unit is not None: assert isinstance(unit, Unit) @@ -363,7 +363,7 @@ def tril(a: Union[Quantity, bst.typing.ArrayLike], unit: Unit, optional Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' if unit is not None: assert isinstance(unit, Unit) @@ -388,7 +388,7 @@ def triu(a: Union[Quantity, bst.typing.ArrayLike], unit: Unit, optional Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' if unit is not None: assert isinstance(unit, Unit) @@ -416,7 +416,7 @@ def empty_like(a: Union[Quantity, bst.typing.ArrayLike], unit: Unit, optional Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' if unit is not None: assert isinstance(unit, Unit) @@ -444,7 +444,7 @@ def ones_like(a: Union[Quantity, bst.typing.ArrayLike], unit: Unit, optional Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' if unit is not None: assert isinstance(unit, Unit) @@ -472,7 +472,7 @@ def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], unit: Unit, optional Returns: - out: Quantity if `unit` is provided, else an array. + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' if unit is not None: assert isinstance(unit, Unit) @@ -526,7 +526,7 @@ def arange(*args, **kwargs): unit: Unit, optional Returns: - out: Quantity if start and stop are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. ''' # arange has a bit of a complicated argument structure unfortunately # we leave the actual checking of the number of arguments to numpy, though @@ -624,7 +624,7 @@ def linspace(start: Union[Quantity, bst.typing.ArrayLike], dtype: dtype, optional Returns: - out: Quantity if start and stop are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. ''' fail_for_dimension_mismatch( start, @@ -660,7 +660,7 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike], dtype: dtype, optional Returns: - out: Quantity if start and stop are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. ''' fail_for_dimension_mismatch( start, @@ -692,7 +692,7 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], inplace: bool, optional Returns: - out: Quantity if `a` and `val` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `a` and `val` are Quantities that have the same unit, else an array. ''' if isinstance(a, Quantity) and isinstance(val, Quantity): fail_for_dimension_mismatch(a, val) @@ -718,7 +718,7 @@ def array_split(ary: Union[Quantity, bst.typing.ArrayLike], axis: int, optional Returns: - out: Quantity if `ary` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `ary` is a Quantity, else an array. ''' if isinstance(ary, Quantity): return [Quantity(x, unit=ary.unit) for x in jnp.array_split(ary.value, indices_or_sections, axis)] @@ -743,7 +743,7 @@ def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike], indexing: str, optional Returns: - out: Quantity if `xi` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `xi` are Quantities that have the same unit, else an array. ''' from builtins import all as origin_all if origin_all(isinstance(x, Quantity) for x in xi): @@ -768,7 +768,7 @@ def vander(x: Union[Quantity, bst.typing.ArrayLike], increasing: bool, optional Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' if isinstance(x, Quantity): return Quantity(jnp.vander(x.value, N=N, increasing=increasing), unit=x.unit) @@ -790,7 +790,7 @@ def ndim(a: Union[Quantity, bst.typing.ArrayLike]) -> int: a: array_like, Quantity Returns: - out: int + Union[jax.Array, Quantity]: int ''' if isinstance(a, Quantity): return a.ndim @@ -807,7 +807,7 @@ def isreal(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: a: array_like, Quantity Returns: - out: boolean array + Union[jax.Array, Quantity]: boolean array ''' if isinstance(a, Quantity): return a.isreal @@ -824,7 +824,7 @@ def isscalar(a: Union[Quantity, bst.typing.ArrayLike]) -> bool: a: array_like, Quantity Returns: - out: boolean array + Union[jax.Array, Quantity]: boolean array ''' if isinstance(a, Quantity): return a.isscalar @@ -841,7 +841,7 @@ def isfinite(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: a: array_like, Quantity Returns: - out: boolean array + Union[jax.Array, Quantity]: boolean array ''' if isinstance(a, Quantity): return a.isfinite @@ -858,7 +858,7 @@ def isinf(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: a: array_like, Quantity Returns: - out: boolean array + Union[jax.Array, Quantity]: boolean array ''' if isinstance(a, Quantity): return a.isinf @@ -875,7 +875,7 @@ def isnan(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: a: array_like, Quantity Returns: - out: boolean array + Union[jax.Array, Quantity]: boolean array ''' if isinstance(a, Quantity): return a.isnan @@ -1025,7 +1025,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' imag.__doc__ = ''' @@ -1035,7 +1035,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' conj.__doc__ = ''' @@ -1045,7 +1045,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' conjugate.__doc__ = ''' @@ -1055,7 +1055,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' negative.__doc__ = ''' @@ -1065,7 +1065,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' positive.__doc__ = ''' @@ -1075,7 +1075,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' abs.__doc__ = ''' @@ -1085,7 +1085,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' round_.__doc__ = ''' @@ -1095,7 +1095,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' around.__doc__ = ''' @@ -1105,7 +1105,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' round.__doc__ = ''' @@ -1115,7 +1115,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' rint.__doc__ = ''' @@ -1125,7 +1125,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' floor.__doc__ = ''' @@ -1135,7 +1135,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' ceil.__doc__ = ''' @@ -1145,7 +1145,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' trunc.__doc__ = ''' @@ -1155,7 +1155,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' fix.__doc__ = ''' @@ -1165,7 +1165,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' sum.__doc__ = ''' @@ -1175,7 +1175,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' nancumsum.__doc__ = ''' @@ -1185,7 +1185,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' nansum.__doc__ = ''' @@ -1195,7 +1195,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' cumsum.__doc__ = ''' @@ -1205,7 +1205,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' ediff1d.__doc__ = ''' @@ -1215,7 +1215,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' absolute.__doc__ = ''' @@ -1225,7 +1225,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' fabs.__doc__ = ''' @@ -1235,7 +1235,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' median.__doc__ = ''' @@ -1245,7 +1245,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' nanmin.__doc__ = ''' @@ -1255,7 +1255,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' nanmax.__doc__ = ''' @@ -1265,7 +1265,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' ptp.__doc__ = ''' @@ -1275,7 +1275,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' average.__doc__ = ''' @@ -1285,7 +1285,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' mean.__doc__ = ''' @@ -1295,7 +1295,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' std.__doc__ = ''' @@ -1305,7 +1305,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' nanmedian.__doc__ = ''' @@ -1315,7 +1315,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' nanmean.__doc__ = ''' @@ -1325,7 +1325,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' nanstd.__doc__ = ''' @@ -1335,7 +1335,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' diff.__doc__ = ''' @@ -1345,7 +1345,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' modf.__doc__ = ''' @@ -1355,7 +1355,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity tuple if `x` is a Quantity, else an array tuple. + Union[jax.Array, Quantity]: Quantity tuple if `x` is a Quantity, else an array tuple. ''' @@ -1395,7 +1395,7 @@ def f(x1, x2, *args, **kwargs): x2: array_like, Quantity Returns: - out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. ''' mod.__doc__ = ''' @@ -1406,7 +1406,7 @@ def f(x1, x2, *args, **kwargs): x2: array_like, Quantity Returns: - out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. ''' copysign.__doc__ = ''' @@ -1417,7 +1417,7 @@ def f(x1, x2, *args, **kwargs): x2: array_like, Quantity Returns: - out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. ''' heaviside.__doc__ = ''' @@ -1428,7 +1428,7 @@ def f(x1, x2, *args, **kwargs): x2: array_like, Quantity Returns: - out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. ''' maximum.__doc__ = ''' @@ -1439,7 +1439,7 @@ def f(x1, x2, *args, **kwargs): x2: array_like, Quantity Returns: - out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. ''' minimum.__doc__ = ''' @@ -1450,7 +1450,7 @@ def f(x1, x2, *args, **kwargs): x2: array_like, Quantity Returns: - out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. ''' fmax.__doc__ = ''' @@ -1461,7 +1461,7 @@ def f(x1, x2, *args, **kwargs): x2: array_like, Quantity Returns: - out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. ''' fmin.__doc__ = ''' @@ -1472,7 +1472,7 @@ def f(x1, x2, *args, **kwargs): x2: array_like, Quantity Returns: - out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. ''' lcm.__doc__ = ''' @@ -1483,7 +1483,7 @@ def f(x1, x2, *args, **kwargs): x2: array_like, Quantity Returns: - out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. ''' gcd.__doc__ = ''' @@ -1494,7 +1494,7 @@ def f(x1, x2, *args, **kwargs): x2: array_like, Quantity Returns: - out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. ''' @@ -1519,7 +1519,7 @@ def interp(x: Union[Quantity, bst.typing.ArrayLike], period: array_like, Quantity, optional Returns: - out: Quantity if `x`, `xp`, and `fp` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x`, `xp`, and `fp` are Quantities that have the same unit, else an array. ''' unit = None if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): @@ -1556,7 +1556,7 @@ def clip(a: Union[Quantity, bst.typing.ArrayLike], a_max: array_like, Quantity Returns: - out: Quantity if `a`, `a_min`, and `a_max` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `a`, `a_min`, and `a_max` are Quantities that have the same unit, else an array. ''' unit = None if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): @@ -1620,7 +1620,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: Quantity if `x` and `y` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. ''' subtract.__doc__ = ''' @@ -1631,7 +1631,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: Quantity if `x` and `y` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. ''' nextafter.__doc__ = ''' @@ -1642,7 +1642,7 @@ def f(x, y, *args, **kwargs): x2: array_like, Quantity Returns: - out: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. ''' @@ -1670,7 +1670,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' @@ -1697,7 +1697,7 @@ def prod(x: Union[Quantity, bst.typing.ArrayLike], promote_integers: bool, optional Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' if isinstance(x, Quantity): return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, @@ -1728,7 +1728,7 @@ def nanprod(x: Union[Quantity, bst.typing.ArrayLike], where: array_like, Quantity, optional Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' if isinstance(x, Quantity): return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) @@ -1754,7 +1754,7 @@ def cumprod(x: Union[Quantity, bst.typing.ArrayLike], out: array, optional Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' if isinstance(x, Quantity): return x.cumprod(axis=axis, dtype=dtype, out=out) @@ -1777,7 +1777,7 @@ def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], out: array, optional Returns: - out: Quantity if `x` is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' if isinstance(x, Quantity): return x.nancumprod(axis=axis, dtype=dtype, out=out) @@ -1802,7 +1802,7 @@ def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], x: array_like, Quantity Returns: - out: Quantity if the final unit is the square of the unit of `x`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. ''' nanvar.__doc__ = ''' @@ -1812,7 +1812,7 @@ def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], x: array_like, Quantity Returns: - out: Quantity if the final unit is the square of the unit of `x`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. ''' frexp.__doc__ = ''' @@ -1822,7 +1822,7 @@ def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], x: array_like, Quantity Returns: - out: Tuple of Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the exponent, else a tuple of arrays. + Union[jax.Array, Quantity]: Tuple of Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the exponent, else a tuple of arrays. ''' sqrt.__doc__ = ''' @@ -1832,7 +1832,7 @@ def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], x: array_like, Quantity Returns: - out: Quantity if the final unit is the square root of the unit of `x`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the square root of the unit of `x`, else an array. ''' cbrt.__doc__ = ''' @@ -1842,7 +1842,7 @@ def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], x: array_like, Quantity Returns: - out: Quantity if the final unit is the cube root of the unit of `x`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the cube root of the unit of `x`, else an array. ''' square.__doc__ = ''' @@ -1852,7 +1852,7 @@ def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], x: array_like, Quantity Returns: - out: Quantity if the final unit is the square of the unit of `x`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. ''' @@ -1897,7 +1897,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. ''' divide.__doc__ = ''' @@ -1907,7 +1907,7 @@ def f(x, y, *args, **kwargs): x: array_like, Quantity Returns: - out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. ''' cross.__doc__ = ''' @@ -1918,7 +1918,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. ''' ldexp.__doc__ = ''' @@ -1929,7 +1929,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the unit of `y`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the unit of `y`, else an array. ''' true_divide.__doc__ = ''' @@ -1940,7 +1940,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. ''' divmod.__doc__ = ''' @@ -1951,7 +1951,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. ''' convolve.__doc__ = ''' @@ -1962,7 +1962,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. ''' @@ -1977,7 +1977,7 @@ def power(x: Union[Quantity, bst.typing.ArrayLike], y: array_like, Quantity Returns: - out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. ''' if isinstance(x, Quantity) and isinstance(y, Quantity): return _return_check_unitless(Quantity(jnp.power(x.value, y.value), unit=x.unit ** y.unit)) @@ -2002,7 +2002,7 @@ def floor_divide(x: Union[Quantity, bst.typing.ArrayLike], y: array_like, Quantity Returns: - out: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. ''' if isinstance(x, Quantity) and isinstance(y, Quantity): return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), unit=x.unit / y.unit)) @@ -2027,7 +2027,7 @@ def float_power(x: Union[Quantity, bst.typing.ArrayLike], y: array_like Returns: - out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. ''' if isinstance(y, Quantity): assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' @@ -2111,7 +2111,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' exp2.__doc__ = ''' @@ -2121,7 +2121,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' expm1.__doc__ = ''' @@ -2131,7 +2131,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' log.__doc__ = ''' @@ -2141,7 +2141,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' log10.__doc__ = ''' @@ -2151,7 +2151,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' log1p.__doc__ = ''' @@ -2161,7 +2161,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' log2.__doc__ = ''' @@ -2171,7 +2171,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' arccos.__doc__ = ''' @@ -2181,7 +2181,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' arccosh.__doc__ = ''' @@ -2191,7 +2191,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' arcsin.__doc__ = ''' @@ -2201,7 +2201,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' arcsinh.__doc__ = ''' @@ -2211,7 +2211,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' arctan.__doc__ = ''' @@ -2221,7 +2221,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' arctanh.__doc__ = ''' @@ -2231,7 +2231,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' cos.__doc__ = ''' @@ -2241,7 +2241,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' cosh.__doc__ = ''' @@ -2251,7 +2251,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' sin.__doc__ = ''' @@ -2261,7 +2261,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' sinc.__doc__ = ''' @@ -2271,7 +2271,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' sinh.__doc__ = ''' @@ -2281,7 +2281,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' tan.__doc__ = ''' @@ -2291,7 +2291,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' tanh.__doc__ = ''' @@ -2301,7 +2301,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' deg2rad.__doc__ = ''' @@ -2311,7 +2311,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' rad2deg.__doc__ = ''' @@ -2321,7 +2321,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' degrees.__doc__ = ''' @@ -2331,7 +2331,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' radians.__doc__ = ''' @@ -2341,7 +2341,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' angle.__doc__ = ''' @@ -2351,7 +2351,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' percentile.__doc__ = ''' @@ -2361,7 +2361,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' nanpercentile.__doc__ = ''' @@ -2371,7 +2371,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' quantile.__doc__ = ''' @@ -2381,7 +2381,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' nanquantile.__doc__ = ''' @@ -2391,7 +2391,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' @@ -2437,7 +2437,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' arctan2.__doc__ = ''' @@ -2448,7 +2448,7 @@ def f(x, y, *args, **kwargs): x2: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' logaddexp.__doc__ = ''' @@ -2459,7 +2459,7 @@ def f(x, y, *args, **kwargs): x2: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' logaddexp2.__doc__ = ''' @@ -2470,7 +2470,7 @@ def f(x, y, *args, **kwargs): x2: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' @@ -2500,7 +2500,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' sign.__doc__ = ''' @@ -2510,7 +2510,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' histogram.__doc__ = ''' @@ -2520,7 +2520,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: Tuple of arrays (hist, bin_edges) + tuple[jax.Array]: Tuple of arrays (hist, bin_edges) ''' bincount.__doc__ = ''' @@ -2530,7 +2530,7 @@ def f(x, *args, **kwargs): x: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' @@ -2565,7 +2565,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' correlate.__doc__ = ''' @@ -2576,7 +2576,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' cov.__doc__ = ''' @@ -2587,7 +2587,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity (optional, if not provided, x is assumed to be a 2D array) Returns: - out: an array + jax.Array: an array ''' digitize.__doc__ = ''' @@ -2598,7 +2598,7 @@ def f(x, y, *args, **kwargs): bins: array_like, Quantity Returns: - out: an array + jax.Array: an array ''' # array manipulation @@ -2682,7 +2682,7 @@ def f(x, y, *args, **kwargs): destination: int or sequence of ints Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' transpose.__doc__ = ''' @@ -2693,7 +2693,7 @@ def f(x, y, *args, **kwargs): axes: tuple or list of ints, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' swapaxes.__doc__ = ''' @@ -2705,7 +2705,7 @@ def f(x, y, *args, **kwargs): axis2: int Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' concatenate.__doc__ = ''' @@ -2716,7 +2716,7 @@ def f(x, y, *args, **kwargs): axis: int, optional Returns: - out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array ''' stack.__doc__ = ''' @@ -2727,7 +2727,7 @@ def f(x, y, *args, **kwargs): axis: int Returns: - out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array ''' vstack.__doc__ = ''' @@ -2737,7 +2737,7 @@ def f(x, y, *args, **kwargs): arrays: sequence of array_like, Quantity Returns: - out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array ''' hstack.__doc__ = ''' @@ -2747,7 +2747,7 @@ def f(x, y, *args, **kwargs): arrays: sequence of array_like, Quantity Returns: - out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array ''' dstack.__doc__ = ''' @@ -2757,7 +2757,7 @@ def f(x, y, *args, **kwargs): arrays: sequence of array_like, Quantity Returns: - out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array ''' column_stack.__doc__ = ''' @@ -2767,7 +2767,7 @@ def f(x, y, *args, **kwargs): arrays: sequence of 1-D or 2-D array_like, Quantity Returns: - out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array ''' split.__doc__ = ''' @@ -2779,7 +2779,7 @@ def f(x, y, *args, **kwargs): axis: int, optional Returns: - out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array ''' dsplit.__doc__ = ''' @@ -2790,7 +2790,7 @@ def f(x, y, *args, **kwargs): indices_or_sections: int or 1-D array Returns: - out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array ''' hsplit.__doc__ = ''' @@ -2801,7 +2801,7 @@ def f(x, y, *args, **kwargs): indices_or_sections: int or 1-D array Returns: - out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array ''' vsplit.__doc__ = ''' @@ -2812,7 +2812,7 @@ def f(x, y, *args, **kwargs): indices_or_sections: int or 1-D array Returns: - out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array ''' tile.__doc__ = ''' @@ -2823,7 +2823,7 @@ def f(x, y, *args, **kwargs): reps: array_like Returns: - out: a Quantity if A is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if A is a Quantity, otherwise a jax.Array ''' repeat.__doc__ = ''' @@ -2835,7 +2835,7 @@ def f(x, y, *args, **kwargs): axis: int, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' unique.__doc__ = ''' @@ -2849,7 +2849,7 @@ def f(x, y, *args, **kwargs): axis: int or None, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' append.__doc__ = ''' @@ -2861,7 +2861,7 @@ def f(x, y, *args, **kwargs): axis: int, optional Returns: - out: a Quantity if arr and values are Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if arr and values are Quantity, otherwise a jax.Array ''' flip.__doc__ = ''' @@ -2872,7 +2872,7 @@ def f(x, y, *args, **kwargs): axis: int or tuple of ints, optional Returns: - out: a Quantity if m is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array ''' fliplr.__doc__ = ''' @@ -2882,7 +2882,7 @@ def f(x, y, *args, **kwargs): m: array_like, Quantity Returns: - out: a Quantity if m is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array ''' flipud.__doc__ = ''' @@ -2892,7 +2892,7 @@ def f(x, y, *args, **kwargs): m: array_like, Quantity Returns: - out: a Quantity if m is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array ''' roll.__doc__ = ''' @@ -2904,7 +2904,7 @@ def f(x, y, *args, **kwargs): axis: int or tuple of ints, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' atleast_1d.__doc__ = ''' @@ -2914,7 +2914,7 @@ def f(x, y, *args, **kwargs): *args: array_like, Quantity Returns: - out: a Quantity if any input is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array ''' atleast_2d.__doc__ = ''' @@ -2924,7 +2924,7 @@ def f(x, y, *args, **kwargs): *args: array_like, Quantity Returns: - out: a Quantity if any input is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array ''' atleast_3d.__doc__ = ''' @@ -2934,7 +2934,7 @@ def f(x, y, *args, **kwargs): *args: array_like, Quantity Returns: - out: a Quantity if any input is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array ''' expand_dims.__doc__ = ''' @@ -2945,7 +2945,7 @@ def f(x, y, *args, **kwargs): axis: int or tuple of ints Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' squeeze.__doc__ = ''' @@ -2956,7 +2956,7 @@ def f(x, y, *args, **kwargs): axis: None or int or tuple of ints, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' sort.__doc__ = ''' @@ -2969,7 +2969,7 @@ def f(x, y, *args, **kwargs): order: str or list of str, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' max.__doc__ = ''' Return the maximum of an array or maximum along an axis. @@ -2980,7 +2980,7 @@ def f(x, y, *args, **kwargs): keepdims: bool, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' min.__doc__ = ''' @@ -2992,7 +2992,7 @@ def f(x, y, *args, **kwargs): keepdims: bool, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' choose.__doc__ = ''' @@ -3003,7 +3003,7 @@ def f(x, y, *args, **kwargs): choices: array_like, Quantity Returns: - out: a Quantity if a and choices are Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a and choices are Quantity, otherwise a jax.Array ''' block.__doc__ = ''' @@ -3013,7 +3013,7 @@ def f(x, y, *args, **kwargs): arrays: sequence of array_like, Quantity Returns: - out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array ''' compress.__doc__ = ''' @@ -3025,7 +3025,7 @@ def f(x, y, *args, **kwargs): axis: int, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' diagflat.__doc__ = ''' @@ -3036,7 +3036,7 @@ def f(x, y, *args, **kwargs): offset: int, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array ''' argsort.__doc__ = ''' @@ -3049,7 +3049,7 @@ def f(x, y, *args, **kwargs): order: str or list of str, optional Returns: - out: jax.numpy.Array (does not return a Quantity) + jax.Array jax.numpy.Array (does not return a Quantity) ''' argmax.__doc__ = ''' @@ -3061,7 +3061,7 @@ def f(x, y, *args, **kwargs): out: array, optional Returns: - out: jax.numpy.Array (does not return a Quantity) + jax.Array: an array (does not return a Quantity) ''' argmin.__doc__ = ''' @@ -3073,7 +3073,7 @@ def f(x, y, *args, **kwargs): out: array, optional Returns: - out: jax.numpy.Array (does not return a Quantity) + jax.Array: an array (does not return a Quantity) ''' argwhere.__doc__ = ''' @@ -3083,7 +3083,7 @@ def f(x, y, *args, **kwargs): a: array_like, Quantity Returns: - out: jax.numpy.Array (does not return a Quantity) + jax.Array: an array (does not return a Quantity) ''' nonzero.__doc__ = ''' @@ -3093,7 +3093,7 @@ def f(x, y, *args, **kwargs): a: array_like, Quantity Returns: - out: jax.numpy.Array (does not return a Quantity) + jax.Array: an array (does not return a Quantity) ''' flatnonzero.__doc__ = ''' @@ -3103,7 +3103,7 @@ def f(x, y, *args, **kwargs): a: array_like, Quantity Returns: - out: jax.numpy.Array (does not return a Quantity) + jax.Array: an array (does not return a Quantity) ''' searchsorted.__doc__ = ''' @@ -3115,7 +3115,7 @@ def f(x, y, *args, **kwargs): side: {'left', 'right'}, optional Returns: - out: jax.numpy.Array (does not return a Quantity) + jax.Array: an array (does not return a Quantity) ''' extract.__doc__ = ''' @@ -3126,7 +3126,7 @@ def f(x, y, *args, **kwargs): a: array_like, Quantity Returns: - out: jax.numpy.Array (does not return a Quantity) + jax.Array: an array (does not return a Quantity) ''' count_nonzero.__doc__ = ''' @@ -3137,7 +3137,7 @@ def f(x, y, *args, **kwargs): axis: int or tuple of ints, optional Returns: - out: jax.numpy.Array (does not return a Quantity) + jax.Array: an array (does not return a Quantity) ''' @@ -3166,7 +3166,7 @@ def f(x, *args, **kwargs): axis2: int, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array ''' ravel.__doc__ = ''' @@ -3177,7 +3177,7 @@ def f(x, *args, **kwargs): order: {'C', 'F', 'A', 'K'}, optional Returns: - out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array + Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array ''' @@ -3208,7 +3208,7 @@ def f(x, *args, **kwargs): x: array_like Returns: - out: an array + jax.Array: an array ''' invert.__doc__ = ''' @@ -3218,7 +3218,7 @@ def f(x, *args, **kwargs): x: array_like Returns: - out: an array + jax.Array: an array ''' @@ -3253,7 +3253,7 @@ def f(x, y, *args, **kwargs): y: array_like Returns: - out: an array + jax.Array: an array ''' bitwise_or.__doc__ = ''' @@ -3264,7 +3264,7 @@ def f(x, y, *args, **kwargs): y: array_like Returns: - out: an array + jax.Array: an array ''' bitwise_xor.__doc__ = ''' @@ -3275,7 +3275,7 @@ def f(x, y, *args, **kwargs): y: array_like Returns: - out: an array + jax.Array: an array ''' left_shift.__doc__ = ''' @@ -3286,7 +3286,7 @@ def f(x, y, *args, **kwargs): y: array_like Returns: - out: an array + jax.Array: an array ''' right_shift.__doc__ = ''' @@ -3297,7 +3297,7 @@ def f(x, y, *args, **kwargs): y: array_like Returns: - out: an array + jax.Array: an array ''' @@ -3335,7 +3335,7 @@ def f(x, *args, **kwargs): where: array_like of bool, optional Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' any.__doc__ = ''' @@ -3349,7 +3349,7 @@ def f(x, *args, **kwargs): where: array_like of bool, optional Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' logical_not.__doc__ = ''' @@ -3360,7 +3360,7 @@ def f(x, *args, **kwargs): out: array, optional Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' @@ -3404,7 +3404,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' not_equal.__doc__ = ''' @@ -3415,7 +3415,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' greater.__doc__ = ''' @@ -3426,7 +3426,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' greater_equal.__doc__ = ''' @@ -3437,7 +3437,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' less.__doc__ = ''' @@ -3448,7 +3448,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' less_equal.__doc__ = ''' @@ -3459,7 +3459,7 @@ def f(x, y, *args, **kwargs): y: array_like, Quantity Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' array_equal.__doc__ = ''' @@ -3470,7 +3470,7 @@ def f(x, y, *args, **kwargs): x2: array_like, Quantity Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' isclose.__doc__ = ''' @@ -3484,7 +3484,7 @@ def f(x, y, *args, **kwargs): equal_nan: bool, optional Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' allclose.__doc__ = ''' @@ -3498,7 +3498,7 @@ def f(x, y, *args, **kwargs): equal_nan: bool, optional Returns: - out: bool + bool: boolean result ''' logical_and.__doc__ = ''' @@ -3510,7 +3510,7 @@ def f(x, y, *args, **kwargs): out: array, optional Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' logical_or.__doc__ = ''' @@ -3522,7 +3522,7 @@ def f(x, y, *args, **kwargs): out: array, optional Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' logical_xor.__doc__ = ''' @@ -3534,7 +3534,7 @@ def f(x, y, *args, **kwargs): out: array, optional Returns: - out: bool or array + Union[bool, jax.Array]: bool or array ''' @@ -3586,7 +3586,7 @@ def where(condition: Union[bool, bst.typing.ArrayLike], k: int, optional Returns: - out: tuple[array] + tuple[jax.Array]: tuple[array] ''' @@ -3601,7 +3601,7 @@ def tril_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], k: int, optional Returns: - out: tuple[array] + tuple[jax.Array]: tuple[array] ''' if isinstance(arr, Quantity): return jnp.tril_indices_from(arr.value, k=k) @@ -3619,7 +3619,7 @@ def tril_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], k: int, optional Returns: - out: tuple[array] + tuple[jax.Array]: tuple[array] ''' @@ -3634,7 +3634,7 @@ def triu_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], k: int, optional Returns: - out: tuple[array] + tuple[jax.Array]: tuple[array] ''' if isinstance(arr, Quantity): return jnp.triu_indices_from(arr.value, k=k) @@ -3720,7 +3720,7 @@ def f(*args, **kwargs): b: array_like, Quantity Returns: - out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. ''' vdot.__doc__ = ''' @@ -3731,7 +3731,7 @@ def f(*args, **kwargs): b: array_like, Quantity Returns: - out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. ''' inner.__doc__ = ''' @@ -3742,7 +3742,7 @@ def f(*args, **kwargs): b: array_like, Quantity Returns: - out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. ''' outer.__doc__ = ''' @@ -3753,7 +3753,7 @@ def f(*args, **kwargs): b: array_like, Quantity Returns: - out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. ''' kron.__doc__ = ''' @@ -3764,7 +3764,7 @@ def f(*args, **kwargs): b: array_like, Quantity Returns: - out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. ''' matmul.__doc__ = ''' @@ -3775,7 +3775,7 @@ def f(*args, **kwargs): b: array_like, Quantity Returns: - out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. ''' trace.__doc__ = ''' @@ -3786,7 +3786,7 @@ def f(*args, **kwargs): offset: int, optional Returns: - out: Quantity if the input is a Quantity, else an array. + Union[jax.Array, Quantity]: Quantity if the input is a Quantity, else an array. ''' # data types @@ -4066,4 +4066,4 @@ def intersect1d( Returns: tensor dot product of the two arrays. -''' \ No newline at end of file +''' From f6a004012f13f516a9bde66740de1b515989c6da Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 13:23:51 +0800 Subject: [PATCH 08/23] Update _unit_test.py --- brainunit/_unit_test.py | 68 ++++++++++++++++++++++------------------- 1 file changed, 36 insertions(+), 32 deletions(-) diff --git a/brainunit/_unit_test.py b/brainunit/_unit_test.py index 9309ceb..f24c622 100644 --- a/brainunit/_unit_test.py +++ b/brainunit/_unit_test.py @@ -193,55 +193,59 @@ def test_display(): def test_unary_operations(): - q = Quantity(5, unit=mV) - assert_quantity(-q, -5, mV) - assert_quantity(+q, 5, mV) - assert_quantity(abs(Quantity(-5, unit=mV)), 5, mV) + q = 5 * second + assert_quantity(-q, -5, second) + assert_quantity(+q, 5, second) + assert_quantity(abs(-5 * second), 5, second) assert_quantity(~Quantity(0b101, unit=DIMENSIONLESS), -0b110, DIMENSIONLESS) def test_operations(): - q1 = Quantity(5, unit=mV) - q2 = Quantity(10, unit=mV) - assert_quantity(q1 + q2, 15, mV) - assert_quantity(q1 - q2, -5, mV) - assert_quantity(q1 * q2, 50, mV * mV) + q1 = 5 * second + q2 = 10 * second + assert_quantity(q1 + q2, 15, second) + assert_quantity(q1 - q2, -5, second) + assert_quantity(q1 * q2, 50, second * second) assert_quantity(q2 / q1, 2, DIMENSIONLESS) assert_quantity(q2 // q1, 2, DIMENSIONLESS) - assert_quantity(q2 % q1, 0, mV) + assert_quantity(q2 % q1, 0, second) assert_quantity(divmod(q2, q1)[0], 2, DIMENSIONLESS) - assert_quantity(divmod(q2, q1)[1], 0, mV) - assert_quantity(q1 ** 2, 25, mV ** 2) - assert_quantity(q1 << 1, 10, mV) - assert_quantity(q1 >> 1, 2, mV) - assert_quantity(round(q1, 0), 5, mV) + assert_quantity(divmod(q2, q1)[1], 0, second) + assert_quantity(q1 ** 2, 25, second ** 2) + assert_quantity(round(q1, 0), 5, second) + # matmul - q1 = Quantity([1, 2], unit=mV) - q2 = Quantity([3, 4], unit=mV) - assert_quantity(q1 @ q2, 11, mV ** 2) + q1 = Quantity([1, 2], unit=second) + q2 = Quantity([3, 4], unit=second) + assert_quantity(q1 @ q2, 11, second ** 2) + + # shift + q1 = Quantity(0b1100, dtype=jnp.int32, unit=DIMENSIONLESS) + assert_quantity(q1 << 1, 0b11000, second) + assert_quantity(q1 >> 1, 0b110, second) def test_numpy_methods(): - q = Quantity([[1, 2], [3, 4]], unit=mV) + q = [[1, 2], [3, 4]] * second assert q.all() assert q.any() assert q.nonzero()[0].tolist() == [0, 0, 1, 1] assert q.argmax() == 3 assert q.argmin() == 0 assert q.argsort(axis=None).tolist() == [0, 1, 2, 3] - assert_quantity(q.var(), 1.25, mV ** 2) - assert_quantity(q.round(), [[1, 2], [3, 4]], mV) - assert_quantity(q.std(), 1.11803398875, mV) - assert_quantity(q.sum(), 10, mV) - assert_quantity(q.trace(), 5, mV) - assert_quantity(q.cumsum(), [1, 3, 6, 10], mV) - assert_quantity(q.cumprod(), [1, 2, 6, 24], mV ** 4) - assert_quantity(q.diagonal(), [1, 4], mV) - assert_quantity(q.max(), 4, mV) - assert_quantity(q.mean(), 2.5, mV) - assert_quantity(q.min(), 1, mV) - assert_quantity(q.ptp(), 3, mV) - assert_quantity(q.ravel(), [1, 2, 3, 4], mV) + assert_quantity(q.var(), 1.25, second ** 2) + assert_quantity(q.round(), [[1, 2], [3, 4]], second) + assert_quantity(q.std(), 1.11803398875, second) + assert_quantity(q.sum(), 10, second) + assert_quantity(q.trace(), 5, second) + assert_quantity(q.cumsum(), [1, 3, 6, 10], second) + assert_quantity(q.cumprod(), [1, 2, 6, 24], second ** 4) + assert_quantity(q.diagonal(), [1, 4], second) + assert_quantity(q.max(), 4, second) + assert_quantity(q.mean(), 2.5, second) + assert_quantity(q.min(), 1, second) + assert_quantity(q.ptp(), 3, second) + assert_quantity(q.ravel(), [1, 2, 3, 4], second) def test_shape_manipulation(): From fea30e206bd0f930cc8e60565948e4f28c2c4e18 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 17:07:17 +0800 Subject: [PATCH 09/23] Restruct --- brainunit/math/__init__.py | 47 +- brainunit/math/_compat_numpy.py | 4069 ----------------- .../math/_compat_numpy_array_creation.py | 720 +++ .../math/_compat_numpy_array_manipulation.py | 821 ++++ .../_compat_numpy_funcs_accept_unitless.py | 588 +++ .../math/_compat_numpy_funcs_bit_operation.py | 182 + .../math/_compat_numpy_funcs_change_unit.py | 527 +++ .../math/_compat_numpy_funcs_indexing.py | 166 + .../math/_compat_numpy_funcs_keep_unit.py | 832 ++++ brainunit/math/_compat_numpy_funcs_logic.py | 343 ++ .../math/_compat_numpy_funcs_match_unit.py | 108 + .../math/_compat_numpy_funcs_remove_unit.py | 191 + brainunit/math/_compat_numpy_funcs_window.py | 69 + brainunit/math/_compat_numpy_get_attribute.py | 215 + .../math/_compat_numpy_linear_algebra.py | 149 + brainunit/math/_compat_numpy_misc.py | 354 ++ brainunit/math/_compat_numpy_test.py | 7 +- brainunit/math/_utils.py | 114 +- docs/apis/brainunit.math.rst | 649 ++- docs/auto_generater.py | 32 +- 20 files changed, 5681 insertions(+), 4502 deletions(-) delete mode 100644 brainunit/math/_compat_numpy.py create mode 100644 brainunit/math/_compat_numpy_array_creation.py create mode 100644 brainunit/math/_compat_numpy_array_manipulation.py create mode 100644 brainunit/math/_compat_numpy_funcs_accept_unitless.py create mode 100644 brainunit/math/_compat_numpy_funcs_bit_operation.py create mode 100644 brainunit/math/_compat_numpy_funcs_change_unit.py create mode 100644 brainunit/math/_compat_numpy_funcs_indexing.py create mode 100644 brainunit/math/_compat_numpy_funcs_keep_unit.py create mode 100644 brainunit/math/_compat_numpy_funcs_logic.py create mode 100644 brainunit/math/_compat_numpy_funcs_match_unit.py create mode 100644 brainunit/math/_compat_numpy_funcs_remove_unit.py create mode 100644 brainunit/math/_compat_numpy_funcs_window.py create mode 100644 brainunit/math/_compat_numpy_get_attribute.py create mode 100644 brainunit/math/_compat_numpy_linear_algebra.py create mode 100644 brainunit/math/_compat_numpy_misc.py diff --git a/brainunit/math/__init__.py b/brainunit/math/__init__.py index 5b1a673..b43d188 100644 --- a/brainunit/math/__init__.py +++ b/brainunit/math/__init__.py @@ -13,7 +13,48 @@ # limitations under the License. # ============================================================================== -from ._compat_numpy import * -from ._compat_numpy import __all__ as _compat_numpy_all +# from ._compat_numpy import * +# from ._compat_numpy import __all__ as _compat_numpy_all +from ._compat_numpy_array_creation import * +from ._compat_numpy_array_creation import __all__ as _compat_array_creation_all +from ._compat_numpy_array_manipulation import * +from ._compat_numpy_array_manipulation import __all__ as _compat_array_manipulation_all +from ._compat_numpy_funcs_accept_unitless import * +from ._compat_numpy_funcs_accept_unitless import __all__ as _compat_funcs_accept_unitless_all +from ._compat_numpy_funcs_bit_operation import * +from ._compat_numpy_funcs_bit_operation import __all__ as _compat_funcs_bit_operation_all +from ._compat_numpy_funcs_change_unit import * +from ._compat_numpy_funcs_change_unit import __all__ as _compat_funcs_change_unit_all +from ._compat_numpy_funcs_indexing import * +from ._compat_numpy_funcs_indexing import __all__ as _compat_funcs_indexing_all +from ._compat_numpy_funcs_keep_unit import * +from ._compat_numpy_funcs_keep_unit import __all__ as _compat_funcs_keep_unit_all +from ._compat_numpy_funcs_logic import * +from ._compat_numpy_funcs_logic import __all__ as _compat_funcs_logic_all +from ._compat_numpy_funcs_match_unit import * +from ._compat_numpy_funcs_match_unit import __all__ as _compat_funcs_match_unit_all +from ._compat_numpy_funcs_remove_unit import * +from ._compat_numpy_funcs_remove_unit import __all__ as _compat_funcs_remove_unit_all +from ._compat_numpy_funcs_window import * +from ._compat_numpy_funcs_window import __all__ as _compat_funcs_window_all +from ._compat_numpy_get_attribute import * +from ._compat_numpy_get_attribute import __all__ as _compat_get_attribute_all +from ._compat_numpy_linear_algebra import * +from ._compat_numpy_linear_algebra import __all__ as _compat_linear_algebra_all +from ._compat_numpy_misc import * +from ._compat_numpy_misc import __all__ as _compat_misc_all -__all__ = _compat_numpy_all +__all__ = _compat_array_creation_all + \ + _compat_array_manipulation_all + \ + _compat_funcs_change_unit_all + \ + _compat_funcs_keep_unit_all + \ + _compat_funcs_accept_unitless_all + \ + _compat_funcs_match_unit_all + \ + _compat_funcs_remove_unit_all + \ + _compat_get_attribute_all + \ + _compat_funcs_bit_operation_all + \ + _compat_funcs_logic_all + \ + _compat_funcs_indexing_all + \ + _compat_funcs_window_all + \ + _compat_linear_algebra_all + \ + _compat_misc_all diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py deleted file mode 100644 index 360c088..0000000 --- a/brainunit/math/_compat_numpy.py +++ /dev/null @@ -1,4069 +0,0 @@ -# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -from collections.abc import Sequence -from functools import wraps -from typing import (Callable, Union, Optional, Any) - -import brainstate as bst -import jax -import jax.numpy as jnp -import numpy as np -import opt_einsum -from brainstate._utils import set_module_as -from jax import Array -from jax._src.numpy.lax_numpy import _einsum - -from ._utils import _compatible_with_quantity -from .._base import (DIMENSIONLESS, - Quantity, - Unit, - fail_for_dimension_mismatch, - is_unitless, - get_unit, ) -from .._base import _return_check_unitless - -__all__ = [ - # array creation - 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', - 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', - 'array_split', 'meshgrid', 'vander', - - # getting attribute funcs - 'ndim', 'isreal', 'isscalar', 'isfinite', 'isinf', - 'isnan', 'shape', 'size', - - # math funcs keep unit (unary) - 'real', 'imag', 'conj', 'conjugate', 'negative', 'positive', - 'abs', 'round', 'around', 'round_', 'rint', - 'floor', 'ceil', 'trunc', 'fix', 'sum', 'nancumsum', 'nansum', - 'cumsum', 'ediff1d', 'absolute', 'fabs', 'median', - 'nanmin', 'nanmax', 'ptp', 'average', 'mean', 'std', - 'nanmedian', 'nanmean', 'nanstd', 'diff', 'modf', - - # math funcs keep unit (binary) - 'fmod', 'mod', 'copysign', 'heaviside', - 'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', - - # math funcs keep unit (n-ary) - 'interp', 'clip', - - # math funcs match unit (binary) - 'add', 'subtract', 'nextafter', - - # math funcs change unit (unary) - 'reciprocal', 'prod', 'product', 'nancumprod', 'nanprod', 'cumprod', - 'cumproduct', 'var', 'nanvar', 'cbrt', 'square', 'frexp', 'sqrt', - - # math funcs change unit (binary) - 'multiply', 'divide', 'power', 'cross', 'ldexp', - 'true_divide', 'floor_divide', 'float_power', - 'divmod', 'remainder', 'convolve', - - # math funcs only accept unitless (unary) - 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2', - 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', - 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', - 'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', - 'percentile', 'nanpercentile', 'quantile', 'nanquantile', - - # math funcs only accept unitless (binary) - 'hypot', 'arctan2', 'logaddexp', 'logaddexp2', - - # math funcs remove unit (unary) - 'signbit', 'sign', 'histogram', 'bincount', - - # math funcs remove unit (binary) - 'corrcoef', 'correlate', 'cov', 'digitize', - - # array manipulation - 'reshape', 'moveaxis', 'transpose', 'swapaxes', 'row_stack', - 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', - 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', - 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', - 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', - 'argwhere', 'nonzero', 'flatnonzero', 'searchsorted', 'extract', - 'count_nonzero', 'max', 'min', 'amax', 'amin', 'block', 'compress', - 'diagflat', 'diagonal', 'choose', 'ravel', - - # Elementwise bit operations (unary) - 'bitwise_not', 'invert', - - # Elementwise bit operations (binary) - 'bitwise_and', 'bitwise_or', 'bitwise_xor', 'left_shift', 'right_shift', - - # logic funcs (unary) - 'all', 'any', 'logical_not', - - # logic funcs (binary) - 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', - 'array_equal', 'isclose', 'allclose', 'logical_and', - 'logical_or', 'logical_xor', "alltrue", 'sometrue', - - # indexing funcs - 'nonzero', 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', - 'triu_indices_from', 'take', 'select', - - # window funcs - 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', - - # constants - 'e', 'pi', 'inf', - - # linear algebra - 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', - - # data types - 'dtype', 'finfo', 'iinfo', - - # more - 'broadcast_arrays', 'broadcast_shapes', - 'einsum', 'gradient', 'intersect1d', 'nan_to_num', 'nanargmax', 'nanargmin', - 'rot90', 'tensordot', - -] - - -# array creation -# -------------- - -def wrap_array_creation_function(func): - def f(*args, unit: Unit = None, **kwargs): - if unit is not None: - assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - return func(*args, **kwargs) * unit - else: - return func(*args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -# array creation -# -------------- - -full: Callable = wrap_array_creation_function(jnp.full) -eye: Callable = wrap_array_creation_function(jnp.eye) -identity: Callable = wrap_array_creation_function(jnp.identity) -tri: Callable = wrap_array_creation_function(jnp.tri) -empty: Callable = wrap_array_creation_function(jnp.empty) -ones: Callable = wrap_array_creation_function(jnp.ones) -zeros: Callable = wrap_array_creation_function(jnp.zeros) - -# docs for full, eye, identity, tri, empty, ones, zeros - -full.__doc__ = """ - Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. - else return an array of `shape` filled with `fill_value`. - - Args: - shape: sequence of integers, describing the shape of the output array. - fill_value: the value to fill the new array with. - dtype: the type of the output array, or `None`. If not `None`, `fill_value` - will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. - unit: the unit of the output array, or `None`. - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" - -eye.__doc__ = """ - Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. - else return an identity matrix of `shape`. - - Args: - n: the number of rows (and columns) in the output array. - k: the index of the diagonal: 0 (the default) refers to the main diagonal, - a positive value refers to an upper diagonal, and a negative value to a - lower diagonal. - dtype: the type of the output array, or `None`. If not `None`, elements - will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. - unit: the unit of the output array, or `None`. - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" - -identity.__doc__ = """ - Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. - else return an identity matrix of `shape`. - - Args: - n: the number of rows (and columns) in the output array. - dtype: the type of the output array, or `None`. If not `None`, elements - will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. - unit: the unit of the output array, or `None`. - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" - -tri.__doc__ = """ - Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. - else return a triangular matrix of `shape`. - - Args: - n: the number of rows in the output array. - m: the number of columns with default being `n`. - k: the index of the diagonal: 0 (the default) refers to the main diagonal, - a positive value refers to an upper diagonal, and a negative value to a - lower diagonal. - dtype: the type of the output array, or `None`. If not `None`, elements - will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. - unit: the unit of the output array, or `None`. - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" - -# empty -empty.__doc__ = """ - Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. - else return an array of `shape` with uninitialized values. - - Args: - shape: sequence of integers, describing the shape of the output array. - dtype: the type of the output array, or `None`. If not `None`, elements - will be of type `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. - unit: the unit of the output array, or `None`. - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" - -# ones -ones.__doc__ = """ - Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. - else return an array of `shape` filled with 1. - - Args: - shape: sequence of integers, describing the shape of the output array. - dtype: the type of the output array, or `None`. If not `None`, elements - will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. - unit: the unit of the output array, or `None`. - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" - -# zeros -zeros.__doc__ = """ - Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. - else return an array of `shape` filled with 0. - - Args: - shape: sequence of integers, describing the shape of the output array. - dtype: the type of the output array, or `None`. If not `None`, elements - will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. - unit: the unit of the output array, or `None`. - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" - - -@set_module_as('brainunit.math') -def full_like(a: Union[Quantity, bst.typing.ArrayLike], - fill_value: Union[bst.typing.ArrayLike], - unit: Unit = None, - dtype: Optional[bst.typing.DTypeLike] = None, - shape: Any = None) -> Union[Quantity, jax.Array]: - ''' - Return a Quantity of `a` and `unit`, filled with `fill_value` if `unit` is provided. - else return an array of `a` filled with `fill_value`. - - Args: - a: array_like, Quantity, shape, or dtype - fill_value: scalar or array_like - unit: Unit, optional - dtype: data-type, optional - shape: sequence of ints, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. - ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) * unit - else: - return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) * unit - else: - return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) - - -@set_module_as('brainunit.math') -def diag(a: Union[Quantity, bst.typing.ArrayLike], - k: int = 0, - unit: Unit = None) -> Union[Quantity, jax.Array]: - ''' - Extract a diagonal or construct a diagonal array. - - Args: - a: array_like, Quantity - k: int, optional - unit: Unit, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. - ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.diag(a.value, k=k) * unit - else: - return jnp.diag(a, k=k) * unit - else: - return jnp.diag(a, k=k) - - -@set_module_as('brainunit.math') -def tril(a: Union[Quantity, bst.typing.ArrayLike], - k: int = 0, - unit: Unit = None) -> Union[Quantity, jax.Array]: - ''' - Lower triangle of an array. - - Args: - a: array_like, Quantity - k: int, optional - unit: Unit, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. - ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.tril(a.value, k=k) * unit - else: - return jnp.tril(a, k=k) * unit - else: - return jnp.tril(a, k=k) - - -@set_module_as('brainunit.math') -def triu(a: Union[Quantity, bst.typing.ArrayLike], - k: int = 0, - unit: Unit = None) -> Union[Quantity, jax.Array]: - ''' - Upper triangle of an array. - - Args: - a: array_like, Quantity - k: int, optional - unit: Unit, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. - ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.triu(a.value, k=k) * unit - else: - return jnp.triu(a, k=k) * unit - else: - return jnp.triu(a, k=k) - - -@set_module_as('brainunit.math') -def empty_like(a: Union[Quantity, bst.typing.ArrayLike], - dtype: Optional[bst.typing.DTypeLike] = None, - shape: Any = None, - unit: Unit = None) -> Union[Quantity, jax.Array]: - ''' - Return a Quantity of `a` and `unit`, with uninitialized values if `unit` is provided. - else return an array of `a` with uninitialized values. - - Args: - a: array_like, Quantity, shape, or dtype - dtype: data-type, optional - shape: sequence of ints, optional - unit: Unit, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. - ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.empty_like(a.value, dtype=dtype, shape=shape) * unit - else: - return jnp.empty_like(a, dtype=dtype, shape=shape) * unit - else: - return jnp.empty_like(a, dtype=dtype, shape=shape) - - -@set_module_as('brainunit.math') -def ones_like(a: Union[Quantity, bst.typing.ArrayLike], - dtype: Optional[bst.typing.DTypeLike] = None, - shape: Any = None, - unit: Unit = None) -> Union[Quantity, jax.Array]: - ''' - Return a Quantity of `a` and `unit`, filled with 1 if `unit` is provided. - else return an array of `a` filled with 1. - - Args: - a: array_like, Quantity, shape, or dtype - dtype: data-type, optional - shape: sequence of ints, optional - unit: Unit, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. - ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.ones_like(a.value, dtype=dtype, shape=shape) * unit - else: - return jnp.ones_like(a, dtype=dtype, shape=shape) * unit - else: - return jnp.ones_like(a, dtype=dtype, shape=shape) - - -@set_module_as('brainunit.math') -def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], - dtype: Optional[bst.typing.DTypeLike] = None, - shape: Any = None, - unit: Unit = None) -> Union[Quantity, jax.Array]: - ''' - Return a Quantity of `a` and `unit`, filled with 0 if `unit` is provided. - else return an array of `a` filled with 0. - - Args: - a: array_like, Quantity, shape, or dtype - dtype: data-type, optional - shape: sequence of ints, optional - unit: Unit, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. - ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.zeros_like(a.value, dtype=dtype, shape=shape) * unit - else: - return jnp.zeros_like(a, dtype=dtype, shape=shape) * unit - else: - return jnp.zeros_like(a, dtype=dtype, shape=shape) - - -@set_module_as('brainunit.math') -def asarray( - a: Union[Quantity, bst.typing.ArrayLike, Sequence[Quantity]], - dtype: Optional[bst.typing.DTypeLike] = None, - order: Optional[str] = None, - unit: Optional[Unit] = None, -) -> Union[Quantity, jax.Array]: - from builtins import all as origin_all - from builtins import any as origin_any - if isinstance(a, Quantity): - return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.asarray(a, dtype=dtype, order=order) - # list[Quantity] - elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a): - # check all elements have the same unit - if origin_any(x.unit != a[0].unit for x in a): - raise ValueError('Units do not match for asarray operation.') - values = [x.value for x in a] - unit = a[0].unit - # Convert the values to a jnp.ndarray and create a Quantity object - return Quantity(jnp.asarray(values, dtype=dtype, order=order), unit=unit) - else: - return jnp.asarray(a, dtype=dtype, order=order) - - -array = asarray - - -@set_module_as('brainunit.math') -def arange(*args, **kwargs): - ''' - Return a Quantity of `arange` and `unit`, with uninitialized values if `unit` is provided. - - Args: - start: number, Quantity, optional - stop: number, Quantity, optional - step: number, optional - dtype: dtype, optional - unit: Unit, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. - ''' - # arange has a bit of a complicated argument structure unfortunately - # we leave the actual checking of the number of arguments to numpy, though - - # default values - start = kwargs.pop("start", 0) - step = kwargs.pop("step", 1) - stop = kwargs.pop("stop", None) - if len(args) == 1: - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - stop = args[0] - elif len(args) == 2: - if start != 0: - raise TypeError("Duplicate definition of 'start'") - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - start, stop = args - elif len(args) == 3: - if start != 0: - raise TypeError("Duplicate definition of 'start'") - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - if step != 1: - raise TypeError("Duplicate definition of 'step'") - start, stop, step = args - elif len(args) > 3: - raise TypeError("Need between 1 and 3 non-keyword arguments") - - if stop is None: - raise TypeError("Missing stop argument.") - if stop is not None and not is_unitless(stop): - start = Quantity(start, unit=stop.unit) - - fail_for_dimension_mismatch( - start, - stop, - error_message=( - "Start value {start} and stop value {stop} have to have the same units." - ), - start=start, - stop=stop, - ) - fail_for_dimension_mismatch( - stop, - step, - error_message=( - "Stop value {stop} and step value {step} have to have the same units." - ), - stop=stop, - step=step, - ) - unit = getattr(stop, "unit", DIMENSIONLESS) - # start is a position-only argument in numpy 2.0 - # https://numpy.org/devdocs/release/2.0.0-notes.html#arange-s-start-argument-is-positional-only - # TODO: check whether this is still the case in the final release - if start == 0: - return Quantity( - jnp.arange( - start=start.value if isinstance(start, Quantity) else jnp.asarray(start), - stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), - step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - **kwargs, - ), - unit=unit, - ) - else: - return Quantity( - jnp.arange( - start.value if isinstance(start, Quantity) else jnp.asarray(start), - stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), - step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - **kwargs, - ), - unit=unit, - ) - - -@set_module_as('brainunit.math') -def linspace(start: Union[Quantity, bst.typing.ArrayLike], - stop: Union[Quantity, bst.typing.ArrayLike], - num: int = 50, - endpoint: Optional[bool] = True, - retstep: Optional[bool] = False, - dtype: Optional[bst.typing.DTypeLike] = None) -> Union[Quantity, jax.Array]: - ''' - Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. - - Args: - start: number, Quantity - stop: number, Quantity - num: int, optional - endpoint: bool, optional - retstep: bool, optional - dtype: dtype, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. - ''' - fail_for_dimension_mismatch( - start, - stop, - error_message="Start value {start} and stop value {stop} have to have the same units.", - start=start, - stop=stop, - ) - unit = getattr(start, "unit", DIMENSIONLESS) - start = start.value if isinstance(start, Quantity) else start - stop = stop.value if isinstance(stop, Quantity) else stop - - result = jnp.linspace(start, stop, num=num, endpoint=endpoint, retstep=retstep, dtype=dtype) - return Quantity(result, unit=unit) - - -@set_module_as('brainunit.math') -def logspace(start: Union[Quantity, bst.typing.ArrayLike], - stop: Union[Quantity, bst.typing.ArrayLike], - num: Optional[int] = 50, - endpoint: Optional[bool] = True, - base: Optional[float] = 10.0, - dtype: Optional[bst.typing.DTypeLike] = None): - ''' - Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided. - - Args: - start: number, Quantity - stop: number, Quantity - num: int, optional - endpoint: bool, optional - base: float, optional - dtype: dtype, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. - ''' - fail_for_dimension_mismatch( - start, - stop, - error_message="Start value {start} and stop value {stop} have to have the same units.", - start=start, - stop=stop, - ) - unit = getattr(start, "unit", DIMENSIONLESS) - start = start.value if isinstance(start, Quantity) else start - stop = stop.value if isinstance(stop, Quantity) else stop - - result = jnp.logspace(start, stop, num=num, endpoint=endpoint, base=base, dtype=dtype) - return Quantity(result, unit=unit) - - -@set_module_as('brainunit.math') -def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], - val: Union[Quantity, bst.typing.ArrayLike], - wrap: Optional[bool] = False, - inplace: Optional[bool] = True) -> Union[Quantity, jax.Array]: - ''' - Fill the main diagonal of the given array of `a` with `val`. - - Args: - a: array_like, Quantity - val: scalar, Quantity - wrap: bool, optional - inplace: bool, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `a` and `val` are Quantities that have the same unit, else an array. - ''' - if isinstance(a, Quantity) and isinstance(val, Quantity): - fail_for_dimension_mismatch(a, val) - return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), unit=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)) and isinstance(val, (jax.Array, np.ndarray)): - return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) - elif is_unitless(a) or is_unitless(val): - return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) - else: - raise ValueError(f'Unsupported types : {type(a)} abd {type(val)} for fill_diagonal') - - -@set_module_as('brainunit.math') -def array_split(ary: Union[Quantity, bst.typing.ArrayLike], - indices_or_sections: Union[int, bst.typing.ArrayLike], - axis: Optional[int] = 0) -> Union[list[Quantity], list[Array]]: - ''' - Split an array into multiple sub-arrays. - - Args: - ary: array_like, Quantity - indices_or_sections: int, array_like - axis: int, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `ary` is a Quantity, else an array. - ''' - if isinstance(ary, Quantity): - return [Quantity(x, unit=ary.unit) for x in jnp.array_split(ary.value, indices_or_sections, axis)] - elif isinstance(ary, (bst.typing.ArrayLike)): - return jnp.array_split(ary, indices_or_sections, axis) - else: - raise ValueError(f'Unsupported type: {type(ary)} for array_split') - - -@set_module_as('brainunit.math') -def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike], - copy: Optional[bool] = True, - sparse: Optional[bool] = False, - indexing: Optional[str] = 'xy'): - ''' - Return coordinate matrices from coordinate vectors. - - Args: - xi: array_like, Quantity - copy: bool, optional - sparse: bool, optional - indexing: str, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `xi` are Quantities that have the same unit, else an array. - ''' - from builtins import all as origin_all - if origin_all(isinstance(x, Quantity) for x in xi): - fail_for_dimension_mismatch(*xi) - return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), unit=xi[0].unit) - elif origin_all(isinstance(x, (jax.Array, np.ndarray)) for x in xi): - return jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) - else: - raise ValueError(f'Unsupported types : {type(xi)} for meshgrid') - - -@set_module_as('brainunit.math') -def vander(x: Union[Quantity, bst.typing.ArrayLike], - N: Optional[bool] = None, - increasing: Optional[bool] = False) -> Union[Quantity, jax.Array]: - ''' - Generate a Vandermonde matrix. - - Args: - x: array_like, Quantity - N: int, optional - increasing: bool, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - ''' - if isinstance(x, Quantity): - return Quantity(jnp.vander(x.value, N=N, increasing=increasing), unit=x.unit) - elif isinstance(x, (jax.Array, np.ndarray)): - return jnp.vander(x, N=N, increasing=increasing) - else: - raise ValueError(f'Unsupported type: {type(x)} for vander') - - -# getting attribute funcs -# ----------------------- - -@set_module_as('brainunit.math') -def ndim(a: Union[Quantity, bst.typing.ArrayLike]) -> int: - ''' - Return the number of dimensions of an array. - - Args: - a: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: int - ''' - if isinstance(a, Quantity): - return a.ndim - else: - return jnp.ndim(a) - - -@set_module_as('brainunit.math') -def isreal(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: - ''' - Return True if the input array is real. - - Args: - a: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: boolean array - ''' - if isinstance(a, Quantity): - return a.isreal - else: - return jnp.isreal(a) - - -@set_module_as('brainunit.math') -def isscalar(a: Union[Quantity, bst.typing.ArrayLike]) -> bool: - ''' - Return True if the input is a scalar. - - Args: - a: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: boolean array - ''' - if isinstance(a, Quantity): - return a.isscalar - else: - return jnp.isscalar(a) - - -@set_module_as('brainunit.math') -def isfinite(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: - ''' - Return each element of the array is finite or not. - - Args: - a: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: boolean array - ''' - if isinstance(a, Quantity): - return a.isfinite - else: - return jnp.isfinite(a) - - -@set_module_as('brainunit.math') -def isinf(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: - ''' - Return each element of the array is infinite or not. - - Args: - a: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: boolean array - ''' - if isinstance(a, Quantity): - return a.isinf - else: - return jnp.isinf(a) - - -@set_module_as('brainunit.math') -def isnan(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: - ''' - Return each element of the array is NaN or not. - - Args: - a: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: boolean array - ''' - if isinstance(a, Quantity): - return a.isnan - else: - return jnp.isnan(a) - - -@set_module_as('brainunit.math') -def shape(a: Union[Quantity, bst.typing.ArrayLike]) -> tuple[int, ...]: - """ - Return the shape of an array. - - Parameters - ---------- - a : array_like - Input array. - - Returns - ------- - shape : tuple of ints - The elements of the shape tuple give the lengths of the - corresponding array dimensions. - - See Also - -------- - len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with - ``N>=1``. - ndarray.shape : Equivalent array method. - - Examples - -------- - >>> brainunit.math.shape(brainunit.math.eye(3)) - (3, 3) - >>> brainunit.math.shape([[1, 3]]) - (1, 2) - >>> brainunit.math.shape([0]) - (1,) - >>> brainunit.math.shape(0) - () - - """ - if isinstance(a, (Quantity, jax.Array, np.ndarray)): - return a.shape - else: - return np.shape(a) - - -@set_module_as('brainunit.math') -def size(a: Union[Quantity, bst.typing.ArrayLike], axis: int = None) -> int: - """ - Return the number of elements along a given axis. - - Parameters - ---------- - a : array_like - Input data. - axis : int, optional - Axis along which the elements are counted. By default, give - the total number of elements. - - Returns - ------- - element_count : int - Number of elements along the specified axis. - - See Also - -------- - shape : dimensions of array - Array.shape : dimensions of array - Array.size : number of elements in array - - Examples - -------- - >>> a = Quantity([[1,2,3], [4,5,6]]) - >>> brainunit.math.size(a) - 6 - >>> brainunit.math.size(a, 1) - 3 - >>> brainunit.math.size(a, 0) - 2 - """ - if isinstance(a, (Quantity, jax.Array, np.ndarray)): - if axis is None: - return a.size - else: - return a.shape[axis] - else: - return np.size(a, axis=axis) - - -# math funcs keep unit (unary) -# ---------------------------- - -def wrap_math_funcs_keep_unit_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), unit=x.unit) - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -real = wrap_math_funcs_keep_unit_unary(jnp.real) -imag = wrap_math_funcs_keep_unit_unary(jnp.imag) -conj = wrap_math_funcs_keep_unit_unary(jnp.conj) -conjugate = wrap_math_funcs_keep_unit_unary(jnp.conjugate) -negative = wrap_math_funcs_keep_unit_unary(jnp.negative) -positive = wrap_math_funcs_keep_unit_unary(jnp.positive) -abs = wrap_math_funcs_keep_unit_unary(jnp.abs) -round_ = wrap_math_funcs_keep_unit_unary(jnp.round) -around = wrap_math_funcs_keep_unit_unary(jnp.around) -round = wrap_math_funcs_keep_unit_unary(jnp.round) -rint = wrap_math_funcs_keep_unit_unary(jnp.rint) -floor = wrap_math_funcs_keep_unit_unary(jnp.floor) -ceil = wrap_math_funcs_keep_unit_unary(jnp.ceil) -trunc = wrap_math_funcs_keep_unit_unary(jnp.trunc) -fix = wrap_math_funcs_keep_unit_unary(jnp.fix) -sum = wrap_math_funcs_keep_unit_unary(jnp.sum) -nancumsum = wrap_math_funcs_keep_unit_unary(jnp.nancumsum) -nansum = wrap_math_funcs_keep_unit_unary(jnp.nansum) -cumsum = wrap_math_funcs_keep_unit_unary(jnp.cumsum) -ediff1d = wrap_math_funcs_keep_unit_unary(jnp.ediff1d) -absolute = wrap_math_funcs_keep_unit_unary(jnp.absolute) -fabs = wrap_math_funcs_keep_unit_unary(jnp.fabs) -median = wrap_math_funcs_keep_unit_unary(jnp.median) -nanmin = wrap_math_funcs_keep_unit_unary(jnp.nanmin) -nanmax = wrap_math_funcs_keep_unit_unary(jnp.nanmax) -ptp = wrap_math_funcs_keep_unit_unary(jnp.ptp) -average = wrap_math_funcs_keep_unit_unary(jnp.average) -mean = wrap_math_funcs_keep_unit_unary(jnp.mean) -std = wrap_math_funcs_keep_unit_unary(jnp.std) -nanmedian = wrap_math_funcs_keep_unit_unary(jnp.nanmedian) -nanmean = wrap_math_funcs_keep_unit_unary(jnp.nanmean) -nanstd = wrap_math_funcs_keep_unit_unary(jnp.nanstd) -diff = wrap_math_funcs_keep_unit_unary(jnp.diff) -modf = wrap_math_funcs_keep_unit_unary(jnp.modf) - -# docs for the functions above -real.__doc__ = ''' - Return the real part of the complex argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -imag.__doc__ = ''' - Return the imaginary part of the complex argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -conj.__doc__ = ''' - Return the complex conjugate of the argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -conjugate.__doc__ = ''' - Return the complex conjugate of the argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -negative.__doc__ = ''' - Return the negative of the argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -positive.__doc__ = ''' - Return the positive of the argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -abs.__doc__ = ''' - Return the absolute value of the argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -round_.__doc__ = ''' - Round an array to the nearest integer. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -around.__doc__ = ''' - Round an array to the nearest integer. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -round.__doc__ = ''' - Round an array to the nearest integer. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -rint.__doc__ = ''' - Round an array to the nearest integer. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -floor.__doc__ = ''' - Return the floor of the argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -ceil.__doc__ = ''' - Return the ceiling of the argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -trunc.__doc__ = ''' - Return the truncated value of the argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -fix.__doc__ = ''' - Return the nearest integer towards zero. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -sum.__doc__ = ''' - Return the sum of the array elements. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -nancumsum.__doc__ = ''' - Return the cumulative sum of the array elements, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -nansum.__doc__ = ''' - Return the sum of the array elements, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -cumsum.__doc__ = ''' - Return the cumulative sum of the array elements. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -ediff1d.__doc__ = ''' - Return the differences between consecutive elements of the array. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -absolute.__doc__ = ''' - Return the absolute value of the argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -fabs.__doc__ = ''' - Return the absolute value of the argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -median.__doc__ = ''' - Return the median of the array elements. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -nanmin.__doc__ = ''' - Return the minimum of the array elements, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -nanmax.__doc__ = ''' - Return the maximum of the array elements, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -ptp.__doc__ = ''' - Return the range of the array elements (maximum - minimum). - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -average.__doc__ = ''' - Return the weighted average of the array elements. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -mean.__doc__ = ''' - Return the mean of the array elements. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -std.__doc__ = ''' - Return the standard deviation of the array elements. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -nanmedian.__doc__ = ''' - Return the median of the array elements, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -nanmean.__doc__ = ''' - Return the mean of the array elements, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -nanstd.__doc__ = ''' - Return the standard deviation of the array elements, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -diff.__doc__ = ''' - Return the differences between consecutive elements of the array. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - -modf.__doc__ = ''' - Return the fractional and integer parts of the array elements. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity tuple if `x` is a Quantity, else an array tuple. -''' - - -# math funcs keep unit (binary) -# ----------------------------- - -def wrap_math_funcs_keep_unit_binary(func): - def f(x1, x2, *args, **kwargs): - if isinstance(x1, Quantity) and isinstance(x2, Quantity): - return Quantity(func(x1.value, x2.value, *args, **kwargs), unit=x1.unit) - elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): - return func(x1, x2, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -fmod = wrap_math_funcs_keep_unit_binary(jnp.fmod) -mod = wrap_math_funcs_keep_unit_binary(jnp.mod) -copysign = wrap_math_funcs_keep_unit_binary(jnp.copysign) -heaviside = wrap_math_funcs_keep_unit_binary(jnp.heaviside) -maximum = wrap_math_funcs_keep_unit_binary(jnp.maximum) -minimum = wrap_math_funcs_keep_unit_binary(jnp.minimum) -fmax = wrap_math_funcs_keep_unit_binary(jnp.fmax) -fmin = wrap_math_funcs_keep_unit_binary(jnp.fmin) -lcm = wrap_math_funcs_keep_unit_binary(jnp.lcm) -gcd = wrap_math_funcs_keep_unit_binary(jnp.gcd) - -# docs for the functions above -fmod.__doc__ = ''' - Return the element-wise remainder of division. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' - -mod.__doc__ = ''' - Return the element-wise modulus of division. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' - -copysign.__doc__ = ''' - Return a copy of the first array elements with the sign of the second array. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' - -heaviside.__doc__ = ''' - Compute the Heaviside step function. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' - -maximum.__doc__ = ''' - Element-wise maximum of array elements. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' - -minimum.__doc__ = ''' - Element-wise minimum of array elements. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' - -fmax.__doc__ = ''' - Element-wise maximum of array elements ignoring NaNs. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' - -fmin.__doc__ = ''' - Element-wise minimum of array elements ignoring NaNs. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' - -lcm.__doc__ = ''' - Return the least common multiple of `x1` and `x2`. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' - -gcd.__doc__ = ''' - Return the greatest common divisor of `x1` and `x2`. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' - - -# math funcs keep unit (n-ary) -# ---------------------------- -@set_module_as('brainunit.math') -def interp(x: Union[Quantity, bst.typing.ArrayLike], - xp: Union[Quantity, bst.typing.ArrayLike], - fp: Union[Quantity, bst.typing.ArrayLike], - left: Union[Quantity, bst.typing.ArrayLike] = None, - right: Union[Quantity, bst.typing.ArrayLike] = None, - period: Union[Quantity, bst.typing.ArrayLike] = None) -> Union[Quantity, jax.Array]: - ''' - One-dimensional linear interpolation. - - Args: - x: array_like, Quantity - xp: array_like, Quantity - fp: array_like, Quantity - left: array_like, Quantity, optional - right: array_like, Quantity, optional - period: array_like, Quantity, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `x`, `xp`, and `fp` are Quantities that have the same unit, else an array. - ''' - unit = None - if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): - unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit - if isinstance(x, Quantity): - x_value = x.value - else: - x_value = x - if isinstance(xp, Quantity): - xp_value = xp.value - else: - xp_value = xp - if isinstance(fp, Quantity): - fp_value = fp.value - else: - fp_value = fp - result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) - if unit is not None: - return Quantity(result, unit=unit) - else: - return result - - -@set_module_as('brainunit.math') -def clip(a: Union[Quantity, bst.typing.ArrayLike], - a_min: Union[Quantity, bst.typing.ArrayLike], - a_max: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - ''' - Clip (limit) the values in an array. - - Args: - a: array_like, Quantity - a_min: array_like, Quantity - a_max: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `a`, `a_min`, and `a_max` are Quantities that have the same unit, else an array. - ''' - unit = None - if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): - unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit - if isinstance(a, Quantity): - a_value = a.value - else: - a_value = a - if isinstance(a_min, Quantity): - a_min_value = a_min.value - else: - a_min_value = a_min - if isinstance(a_max, Quantity): - a_max_value = a_max.value - else: - a_max_value = a_max - result = jnp.clip(a_value, a_min_value, a_max_value) - if unit is not None: - return Quantity(result, unit=unit) - else: - return result - - -# math funcs match unit (binary) -# ------------------------------ - -def wrap_math_funcs_match_unit_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - fail_for_dimension_mismatch(x, y) - return Quantity(func(x.value, y.value, *args, **kwargs), unit=x.unit) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - if x.is_unitless: - return Quantity(func(x.value, y, *args, **kwargs), unit=x.unit) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - elif isinstance(y, Quantity): - if y.is_unitless: - return Quantity(func(x, y.value, *args, **kwargs), unit=y.unit) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -add = wrap_math_funcs_match_unit_binary(jnp.add) -subtract = wrap_math_funcs_match_unit_binary(jnp.subtract) -nextafter = wrap_math_funcs_match_unit_binary(jnp.nextafter) - -# docs for the functions above -add.__doc__ = ''' - Add arguments element-wise. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. -''' - -subtract.__doc__ = ''' - Subtract arguments element-wise. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. -''' - -nextafter.__doc__ = ''' - Return the next floating-point value after `x1` towards `x2`. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. -''' - - -# math funcs change unit (unary) -# ------------------------------ - -def wrap_math_funcs_change_unit_unary(func, change_unit_func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), unit=change_unit_func(x.unit))) - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -reciprocal = wrap_math_funcs_change_unit_unary(jnp.reciprocal, lambda x: x ** -1) -reciprocal.__doc__ = ''' - Return the reciprocal of the argument. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. -''' - - -@set_module_as('brainunit.math') -def prod(x: Union[Quantity, bst.typing.ArrayLike], - axis: Optional[int] = None, - dtype: Optional[bst.typing.DTypeLike] = None, - out: None = None, - keepdims: Optional[bool] = False, - initial: Union[Quantity, bst.typing.ArrayLike] = None, - where: Union[Quantity, bst.typing.ArrayLike] = None, - promote_integers: bool = True) -> Union[Quantity, jax.Array]: - ''' - Return the product of array elements over a given axis. - - Args: - x: array_like, Quantity - axis: int, optional - dtype: dtype, optional - out: array, optional - keepdims: bool, optional - initial: array_like, Quantity, optional - where: array_like, Quantity, optional - promote_integers: bool, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - ''' - if isinstance(x, Quantity): - return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, - promote_integers=promote_integers) - else: - return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, - promote_integers=promote_integers) - - -@set_module_as('brainunit.math') -def nanprod(x: Union[Quantity, bst.typing.ArrayLike], - axis: Optional[int] = None, - dtype: Optional[bst.typing.DTypeLike] = None, - out: None = None, - keepdims: None = False, - initial: Union[Quantity, bst.typing.ArrayLike] = None, - where: Union[Quantity, bst.typing.ArrayLike] = None): - ''' - Return the product of array elements over a given axis treating Not a Numbers (NaNs) as one. - - Args: - x: array_like, Quantity - axis: int, optional - dtype: dtype, optional - out: array, optional - keepdims: bool, optional - initial: array_like, Quantity, optional - where: array_like, Quantity, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - ''' - if isinstance(x, Quantity): - return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) - else: - return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) - - -product = prod - - -@set_module_as('brainunit.math') -def cumprod(x: Union[Quantity, bst.typing.ArrayLike], - axis: Optional[int] = None, - dtype: Optional[bst.typing.DTypeLike] = None, - out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: - ''' - Return the cumulative product of elements along a given axis. - - Args: - x: array_like, Quantity - axis: int, optional - dtype: dtype, optional - out: array, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - ''' - if isinstance(x, Quantity): - return x.cumprod(axis=axis, dtype=dtype, out=out) - else: - return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) - - -@set_module_as('brainunit.math') -def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], - axis: Optional[int] = None, - dtype: Optional[bst.typing.DTypeLike] = None, - out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: - ''' - Return the cumulative product of elements along a given axis treating Not a Numbers (NaNs) as one. - - Args: - x: array_like, Quantity - axis: int, optional - dtype: dtype, optional - out: array, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. - ''' - if isinstance(x, Quantity): - return x.nancumprod(axis=axis, dtype=dtype, out=out) - else: - return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) - - -cumproduct = cumprod - -var = wrap_math_funcs_change_unit_unary(jnp.var, lambda x: x ** 2) -nanvar = wrap_math_funcs_change_unit_unary(jnp.nanvar, lambda x: x ** 2) -frexp = wrap_math_funcs_change_unit_unary(jnp.frexp, lambda x, y: x * 2 ** y) -sqrt = wrap_math_funcs_change_unit_unary(jnp.sqrt, lambda x: x ** 0.5) -cbrt = wrap_math_funcs_change_unit_unary(jnp.cbrt, lambda x: x ** (1 / 3)) -square = wrap_math_funcs_change_unit_unary(jnp.square, lambda x: x ** 2) - -# docs for the functions above -var.__doc__ = ''' - Compute the variance along the specified axis. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. -''' - -nanvar.__doc__ = ''' - Compute the variance along the specified axis, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. -''' - -frexp.__doc__ = ''' - Decompose a floating-point number into its mantissa and exponent. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Tuple of Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the exponent, else a tuple of arrays. -''' - -sqrt.__doc__ = ''' - Compute the square root of each element. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the square root of the unit of `x`, else an array. -''' - -cbrt.__doc__ = ''' - Compute the cube root of each element. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the cube root of the unit of `x`, else an array. -''' - -square.__doc__ = ''' - Compute the square of each element. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. -''' - - -# math funcs change unit (binary) -# ------------------------------- - -def wrap_math_funcs_change_unit_binary(func, change_unit_func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless( - Quantity(func(x.value, y.value, *args, **kwargs), unit=change_unit_func(x.unit, y.unit)) - ) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless( - Quantity(func(x.value, y, *args, **kwargs), unit=change_unit_func(x.unit, DIMENSIONLESS))) - elif isinstance(y, Quantity): - return _return_check_unitless( - Quantity(func(x, y.value, *args, **kwargs), unit=change_unit_func(DIMENSIONLESS, y.unit))) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -multiply = wrap_math_funcs_change_unit_binary(jnp.multiply, lambda x, y: x * y) -divide = wrap_math_funcs_change_unit_binary(jnp.divide, lambda x, y: x / y) -cross = wrap_math_funcs_change_unit_binary(jnp.cross, lambda x, y: x * y) -ldexp = wrap_math_funcs_change_unit_binary(jnp.ldexp, lambda x, y: x * 2 ** y) -true_divide = wrap_math_funcs_change_unit_binary(jnp.true_divide, lambda x, y: x / y) -divmod = wrap_math_funcs_change_unit_binary(jnp.divmod, lambda x, y: x / y) -convolve = wrap_math_funcs_change_unit_binary(jnp.convolve, lambda x, y: x * y) - -# docs for the functions above -multiply.__doc__ = ''' - Multiply arguments element-wise. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. -''' - -divide.__doc__ = ''' - Divide arguments element-wise. - - Args: - x: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. -''' - -cross.__doc__ = ''' - Return the cross product of two (arrays of) vectors. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. -''' - -ldexp.__doc__ = ''' - Return x1 * 2**x2, element-wise. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the unit of `y`, else an array. -''' - -true_divide.__doc__ = ''' - Returns a true division of the inputs, element-wise. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. -''' - -divmod.__doc__ = ''' - Return element-wise quotient and remainder simultaneously. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. -''' - -convolve.__doc__ = ''' - Returns the discrete, linear convolution of two one-dimensional sequences. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. -''' - - -@set_module_as('brainunit.math') -def power(x: Union[Quantity, bst.typing.ArrayLike], - y: Union[Quantity, bst.typing.ArrayLike], ) -> Union[Quantity, jax.Array]: - ''' - First array elements raised to powers from second array, element-wise. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. - ''' - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y.value), unit=x.unit ** y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.power(x, y) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y), unit=x.unit ** y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x, y.value), unit=x ** y.unit)) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') - - -@set_module_as('brainunit.math') -def floor_divide(x: Union[Quantity, bst.typing.ArrayLike], - y: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - ''' - Return the largest integer smaller or equal to the division of the inputs. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. - ''' - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), unit=x.unit / y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.floor_divide(x, y) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), unit=x.unit / y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), unit=x / y.unit)) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') - - -@set_module_as('brainunit.math') -def float_power(x: Union[Quantity, bst.typing.ArrayLike], - y: bst.typing.ArrayLike) -> Union[Quantity, jax.Array]: - ''' - First array elements raised to powers from second array, element-wise. - - Args: - x: array_like, Quantity - y: array_like - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. - ''' - if isinstance(y, Quantity): - assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' - if isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y), unit=x.unit ** y)) - elif isinstance(x, (jax.Array, np.ndarray)): - return jnp.float_power(x, y) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.float_power.__name__}') - - -@set_module_as('brainunit.math') -def remainder(x: Union[Quantity, bst.typing.ArrayLike], - y: Union[Quantity, bst.typing.ArrayLike]): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), unit=x.unit / y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.remainder(x, y) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y), unit=x.unit % y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x, y.value), unit=x % y.unit)) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}') - - -# math funcs only accept unitless (unary) -# --------------------------------------- - -def wrap_math_funcs_only_accept_unitless_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - fail_for_dimension_mismatch( - x, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=x, - ) - return func(jnp.array(x.value), *args, **kwargs) - else: - return func(x, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -exp = wrap_math_funcs_only_accept_unitless_unary(jnp.exp) -exp2 = wrap_math_funcs_only_accept_unitless_unary(jnp.exp2) -expm1 = wrap_math_funcs_only_accept_unitless_unary(jnp.expm1) -log = wrap_math_funcs_only_accept_unitless_unary(jnp.log) -log10 = wrap_math_funcs_only_accept_unitless_unary(jnp.log10) -log1p = wrap_math_funcs_only_accept_unitless_unary(jnp.log1p) -log2 = wrap_math_funcs_only_accept_unitless_unary(jnp.log2) -arccos = wrap_math_funcs_only_accept_unitless_unary(jnp.arccos) -arccosh = wrap_math_funcs_only_accept_unitless_unary(jnp.arccosh) -arcsin = wrap_math_funcs_only_accept_unitless_unary(jnp.arcsin) -arcsinh = wrap_math_funcs_only_accept_unitless_unary(jnp.arcsinh) -arctan = wrap_math_funcs_only_accept_unitless_unary(jnp.arctan) -arctanh = wrap_math_funcs_only_accept_unitless_unary(jnp.arctanh) -cos = wrap_math_funcs_only_accept_unitless_unary(jnp.cos) -cosh = wrap_math_funcs_only_accept_unitless_unary(jnp.cosh) -sin = wrap_math_funcs_only_accept_unitless_unary(jnp.sin) -sinc = wrap_math_funcs_only_accept_unitless_unary(jnp.sinc) -sinh = wrap_math_funcs_only_accept_unitless_unary(jnp.sinh) -tan = wrap_math_funcs_only_accept_unitless_unary(jnp.tan) -tanh = wrap_math_funcs_only_accept_unitless_unary(jnp.tanh) -deg2rad = wrap_math_funcs_only_accept_unitless_unary(jnp.deg2rad) -rad2deg = wrap_math_funcs_only_accept_unitless_unary(jnp.rad2deg) -degrees = wrap_math_funcs_only_accept_unitless_unary(jnp.degrees) -radians = wrap_math_funcs_only_accept_unitless_unary(jnp.radians) -angle = wrap_math_funcs_only_accept_unitless_unary(jnp.angle) -percentile = wrap_math_funcs_only_accept_unitless_unary(jnp.percentile) -nanpercentile = wrap_math_funcs_only_accept_unitless_unary(jnp.nanpercentile) -quantile = wrap_math_funcs_only_accept_unitless_unary(jnp.quantile) -nanquantile = wrap_math_funcs_only_accept_unitless_unary(jnp.nanquantile) - -# docs for the functions above -exp.__doc__ = ''' - Calculate the exponential of all elements in the input array. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -exp2.__doc__ = ''' - Calculate 2 raised to the power of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -expm1.__doc__ = ''' - Calculate the exponential of the input elements minus 1. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -log.__doc__ = ''' - Natural logarithm, element-wise. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -log10.__doc__ = ''' - Base-10 logarithm of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -log1p.__doc__ = ''' - Natural logarithm of 1 + the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -log2.__doc__ = ''' - Base-2 logarithm of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -arccos.__doc__ = ''' - Compute the arccosine of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -arccosh.__doc__ = ''' - Compute the hyperbolic arccosine of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -arcsin.__doc__ = ''' - Compute the arcsine of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -arcsinh.__doc__ = ''' - Compute the hyperbolic arcsine of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -arctan.__doc__ = ''' - Compute the arctangent of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -arctanh.__doc__ = ''' - Compute the hyperbolic arctangent of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -cos.__doc__ = ''' - Compute the cosine of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -cosh.__doc__ = ''' - Compute the hyperbolic cosine of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -sin.__doc__ = ''' - Compute the sine of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -sinc.__doc__ = ''' - Compute the sinc function of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -sinh.__doc__ = ''' - Compute the hyperbolic sine of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -tan.__doc__ = ''' - Compute the tangent of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -tanh.__doc__ = ''' - Compute the hyperbolic tangent of the input elements. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -deg2rad.__doc__ = ''' - Convert angles from degrees to radians. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -rad2deg.__doc__ = ''' - Convert angles from radians to degrees. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -degrees.__doc__ = ''' - Convert angles from radians to degrees. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -radians.__doc__ = ''' - Convert angles from degrees to radians. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -angle.__doc__ = ''' - Return the angle of the complex argument. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -percentile.__doc__ = ''' - Compute the nth percentile of the input array along the specified axis. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -nanpercentile.__doc__ = ''' - Compute the nth percentile of the input array along the specified axis, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -quantile.__doc__ = ''' - Compute the qth quantile of the input array along the specified axis. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -nanquantile.__doc__ = ''' - Compute the qth quantile of the input array along the specified axis, ignoring NaNs. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - - -# math funcs only accept unitless (binary) -# ---------------------------------------- - -def wrap_math_funcs_only_accept_unitless_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity): - x_value = x.value - if isinstance(y, Quantity): - y_value = y.value - if isinstance(x, Quantity) or isinstance(y, Quantity): - fail_for_dimension_mismatch( - x, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=x, - ) - fail_for_dimension_mismatch( - y, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=y, - ) - return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) - else: - return func(x, y, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -hypot = wrap_math_funcs_only_accept_unitless_binary(jnp.hypot) -arctan2 = wrap_math_funcs_only_accept_unitless_binary(jnp.arctan2) -logaddexp = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp) -logaddexp2 = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp2) - -# docs for the functions above -hypot.__doc__ = ''' - Given the “legs” of a right triangle, return its hypotenuse. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - jax.Array: an array -''' - -arctan2.__doc__ = ''' - Element-wise arc tangent of `x1/x2` choosing the quadrant correctly. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - jax.Array: an array -''' - -logaddexp.__doc__ = ''' - Logarithm of the sum of exponentiations of the inputs. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - jax.Array: an array -''' - -logaddexp2.__doc__ = ''' - Logarithm of the sum of exponentiations of the inputs in base-2. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - jax.Array: an array -''' - - -# math funcs remove unit (unary) -# ------------------------------ -def wrap_math_funcs_remove_unit_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return func(x.value, *args, **kwargs) - else: - return func(x, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -signbit = wrap_math_funcs_remove_unit_unary(jnp.signbit) -sign = wrap_math_funcs_remove_unit_unary(jnp.sign) -histogram = wrap_math_funcs_remove_unit_unary(jnp.histogram) -bincount = wrap_math_funcs_remove_unit_unary(jnp.bincount) - -# docs for the functions above -signbit.__doc__ = ''' - Returns element-wise True where signbit is set (less than zero). - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -sign.__doc__ = ''' - Returns the sign of each element in the input array. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - -histogram.__doc__ = ''' - Compute the histogram of a set of data. - - Args: - x: array_like, Quantity - - Returns: - tuple[jax.Array]: Tuple of arrays (hist, bin_edges) -''' - -bincount.__doc__ = ''' - Count number of occurrences of each value in array of non-negative integers. - - Args: - x: array_like, Quantity - - Returns: - jax.Array: an array -''' - - -# math funcs remove unit (binary) -# ------------------------------- -def wrap_math_funcs_remove_unit_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity): - x_value = x.value - if isinstance(y, Quantity): - y_value = y.value - if isinstance(x, Quantity) or isinstance(y, Quantity): - return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) - else: - return func(x, y, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -corrcoef = wrap_math_funcs_remove_unit_binary(jnp.corrcoef) -correlate = wrap_math_funcs_remove_unit_binary(jnp.correlate) -cov = wrap_math_funcs_remove_unit_binary(jnp.cov) -digitize = wrap_math_funcs_remove_unit_binary(jnp.digitize) - -# docs for the functions above -corrcoef.__doc__ = ''' - Return Pearson product-moment correlation coefficients. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - jax.Array: an array -''' - -correlate.__doc__ = ''' - Cross-correlation of two sequences. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - jax.Array: an array -''' - -cov.__doc__ = ''' - Covariance matrix. - - Args: - x: array_like, Quantity - y: array_like, Quantity (optional, if not provided, x is assumed to be a 2D array) - - Returns: - jax.Array: an array -''' - -digitize.__doc__ = ''' - Return the indices of the bins to which each value in input array belongs. - - Args: - x: array_like, Quantity - bins: array_like, Quantity - - Returns: - jax.Array: an array -''' - -# array manipulation -# ------------------ - -reshape = _compatible_with_quantity(jnp.reshape) -moveaxis = _compatible_with_quantity(jnp.moveaxis) -transpose = _compatible_with_quantity(jnp.transpose) -swapaxes = _compatible_with_quantity(jnp.swapaxes) -concatenate = _compatible_with_quantity(jnp.concatenate) -stack = _compatible_with_quantity(jnp.stack) -vstack = _compatible_with_quantity(jnp.vstack) -row_stack = vstack -hstack = _compatible_with_quantity(jnp.hstack) -dstack = _compatible_with_quantity(jnp.dstack) -column_stack = _compatible_with_quantity(jnp.column_stack) -split = _compatible_with_quantity(jnp.split) -dsplit = _compatible_with_quantity(jnp.dsplit) -hsplit = _compatible_with_quantity(jnp.hsplit) -vsplit = _compatible_with_quantity(jnp.vsplit) -tile = _compatible_with_quantity(jnp.tile) -repeat = _compatible_with_quantity(jnp.repeat) -unique = _compatible_with_quantity(jnp.unique) -append = _compatible_with_quantity(jnp.append) -flip = _compatible_with_quantity(jnp.flip) -fliplr = _compatible_with_quantity(jnp.fliplr) -flipud = _compatible_with_quantity(jnp.flipud) -roll = _compatible_with_quantity(jnp.roll) -atleast_1d = _compatible_with_quantity(jnp.atleast_1d) -atleast_2d = _compatible_with_quantity(jnp.atleast_2d) -atleast_3d = _compatible_with_quantity(jnp.atleast_3d) -expand_dims = _compatible_with_quantity(jnp.expand_dims) -squeeze = _compatible_with_quantity(jnp.squeeze) -sort = _compatible_with_quantity(jnp.sort) - -max = _compatible_with_quantity(jnp.max) -min = _compatible_with_quantity(jnp.min) - -amax = max -amin = min - -choose = _compatible_with_quantity(jnp.choose) -block = _compatible_with_quantity(jnp.block) -compress = _compatible_with_quantity(jnp.compress) -diagflat = _compatible_with_quantity(jnp.diagflat) - -# return jax.numpy.Array, not Quantity -argsort = _compatible_with_quantity(jnp.argsort, return_quantity=False) -argmax = _compatible_with_quantity(jnp.argmax, return_quantity=False) -argmin = _compatible_with_quantity(jnp.argmin, return_quantity=False) -argwhere = _compatible_with_quantity(jnp.argwhere, return_quantity=False) -nonzero = _compatible_with_quantity(jnp.nonzero, return_quantity=False) -flatnonzero = _compatible_with_quantity(jnp.flatnonzero, return_quantity=False) -searchsorted = _compatible_with_quantity(jnp.searchsorted, return_quantity=False) -extract = _compatible_with_quantity(jnp.extract, return_quantity=False) -count_nonzero = _compatible_with_quantity(jnp.count_nonzero, return_quantity=False) - -# docs for the functions above -reshape.__doc__ = ''' - Return a reshaped copy of an array or a Quantity. - - Args: - a: input array or Quantity to reshape - shape: integer or sequence of integers giving the new shape, which must match the - size of the input array. If any single dimension is given size ``-1``, it will be - replaced with a value such that the output has the correct size. - order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major - (fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``. - brainunit does not support ``order="A"``. - - Returns: - reshaped copy of input array with the specified shape. -''' - -moveaxis.__doc__ = ''' - Moves axes of an array to new positions. Other axes remain in their original order. - - Args: - a: array_like, Quantity - source: int or sequence of ints - destination: int or sequence of ints - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -transpose.__doc__ = ''' - Returns a view of the array with axes transposed. - - Args: - a: array_like, Quantity - axes: tuple or list of ints, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -swapaxes.__doc__ = ''' - Interchanges two axes of an array. - - Args: - a: array_like, Quantity - axis1: int - axis2: int - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -concatenate.__doc__ = ''' - Join a sequence of arrays along an existing axis. - - Args: - arrays: sequence of array_like, Quantity - axis: int, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' - -stack.__doc__ = ''' - Join a sequence of arrays along a new axis. - - Args: - arrays: sequence of array_like, Quantity - axis: int - - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' - -vstack.__doc__ = ''' - Stack arrays in sequence vertically (row wise). - - Args: - arrays: sequence of array_like, Quantity - - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array -''' - -hstack.__doc__ = ''' - Stack arrays in sequence horizontally (column wise). - - Args: - arrays: sequence of array_like, Quantity - - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' - -dstack.__doc__ = ''' - Stack arrays in sequence depth wise (along third axis). - - Args: - arrays: sequence of array_like, Quantity - - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' - -column_stack.__doc__ = ''' - Stack 1-D arrays as columns into a 2-D array. - - Args: - arrays: sequence of 1-D or 2-D array_like, Quantity - - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' - -split.__doc__ = ''' - Split an array into multiple sub-arrays. - - Args: - a: array_like, Quantity - indices_or_sections: int or 1-D array - axis: int, optional - - Returns: - Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array -''' - -dsplit.__doc__ = ''' - Split array along third axis (depth). - - Args: - a: array_like, Quantity - indices_or_sections: int or 1-D array - - Returns: - Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array -''' - -hsplit.__doc__ = ''' - Split an array into multiple sub-arrays horizontally (column-wise). - - Args: - a: array_like, Quantity - indices_or_sections: int or 1-D array - - Returns: - Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array -''' - -vsplit.__doc__ = ''' - Split an array into multiple sub-arrays vertically (row-wise). - - Args: - a: array_like, Quantity - indices_or_sections: int or 1-D array - - Returns: - Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array -''' - -tile.__doc__ = ''' - Construct an array by repeating A the number of times given by reps. - - Args: - A: array_like, Quantity - reps: array_like - - Returns: - Union[jax.Array, Quantity] a Quantity if A is a Quantity, otherwise a jax.Array -''' - -repeat.__doc__ = ''' - Repeat elements of an array. - - Args: - a: array_like, Quantity - repeats: array_like - axis: int, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -unique.__doc__ = ''' - Find the unique elements of an array. - - Args: - a: array_like, Quantity - return_index: bool, optional - return_inverse: bool, optional - return_counts: bool, optional - axis: int or None, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -append.__doc__ = ''' - Append values to the end of an array. - - Args: - arr: array_like, Quantity - values: array_like, Quantity - axis: int, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if arr and values are Quantity, otherwise a jax.Array -''' - -flip.__doc__ = ''' - Reverse the order of elements in an array along the given axis. - - Args: - m: array_like, Quantity - axis: int or tuple of ints, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array -''' - -fliplr.__doc__ = ''' - Flip array in the left/right direction. - - Args: - m: array_like, Quantity - - Returns: - Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array -''' - -flipud.__doc__ = ''' - Flip array in the up/down direction. - - Args: - m: array_like, Quantity - - Returns: - Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array -''' - -roll.__doc__ = ''' - Roll array elements along a given axis. - - Args: - a: array_like, Quantity - shift: int or tuple of ints - axis: int or tuple of ints, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -atleast_1d.__doc__ = ''' - View inputs as arrays with at least one dimension. - - Args: - *args: array_like, Quantity - - Returns: - Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array -''' - -atleast_2d.__doc__ = ''' - View inputs as arrays with at least two dimensions. - - Args: - *args: array_like, Quantity - - Returns: - Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array -''' - -atleast_3d.__doc__ = ''' - View inputs as arrays with at least three dimensions. - - Args: - *args: array_like, Quantity - - Returns: - Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array -''' - -expand_dims.__doc__ = ''' - Expand the shape of an array. - - Args: - a: array_like, Quantity - axis: int or tuple of ints - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -squeeze.__doc__ = ''' - Remove single-dimensional entries from the shape of an array. - - Args: - a: array_like, Quantity - axis: None or int or tuple of ints, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -sort.__doc__ = ''' - Return a sorted copy of an array. - - Args: - a: array_like, Quantity - axis: int or None, optional - kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional - order: str or list of str, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' -max.__doc__ = ''' - Return the maximum of an array or maximum along an axis. - - Args: - a: array_like, Quantity - axis: int or tuple of ints, optional - keepdims: bool, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -min.__doc__ = ''' - Return the minimum of an array or minimum along an axis. - - Args: - a: array_like, Quantity - axis: int or tuple of ints, optional - keepdims: bool, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -choose.__doc__ = ''' - Use an index array to construct a new array from a set of choices. - - Args: - a: array_like, Quantity - choices: array_like, Quantity - - Returns: - Union[jax.Array, Quantity] a Quantity if a and choices are Quantity, otherwise a jax.Array -''' - -block.__doc__ = ''' - Assemble an nd-array from nested lists of blocks. - - Args: - arrays: sequence of array_like, Quantity - - Returns: - Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array -''' - -compress.__doc__ = ''' - Return selected slices of an array along given axis. - - Args: - condition: array_like, Quantity - a: array_like, Quantity - axis: int, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -diagflat.__doc__ = ''' - Create a two-dimensional array with the flattened input as a diagonal. - - Args: - a: array_like, Quantity - offset: int, optional - - Returns: - Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array -''' - -argsort.__doc__ = ''' - Returns the indices that would sort an array. - - Args: - a: array_like, Quantity - axis: int or None, optional - kind: {'quicksort', 'mergesort', 'heapsort'}, optional - order: str or list of str, optional - - Returns: - jax.Array jax.numpy.Array (does not return a Quantity) -''' - -argmax.__doc__ = ''' - Returns indices of the max value along an axis. - - Args: - a: array_like, Quantity - axis: int, optional - out: array, optional - - Returns: - jax.Array: an array (does not return a Quantity) -''' - -argmin.__doc__ = ''' - Returns indices of the min value along an axis. - - Args: - a: array_like, Quantity - axis: int, optional - out: array, optional - - Returns: - jax.Array: an array (does not return a Quantity) -''' - -argwhere.__doc__ = ''' - Find indices of non-zero elements. - - Args: - a: array_like, Quantity - - Returns: - jax.Array: an array (does not return a Quantity) -''' - -nonzero.__doc__ = ''' - Return the indices of the elements that are non-zero. - - Args: - a: array_like, Quantity - - Returns: - jax.Array: an array (does not return a Quantity) -''' - -flatnonzero.__doc__ = ''' - Return indices that are non-zero in the flattened version of a. - - Args: - a: array_like, Quantity - - Returns: - jax.Array: an array (does not return a Quantity) -''' - -searchsorted.__doc__ = ''' - Find indices where elements should be inserted to maintain order. - - Args: - a: array_like, Quantity - v: array_like, Quantity - side: {'left', 'right'}, optional - - Returns: - jax.Array: an array (does not return a Quantity) -''' - -extract.__doc__ = ''' - Return the elements of an array that satisfy some condition. - - Args: - condition: array_like, Quantity - a: array_like, Quantity - - Returns: - jax.Array: an array (does not return a Quantity) -''' - -count_nonzero.__doc__ = ''' - Counts the number of non-zero values in the array a. - - Args: - a: array_like, Quantity - axis: int or tuple of ints, optional - - Returns: - jax.Array: an array (does not return a Quantity) -''' - - -def wrap_function_to_method(func): - @wraps(func) - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), unit=x.unit) - else: - return func(x, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -diagonal = wrap_function_to_method(jnp.diagonal) -ravel = wrap_function_to_method(jnp.ravel) - -diagonal.__doc__ = ''' - Return specified diagonals. - - Args: - a: array_like, Quantity - offset: int, optional - axis1: int, optional - axis2: int, optional - - Returns: - Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array -''' - -ravel.__doc__ = ''' - Return a contiguous flattened array. - - Args: - a: array_like, Quantity - order: {'C', 'F', 'A', 'K'}, optional - - Returns: - Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array -''' - - -# Elementwise bit operations (unary) -# ---------------------------------- - -def wrap_elementwise_bit_operation_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - raise ValueError(f'Expected integers, got {x}') - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -bitwise_not = wrap_elementwise_bit_operation_unary(jnp.bitwise_not) -invert = wrap_elementwise_bit_operation_unary(jnp.invert) - -# docs for functions above -bitwise_not.__doc__ = ''' - Compute the bit-wise NOT of an array, element-wise. - - Args: - x: array_like - - Returns: - jax.Array: an array -''' - -invert.__doc__ = ''' - Compute bit-wise inversion, or bit-wise NOT, element-wise. - - Args: - x: array_like - - Returns: - jax.Array: an array -''' - - -# Elementwise bit operations (binary) -# ----------------------------------- - -def wrap_elementwise_bit_operation_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) or isinstance(y, Quantity): - raise ValueError(f'Expected integers, got {x} and {y}') - elif isinstance(x, bst.typing.ArrayLike) and isinstance(y, bst.typing.ArrayLike): - return func(x, y, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -bitwise_and = wrap_elementwise_bit_operation_binary(jnp.bitwise_and) -bitwise_or = wrap_elementwise_bit_operation_binary(jnp.bitwise_or) -bitwise_xor = wrap_elementwise_bit_operation_binary(jnp.bitwise_xor) -left_shift = wrap_elementwise_bit_operation_binary(jnp.left_shift) -right_shift = wrap_elementwise_bit_operation_binary(jnp.right_shift) - -# docs for functions above -bitwise_and.__doc__ = ''' - Compute the bit-wise AND of two arrays element-wise. - - Args: - x: array_like - y: array_like - - Returns: - jax.Array: an array -''' - -bitwise_or.__doc__ = ''' - Compute the bit-wise OR of two arrays element-wise. - - Args: - x: array_like - y: array_like - - Returns: - jax.Array: an array -''' - -bitwise_xor.__doc__ = ''' - Compute the bit-wise XOR of two arrays element-wise. - - Args: - x: array_like - y: array_like - - Returns: - jax.Array: an array -''' - -left_shift.__doc__ = ''' - Shift the bits of an integer to the left. - - Args: - x: array_like - y: array_like - - Returns: - jax.Array: an array -''' - -right_shift.__doc__ = ''' - Shift the bits of an integer to the right. - - Args: - x: array_like - y: array_like - - Returns: - jax.Array: an array -''' - - -# logic funcs (unary) -# ------------------- - -def wrap_logic_func_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - raise ValueError(f'Expected booleans, got {x}') - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -all = wrap_logic_func_unary(jnp.all) -any = wrap_logic_func_unary(jnp.any) -alltrue = all -sometrue = any -logical_not = wrap_logic_func_unary(jnp.logical_not) - -# docs for functions above -all.__doc__ = ''' - Test whether all array elements along a given axis evaluate to True. - - Args: - a: array_like - axis: int, optional - out: array, optional - keepdims: bool, optional - where: array_like of bool, optional - - Returns: - Union[bool, jax.Array]: bool or array -''' - -any.__doc__ = ''' - Test whether any array element along a given axis evaluates to True. - - Args: - a: array_like - axis: int, optional - out: array, optional - keepdims: bool, optional - where: array_like of bool, optional - - Returns: - Union[bool, jax.Array]: bool or array -''' - -logical_not.__doc__ = ''' - Compute the truth value of NOT x element-wise. - - Args: - x: array_like - out: array, optional - - Returns: - Union[bool, jax.Array]: bool or array -''' - - -# logic funcs (binary) -# -------------------- - -def wrap_logic_func_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - fail_for_dimension_mismatch(x, y) - return func(x.value, y.value, *args, **kwargs) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -equal = wrap_logic_func_binary(jnp.equal) -not_equal = wrap_logic_func_binary(jnp.not_equal) -greater = wrap_logic_func_binary(jnp.greater) -greater_equal = wrap_logic_func_binary(jnp.greater_equal) -less = wrap_logic_func_binary(jnp.less) -less_equal = wrap_logic_func_binary(jnp.less_equal) -array_equal = wrap_logic_func_binary(jnp.array_equal) -isclose = wrap_logic_func_binary(jnp.isclose) -allclose = wrap_logic_func_binary(jnp.allclose) -logical_and = wrap_logic_func_binary(jnp.logical_and) - -logical_or = wrap_logic_func_binary(jnp.logical_or) -logical_xor = wrap_logic_func_binary(jnp.logical_xor) - -# docs for functions above -equal.__doc__ = ''' - Return (x == y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array -''' - -not_equal.__doc__ = ''' - Return (x != y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array -''' - -greater.__doc__ = ''' - Return (x > y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array -''' - -greater_equal.__doc__ = ''' - Return (x >= y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array -''' - -less.__doc__ = ''' - Return (x < y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array -''' - -less_equal.__doc__ = ''' - Return (x <= y) element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like, Quantity - y: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array -''' - -array_equal.__doc__ = ''' - Return True if two arrays have the same shape, elements, and units (if they are Quantity), False otherwise. - - Args: - x1: array_like, Quantity - x2: array_like, Quantity - - Returns: - Union[bool, jax.Array]: bool or array -''' - -isclose.__doc__ = ''' - Returns a boolean array where two arrays are element-wise equal within a tolerance and have the same unit if they are Quantity. - - Args: - a: array_like, Quantity - b: array_like, Quantity - rtol: float, optional - atol: float, optional - equal_nan: bool, optional - - Returns: - Union[bool, jax.Array]: bool or array -''' - -allclose.__doc__ = ''' - Returns True if the two arrays are equal within the given tolerance and have the same unit if they are Quantity; False otherwise. - - Args: - a: array_like, Quantity - b: array_like, Quantity - rtol: float, optional - atol: float, optional - equal_nan: bool, optional - - Returns: - bool: boolean result -''' - -logical_and.__doc__ = ''' - Compute the truth value of x AND y element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like - y: array_like - out: array, optional - - Returns: - Union[bool, jax.Array]: bool or array -''' - -logical_or.__doc__ = ''' - Compute the truth value of x OR y element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like - y: array_like - out: array, optional - - Returns: - Union[bool, jax.Array]: bool or array -''' - -logical_xor.__doc__ = ''' - Compute the truth value of x XOR y element-wise and have the same unit if x and y are Quantity. - - Args: - x: array_like - y: array_like - out: array, optional - - Returns: - Union[bool, jax.Array]: bool or array -''' - - -# indexing funcs -# -------------- -@set_module_as('brainunit.math') -def where(condition: Union[bool, bst.typing.ArrayLike], - *args: Union[Quantity, bst.typing.ArrayLike], - **kwds) -> Union[Quantity, jax.Array]: - condition = jnp.asarray(condition) - if len(args) == 0: - # nothing to do - return jnp.where(condition, *args, **kwds) - elif len(args) == 2: - # check that x and y have the same dimensions - fail_for_dimension_mismatch( - args[0], args[1], "x and y need to have the same dimensions" - ) - new_args = [] - for arg in args: - if isinstance(arg, Quantity): - new_args.append(arg.value) - if is_unitless(args[0]): - if len(new_args) == 2: - return jnp.where(condition, *new_args, **kwds) - else: - return jnp.where(condition, *args, **kwds) - else: - # as both arguments have the same unit, just use the first one's - dimensionless_args = [jnp.asarray(arg.value) if isinstance(arg, Quantity) else jnp.asarray(arg) for arg in args] - return Quantity.with_units( - jnp.where(condition, *dimensionless_args), args[0].unit - ) - else: - # illegal number of arguments - if len(args) == 1: - raise ValueError("where() takes 2 or 3 positional arguments but 1 was given") - elif len(args) > 2: - raise TypeError("where() takes 2 or 3 positional arguments but {} were given".format(len(args))) - - -tril_indices = jnp.tril_indices -tril_indices.__doc__ = ''' - Return the indices for the lower-triangle of an (n, m) array. - - Args: - n: int - m: int - k: int, optional - - Returns: - tuple[jax.Array]: tuple[array] -''' - - -@set_module_as('brainunit.math') -def tril_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], - k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: - ''' - Return the indices for the lower-triangle of an (n, m) array. - - Args: - arr: array_like, Quantity - k: int, optional - - Returns: - tuple[jax.Array]: tuple[array] - ''' - if isinstance(arr, Quantity): - return jnp.tril_indices_from(arr.value, k=k) - else: - return jnp.tril_indices_from(arr, k=k) - - -triu_indices = jnp.triu_indices -triu_indices.__doc__ = ''' - Return the indices for the upper-triangle of an (n, m) array. - - Args: - n: int - m: int - k: int, optional - - Returns: - tuple[jax.Array]: tuple[array] -''' - - -@set_module_as('brainunit.math') -def triu_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], - k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: - ''' - Return the indices for the upper-triangle of an (n, m) array. - - Args: - arr: array_like, Quantity - k: int, optional - - Returns: - tuple[jax.Array]: tuple[array] - ''' - if isinstance(arr, Quantity): - return jnp.triu_indices_from(arr.value, k=k) - else: - return jnp.triu_indices_from(arr, k=k) - - -@set_module_as('brainunit.math') -def take(a: Union[Quantity, bst.typing.ArrayLike], - indices: Union[Quantity, bst.typing.ArrayLike], - axis: Optional[int] = None, - mode: Optional[str] = None) -> Union[Quantity, jax.Array]: - if isinstance(a, Quantity): - return a.take(indices, axis=axis, mode=mode) - else: - return jnp.take(a, indices, axis=axis, mode=mode) - - -@set_module_as('brainunit.math') -def select(condlist: list[Union[bst.typing.ArrayLike]], - choicelist: Union[Quantity, bst.typing.ArrayLike], - default: int = 0) -> Union[Quantity, jax.Array]: - from builtins import all as origin_all - from builtins import any as origin_any - if origin_all(isinstance(choice, Quantity) for choice in choicelist): - if origin_any(choice.unit != choicelist[0].unit for choice in choicelist): - raise ValueError("All choices must have the same unit") - else: - return Quantity(jnp.select(condlist, [choice.value for choice in choicelist], default=default), - unit=choicelist[0].unit) - elif origin_all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist): - return jnp.select(condlist, choicelist, default=default) - else: - raise ValueError(f"Unsupported types : {type(condlist)} and {type(choicelist)} for select") - - -# window funcs -# ------------ - -def wrap_window_funcs(func): - def f(*args, **kwargs): - return func(*args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -bartlett = wrap_window_funcs(jnp.bartlett) -blackman = wrap_window_funcs(jnp.blackman) -hamming = wrap_window_funcs(jnp.hamming) -hanning = wrap_window_funcs(jnp.hanning) -kaiser = wrap_window_funcs(jnp.kaiser) - -# docs for functions above -bartlett.__doc__ = jnp.bartlett.__doc__ -blackman.__doc__ = jnp.blackman.__doc__ -hamming.__doc__ = jnp.hamming.__doc__ -hanning.__doc__ = jnp.hanning.__doc__ -kaiser.__doc__ = jnp.kaiser.__doc__ - -# constants -# --------- -e = jnp.e -pi = jnp.pi -inf = jnp.inf - -# linear algebra -# -------------- -dot = wrap_math_funcs_change_unit_binary(jnp.dot, lambda x, y: x * y) -vdot = wrap_math_funcs_change_unit_binary(jnp.vdot, lambda x, y: x * y) -inner = wrap_math_funcs_change_unit_binary(jnp.inner, lambda x, y: x * y) -outer = wrap_math_funcs_change_unit_binary(jnp.outer, lambda x, y: x * y) -kron = wrap_math_funcs_change_unit_binary(jnp.kron, lambda x, y: x * y) -matmul = wrap_math_funcs_change_unit_binary(jnp.matmul, lambda x, y: x * y) -trace = wrap_math_funcs_keep_unit_unary(jnp.trace) - -# docs for functions above -dot.__doc__ = ''' - Dot product of two arrays or quantities. - - Args: - a: array_like, Quantity - b: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. -''' - -vdot.__doc__ = ''' - Return the dot product of two vectors or quantities. - - Args: - a: array_like, Quantity - b: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. -''' - -inner.__doc__ = ''' - Inner product of two arrays or quantities. - - Args: - a: array_like, Quantity - b: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. -''' - -outer.__doc__ = ''' - Compute the outer product of two vectors or quantities. - - Args: - a: array_like, Quantity - b: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. -''' - -kron.__doc__ = ''' - Compute the Kronecker product of two arrays or quantities. - - Args: - a: array_like, Quantity - b: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. -''' - -matmul.__doc__ = ''' - Matrix product of two arrays or quantities. - - Args: - a: array_like, Quantity - b: array_like, Quantity - - Returns: - Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. -''' - -trace.__doc__ = ''' - Return the sum of the diagonal elements of a matrix or quantity. - - Args: - a: array_like, Quantity - offset: int, optional - - Returns: - Union[jax.Array, Quantity]: Quantity if the input is a Quantity, else an array. -''' - -# data types -# ---------- -dtype = jnp.dtype - - -@set_module_as('brainunit.math') -def finfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.finfo: - if isinstance(a, Quantity): - return jnp.finfo(a.value) - else: - return jnp.finfo(a) - - -@set_module_as('brainunit.math') -def iinfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.iinfo: - if isinstance(a, Quantity): - return jnp.iinfo(a.value) - else: - return jnp.iinfo(a) - - -# more -# ---- -@set_module_as('brainunit.math') -def broadcast_arrays(*args: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: - from builtins import all as origin_all - from builtins import any as origin_any - if origin_all(isinstance(arg, Quantity) for arg in args): - if origin_any(arg.unit != args[0].unit for arg in args): - raise ValueError("All arguments must have the same unit") - return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), unit=args[0].unit) - elif origin_all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args): - return jnp.broadcast_arrays(*args) - else: - raise ValueError(f"Unsupported types : {type(args)} for broadcast_arrays") - - -broadcast_shapes = jnp.broadcast_shapes - - -@set_module_as('brainunit.math') -def einsum( - subscripts: str, - /, - *operands: Union[Quantity, jax.Array], - out: None = None, - optimize: Union[str, bool] = "optimal", - precision: jax.lax.PrecisionLike = None, - preferred_element_type: Union[jax.typing.DTypeLike, None] = None, - _dot_general: Callable[..., jax.Array] = jax.lax.dot_general, -) -> Union[jax.Array, Quantity]: - ''' - Evaluates the Einstein summation convention on the operands. - - Args: - subscripts: string containing axes names separated by commas. - *operands: sequence of one or more arrays or quantities corresponding to the subscripts. - optimize: determine whether to optimize the order of computation. In JAX - this defaults to ``"optimize"`` which produces optimized expressions via - the opt_einsum_ package. - precision: either ``None`` (default), which means the default precision for - the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, - ``Precision.HIGH`` or ``Precision.HIGHEST``). - preferred_element_type: either ``None`` (default), which means the default - accumulation type for the input types, or a datatype, indicating to - accumulate results to and return a result with that datatype. - out: unsupported by JAX - _dot_general: optionally override the ``dot_general`` callable used by ``einsum``. - This parameter is experimental, and may be removed without warning at any time. - - Returns: - array containing the result of the einstein summation. - ''' - operands = (subscripts, *operands) - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.") - spec = operands[0] if isinstance(operands[0], str) else None - optimize = 'optimal' if optimize is True else optimize - - # Allow handling of shape polymorphism - non_constant_dim_types = { - type(d) for op in operands if not isinstance(op, str) - for d in np.shape(op) if not jax.core.is_constant_dim(d) - } - if not non_constant_dim_types: - contract_path = opt_einsum.contract_path - else: - from jax._src.numpy.lax_numpy import _default_poly_einsum_handler - contract_path = _default_poly_einsum_handler - - operands, contractions = contract_path( - *operands, einsum_call=True, use_blas=True, optimize=optimize) - - unit = None - for i in range(len(contractions) - 1): - if contractions[i][4] == 'False': - - fail_for_dimension_mismatch( - Quantity([], unit=unit), operands[i + 1], 'einsum' - ) - elif contractions[i][4] == 'DOT' or \ - contractions[i][4] == 'TDOT' or \ - contractions[i][4] == 'GEMM' or \ - contractions[i][4] == 'OUTER/EINSUM': - if i == 0: - if isinstance(operands[i], Quantity) and isinstance(operands[i + 1], Quantity): - unit = operands[i].unit * operands[i + 1].unit - elif isinstance(operands[i], Quantity): - unit = operands[i].unit - elif isinstance(operands[i + 1], Quantity): - unit = operands[i + 1].unit - else: - if isinstance(operands[i + 1], Quantity): - unit = unit * operands[i + 1].unit - - contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) - - einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True) - if spec is not None: - einsum = jax.named_call(einsum, name=spec) - operands = [op.value if isinstance(op, Quantity) else op for op in operands] - r = einsum(operands, contractions, precision, # type: ignore[operator] - preferred_element_type, _dot_general) - if unit is not None: - return Quantity(r, unit=unit) - else: - return r - - -@set_module_as('brainunit.math') -def gradient( - f: Union[bst.typing.ArrayLike, Quantity], - *varargs: Union[bst.typing.ArrayLike, Quantity], - axis: Union[int, Sequence[int], None] = None, - edge_order: Union[int, None] = None, -) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: - ''' - Computes the gradient of a scalar field. - - Args: - f: input array. - *varargs: list of scalar fields to compute the gradient. - axis: axis or axes along which to compute the gradient. The default is to compute the gradient along all axes. - edge_order: order of the edge used for the finite difference computation. The default is 1. - - Returns: - array containing the gradient of the scalar field. - ''' - if edge_order is not None: - raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") - - if len(varargs) == 0: - if isinstance(f, Quantity) and not is_unitless(f): - return Quantity(jnp.gradient(f.value, axis=axis), unit=f.unit) - else: - return jnp.gradient(f) - elif len(varargs) == 1: - unit = get_unit(f) / get_unit(varargs[0]) - if unit is None or unit == DIMENSIONLESS: - return jnp.gradient(f, varargs[0], axis=axis) - else: - return [Quantity(r, unit=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] - else: - unit_list = [get_unit(f) / get_unit(v) for v in varargs] - f = f.value if isinstance(f, Quantity) else f - varargs = [v.value if isinstance(v, Quantity) else v for v in varargs] - result_list = jnp.gradient(f, *varargs, axis=axis) - return [Quantity(r, unit=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)] - - -@set_module_as('brainunit.math') -def intersect1d( - ar1: Union[bst.typing.ArrayLike], - ar2: Union[bst.typing.ArrayLike], - assume_unique: bool = False, - return_indices: bool = False -) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: - ''' - Find the intersection of two arrays. - - Args: - ar1: input array. - ar2: input array. - assume_unique: if True, the input arrays are both assumed to be unique. - return_indices: if True, the indices which correspond to the intersection of the two arrays are returned. - - Returns: - array containing the intersection of the two arrays. - ''' - fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') - unit = None - if isinstance(ar1, Quantity): - unit = ar1.unit - ar1 = ar1.value if isinstance(ar1, Quantity) else ar1 - ar2 = ar2.value if isinstance(ar2, Quantity) else ar2 - result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - if return_indices: - if unit is not None: - return (Quantity(result[0], unit=unit), result[1], result[2]) - else: - return result - else: - if unit is not None: - return Quantity(result, unit=unit) - else: - return result - - -nan_to_num = wrap_math_funcs_keep_unit_unary(jnp.nan_to_num) -nanargmax = _compatible_with_quantity(jnp.nanargmax, return_quantity=False) -nanargmin = _compatible_with_quantity(jnp.nanargmin, return_quantity=False) - -rot90 = wrap_math_funcs_keep_unit_unary(jnp.rot90) -tensordot = wrap_math_funcs_change_unit_binary(jnp.tensordot, lambda x, y: x * y) - -# docs for functions above -nan_to_num.__doc__ = ''' - Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the `nan`, `posinf` and `neginf` arguments. - - Args: - x: input array. - nan: value to replace NaNs with. - posinf: value to replace positive infinity with. - neginf: value to replace negative infinity with. - - Returns: - array with NaNs replaced by zero and infinities replaced by large finite numbers. -''' - -nanargmax.__doc__ = ''' - Return the index of the maximum value in an array, ignoring NaNs. - - Args: - a: array like, Quantity. - axis: axis along which to operate. The default is to compute the index of the maximum over all the dimensions of the input array. - out: output array, optional. - keepdims: if True, the result is broadcast to the input array with the same number of dimensions. - - Returns: - index of the maximum value in the array. -''' - -nanargmin.__doc__ = ''' - Return the index of the minimum value in an array, ignoring NaNs. - - Args: - a: array like, Quantity. - axis: axis along which to operate. The default is to compute the index of the minimum over all the dimensions of the input array. - out: output array, optional. - keepdims: if True, the result is broadcast to the input array with the same number of dimensions. - - Returns: - index of the minimum value in the array. -''' - -rot90.__doc__ = ''' - Rotate an array by 90 degrees in the plane specified by axes. - - Args: - m: array like, Quantity. - k: number of times the array is rotated by 90 degrees. - axes: plane of rotation. Default is the last two axes. - - Returns: - rotated array. -''' - -tensordot.__doc__ = ''' - Compute tensor dot product along specified axes for arrays. - - Args: - a: array like, Quantity. - b: array like, Quantity. - axes: axes along which to compute the tensor dot product. - - Returns: - tensor dot product of the two arrays. -''' diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py new file mode 100644 index 0000000..9080502 --- /dev/null +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -0,0 +1,720 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from functools import wraps +from typing import (Callable, Union, Optional, Any) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as +from jax import Array + +from .._base import (DIMENSIONLESS, + Quantity, + Unit, + fail_for_dimension_mismatch, + is_unitless, + ) + +__all__ = [ + # array creation + 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', + 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', + 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', + 'array_split', 'meshgrid', 'vander', +] + + +def wrap_array_creation_function(func: Callable) -> Callable: + @wraps(func) + def f(*args, unit: Unit = None, **kwargs): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return func(*args, **kwargs) * unit + else: + return func(*args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_array_creation_function +def full(shape: Sequence[int], + fill_value: Any, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.full(shape, fill_value, dtype=dtype) + + +@wrap_array_creation_function +def eye(N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.eye(N, M, k, dtype=dtype) + + +@wrap_array_creation_function +def identity(n: int, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.identity(n, dtype=dtype) + + +@wrap_array_creation_function +def tri(N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.tri(N, M, k, dtype=dtype) + + +@wrap_array_creation_function +def empty(shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.empty(shape, dtype=dtype) + + +@wrap_array_creation_function +def ones(shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.ones(shape, dtype=dtype) + + +@wrap_array_creation_function +def zeros(shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.zeros(shape, dtype=dtype) + + +full.__doc__ = ''' + Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. + else return an array of `shape` filled with `fill_value`. + + Args: + shape: sequence of integers, describing the shape of the output array. + fill_value: the value to fill the new array with. + dtype: the type of the output array, or `None`. If not `None`, `fill_value` + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + +eye.__doc__ = """ + Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. + else return an identity matrix of `shape`. + + Args: + n: the number of rows (and columns) in the output array. + k: the index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value to a + lower diagonal. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +identity.__doc__ = """ + Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. + else return an identity matrix of `shape`. + + Args: + n: the number of rows (and columns) in the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +tri.__doc__ = """ + Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. + else return a triangular matrix of `shape`. + + Args: + n: the number of rows in the output array. + m: the number of columns with default being `n`. + k: the index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value to a + lower diagonal. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +# empty +empty.__doc__ = """ + Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. + else return an array of `shape` with uninitialized values. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be of type `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +# ones +ones.__doc__ = """ + Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. + else return an array of `shape` filled with 1. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +# zeros +zeros.__doc__ = """ + Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. + else return an array of `shape` filled with 0. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + + +@set_module_as('brainunit.math') +def full_like(a: Union[Quantity, bst.typing.ArrayLike], + fill_value: Union[bst.typing.ArrayLike], + unit: Unit = None, + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with `fill_value` if `unit` is provided. + else return an array of `a` filled with `fill_value`. + + Args: + a: array_like, Quantity, shape, or dtype + fill_value: scalar or array_like + unit: Unit, optional + dtype: data-type, optional + shape: sequence of ints, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) * unit + else: + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) * unit + else: + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def diag(a: Union[Quantity, bst.typing.ArrayLike], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Extract a diagonal or construct a diagonal array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.diag(a.value, k=k) * unit + else: + return jnp.diag(a, k=k) * unit + else: + return jnp.diag(a, k=k) + + +@set_module_as('brainunit.math') +def tril(a: Union[Quantity, bst.typing.ArrayLike], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Lower triangle of an array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.tril(a.value, k=k) * unit + else: + return jnp.tril(a, k=k) * unit + else: + return jnp.tril(a, k=k) + + +@set_module_as('brainunit.math') +def triu(a: Union[Quantity, bst.typing.ArrayLike], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Upper triangle of an array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.triu(a.value, k=k) * unit + else: + return jnp.triu(a, k=k) * unit + else: + return jnp.triu(a, k=k) + + +@set_module_as('brainunit.math') +def empty_like(a: Union[Quantity, bst.typing.ArrayLike], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, with uninitialized values if `unit` is provided. + else return an array of `a` with uninitialized values. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.empty_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.empty_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.empty_like(a, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def ones_like(a: Union[Quantity, bst.typing.ArrayLike], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with 1 if `unit` is provided. + else return an array of `a` filled with 1. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.ones_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.ones_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.ones_like(a, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with 0 if `unit` is provided. + else return an array of `a` filled with 0. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.zeros_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.zeros_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.zeros_like(a, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def asarray( + a: Union[Quantity, bst.typing.ArrayLike, Sequence[Quantity]], + dtype: Optional[bst.typing.DTypeLike] = None, + order: Optional[str] = None, + unit: Optional[Unit] = None, +) -> Union[Quantity, jax.Array]: + from builtins import all as origin_all + from builtins import any as origin_any + if isinstance(a, Quantity): + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), unit=a.unit) + elif isinstance(a, (jax.Array, np.ndarray)): + return jnp.asarray(a, dtype=dtype, order=order) + # list[Quantity] + elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a): + # check all elements have the same unit + if origin_any(x.unit != a[0].unit for x in a): + raise ValueError('Units do not match for asarray operation.') + values = [x.value for x in a] + unit = a[0].unit + # Convert the values to a jnp.ndarray and create a Quantity object + return Quantity(jnp.asarray(values, dtype=dtype, order=order), unit=unit) + else: + return jnp.asarray(a, dtype=dtype, order=order) + + +array = asarray + + +@set_module_as('brainunit.math') +def arange(*args, **kwargs): + ''' + Return a Quantity of `arange` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity, optional + stop: number, Quantity, optional + step: number, optional + dtype: dtype, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' + # arange has a bit of a complicated argument structure unfortunately + # we leave the actual checking of the number of arguments to numpy, though + + # default values + start = kwargs.pop("start", 0) + step = kwargs.pop("step", 1) + stop = kwargs.pop("stop", None) + if len(args) == 1: + if stop is not None: + raise TypeError("Duplicate definition of 'stop'") + stop = args[0] + elif len(args) == 2: + if start != 0: + raise TypeError("Duplicate definition of 'start'") + if stop is not None: + raise TypeError("Duplicate definition of 'stop'") + start, stop = args + elif len(args) == 3: + if start != 0: + raise TypeError("Duplicate definition of 'start'") + if stop is not None: + raise TypeError("Duplicate definition of 'stop'") + if step != 1: + raise TypeError("Duplicate definition of 'step'") + start, stop, step = args + elif len(args) > 3: + raise TypeError("Need between 1 and 3 non-keyword arguments") + + if stop is None: + raise TypeError("Missing stop argument.") + if stop is not None and not is_unitless(stop): + start = Quantity(start, unit=stop.unit) + + fail_for_dimension_mismatch( + start, + stop, + error_message=( + "Start value {start} and stop value {stop} have to have the same units." + ), + start=start, + stop=stop, + ) + fail_for_dimension_mismatch( + stop, + step, + error_message=( + "Stop value {stop} and step value {step} have to have the same units." + ), + stop=stop, + step=step, + ) + unit = getattr(stop, "unit", DIMENSIONLESS) + # start is a position-only argument in numpy 2.0 + # https://numpy.org/devdocs/release/2.0.0-notes.html#arange-s-start-argument-is-positional-only + # TODO: check whether this is still the case in the final release + if start == 0: + return Quantity( + jnp.arange( + start=start.value if isinstance(start, Quantity) else jnp.asarray(start), + stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), + step=step.value if isinstance(step, Quantity) else jnp.asarray(step), + **kwargs, + ), + unit=unit, + ) + else: + return Quantity( + jnp.arange( + start.value if isinstance(start, Quantity) else jnp.asarray(start), + stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), + step=step.value if isinstance(step, Quantity) else jnp.asarray(step), + **kwargs, + ), + unit=unit, + ) + + +@set_module_as('brainunit.math') +def linspace(start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], + num: int = 50, + endpoint: Optional[bool] = True, + retstep: Optional[bool] = False, + dtype: Optional[bst.typing.DTypeLike] = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity + stop: number, Quantity + num: int, optional + endpoint: bool, optional + retstep: bool, optional + dtype: dtype, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' + fail_for_dimension_mismatch( + start, + stop, + error_message="Start value {start} and stop value {stop} have to have the same units.", + start=start, + stop=stop, + ) + unit = getattr(start, "unit", DIMENSIONLESS) + start = start.value if isinstance(start, Quantity) else start + stop = stop.value if isinstance(stop, Quantity) else stop + + result = jnp.linspace(start, stop, num=num, endpoint=endpoint, retstep=retstep, dtype=dtype) + return Quantity(result, unit=unit) + + +@set_module_as('brainunit.math') +def logspace(start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], + num: Optional[int] = 50, + endpoint: Optional[bool] = True, + base: Optional[float] = 10.0, + dtype: Optional[bst.typing.DTypeLike] = None): + ''' + Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity + stop: number, Quantity + num: int, optional + endpoint: bool, optional + base: float, optional + dtype: dtype, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' + fail_for_dimension_mismatch( + start, + stop, + error_message="Start value {start} and stop value {stop} have to have the same units.", + start=start, + stop=stop, + ) + unit = getattr(start, "unit", DIMENSIONLESS) + start = start.value if isinstance(start, Quantity) else start + stop = stop.value if isinstance(stop, Quantity) else stop + + result = jnp.logspace(start, stop, num=num, endpoint=endpoint, base=base, dtype=dtype) + return Quantity(result, unit=unit) + + +@set_module_as('brainunit.math') +def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], + val: Union[Quantity, bst.typing.ArrayLike], + wrap: Optional[bool] = False, + inplace: Optional[bool] = True) -> Union[Quantity, jax.Array]: + ''' + Fill the main diagonal of the given array of `a` with `val`. + + Args: + a: array_like, Quantity + val: scalar, Quantity + wrap: bool, optional + inplace: bool, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `a` and `val` are Quantities that have the same unit, else an array. + ''' + if isinstance(a, Quantity) and isinstance(val, Quantity): + fail_for_dimension_mismatch(a, val) + return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), unit=a.unit) + elif isinstance(a, (jax.Array, np.ndarray)) and isinstance(val, (jax.Array, np.ndarray)): + return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) + elif is_unitless(a) or is_unitless(val): + return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) + else: + raise ValueError(f'Unsupported types : {type(a)} abd {type(val)} for fill_diagonal') + + +@set_module_as('brainunit.math') +def array_split(ary: Union[Quantity, bst.typing.ArrayLike], + indices_or_sections: Union[int, bst.typing.ArrayLike], + axis: Optional[int] = 0) -> Union[list[Quantity], list[Array]]: + ''' + Split an array into multiple sub-arrays. + + Args: + ary: array_like, Quantity + indices_or_sections: int, array_like + axis: int, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `ary` is a Quantity, else an array. + ''' + if isinstance(ary, Quantity): + return [Quantity(x, unit=ary.unit) for x in jnp.array_split(ary.value, indices_or_sections, axis)] + elif isinstance(ary, bst.typing.ArrayLike): + return jnp.array_split(ary, indices_or_sections, axis) + else: + raise ValueError(f'Unsupported type: {type(ary)} for array_split') + + +@set_module_as('brainunit.math') +def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike], + copy: Optional[bool] = True, + sparse: Optional[bool] = False, + indexing: Optional[str] = 'xy'): + ''' + Return coordinate matrices from coordinate vectors. + + Args: + xi: array_like, Quantity + copy: bool, optional + sparse: bool, optional + indexing: str, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `xi` are Quantities that have the same unit, else an array. + ''' + from builtins import all as origin_all + if origin_all(isinstance(x, Quantity) for x in xi): + fail_for_dimension_mismatch(*xi) + return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), unit=xi[0].unit) + elif origin_all(isinstance(x, (jax.Array, np.ndarray)) for x in xi): + return jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) + else: + raise ValueError(f'Unsupported types : {type(xi)} for meshgrid') + + +@set_module_as('brainunit.math') +def vander(x: Union[Quantity, bst.typing.ArrayLike], + N: Optional[bool] = None, + increasing: Optional[bool] = False) -> Union[Quantity, jax.Array]: + ''' + Generate a Vandermonde matrix. + + Args: + x: array_like, Quantity + N: int, optional + increasing: bool, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return Quantity(jnp.vander(x.value, N=N, increasing=increasing), unit=x.unit) + elif isinstance(x, (jax.Array, np.ndarray)): + return jnp.vander(x, N=N, increasing=increasing) + else: + raise ValueError(f'Unsupported type: {type(x)} for vander') diff --git a/brainunit/math/_compat_numpy_array_manipulation.py b/brainunit/math/_compat_numpy_array_manipulation.py new file mode 100644 index 0000000..dbfca5e --- /dev/null +++ b/brainunit/math/_compat_numpy_array_manipulation.py @@ -0,0 +1,821 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from functools import wraps +from typing import (Union, Optional, Tuple, List) + +import jax +import jax.numpy as jnp +from jax import Array + +from ._utils import _compatible_with_quantity +from .._base import (Quantity, + ) + +__all__ = [ + # array manipulation + 'reshape', 'moveaxis', 'transpose', 'swapaxes', 'row_stack', + 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', + 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', + 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', + 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', + 'argwhere', 'nonzero', 'flatnonzero', 'searchsorted', 'extract', + 'count_nonzero', 'max', 'min', 'amax', 'amin', 'block', 'compress', + 'diagflat', 'diagonal', 'choose', 'ravel', +] + + +# array manipulation +# ------------------ + + +@_compatible_with_quantity() +def reshape(a: Union[Array, Quantity], shape: Union[int, Tuple[int, ...]], order: str = 'C') -> Union[Array, Quantity]: + return jnp.reshape(a, shape, order) + + +@_compatible_with_quantity() +def moveaxis(a: Union[Array, Quantity], source: Union[int, Tuple[int, ...]], + destination: Union[int, Tuple[int, ...]]) -> Union[Array, Quantity]: + return jnp.moveaxis(a, source, destination) + + +@_compatible_with_quantity() +def transpose(a: Union[Array, Quantity], axes: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.transpose(a, axes) + + +@_compatible_with_quantity() +def swapaxes(a: Union[Array, Quantity], axis1: int, axis2: int) -> Union[Array, Quantity]: + return jnp.swapaxes(a, axis1, axis2) + + +@_compatible_with_quantity() +def concatenate(arrays: Union[Sequence[Array], Sequence[Quantity]], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.concatenate(arrays, axis) + + +@_compatible_with_quantity() +def stack(arrays: Union[Sequence[Array], Sequence[Quantity]], axis: int = 0) -> Union[Array, Quantity]: + return jnp.stack(arrays, axis) + + +@_compatible_with_quantity() +def vstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.vstack(arrays) + + +row_stack = vstack + + +@_compatible_with_quantity() +def hstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.hstack(arrays) + + +@_compatible_with_quantity() +def dstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.dstack(arrays) + + +@_compatible_with_quantity() +def column_stack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.column_stack(arrays) + + +@_compatible_with_quantity() +def split(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]], axis: int = 0) -> Union[ + List[Array], List[Quantity]]: + return jnp.split(a, indices_or_sections, axis) + + +@_compatible_with_quantity() +def dsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ + List[Array], List[Quantity]]: + return jnp.dsplit(a, indices_or_sections) + + +@_compatible_with_quantity() +def hsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ + List[Array], List[Quantity]]: + return jnp.hsplit(a, indices_or_sections) + + +@_compatible_with_quantity() +def vsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ + List[Array], List[Quantity]]: + return jnp.vsplit(a, indices_or_sections) + + +@_compatible_with_quantity() +def tile(A: Union[Array, Quantity], reps: Union[int, Tuple[int, ...]]) -> Union[Array, Quantity]: + return jnp.tile(A, reps) + + +@_compatible_with_quantity() +def repeat(a: Union[Array, Quantity], repeats: Union[int, Tuple[int, ...]], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.repeat(a, repeats, axis) + + +@_compatible_with_quantity() +def unique(a: Union[Array, Quantity], return_index: bool = False, return_inverse: bool = False, + return_counts: bool = False, axis: Optional[int] = None) -> Union[Array, Quantity]: + return jnp.unique(a, return_index, return_inverse, return_counts, axis) + + +@_compatible_with_quantity() +def append(arr: Union[Array, Quantity], values: Union[Array, Quantity], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.append(arr, values, axis) + + +@_compatible_with_quantity() +def flip(m: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.flip(m, axis) + + +@_compatible_with_quantity() +def fliplr(m: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.fliplr(m) + + +@_compatible_with_quantity() +def flipud(m: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.flipud(m) + + +@_compatible_with_quantity() +def roll(a: Union[Array, Quantity], shift: Union[int, Tuple[int, ...]], + axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.roll(a, shift, axis) + + +@_compatible_with_quantity() +def atleast_1d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.atleast_1d(*arys) + + +@_compatible_with_quantity() +def atleast_2d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.atleast_2d(*arys) + + +@_compatible_with_quantity() +def atleast_3d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.atleast_3d(*arys) + + +@_compatible_with_quantity() +def expand_dims(a: Union[Array, Quantity], axis: int) -> Union[Array, Quantity]: + return jnp.expand_dims(a, axis) + + +@_compatible_with_quantity() +def squeeze(a: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.squeeze(a, axis) + + +@_compatible_with_quantity() +def sort(a: Union[Array, Quantity], + axis: Optional[int] = -1, + kind: None = None, + order: None = None, + stable: bool = True, + descending: bool = False, ) -> Union[Array, Quantity]: + return jnp.sort(a, axis, kind=kind, order=order, stable=stable, descending=descending) + + +@_compatible_with_quantity() +def argsort(a: Union[Array, Quantity], + axis: Optional[int] = -1, + kind: None = None, + order: None = None, + stable: bool = True, + descending: bool = False, ) -> Array: + return jnp.argsort(a, axis, kind=kind, order=order, stable=stable, descending=descending) + + +@_compatible_with_quantity() +def max(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None, + keepdims: bool = False) -> Union[Array, Quantity]: + return jnp.max(a, axis, out, keepdims) + + +@_compatible_with_quantity() +def min(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None, + keepdims: bool = False) -> Union[Array, Quantity]: + return jnp.min(a, axis, out, keepdims) + + +@_compatible_with_quantity() +def choose(a: Union[Array, Quantity], choices: Sequence[Union[Array, Quantity]]) -> Union[Array, Quantity]: + return jnp.choose(a, choices) + + +@_compatible_with_quantity() +def block(arrays: Sequence[Union[Array, Quantity]]) -> Union[Array, Quantity]: + return jnp.block(arrays) + + +@_compatible_with_quantity() +def compress(condition: Union[Array, Quantity], a: Union[Array, Quantity], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.compress(condition, a, axis) + + +@_compatible_with_quantity() +def diagflat(v: Union[Array, Quantity], k: int = 0) -> Union[Array, Quantity]: + return jnp.diagflat(v, k) + + +# return jax.numpy.Array, not Quantity + +@_compatible_with_quantity(return_quantity=False) +def argmax(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None) -> Array: + return jnp.argmax(a, axis, out) + + +@_compatible_with_quantity(return_quantity=False) +def argmin(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None) -> Array: + return jnp.argmin(a, axis, out) + + +@_compatible_with_quantity(return_quantity=False) +def argwhere(a: Union[Array, Quantity]) -> Array: + return jnp.argwhere(a) + + +@_compatible_with_quantity(return_quantity=False) +def nonzero(a: Union[Array, Quantity]) -> Tuple[Array, ...]: + return jnp.nonzero(a) + + +@_compatible_with_quantity(return_quantity=False) +def flatnonzero(a: Union[Array, Quantity]) -> Array: + return jnp.flatnonzero(a) + + +@_compatible_with_quantity(return_quantity=False) +def searchsorted(a: Union[Array, Quantity], v: Union[Array, Quantity], side: str = 'left', + sorter: Optional[Array] = None) -> Array: + return jnp.searchsorted(a, v, side, sorter) + + +@_compatible_with_quantity(return_quantity=False) +def extract(condition: Union[Array, Quantity], arr: Union[Array, Quantity]) -> Array: + return jnp.extract(condition, arr) + + +@_compatible_with_quantity(return_quantity=False) +def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Array: + return jnp.count_nonzero(a, axis) + + +amax = max +amin = min + +# docs for the functions above +reshape.__doc__ = ''' + Return a reshaped copy of an array or a Quantity. + + Args: + a: input array or Quantity to reshape + shape: integer or sequence of integers giving the new shape, which must match the + size of the input array. If any single dimension is given size ``-1``, it will be + replaced with a value such that the output has the correct size. + order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major + (fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``. + brainunit does not support ``order="A"``. + + Returns: + reshaped copy of input array with the specified shape. +''' + +moveaxis.__doc__ = ''' + Moves axes of an array to new positions. Other axes remain in their original order. + + Args: + a: array_like, Quantity + source: int or sequence of ints + destination: int or sequence of ints + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +transpose.__doc__ = ''' + Returns a view of the array with axes transposed. + + Args: + a: array_like, Quantity + axes: tuple or list of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +swapaxes.__doc__ = ''' + Interchanges two axes of an array. + + Args: + a: array_like, Quantity + axis1: int + axis2: int + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +concatenate.__doc__ = ''' + Join a sequence of arrays along an existing axis. + + Args: + arrays: sequence of array_like, Quantity + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +stack.__doc__ = ''' + Join a sequence of arrays along a new axis. + + Args: + arrays: sequence of array_like, Quantity + axis: int + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +vstack.__doc__ = ''' + Stack arrays in sequence vertically (row wise). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +hstack.__doc__ = ''' + Stack arrays in sequence horizontally (column wise). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +dstack.__doc__ = ''' + Stack arrays in sequence depth wise (along third axis). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +column_stack.__doc__ = ''' + Stack 1-D arrays as columns into a 2-D array. + + Args: + arrays: sequence of 1-D or 2-D array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +split.__doc__ = ''' + Split an array into multiple sub-arrays. + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +dsplit.__doc__ = ''' + Split array along third axis (depth). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +hsplit.__doc__ = ''' + Split an array into multiple sub-arrays horizontally (column-wise). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +vsplit.__doc__ = ''' + Split an array into multiple sub-arrays vertically (row-wise). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +tile.__doc__ = ''' + Construct an array by repeating A the number of times given by reps. + + Args: + A: array_like, Quantity + reps: array_like + + Returns: + Union[jax.Array, Quantity] a Quantity if A is a Quantity, otherwise a jax.Array +''' + +repeat.__doc__ = ''' + Repeat elements of an array. + + Args: + a: array_like, Quantity + repeats: array_like + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +unique.__doc__ = ''' + Find the unique elements of an array. + + Args: + a: array_like, Quantity + return_index: bool, optional + return_inverse: bool, optional + return_counts: bool, optional + axis: int or None, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +append.__doc__ = ''' + Append values to the end of an array. + + Args: + arr: array_like, Quantity + values: array_like, Quantity + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if arr and values are Quantity, otherwise a jax.Array +''' + +flip.__doc__ = ''' + Reverse the order of elements in an array along the given axis. + + Args: + m: array_like, Quantity + axis: int or tuple of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array +''' + +fliplr.__doc__ = ''' + Flip array in the left/right direction. + + Args: + m: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array +''' + +flipud.__doc__ = ''' + Flip array in the up/down direction. + + Args: + m: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array +''' + +roll.__doc__ = ''' + Roll array elements along a given axis. + + Args: + a: array_like, Quantity + shift: int or tuple of ints + axis: int or tuple of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +atleast_1d.__doc__ = ''' + View inputs as arrays with at least one dimension. + + Args: + *args: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array +''' + +atleast_2d.__doc__ = ''' + View inputs as arrays with at least two dimensions. + + Args: + *args: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array +''' + +atleast_3d.__doc__ = ''' + View inputs as arrays with at least three dimensions. + + Args: + *args: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array +''' + +expand_dims.__doc__ = ''' + Expand the shape of an array. + + Args: + a: array_like, Quantity + axis: int or tuple of ints + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +squeeze.__doc__ = ''' + Remove single-dimensional entries from the shape of an array. + + Args: + a: array_like, Quantity + axis: None or int or tuple of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +sort.__doc__ = ''' + Return a sorted copy of an array. + + Args: + a: array_like, Quantity + axis: int or None, optional + kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional + order: str or list of str, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' +max.__doc__ = ''' + Return the maximum of an array or maximum along an axis. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + keepdims: bool, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +min.__doc__ = ''' + Return the minimum of an array or minimum along an axis. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + keepdims: bool, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +choose.__doc__ = ''' + Use an index array to construct a new array from a set of choices. + + Args: + a: array_like, Quantity + choices: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if a and choices are Quantity, otherwise a jax.Array +''' + +block.__doc__ = ''' + Assemble an nd-array from nested lists of blocks. + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +compress.__doc__ = ''' + Return selected slices of an array along given axis. + + Args: + condition: array_like, Quantity + a: array_like, Quantity + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +diagflat.__doc__ = ''' + Create a two-dimensional array with the flattened input as a diagonal. + + Args: + a: array_like, Quantity + offset: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +argsort.__doc__ = ''' + Returns the indices that would sort an array. + + Args: + a: array_like, Quantity + axis: int or None, optional + kind: {'quicksort', 'mergesort', 'heapsort'}, optional + order: str or list of str, optional + + Returns: + jax.Array jax.numpy.Array (does not return a Quantity) +''' + +argmax.__doc__ = ''' + Returns indices of the max value along an axis. + + Args: + a: array_like, Quantity + axis: int, optional + out: array, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +argmin.__doc__ = ''' + Returns indices of the min value along an axis. + + Args: + a: array_like, Quantity + axis: int, optional + out: array, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +argwhere.__doc__ = ''' + Find indices of non-zero elements. + + Args: + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +nonzero.__doc__ = ''' + Return the indices of the elements that are non-zero. + + Args: + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +flatnonzero.__doc__ = ''' + Return indices that are non-zero in the flattened version of a. + + Args: + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +searchsorted.__doc__ = ''' + Find indices where elements should be inserted to maintain order. + + Args: + a: array_like, Quantity + v: array_like, Quantity + side: {'left', 'right'}, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +extract.__doc__ = ''' + Return the elements of an array that satisfy some condition. + + Args: + condition: array_like, Quantity + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +count_nonzero.__doc__ = ''' + Counts the number of non-zero values in the array a. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + + +def wrap_function_to_method(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return Quantity(func(x.value, *args, **kwargs), unit=x.unit) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_function_to_method +def diagonal(a: Union[jax.Array, Quantity], offset: int = 0, axis1: int = 0, axis2: int = 1) -> Union[ + jax.Array, Quantity]: + return jnp.diagonal(a, offset, axis1, axis2) + + +@wrap_function_to_method +def ravel(a: Union[jax.Array, Quantity], order: str = 'C') -> Union[jax.Array, Quantity]: + return jnp.ravel(a, order) + + +diagonal.__doc__ = ''' + Return specified diagonals. + + Args: + a: array_like, Quantity + offset: int, optional + axis1: int, optional + axis2: int, optional + + Returns: + Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +ravel.__doc__ = ''' + Return a contiguous flattened array. + + Args: + a: array_like, Quantity + order: {'C', 'F', 'A', 'K'}, optional + + Returns: + Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' diff --git a/brainunit/math/_compat_numpy_funcs_accept_unitless.py b/brainunit/math/_compat_numpy_funcs_accept_unitless.py new file mode 100644 index 0000000..c87890a --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_accept_unitless.py @@ -0,0 +1,588 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import brainstate as bst +import jax.numpy as jnp +from jax import Array + +from .._base import (Quantity, + fail_for_dimension_mismatch, + ) + +__all__ = [ + # math funcs only accept unitless (unary) + 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2', + 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', + 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', + 'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', + 'percentile', 'nanpercentile', 'quantile', 'nanquantile', + + # math funcs only accept unitless (binary) + 'hypot', 'arctan2', 'logaddexp', 'logaddexp2', +] + + +# math funcs only accept unitless (unary) +# --------------------------------------- + +def wrap_math_funcs_only_accept_unitless_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + fail_for_dimension_mismatch( + x, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=x, + ) + return func(jnp.array(x.value), *args, **kwargs) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_only_accept_unitless_unary +def exp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: + return jnp.exp(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def exp2(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: + return jnp.exp2(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def expm1(x: Union[Array, Quantity]) -> Array: + return jnp.expm1(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log(x: Union[Array, Quantity]) -> Array: + return jnp.log(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log10(x: Union[Array, Quantity]) -> Array: + return jnp.log10(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log1p(x: Union[Array, Quantity]) -> Array: + return jnp.log1p(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log2(x: Union[Array, Quantity]) -> Array: + return jnp.log2(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arccos(x: Union[Array, Quantity]) -> Array: + return jnp.arccos(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arccosh(x: Union[Array, Quantity]) -> Array: + return jnp.arccosh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arcsin(x: Union[Array, Quantity]) -> Array: + return jnp.arcsin(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arcsinh(x: Union[Array, Quantity]) -> Array: + return jnp.arcsinh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arctan(x: Union[Array, Quantity]) -> Array: + return jnp.arctan(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arctanh(x: Union[Array, Quantity]) -> Array: + return jnp.arctanh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def cos(x: Union[Array, Quantity]) -> Array: + return jnp.cos(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def cosh(x: Union[Array, Quantity]) -> Array: + return jnp.cosh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def sin(x: Union[Array, Quantity]) -> Array: + return jnp.sin(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def sinc(x: Union[Array, Quantity]) -> Array: + return jnp.sinc(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def sinh(x: Union[Array, Quantity]) -> Array: + return jnp.sinh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def tan(x: Union[Array, Quantity]) -> Array: + return jnp.tan(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def tanh(x: Union[Array, Quantity]) -> Array: + return jnp.tanh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def deg2rad(x: Union[Array, Quantity]) -> Array: + return jnp.deg2rad(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def rad2deg(x: Union[Array, Quantity]) -> Array: + return jnp.rad2deg(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def degrees(x: Union[Array, Quantity]) -> Array: + return jnp.degrees(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def radians(x: Union[Array, Quantity]) -> Array: + return jnp.radians(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def angle(x: Union[Array, Quantity]) -> Array: + return jnp.angle(x) + + +# docs for the functions above +exp.__doc__ = ''' + Calculate the exponential of all elements in the input array. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +exp2.__doc__ = ''' + Calculate 2 raised to the power of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +expm1.__doc__ = ''' + Calculate the exponential of the input elements minus 1. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log.__doc__ = ''' + Natural logarithm, element-wise. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log10.__doc__ = ''' + Base-10 logarithm of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log1p.__doc__ = ''' + Natural logarithm of 1 + the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log2.__doc__ = ''' + Base-2 logarithm of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arccos.__doc__ = ''' + Compute the arccosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arccosh.__doc__ = ''' + Compute the hyperbolic arccosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arcsin.__doc__ = ''' + Compute the arcsine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arcsinh.__doc__ = ''' + Compute the hyperbolic arcsine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arctan.__doc__ = ''' + Compute the arctangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arctanh.__doc__ = ''' + Compute the hyperbolic arctangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +cos.__doc__ = ''' + Compute the cosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +cosh.__doc__ = ''' + Compute the hyperbolic cosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sin.__doc__ = ''' + Compute the sine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sinc.__doc__ = ''' + Compute the sinc function of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sinh.__doc__ = ''' + Compute the hyperbolic sine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +tan.__doc__ = ''' + Compute the tangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +tanh.__doc__ = ''' + Compute the hyperbolic tangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +deg2rad.__doc__ = ''' + Convert angles from degrees to radians. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +rad2deg.__doc__ = ''' + Convert angles from radians to degrees. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +degrees.__doc__ = ''' + Convert angles from radians to degrees. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +radians.__doc__ = ''' + Convert angles from degrees to radians. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +angle.__doc__ = ''' + Return the angle of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + + +# math funcs only accept unitless (binary) +# ---------------------------------------- + +def wrap_math_funcs_only_accept_unitless_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity): + x_value = x.value + if isinstance(y, Quantity): + y_value = y.value + if isinstance(x, Quantity) or isinstance(y, Quantity): + fail_for_dimension_mismatch( + x, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=x, + ) + fail_for_dimension_mismatch( + y, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=y, + ) + return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) + else: + return func(x, y, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_only_accept_unitless_binary +def hypot(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.hypot(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def arctan2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.arctan2(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def logaddexp(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.logaddexp(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def logaddexp2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.logaddexp2(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def percentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.percentile(a, q, *args, **kwargs) + + +@wrap_math_funcs_only_accept_unitless_binary +def nanpercentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.nanpercentile(a, q, *args, **kwargs) + + +@wrap_math_funcs_only_accept_unitless_binary +def quantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.quantile(a, q, *args, **kwargs) + + +@wrap_math_funcs_only_accept_unitless_binary +def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.nanquantile(a, q, *args, **kwargs) + + +# docs for the functions above +hypot.__doc__ = ''' + Given the “legs” of a right triangle, return its hypotenuse. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arctan2.__doc__ = ''' + Element-wise arc tangent of `x1/x2` choosing the quadrant correctly. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + jax.Array: an array +''' + +logaddexp.__doc__ = ''' + Logarithm of the sum of exponentiations of the inputs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + jax.Array: an array +''' + +logaddexp2.__doc__ = ''' + Logarithm of the sum of exponentiations of the inputs in base-2. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + jax.Array: an array +''' + +percentile.__doc__ = ''' + Compute the nth percentile of the input array along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +nanpercentile.__doc__ = ''' + Compute the nth percentile of the input array along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +quantile.__doc__ = ''' + Compute the qth quantile of the input array along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +nanquantile.__doc__ = ''' + Compute the qth quantile of the input array along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' diff --git a/brainunit/math/_compat_numpy_funcs_bit_operation.py b/brainunit/math/_compat_numpy_funcs_bit_operation.py new file mode 100644 index 0000000..c48fa18 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_bit_operation.py @@ -0,0 +1,182 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +from .._base import (Quantity, + ) + +__all__ = [ + + # Elementwise bit operations (unary) + 'bitwise_not', 'invert', + + # Elementwise bit operations (binary) + 'bitwise_and', 'bitwise_or', 'bitwise_xor', 'left_shift', 'right_shift', +] + + +# Elementwise bit operations (unary) +# ---------------------------------- + +def wrap_elementwise_bit_operation_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + raise ValueError(f'Expected integers, got {x}') + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_elementwise_bit_operation_unary +def bitwise_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_not(x) + + +@wrap_elementwise_bit_operation_unary +def invert(x: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.invert(x) + + +# docs for functions above +bitwise_not.__doc__ = ''' + Compute the bit-wise NOT of an array, element-wise. + + Args: + x: array_like + + Returns: + jax.Array: an array +''' + +invert.__doc__ = ''' + Compute bit-wise inversion, or bit-wise NOT, element-wise. + + Args: + x: array_like + + Returns: + jax.Array: an array +''' + + +# Elementwise bit operations (binary) +# ----------------------------------- + +def wrap_elementwise_bit_operation_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) or isinstance(y, Quantity): + raise ValueError(f'Expected integers, got {x} and {y}') + elif isinstance(x, bst.typing.ArrayLike) and isinstance(y, bst.typing.ArrayLike): + return func(x, y, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_elementwise_bit_operation_binary +def bitwise_and(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_and(x, y) + + +@wrap_elementwise_bit_operation_binary +def bitwise_or(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_or(x, y) + + +@wrap_elementwise_bit_operation_binary +def bitwise_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_xor(x, y) + + +@wrap_elementwise_bit_operation_binary +def left_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.left_shift(x, y) + + +@wrap_elementwise_bit_operation_binary +def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.right_shift(x, y) + + +# docs for functions above +bitwise_and.__doc__ = ''' + Compute the bit-wise AND of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +bitwise_or.__doc__ = ''' + Compute the bit-wise OR of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +bitwise_xor.__doc__ = ''' + Compute the bit-wise XOR of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +left_shift.__doc__ = ''' + Shift the bits of an integer to the left. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +right_shift.__doc__ = ''' + Shift the bits of an integer to the right. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' diff --git a/brainunit/math/_compat_numpy_funcs_change_unit.py b/brainunit/math/_compat_numpy_funcs_change_unit.py new file mode 100644 index 0000000..ced279e --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_change_unit.py @@ -0,0 +1,527 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from functools import wraps +from typing import (Callable, Union, Optional) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from ._compat_numpy_get_attribute import isscalar +from .._base import (DIMENSIONLESS, + Quantity, + ) +from .._base import _return_check_unitless + +__all__ = [ + + # math funcs change unit (unary) + 'reciprocal', 'prod', 'product', 'nancumprod', 'nanprod', 'cumprod', + 'cumproduct', 'var', 'nanvar', 'cbrt', 'square', 'frexp', 'sqrt', + + # math funcs change unit (binary) + 'multiply', 'divide', 'power', 'cross', 'ldexp', + 'true_divide', 'floor_divide', 'float_power', + 'divmod', 'remainder', 'convolve', +] + + +# math funcs change unit (unary) +# ------------------------------ + +def wrap_math_funcs_change_unit_unary(change_unit_func: Callable) -> Callable: + def decorator(func: Callable) -> Callable: + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), unit=change_unit_func(x.unit))) + elif isinstance(x, (jnp.ndarray, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + return decorator + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** -1) +def reciprocal(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.reciprocal(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) +def var(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[Union[int, Sequence[int]]] = None, + ddof: int = 0, + keepdims: bool = False) -> Union[Quantity, jax.Array]: + return jnp.var(x, axis=axis, ddof=ddof, keepdims=keepdims) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) +def nanvar(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[Union[int, Sequence[int]]] = None, + ddof: int = 0, + keepdims: bool = False) -> Union[Quantity, jax.Array]: + return jnp.nanvar(x, axis=axis, ddof=ddof, keepdims=keepdims) + + +@wrap_math_funcs_change_unit_unary(lambda x: x * 2 ** -1) +def frexp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.frexp(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 0.5) +def sqrt(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.sqrt(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** (1 / 3)) +def cbrt(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.cbrt(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) +def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.square(x) + + +# docs for the functions above + +reciprocal.__doc__ = ''' + Return the reciprocal of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +var.__doc__ = ''' + Compute the variance along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + +nanvar.__doc__ = ''' + Compute the variance along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + +frexp.__doc__ = ''' + Decompose a floating-point number into its mantissa and exponent. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Tuple of Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the exponent, else a tuple of arrays. +''' + +sqrt.__doc__ = ''' + Compute the square root of each element. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square root of the unit of `x`, else an array. +''' + +cbrt.__doc__ = ''' + Compute the cube root of each element. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the cube root of the unit of `x`, else an array. +''' + +square.__doc__ = ''' + Compute the square of each element. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + + +@set_module_as('brainunit.math') +def prod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None, + keepdims: Optional[bool] = False, + initial: Union[Quantity, bst.typing.ArrayLike] = None, + where: Union[Quantity, bst.typing.ArrayLike] = None, + promote_integers: bool = True) -> Union[Quantity, jax.Array]: + ''' + Return the product of array elements over a given axis. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + keepdims: bool, optional + initial: array_like, Quantity, optional + where: array_like, Quantity, optional + promote_integers: bool, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + promote_integers=promote_integers) + else: + return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + promote_integers=promote_integers) + + +@set_module_as('brainunit.math') +def nanprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None, + keepdims: bool = False, + initial: Union[Quantity, bst.typing.ArrayLike] = None, + where: Union[Quantity, bst.typing.ArrayLike] = None): + ''' + Return the product of array elements over a given axis treating Not a Numbers (NaNs) as one. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + keepdims: bool, optional + initial: array_like, Quantity, optional + where: array_like, Quantity, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + else: + return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + + +product = prod + + +@set_module_as('brainunit.math') +def cumprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: + ''' + Return the cumulative product of elements along a given axis. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.cumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) + + +@set_module_as('brainunit.math') +def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: + ''' + Return the cumulative product of elements along a given axis treating Not a Numbers (NaNs) as one. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.nancumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) + + +cumproduct = cumprod + + +# math funcs change unit (binary) +# ------------------------------- + +def wrap_math_funcs_change_unit_binary(change_unit_func): + def decorator(func: Callable) -> Callable: + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y.value, *args, **kwargs), unit=change_unit_func(x.unit, y.unit)) + ) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y, *args, **kwargs), unit=change_unit_func(x.unit, DIMENSIONLESS))) + elif isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x, y.value, *args, **kwargs), unit=change_unit_func(DIMENSIONLESS, y.unit))) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + return decorator + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def multiply(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.multiply(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) +def divide(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.divide(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def cross(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.cross(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * 2 ** y) +def ldexp(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.ldexp(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) +def true_divide(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.true_divide(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) +def divmod(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.divmod(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.convolve(x, y) + + +# docs for the functions above +multiply.__doc__ = ''' + Multiply arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +divide.__doc__ = ''' + Divide arguments element-wise. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +cross.__doc__ = ''' + Return the cross product of two (arrays of) vectors. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +ldexp.__doc__ = ''' + Return x1 * 2**x2, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the unit of `y`, else an array. +''' + +true_divide.__doc__ = ''' + Returns a true division of the inputs, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +divmod.__doc__ = ''' + Return element-wise quotient and remainder simultaneously. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +convolve.__doc__ = ''' + Returns the discrete, linear convolution of two one-dimensional sequences. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + + +@set_module_as('brainunit.math') +def power(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike], ) -> Union[Quantity, jax.Array]: + ''' + First array elements raised to powers from second array, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.power(x.value, y.value), unit=x.unit ** y.unit)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.power(x, y) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.power(x.value, y), unit=x.unit ** y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.power(x, y.value), unit=x ** y.unit)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') + + +@set_module_as('brainunit.math') +def floor_divide(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Return the largest integer smaller or equal to the division of the inputs. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), unit=x.unit / y.unit)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.floor_divide(x, y) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), unit=x.unit / y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), unit=x / y.unit)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') + + +@set_module_as('brainunit.math') +def float_power(x: Union[Quantity, bst.typing.ArrayLike], + y: bst.typing.ArrayLike) -> Union[Quantity, jax.Array]: + ''' + First array elements raised to powers from second array, element-wise. + + Args: + x: array_like, Quantity + y: array_like + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(y, Quantity): + assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' + if isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.float_power(x.value, y), unit=x.unit ** y)) + elif isinstance(x, (jax.Array, np.ndarray)): + return jnp.float_power(x, y) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.float_power.__name__}') + + +@set_module_as('brainunit.math') +def remainder(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Return element-wise remainder of division. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the remainder of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), unit=x.unit / y.unit)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.remainder(x, y) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x.value, y), unit=x.unit % y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x, y.value), unit=x % y.unit)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}') diff --git a/brainunit/math/_compat_numpy_funcs_indexing.py b/brainunit/math/_compat_numpy_funcs_indexing.py new file mode 100644 index 0000000..bf21d75 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_indexing.py @@ -0,0 +1,166 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import (Union, Optional) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from .._base import (Quantity, + fail_for_dimension_mismatch, + is_unitless, + ) + +__all__ = [ + + # indexing funcs + 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', + 'triu_indices_from', 'take', 'select', +] + + +# indexing funcs +# -------------- +@set_module_as('brainunit.math') +def where(condition: Union[bool, bst.typing.ArrayLike], + *args: Union[Quantity, bst.typing.ArrayLike], + **kwds) -> Union[Quantity, jax.Array]: + condition = jnp.asarray(condition) + if len(args) == 0: + # nothing to do + return jnp.where(condition, *args, **kwds) + elif len(args) == 2: + # check that x and y have the same dimensions + fail_for_dimension_mismatch( + args[0], args[1], "x and y need to have the same dimensions" + ) + new_args = [] + for arg in args: + if isinstance(arg, Quantity): + new_args.append(arg.value) + if is_unitless(args[0]): + if len(new_args) == 2: + return jnp.where(condition, *new_args, **kwds) + else: + return jnp.where(condition, *args, **kwds) + else: + # as both arguments have the same unit, just use the first one's + dimensionless_args = [jnp.asarray(arg.value) if isinstance(arg, Quantity) else jnp.asarray(arg) for arg in args] + return Quantity.with_units( + jnp.where(condition, *dimensionless_args), args[0].unit + ) + else: + # illegal number of arguments + if len(args) == 1: + raise ValueError("where() takes 2 or 3 positional arguments but 1 was given") + elif len(args) > 2: + raise TypeError("where() takes 2 or 3 positional arguments but {} were given".format(len(args))) + + +tril_indices = jnp.tril_indices +tril_indices.__doc__ = ''' + Return the indices for the lower-triangle of an (n, m) array. + + Args: + n: int + m: int + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] +''' + + +@set_module_as('brainunit.math') +def tril_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], + k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: + ''' + Return the indices for the lower-triangle of an (n, m) array. + + Args: + arr: array_like, Quantity + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] + ''' + if isinstance(arr, Quantity): + return jnp.tril_indices_from(arr.value, k=k) + else: + return jnp.tril_indices_from(arr, k=k) + + +triu_indices = jnp.triu_indices +triu_indices.__doc__ = ''' + Return the indices for the upper-triangle of an (n, m) array. + + Args: + n: int + m: int + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] +''' + + +@set_module_as('brainunit.math') +def triu_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], + k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: + ''' + Return the indices for the upper-triangle of an (n, m) array. + + Args: + arr: array_like, Quantity + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] + ''' + if isinstance(arr, Quantity): + return jnp.triu_indices_from(arr.value, k=k) + else: + return jnp.triu_indices_from(arr, k=k) + + +@set_module_as('brainunit.math') +def take(a: Union[Quantity, bst.typing.ArrayLike], + indices: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + mode: Optional[str] = None) -> Union[Quantity, jax.Array]: + if isinstance(a, Quantity): + return a.take(indices, axis=axis, mode=mode) + else: + return jnp.take(a, indices, axis=axis, mode=mode) + + +@set_module_as('brainunit.math') +def select(condlist: list[Union[bst.typing.ArrayLike]], + choicelist: Union[Quantity, bst.typing.ArrayLike], + default: int = 0) -> Union[Quantity, jax.Array]: + from builtins import all as origin_all + from builtins import any as origin_any + if origin_all(isinstance(choice, Quantity) for choice in choicelist): + if origin_any(choice.unit != choicelist[0].unit for choice in choicelist): + raise ValueError("All choices must have the same unit") + else: + return Quantity(jnp.select(condlist, [choice.value for choice in choicelist], default=default), + unit=choicelist[0].unit) + elif origin_all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist): + return jnp.select(condlist, choicelist, default=default) + else: + raise ValueError(f"Unsupported types : {type(condlist)} and {type(choicelist)} for select") diff --git a/brainunit/math/_compat_numpy_funcs_keep_unit.py b/brainunit/math/_compat_numpy_funcs_keep_unit.py new file mode 100644 index 0000000..b11f4c4 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_keep_unit.py @@ -0,0 +1,832 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from .._base import (Quantity, + ) + +__all__ = [ + # math funcs keep unit (unary) + 'real', 'imag', 'conj', 'conjugate', 'negative', 'positive', + 'abs', 'round', 'around', 'round_', 'rint', + 'floor', 'ceil', 'trunc', 'fix', 'sum', 'nancumsum', 'nansum', + 'cumsum', 'ediff1d', 'absolute', 'fabs', 'median', + 'nanmin', 'nanmax', 'ptp', 'average', 'mean', 'std', + 'nanmedian', 'nanmean', 'nanstd', 'diff', 'modf', + + # math funcs keep unit (binary) + 'fmod', 'mod', 'copysign', 'heaviside', + 'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', + + # math funcs keep unit (n-ary) + 'interp', 'clip', +] + + +# math funcs keep unit (unary) +# ---------------------------- + +def wrap_math_funcs_keep_unit_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return Quantity(func(x.value, *args, **kwargs), unit=x.unit) + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_keep_unit_unary +def real(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.real(x) + + +@wrap_math_funcs_keep_unit_unary +def imag(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.imag(x) + + +@wrap_math_funcs_keep_unit_unary +def conj(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.conj(x) + + +@wrap_math_funcs_keep_unit_unary +def conjugate(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.conjugate(x) + + +@wrap_math_funcs_keep_unit_unary +def negative(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.negative(x) + + +@wrap_math_funcs_keep_unit_unary +def positive(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.positive(x) + + +@wrap_math_funcs_keep_unit_unary +def abs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.abs(x) + + +@wrap_math_funcs_keep_unit_unary +def round_(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.round(x) + + +@wrap_math_funcs_keep_unit_unary +def around(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.around(x) + + +@wrap_math_funcs_keep_unit_unary +def round(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.round(x) + + +@wrap_math_funcs_keep_unit_unary +def rint(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.rint(x) + + +@wrap_math_funcs_keep_unit_unary +def floor(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.floor(x) + + +@wrap_math_funcs_keep_unit_unary +def ceil(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.ceil(x) + + +@wrap_math_funcs_keep_unit_unary +def trunc(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.trunc(x) + + +@wrap_math_funcs_keep_unit_unary +def fix(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.fix(x) + + +@wrap_math_funcs_keep_unit_unary +def sum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.sum(x) + + +@wrap_math_funcs_keep_unit_unary +def nancumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nancumsum(x) + + +@wrap_math_funcs_keep_unit_unary +def nansum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nansum(x) + + +@wrap_math_funcs_keep_unit_unary +def cumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.cumsum(x) + + +@wrap_math_funcs_keep_unit_unary +def ediff1d(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.ediff1d(x) + + +@wrap_math_funcs_keep_unit_unary +def absolute(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.absolute(x) + + +@wrap_math_funcs_keep_unit_unary +def fabs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.fabs(x) + + +@wrap_math_funcs_keep_unit_unary +def median(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.median(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmin(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmin(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmax(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmax(x) + + +@wrap_math_funcs_keep_unit_unary +def ptp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.ptp(x) + + +@wrap_math_funcs_keep_unit_unary +def average(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.average(x) + + +@wrap_math_funcs_keep_unit_unary +def mean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.mean(x) + + +@wrap_math_funcs_keep_unit_unary +def std(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.std(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmedian(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmedian(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmean(x) + + +@wrap_math_funcs_keep_unit_unary +def nanstd(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanstd(x) + + +@wrap_math_funcs_keep_unit_unary +def diff(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.diff(x) + + +@wrap_math_funcs_keep_unit_unary +def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.modf(x) + + +# docs for the functions above +real.__doc__ = ''' + Return the real part of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +imag.__doc__ = ''' + Return the imaginary part of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +conj.__doc__ = ''' + Return the complex conjugate of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +conjugate.__doc__ = ''' + Return the complex conjugate of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +negative.__doc__ = ''' + Return the negative of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +positive.__doc__ = ''' + Return the positive of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +abs.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +round_.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +around.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +round.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +rint.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +floor.__doc__ = ''' + Return the floor of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +ceil.__doc__ = ''' + Return the ceiling of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +trunc.__doc__ = ''' + Return the truncated value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +fix.__doc__ = ''' + Return the nearest integer towards zero. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +sum.__doc__ = ''' + Return the sum of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nancumsum.__doc__ = ''' + Return the cumulative sum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nansum.__doc__ = ''' + Return the sum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +cumsum.__doc__ = ''' + Return the cumulative sum of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +ediff1d.__doc__ = ''' + Return the differences between consecutive elements of the array. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +absolute.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +fabs.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +median.__doc__ = ''' + Return the median of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmin.__doc__ = ''' + Return the minimum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmax.__doc__ = ''' + Return the maximum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +ptp.__doc__ = ''' + Return the range of the array elements (maximum - minimum). + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +average.__doc__ = ''' + Return the weighted average of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +mean.__doc__ = ''' + Return the mean of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +std.__doc__ = ''' + Return the standard deviation of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmedian.__doc__ = ''' + Return the median of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmean.__doc__ = ''' + Return the mean of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanstd.__doc__ = ''' + Return the standard deviation of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +diff.__doc__ = ''' + Return the differences between consecutive elements of the array. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +modf.__doc__ = ''' + Return the fractional and integer parts of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity tuple if `x` is a Quantity, else an array tuple. +''' + + +# math funcs keep unit (binary) +# ----------------------------- + +def wrap_math_funcs_keep_unit_binary(func): + @wraps(func) + def f(x1, x2, *args, **kwargs): + if isinstance(x1, Quantity) and isinstance(x2, Quantity): + return Quantity(func(x1.value, x2.value, *args, **kwargs), unit=x1.unit) + elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): + return func(x1, x2, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_keep_unit_binary +def fmod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.fmod(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def mod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.mod(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def copysign(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.copysign(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def heaviside(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.heaviside(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def maximum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.maximum(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def minimum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.minimum(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def fmax(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.fmax(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def fmin(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.fmin(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def lcm(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.lcm(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.gcd(x1, x2) + + +# docs for the functions above +fmod.__doc__ = ''' + Return the element-wise remainder of division. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +mod.__doc__ = ''' + Return the element-wise modulus of division. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +copysign.__doc__ = ''' + Return a copy of the first array elements with the sign of the second array. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +heaviside.__doc__ = ''' + Compute the Heaviside step function. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +maximum.__doc__ = ''' + Element-wise maximum of array elements. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +minimum.__doc__ = ''' + Element-wise minimum of array elements. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +fmax.__doc__ = ''' + Element-wise maximum of array elements ignoring NaNs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +fmin.__doc__ = ''' + Element-wise minimum of array elements ignoring NaNs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +lcm.__doc__ = ''' + Return the least common multiple of `x1` and `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +gcd.__doc__ = ''' + Return the greatest common divisor of `x1` and `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + + +# math funcs keep unit (n-ary) +# ---------------------------- +@set_module_as('brainunit.math') +def interp(x: Union[Quantity, bst.typing.ArrayLike], + xp: Union[Quantity, bst.typing.ArrayLike], + fp: Union[Quantity, bst.typing.ArrayLike], + left: Union[Quantity, bst.typing.ArrayLike] = None, + right: Union[Quantity, bst.typing.ArrayLike] = None, + period: Union[Quantity, bst.typing.ArrayLike] = None) -> Union[Quantity, jax.Array]: + ''' + One-dimensional linear interpolation. + + Args: + x: array_like, Quantity + xp: array_like, Quantity + fp: array_like, Quantity + left: array_like, Quantity, optional + right: array_like, Quantity, optional + period: array_like, Quantity, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x`, `xp`, and `fp` are Quantities that have the same unit, else an array. + ''' + unit = None + if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): + unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit + if isinstance(x, Quantity): + x_value = x.value + else: + x_value = x + if isinstance(xp, Quantity): + xp_value = xp.value + else: + xp_value = xp + if isinstance(fp, Quantity): + fp_value = fp.value + else: + fp_value = fp + result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) + if unit is not None: + return Quantity(result, unit=unit) + else: + return result + + +@set_module_as('brainunit.math') +def clip(a: Union[Quantity, bst.typing.ArrayLike], + a_min: Union[Quantity, bst.typing.ArrayLike], + a_max: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Clip (limit) the values in an array. + + Args: + a: array_like, Quantity + a_min: array_like, Quantity + a_max: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `a`, `a_min`, and `a_max` are Quantities that have the same unit, else an array. + ''' + unit = None + if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): + unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit + if isinstance(a, Quantity): + a_value = a.value + else: + a_value = a + if isinstance(a_min, Quantity): + a_min_value = a_min.value + else: + a_min_value = a_min + if isinstance(a_max, Quantity): + a_max_value = a_max.value + else: + a_max_value = a_max + result = jnp.clip(a_value, a_min_value, a_max_value) + if unit is not None: + return Quantity(result, unit=unit) + else: + return result diff --git a/brainunit/math/_compat_numpy_funcs_logic.py b/brainunit/math/_compat_numpy_funcs_logic.py new file mode 100644 index 0000000..e7d69e7 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_logic.py @@ -0,0 +1,343 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union, Optional) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +from .._base import (Quantity, + fail_for_dimension_mismatch, + ) + +__all__ = [ + # logic funcs (unary) + 'all', 'any', 'logical_not', + + # logic funcs (binary) + 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', + 'array_equal', 'isclose', 'allclose', 'logical_and', + 'logical_or', 'logical_xor', "alltrue", 'sometrue', +] + + +# logic funcs (unary) +# ------------------- + +def wrap_logic_func_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + raise ValueError(f'Expected booleans, got {x}') + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_logic_func_unary +def all(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, + out: Optional[Array] = None, keepdims: bool = False, + where: Optional[Array] = None) -> Union[bool, Array]: + return jnp.all(x, axis=axis, out=out, keepdims=keepdims, where=where) + + +@wrap_logic_func_unary +def any(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, + out: Optional[Array] = None, keepdims: bool = False, + where: Optional[Array] = None) -> Union[bool, Array]: + return jnp.any(x, axis=axis, out=out, keepdims=keepdims, where=where) + + +@wrap_logic_func_unary +def logical_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.logical_not(x) + + +alltrue = all +sometrue = any + +# docs for functions above +all.__doc__ = ''' + Test whether all array elements along a given axis evaluate to True. + + Args: + a: array_like + axis: int, optional + out: array, optional + keepdims: bool, optional + where: array_like of bool, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +any.__doc__ = ''' + Test whether any array element along a given axis evaluates to True. + + Args: + a: array_like + axis: int, optional + out: array, optional + keepdims: bool, optional + where: array_like of bool, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +logical_not.__doc__ = ''' + Compute the truth value of NOT x element-wise. + + Args: + x: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + + +# logic funcs (binary) +# -------------------- + +def wrap_logic_func_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return func(x.value, y.value, *args, **kwargs) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_logic_func_binary +def equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.equal(x, y) + + +@wrap_logic_func_binary +def not_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.not_equal(x, y) + + +@wrap_logic_func_binary +def greater(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.greater(x, y) + + +@wrap_logic_func_binary +def greater_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.greater_equal(x, y) + + +@wrap_logic_func_binary +def less(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.less(x, y) + + +@wrap_logic_func_binary +def less_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.less_equal(x, y) + + +@wrap_logic_func_binary +def array_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.array_equal(x, y) + + +@wrap_logic_func_binary +def isclose(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike], + rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Union[bool, Array]: + return jnp.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +@wrap_logic_func_binary +def allclose(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike], + rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Union[bool, Array]: + return jnp.allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +@wrap_logic_func_binary +def logical_and(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.logical_and(x, y) + + +@wrap_logic_func_binary +def logical_or(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.logical_or(x, y) + + +@wrap_logic_func_binary +def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.logical_xor(x, y) + + +# docs for functions above +equal.__doc__ = ''' + Return (x == y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +not_equal.__doc__ = ''' + Return (x != y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +greater.__doc__ = ''' + Return (x > y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +greater_equal.__doc__ = ''' + Return (x >= y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +less.__doc__ = ''' + Return (x < y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +less_equal.__doc__ = ''' + Return (x <= y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +array_equal.__doc__ = ''' + Return True if two arrays have the same shape, elements, and units (if they are Quantity), False otherwise. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +isclose.__doc__ = ''' + Returns a boolean array where two arrays are element-wise equal within a tolerance and have the same unit if they are Quantity. + + Args: + a: array_like, Quantity + b: array_like, Quantity + rtol: float, optional + atol: float, optional + equal_nan: bool, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +allclose.__doc__ = ''' + Returns True if the two arrays are equal within the given tolerance and have the same unit if they are Quantity; False otherwise. + + Args: + a: array_like, Quantity + b: array_like, Quantity + rtol: float, optional + atol: float, optional + equal_nan: bool, optional + + Returns: + bool: boolean result +''' + +logical_and.__doc__ = ''' + Compute the truth value of x AND y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +logical_or.__doc__ = ''' + Compute the truth value of x OR y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +logical_xor.__doc__ = ''' + Compute the truth value of x XOR y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' diff --git a/brainunit/math/_compat_numpy_funcs_match_unit.py b/brainunit/math/_compat_numpy_funcs_match_unit.py new file mode 100644 index 0000000..b863d87 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_match_unit.py @@ -0,0 +1,108 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +from .._base import (Quantity, + fail_for_dimension_mismatch, + ) + +__all__ = [ + # math funcs match unit (binary) + 'add', 'subtract', 'nextafter', +] + + +# math funcs match unit (binary) +# ------------------------------ + +def wrap_math_funcs_match_unit_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return Quantity(func(x.value, y.value, *args, **kwargs), unit=x.unit) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + if x.is_unitless: + return Quantity(func(x.value, y, *args, **kwargs), unit=x.unit) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + elif isinstance(y, Quantity): + if y.is_unitless: + return Quantity(func(x, y.value, *args, **kwargs), unit=y.unit) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_match_unit_binary +def add(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: + return jnp.add(x, y) + + +@wrap_math_funcs_match_unit_binary +def subtract(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: + return jnp.subtract(x, y) + + +@wrap_math_funcs_match_unit_binary +def nextafter(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: + return jnp.nextafter(x, y) + + +# docs for the functions above +add.__doc__ = ''' + Add arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. +''' + +subtract.__doc__ = ''' + Subtract arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. +''' + +nextafter.__doc__ = ''' + Return the next floating-point value after `x1` towards `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' diff --git a/brainunit/math/_compat_numpy_funcs_remove_unit.py b/brainunit/math/_compat_numpy_funcs_remove_unit.py new file mode 100644 index 0000000..afea533 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_remove_unit.py @@ -0,0 +1,191 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union, Optional) + +import jax.numpy as jnp +from jax import Array + +from .._base import (Quantity, + ) + +__all__ = [ + + # math funcs remove unit (unary) + 'signbit', 'sign', 'histogram', 'bincount', + + # math funcs remove unit (binary) + 'corrcoef', 'correlate', 'cov', 'digitize', +] + + +# math funcs remove unit (unary) +# ------------------------------ +def wrap_math_funcs_remove_unit_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return func(x.value, *args, **kwargs) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_remove_unit_unary +def signbit(x: Union[Array, Quantity]) -> Array: + return jnp.signbit(x) + + +@wrap_math_funcs_remove_unit_unary +def sign(x: Union[Array, Quantity]) -> Array: + return jnp.sign(x) + + +@wrap_math_funcs_remove_unit_unary +def histogram(x: Union[Array, Quantity]) -> tuple[Array, Array]: + return jnp.histogram(x) + + +@wrap_math_funcs_remove_unit_unary +def bincount(x: Union[Array, Quantity]) -> Array: + return jnp.bincount(x) + + +# docs for the functions above +signbit.__doc__ = ''' + Returns element-wise True where signbit is set (less than zero). + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sign.__doc__ = ''' + Returns the sign of each element in the input array. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +histogram.__doc__ = ''' + Compute the histogram of a set of data. + + Args: + x: array_like, Quantity + + Returns: + tuple[jax.Array]: Tuple of arrays (hist, bin_edges) +''' + +bincount.__doc__ = ''' + Count number of occurrences of each value in array of non-negative integers. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + + +# math funcs remove unit (binary) +# ------------------------------- +def wrap_math_funcs_remove_unit_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity): + x_value = x.value + if isinstance(y, Quantity): + y_value = y.value + if isinstance(x, Quantity) or isinstance(y, Quantity): + return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) + else: + return func(x, y, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_remove_unit_binary +def corrcoef(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.corrcoef(x, y) + + +@wrap_math_funcs_remove_unit_binary +def correlate(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.correlate(x, y) + + +@wrap_math_funcs_remove_unit_binary +def cov(x: Union[Array, Quantity], y: Optional[Union[Array, Quantity]] = None) -> Array: + return jnp.cov(x, y) + + +@wrap_math_funcs_remove_unit_binary +def digitize(x: Union[Array, Quantity], bins: Union[Array, Quantity]) -> Array: + return jnp.digitize(x, bins) + + +# docs for the functions above +corrcoef.__doc__ = ''' + Return Pearson product-moment correlation coefficients. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + jax.Array: an array +''' + +correlate.__doc__ = ''' + Cross-correlation of two sequences. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + jax.Array: an array +''' + +cov.__doc__ = ''' + Covariance matrix. + + Args: + x: array_like, Quantity + y: array_like, Quantity (optional, if not provided, x is assumed to be a 2D array) + + Returns: + jax.Array: an array +''' + +digitize.__doc__ = ''' + Return the indices of the bins to which each value in input array belongs. + + Args: + x: array_like, Quantity + bins: array_like, Quantity + + Returns: + jax.Array: an array +''' diff --git a/brainunit/math/_compat_numpy_funcs_window.py b/brainunit/math/_compat_numpy_funcs_window.py new file mode 100644 index 0000000..776450f --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_window.py @@ -0,0 +1,69 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps + +import jax.numpy as jnp +from jax import Array + +__all__ = [ + + # window funcs + 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', +] + + +# window funcs +# ------------ + +def wrap_window_funcs(func): + @wraps(func) + def f(*args, **kwargs): + return func(*args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_window_funcs +def bartlett(M: int) -> Array: + return jnp.bartlett(M) + + +@wrap_window_funcs +def blackman(M: int) -> Array: + return jnp.blackman(M) + + +@wrap_window_funcs +def hamming(M: int) -> Array: + return jnp.hamming(M) + + +@wrap_window_funcs +def hanning(M: int) -> Array: + return jnp.hanning(M) + + +@wrap_window_funcs +def kaiser(M: int, beta: float) -> Array: + return jnp.kaiser(M, beta) + + +# docs for functions above +bartlett.__doc__ = jnp.bartlett.__doc__ +blackman.__doc__ = jnp.blackman.__doc__ +hamming.__doc__ = jnp.hamming.__doc__ +hanning.__doc__ = jnp.hanning.__doc__ +kaiser.__doc__ = jnp.kaiser.__doc__ diff --git a/brainunit/math/_compat_numpy_get_attribute.py b/brainunit/math/_compat_numpy_get_attribute.py new file mode 100644 index 0000000..03bec0d --- /dev/null +++ b/brainunit/math/_compat_numpy_get_attribute.py @@ -0,0 +1,215 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import (Union) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from .._base import (Quantity, + ) + +__all__ = [ + # getting attribute funcs + 'ndim', 'isreal', 'isscalar', 'isfinite', 'isinf', + 'isnan', 'shape', 'size', +] + + +@set_module_as('brainunit.math') +def ndim(a: Union[Quantity, bst.typing.ArrayLike]) -> int: + ''' + Return the number of dimensions of an array. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: int + ''' + if isinstance(a, Quantity): + return a.ndim + else: + return jnp.ndim(a) + + +@set_module_as('brainunit.math') +def isreal(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return True if the input array is real. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isreal + else: + return jnp.isreal(a) + + +@set_module_as('brainunit.math') +def isscalar(a: Union[Quantity, bst.typing.ArrayLike]) -> bool: + ''' + Return True if the input is a scalar. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isscalar + else: + return jnp.isscalar(a) + + +@set_module_as('brainunit.math') +def isfinite(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is finite or not. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isfinite + else: + return jnp.isfinite(a) + + +@set_module_as('brainunit.math') +def isinf(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is infinite or not. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isinf + else: + return jnp.isinf(a) + + +@set_module_as('brainunit.math') +def isnan(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is NaN or not. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isnan + else: + return jnp.isnan(a) + + +@set_module_as('brainunit.math') +def shape(a: Union[Quantity, bst.typing.ArrayLike]) -> tuple[int, ...]: + """ + Return the shape of an array. + + Parameters + ---------- + a : array_like + Input array. + + Returns + ------- + shape : tuple of ints + The elements of the shape tuple give the lengths of the + corresponding array dimensions. + + See Also + -------- + len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with + ``N>=1``. + ndarray.shape : Equivalent array method. + + Examples + -------- + >>> brainunit.math.shape(brainunit.math.eye(3)) + (3, 3) + >>> brainunit.math.shape([[1, 3]]) + (1, 2) + >>> brainunit.math.shape([0]) + (1,) + >>> brainunit.math.shape(0) + () + + """ + if isinstance(a, (Quantity, jax.Array, np.ndarray)): + return a.shape + else: + return np.shape(a) + + +@set_module_as('brainunit.math') +def size(a: Union[Quantity, bst.typing.ArrayLike], axis: int = None) -> int: + """ + Return the number of elements along a given axis. + + Parameters + ---------- + a : array_like + Input data. + axis : int, optional + Axis along which the elements are counted. By default, give + the total number of elements. + + Returns + ------- + element_count : int + Number of elements along the specified axis. + + See Also + -------- + shape : dimensions of array + Array.shape : dimensions of array + Array.size : number of elements in array + + Examples + -------- + >>> a = Quantity([[1,2,3], [4,5,6]]) + >>> brainunit.math.size(a) + 6 + >>> brainunit.math.size(a, 1) + 3 + >>> brainunit.math.size(a, 0) + 2 + """ + if isinstance(a, (Quantity, jax.Array, np.ndarray)): + if axis is None: + return a.size + else: + return a.shape[axis] + else: + return np.size(a, axis=axis) diff --git a/brainunit/math/_compat_numpy_linear_algebra.py b/brainunit/math/_compat_numpy_linear_algebra.py new file mode 100644 index 0000000..88f27e9 --- /dev/null +++ b/brainunit/math/_compat_numpy_linear_algebra.py @@ -0,0 +1,149 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import (Union) + +import jax.numpy as jnp +from jax import Array + +from ._compat_numpy_funcs_change_unit import wrap_math_funcs_change_unit_binary +from ._compat_numpy_funcs_keep_unit import wrap_math_funcs_keep_unit_unary +from .._base import (Quantity, + ) + +__all__ = [ + + # linear algebra + 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', + +] + + + + +# linear algebra +# -------------- + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def dot(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.dot(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def vdot(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.vdot(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def inner(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.inner(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def outer(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.outer(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def kron(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.kron(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def matmul(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.matmul(a, b) + + +@wrap_math_funcs_keep_unit_unary +def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.trace(a) + + +# docs for functions above +dot.__doc__ = ''' + Dot product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +vdot.__doc__ = ''' + Return the dot product of two vectors or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +inner.__doc__ = ''' + Inner product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +outer.__doc__ = ''' + Compute the outer product of two vectors or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +kron.__doc__ = ''' + Compute the Kronecker product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +matmul.__doc__ = ''' + Matrix product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +trace.__doc__ = ''' + Return the sum of the diagonal elements of a matrix or quantity. + + Args: + a: array_like, Quantity + offset: int, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if the input is a Quantity, else an array. +''' diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py new file mode 100644 index 0000000..cebb5aa --- /dev/null +++ b/brainunit/math/_compat_numpy_misc.py @@ -0,0 +1,354 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from typing import (Callable, Union, Tuple) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +import opt_einsum +from brainstate._utils import set_module_as +from jax import Array +from jax._src.numpy.lax_numpy import _einsum + +from ._compat_numpy_funcs_change_unit import wrap_math_funcs_change_unit_binary +from ._compat_numpy_funcs_keep_unit import wrap_math_funcs_keep_unit_unary +from ._utils import _compatible_with_quantity +from .._base import (DIMENSIONLESS, + Quantity, + fail_for_dimension_mismatch, + is_unitless, + get_unit, ) + +__all__ = [ + + # constants + 'e', 'pi', 'inf', + + # data types + 'dtype', 'finfo', 'iinfo', + + # more + 'broadcast_arrays', 'broadcast_shapes', + 'einsum', 'gradient', 'intersect1d', 'nan_to_num', 'nanargmax', 'nanargmin', + 'rot90', 'tensordot', +] + +# constants +# --------- +e = jnp.e +pi = jnp.pi +inf = jnp.inf + +# data types +# ---------- +dtype = jnp.dtype + + +@set_module_as('brainunit.math') +def finfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.finfo: + if isinstance(a, Quantity): + return jnp.finfo(a.value) + else: + return jnp.finfo(a) + + +@set_module_as('brainunit.math') +def iinfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.iinfo: + if isinstance(a, Quantity): + return jnp.iinfo(a.value) + else: + return jnp.iinfo(a) + + +# more +# ---- +@set_module_as('brainunit.math') +def broadcast_arrays(*args: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, list[Array]]: + from builtins import all as origin_all + from builtins import any as origin_any + if origin_all(isinstance(arg, Quantity) for arg in args): + if origin_any(arg.unit != args[0].unit for arg in args): + raise ValueError("All arguments must have the same unit") + return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), unit=args[0].unit) + elif origin_all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args): + return jnp.broadcast_arrays(*args) + else: + raise ValueError(f"Unsupported types : {type(args)} for broadcast_arrays") + + +broadcast_shapes = jnp.broadcast_shapes + + +@set_module_as('brainunit.math') +def einsum( + subscripts: str, + /, + *operands: Union[Quantity, jax.Array], + out: None = None, + optimize: Union[str, bool] = "optimal", + precision: jax.lax.PrecisionLike = None, + preferred_element_type: Union[jax.typing.DTypeLike, None] = None, + _dot_general: Callable[..., jax.Array] = jax.lax.dot_general, +) -> Union[jax.Array, Quantity]: + ''' + Evaluates the Einstein summation convention on the operands. + + Args: + subscripts: string containing axes names separated by commas. + *operands: sequence of one or more arrays or quantities corresponding to the subscripts. + optimize: determine whether to optimize the order of computation. In JAX + this defaults to ``"optimize"`` which produces optimized expressions via + the opt_einsum_ package. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``). + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + out: unsupported by JAX + _dot_general: optionally override the ``dot_general`` callable used by ``einsum``. + This parameter is experimental, and may be removed without warning at any time. + + Returns: + array containing the result of the einstein summation. + ''' + operands = (subscripts, *operands) + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.") + spec = operands[0] if isinstance(operands[0], str) else None + optimize = 'optimal' if optimize is True else optimize + + # Allow handling of shape polymorphism + non_constant_dim_types = { + type(d) for op in operands if not isinstance(op, str) + for d in np.shape(op) if not jax.core.is_constant_dim(d) + } + if not non_constant_dim_types: + contract_path = opt_einsum.contract_path + else: + from jax._src.numpy.lax_numpy import _default_poly_einsum_handler + contract_path = _default_poly_einsum_handler + + operands, contractions = contract_path( + *operands, einsum_call=True, use_blas=True, optimize=optimize) + + unit = None + for i in range(len(contractions) - 1): + if contractions[i][4] == 'False': + + fail_for_dimension_mismatch( + Quantity([], unit=unit), operands[i + 1], 'einsum' + ) + elif contractions[i][4] == 'DOT' or \ + contractions[i][4] == 'TDOT' or \ + contractions[i][4] == 'GEMM' or \ + contractions[i][4] == 'OUTER/EINSUM': + if i == 0: + if isinstance(operands[i], Quantity) and isinstance(operands[i + 1], Quantity): + unit = operands[i].unit * operands[i + 1].unit + elif isinstance(operands[i], Quantity): + unit = operands[i].unit + elif isinstance(operands[i + 1], Quantity): + unit = operands[i + 1].unit + else: + if isinstance(operands[i + 1], Quantity): + unit = unit * operands[i + 1].unit + + contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) + + einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True) + if spec is not None: + einsum = jax.named_call(einsum, name=spec) + operands = [op.value if isinstance(op, Quantity) else op for op in operands] + r = einsum(operands, contractions, precision, # type: ignore[operator] + preferred_element_type, _dot_general) + if unit is not None: + return Quantity(r, unit=unit) + else: + return r + + +@set_module_as('brainunit.math') +def gradient( + f: Union[bst.typing.ArrayLike, Quantity], + *varargs: Union[bst.typing.ArrayLike, Quantity], + axis: Union[int, Sequence[int], None] = None, + edge_order: Union[int, None] = None, +) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: + ''' + Computes the gradient of a scalar field. + + Args: + f: input array. + *varargs: list of scalar fields to compute the gradient. + axis: axis or axes along which to compute the gradient. The default is to compute the gradient along all axes. + edge_order: order of the edge used for the finite difference computation. The default is 1. + + Returns: + array containing the gradient of the scalar field. + ''' + if edge_order is not None: + raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") + + if len(varargs) == 0: + if isinstance(f, Quantity) and not is_unitless(f): + return Quantity(jnp.gradient(f.value, axis=axis), unit=f.unit) + else: + return jnp.gradient(f) + elif len(varargs) == 1: + unit = get_unit(f) / get_unit(varargs[0]) + if unit is None or unit == DIMENSIONLESS: + return jnp.gradient(f, varargs[0], axis=axis) + else: + return [Quantity(r, unit=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] + else: + unit_list = [get_unit(f) / get_unit(v) for v in varargs] + f = f.value if isinstance(f, Quantity) else f + varargs = [v.value if isinstance(v, Quantity) else v for v in varargs] + result_list = jnp.gradient(f, *varargs, axis=axis) + return [Quantity(r, unit=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)] + + +@set_module_as('brainunit.math') +def intersect1d( + ar1: Union[bst.typing.ArrayLike], + ar2: Union[bst.typing.ArrayLike], + assume_unique: bool = False, + return_indices: bool = False +) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: + ''' + Find the intersection of two arrays. + + Args: + ar1: input array. + ar2: input array. + assume_unique: if True, the input arrays are both assumed to be unique. + return_indices: if True, the indices which correspond to the intersection of the two arrays are returned. + + Returns: + array containing the intersection of the two arrays. + ''' + fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') + unit = None + if isinstance(ar1, Quantity): + unit = ar1.unit + ar1 = ar1.value if isinstance(ar1, Quantity) else ar1 + ar2 = ar2.value if isinstance(ar2, Quantity) else ar2 + result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + if return_indices: + if unit is not None: + return (Quantity(result[0], unit=unit), result[1], result[2]) + else: + return result + else: + if unit is not None: + return Quantity(result, unit=unit) + else: + return result + + +@wrap_math_funcs_keep_unit_unary +def nan_to_num(x: Union[bst.typing.ArrayLike, Quantity], nan: float = 0.0, posinf: float = jnp.inf, + neginf: float = -jnp.inf) -> Union[jax.Array, Quantity]: + return jnp.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + + +@wrap_math_funcs_keep_unit_unary +def rot90(m: Union[bst.typing.ArrayLike, Quantity], k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Union[ + jax.Array, Quantity]: + return jnp.rot90(m, k=k, axes=axes) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def tensordot(a: Union[bst.typing.ArrayLike, Quantity], b: Union[bst.typing.ArrayLike, Quantity], + axes: Union[int, Tuple[int, int]] = 2) -> Union[jax.Array, Quantity]: + return jnp.tensordot(a, b, axes=axes) + + +@_compatible_with_quantity(return_quantity=False) +def nanargmax(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: + return jnp.nanargmax(a, axis=axis) + + +@_compatible_with_quantity(return_quantity=False) +def nanargmin(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: + return jnp.nanargmin(a, axis=axis) + + +# docs for functions above +nan_to_num.__doc__ = ''' + Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the `nan`, `posinf` and `neginf` arguments. + + Args: + x: input array. + nan: value to replace NaNs with. + posinf: value to replace positive infinity with. + neginf: value to replace negative infinity with. + + Returns: + array with NaNs replaced by zero and infinities replaced by large finite numbers. +''' + +nanargmax.__doc__ = ''' + Return the index of the maximum value in an array, ignoring NaNs. + + Args: + a: array like, Quantity. + axis: axis along which to operate. The default is to compute the index of the maximum over all the dimensions of the input array. + out: output array, optional. + keepdims: if True, the result is broadcast to the input array with the same number of dimensions. + + Returns: + index of the maximum value in the array. +''' + +nanargmin.__doc__ = ''' + Return the index of the minimum value in an array, ignoring NaNs. + + Args: + a: array like, Quantity. + axis: axis along which to operate. The default is to compute the index of the minimum over all the dimensions of the input array. + out: output array, optional. + keepdims: if True, the result is broadcast to the input array with the same number of dimensions. + + Returns: + index of the minimum value in the array. +''' + +rot90.__doc__ = ''' + Rotate an array by 90 degrees in the plane specified by axes. + + Args: + m: array like, Quantity. + k: number of times the array is rotated by 90 degrees. + axes: plane of rotation. Default is the last two axes. + + Returns: + rotated array. +''' + +tensordot.__doc__ = ''' + Compute tensor dot product along specified axes for arrays. + + Args: + a: array like, Quantity. + b: array like, Quantity. + axes: axes along which to compute the tensor dot product. + + Returns: + tensor dot product of the two arrays. +''' diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 9cfec3e..95169d8 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -24,6 +24,7 @@ from brainunit import DimensionMismatchError from brainunit._base import Quantity from brainunit._unit_shortcuts import ms, mV +from brainunit._unit_common import second bst.environ.set(precision=64) @@ -44,6 +45,10 @@ def test_full(self): self.assertEqual(result.shape, (3,)) self.assertTrue(jnp.all(result == 4)) + q = bu.math.full(3, 4, unit=second) + self.assertEqual(q.shape, (3,)) + assert_quantity(q, result, second) + def test_eye(self): result = bu.math.eye(3) self.assertEqual(result.shape, (3, 3)) @@ -1706,7 +1711,7 @@ def test_argsort(self): q = [2, 3, 1] * bu.second result_q = bu.math.argsort(q) expected_q = jnp.argsort(jnp.array([2, 3, 1])) - assert jnp.all(result_q == expected_q) + assert_quantity(result_q, expected_q, bu.second) def test_argmax(self): array = jnp.array([2, 3, 1]) diff --git a/brainunit/math/_utils.py b/brainunit/math/_utils.py index ae66103..bf8f0cd 100644 --- a/brainunit/math/_utils.py +++ b/brainunit/math/_utils.py @@ -32,74 +32,64 @@ def _is_leaf(a): def _compatible_with_quantity( - fun: Callable, return_quantity: bool = True, - module: str = '' ): - func_to_wrap = fun.__np_wrapped__ if hasattr(fun, '__np_wrapped__') else fun - - @functools.wraps(func_to_wrap) - def new_fun(*args, **kwargs) -> Union[list[Quantity], Quantity, jax.Array]: - unit = None - if isinstance(args[0], Quantity): - unit = args[0].unit - elif isinstance(args[0], tuple): - if len(args[0]) == 1: - unit = args[0][0].unit if isinstance(args[0][0], Quantity) else None - elif len(args[0]) == 2: - # check all args[0] have the same unit - if all(isinstance(a, Quantity) for a in args[0]): - if all(a.unit == args[0][0].unit for a in args[0]): - unit = args[0][0].unit + def decorator(fun: Callable) -> Callable: + @functools.wraps(fun) + def new_fun(*args, **kwargs) -> Union[list[Quantity], Quantity, jax.Array]: + unit = None + if isinstance(args[0], Quantity): + unit = args[0].unit + elif isinstance(args[0], tuple): + if len(args[0]) == 1: + unit = args[0][0].unit if isinstance(args[0][0], Quantity) else None + elif len(args[0]) == 2: + # check all args[0] have the same unit + if all(isinstance(a, Quantity) for a in args[0]): + if all(a.unit == args[0][0].unit for a in args[0]): + unit = args[0][0].unit + else: + raise ValueError(f'Units do not match for {fun.__name__} operation.') + elif all(not isinstance(a, Quantity) for a in args[0]): + unit = None else: raise ValueError(f'Units do not match for {fun.__name__} operation.') - elif all(not isinstance(a, Quantity) for a in args[0]): - unit = None - else: - raise ValueError(f'Units do not match for {fun.__name__} operation.') - args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) - out = None - if len(kwargs): - # compatible with PyTorch syntax - if 'dim' in kwargs: - kwargs['axis'] = kwargs.pop('dim') - if 'keepdim' in kwargs: - kwargs['keepdims'] = kwargs.pop('keepdim') - # compatible with TensorFlow syntax - if 'keep_dims' in kwargs: - kwargs['keepdims'] = kwargs.pop('keep_dims') - # compatible with NumPy/PyTorch syntax - if 'out' in kwargs: - out = kwargs.pop('out') - if not isinstance(out, Quantity): - raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}') - # format - kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) + args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) + out = None + if len(kwargs): + # compatible with PyTorch syntax + if 'dim' in kwargs: + kwargs['axis'] = kwargs.pop('dim') + if 'keepdim' in kwargs: + kwargs['keepdims'] = kwargs.pop('keepdim') + # compatible with TensorFlow syntax + if 'keep_dims' in kwargs: + kwargs['keepdims'] = kwargs.pop('keep_dims') + # compatible with NumPy/PyTorch syntax + if 'out' in kwargs: + out = kwargs.pop('out') + if not isinstance(out, Quantity): + raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}') + # format + kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) - if not return_quantity: - unit = None + if not return_quantity: + unit = None - r = fun(*args, **kwargs) - if unit is not None: - if isinstance(r, (list, tuple)): - return [Quantity(rr, unit=unit) for rr in r] - else: - if out is None: - return Quantity(r, unit=unit) + r = fun(*args, **kwargs) + if unit is not None: + if isinstance(r, (list, tuple)): + return [Quantity(rr, unit=unit) for rr in r] else: - out.value = r - if out is None: - return r - else: - out.value = r + if out is None: + return Quantity(r, unit=unit) + else: + out.value = r + if out is None: + return r + else: + out.value = r - new_fun.__doc__ = ( - f'Similar to ``jax.numpy.{module + fun.__name__}`` function, ' - f'while it is compatible with brainpy Array/Variable. \n\n' - f'Note that this function is also compatible with:\n\n' - f'1. NumPy or PyTorch syntax when receiving ``out`` argument.\n' - f'2. PyTorch syntax when receiving ``keepdim`` or ``dim`` argument.\n' - f'3. TensorFlow syntax when receiving ``keep_dims`` argument.' - ) + return new_fun - return new_fun + return decorator diff --git a/docs/apis/brainunit.math.rst b/docs/apis/brainunit.math.rst index 2b303fc..7d3601d 100644 --- a/docs/apis/brainunit.math.rst +++ b/docs/apis/brainunit.math.rst @@ -1,8 +1,8 @@ ``brainunit.math`` module -========================== +========================= -.. currentmodule:: brainunit.math -.. automodule:: brainunit.math +.. currentmodule:: brainunit.math +.. automodule:: brainunit.math Array Creation -------------- @@ -12,454 +12,387 @@ Array Creation :nosignatures: :template: classtemplate.rst - full - full_like - eye - identity - diag - tri - tril - triu - empty - empty_like - ones - ones_like - zeros - zeros_like - array - asarray - arange - linspace - logspace - fill_diagonal - array_split - meshgrid - vander - -Getting Attribute Funcs ------------------------ - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst + full + full_like + eye + identity + diag + tri + tril + triu + empty + empty_like + ones + ones_like + zeros + zeros_like + array + asarray + arange + linspace + logspace + fill_diagonal + array_split + meshgrid + vander - ndim - isreal - isscalar - isfinite - isinf - isnan - shape - size -Math Funcs Keep Unit (Unary) ------------------------------ +Array Manipulation +------------------ .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - real - imag - conj - conjugate - negative - positive - abs - round - around - round_ - rint - floor - ceil - trunc - fix - sum - nancumsum - nansum - cumsum - ediff1d - absolute - fabs - median - nanmin - nanmax - ptp - average - mean - std - nanmedian - nanmean - nanstd - diff - modf - -Math Funcs Keep Unit (Binary) ------------------------------- + reshape + moveaxis + transpose + swapaxes + row_stack + concatenate + stack + vstack + hstack + dstack + column_stack + split + dsplit + hsplit + vsplit + tile + repeat + unique + append + flip + fliplr + flipud + roll + atleast_1d + atleast_2d + atleast_3d + expand_dims + squeeze + sort + argsort + argmax + argmin + argwhere + nonzero + flatnonzero + searchsorted + extract + count_nonzero + max + min + amax + amin + block + compress + diagflat + diagonal + choose + ravel + + +Functions Accepting Unitless +---------------------------- .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - fmod - mod - copysign - heaviside - maximum - minimum - fmax - fmin - lcm - gcd - -Math Funcs Keep Unit (N-ary) ------------------------------ + exp + exp2 + expm1 + log + log10 + log1p + log2 + arccos + arccosh + arcsin + arcsinh + arctan + arctanh + cos + cosh + sin + sinc + sinh + tan + tanh + deg2rad + rad2deg + degrees + radians + angle + percentile + nanpercentile + quantile + nanquantile + hypot + arctan2 + logaddexp + logaddexp2 + + +Functions with Bitwise Operations +--------------------------------- .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - interp - clip + bitwise_not + invert + bitwise_and + bitwise_or + bitwise_xor + left_shift + right_shift -Math Funcs Match Unit (Binary) -------------------------------- -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - add - subtract - nextafter - -Math Funcs Change Unit (Unary) -------------------------------- +Functions Changing Unit +----------------------- .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - reciprocal - prod - product - nancumprod - nanprod - cumprod - cumproduct - var - nanvar - cbrt - square - frexp - sqrt - -Math Funcs Change Unit (Binary) --------------------------------- + reciprocal + prod + product + nancumprod + nanprod + cumprod + cumproduct + var + nanvar + cbrt + square + frexp + sqrt + multiply + divide + power + cross + ldexp + true_divide + floor_divide + float_power + divmod + remainder + convolve + + +Indexing Functions +------------------ .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - multiply - divide - power - cross - ldexp - true_divide - floor_divide - float_power - divmod - remainder - convolve - -Math Funcs Only Accept Unitless (Unary) ---------------------------------------- + where + tril_indices + tril_indices_from + triu_indices + triu_indices_from + take + select -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - exp - exp2 - expm1 - log - log10 - log1p - log2 - arccos - arccosh - arcsin - arcsinh - arctan - arctanh - cos - cosh - sin - sinc - sinh - tan - tanh - deg2rad - rad2deg - degrees - radians - angle - percentile - nanpercentile - quantile - nanquantile - -Math Funcs Only Accept Unitless (Binary) ----------------------------------------- +Functions Keeping Unit +---------------------- .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - hypot - arctan2 - logaddexp - logaddexp2 - -Math Funcs Remove Unit (Unary) -------------------------------- + real + imag + conj + conjugate + negative + positive + abs + round + around + round_ + rint + floor + ceil + trunc + fix + sum + nancumsum + nansum + cumsum + ediff1d + absolute + fabs + median + nanmin + nanmax + ptp + average + mean + std + nanmedian + nanmean + nanstd + diff + modf + fmod + mod + copysign + heaviside + maximum + minimum + fmax + fmin + lcm + gcd + interp + clip + + +Logical Functions +----------------- .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - signbit - sign - histogram - bincount - -Math Funcs Remove Unit (Binary) --------------------------------- + all + any + logical_not + equal + not_equal + greater + greater_equal + less + less_equal + array_equal + isclose + allclose + logical_and + logical_or + logical_xor + alltrue + sometrue + + +Functions Matching Unit +----------------------- .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - corrcoef - correlate - cov - digitize + add + subtract + nextafter -Array Manipulation -------------------- -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - reshape - moveaxis - transpose - swapaxes - row_stack - concatenate - stack - vstack - hstack - dstack - column_stack - split - dsplit - hsplit - vsplit - tile - repeat - unique - append - flip - fliplr - flipud - roll - atleast_1d - atleast_2d - atleast_3d - expand_dims - squeeze - sort - argsort - argmax - argmin - argwhere - nonzero - flatnonzero - searchsorted - extract - count_nonzero - max - min - amax - amin - block - compress - diagflat - diagonal - choose - ravel - -Elementwise Bit Operations (Unary) ----------------------------------- +Functions Removing Unit +----------------------- .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - bitwise_not - invert + signbit + sign + histogram + bincount + corrcoef + correlate + cov + digitize -Elementwise Bit Operations (Binary) ------------------------------------ -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - bitwise_and - bitwise_or - bitwise_xor - left_shift - right_shift - -Logic Funcs (Unary) --------------------- +Window Functions +---------------- .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - all - any - logical_not + bartlett + blackman + hamming + hanning + kaiser -Logic Funcs (Binary) ---------------------- -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - equal - not_equal - greater - greater_equal - less - less_equal - array_equal - isclose - allclose - logical_and - logical_or - logical_xor - alltrue - sometrue - -Indexing Funcs ---------------- +Get Attribute Functions +----------------------- .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - nonzero - where - tril_indices - tril_indices_from - triu_indices - triu_indices_from - take - select + ndim + isreal + isscalar + isfinite + isinf + isnan + shape + size -Window Funcs -------------- -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - - bartlett - blackman - hamming - hanning - kaiser - -Constants ----------- +Linear Algebra Functions +------------------------ .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - e - pi - inf - -Linear Algebra ---------------- - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst + dot + vdot + inner + outer + kron + matmul + trace - dot - vdot - inner - outer - kron - matmul - trace -Data Types ------------ +More Functions +-------------- .. autosummary:: :toctree: generated/ :nosignatures: :template: classtemplate.rst - dtype - finfo - iinfo + finfo + iinfo + broadcast_arrays + broadcast_shapes + einsum + gradient + intersect1d + nan_to_num + nanargmax + nanargmin + rot90 + tensordot + dtype + e + pi + inf -More ------ - -.. autosummary:: - :toctree: generated/ - :nosignatures: - :template: classtemplate.rst - broadcast_arrays - broadcast_shapes - einsum - gradient - intersect1d - nan_to_num - nanargmax - nanargmin - rot90 - tensordot diff --git a/docs/auto_generater.py b/docs/auto_generater.py index b192b88..7b76528 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -226,7 +226,6 @@ def _write_subsections_v4(module_path, fout.write(f'.. currentmodule:: {out_path} \n') fout.write(f'.. automodule:: {out_path} \n\n') - fout.write('.. autosummary::\n') fout.write(' :toctree: generated/\n') fout.write(' :nosignatures:\n') @@ -319,14 +318,29 @@ def _section(header, numpy_mod, brainpy_mod, jax_mod, klass=None, is_jax=False): def main(): os.makedirs('apis/auto/', exist_ok=True) - _write_module(module_name='brainunit', - filename='apis/brainunit.math.rst', - header='``brainunit.init`` module') - - -if __name__ == '__main__': - main() - + module_and_name = [ + ('_compat_numpy_array_creation', 'Array Creation'), + ('_compat_numpy_array_manipulation', 'Array Manipulation'), + ('_compat_numpy_funcs_accept_unitless', 'Functions Accepting Unitless'), + ('_compat_numpy_funcs_bit_operation', 'Functions with Bitwise Operations'), + ('_compat_numpy_funcs_change_unit', 'Functions Changing Unit'), + ('_compat_numpy_funcs_indexing', 'Indexing Functions'), + ('_compat_numpy_funcs_keep_unit', 'Functions Keeping Unit'), + ('_compat_numpy_funcs_logic', 'Logical Functions'), + ('_compat_numpy_funcs_match_unit', 'Functions Matching Unit'), + ('_compat_numpy_funcs_remove_unit', 'Functions Removing Unit'), + ('_compat_numpy_funcs_window', 'Window Functions'), + ('_compat_numpy_get_attribute', 'Get Attribute Functions'), + ('_compat_numpy_linear_algebra', 'Linear Algebra Functions'), + ('_compat_numpy_misc', 'More Functions'), + ] + _write_submodules(module_name='brainunit.math', + filename='apis/brainunit.math.rst', + header='``brainunit.math`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) +if __name__ == '__main__': + main() From 337b3658ecb921c77b870953f8b232857cb3637a Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 17:25:58 +0800 Subject: [PATCH 10/23] Update --- brainunit/math/__init__.py | 18 +++++++++-- .../math/_compat_numpy_array_creation.py | 30 +++++++++---------- .../math/_compat_numpy_array_manipulation.py | 2 +- .../math/_compat_numpy_funcs_change_unit.py | 28 ++++++++--------- .../math/_compat_numpy_funcs_indexing.py | 6 ++-- .../math/_compat_numpy_funcs_keep_unit.py | 12 ++++---- .../math/_compat_numpy_funcs_match_unit.py | 6 ++-- brainunit/math/_compat_numpy_misc.py | 26 ++++++++-------- 8 files changed, 70 insertions(+), 58 deletions(-) diff --git a/brainunit/math/__init__.py b/brainunit/math/__init__.py index 03fd080..e574603 100644 --- a/brainunit/math/__init__.py +++ b/brainunit/math/__init__.py @@ -46,7 +46,6 @@ from ._compat_numpy_misc import * from ._compat_numpy_misc import __all__ as _compat_misc_all - __all__ = _compat_array_creation_all + \ _compat_array_manipulation_all + \ _compat_funcs_change_unit_all + \ @@ -63,5 +62,18 @@ _compat_misc_all + _other_all + \ _other_all -del _compat_array_creation_all, _compat_array_manipulation_all, _compat_funcs_change_unit_all, _compat_funcs_keep_unit_all, _compat_funcs_accept_unitless_all, _compat_funcs_match_unit_all, _compat_funcs_remove_unit_all, _compat_get_attribute_all, _compat_funcs_bit_operation_all, _compat_funcs_logic_all, _compat_funcs_indexing_all, _compat_funcs_window_all, _compat_linear_algebra_all, _compat_misc_all, _other_all - +del _compat_array_creation_all, \ + _compat_array_manipulation_all, \ + _compat_funcs_change_unit_all, \ + _compat_funcs_keep_unit_all, \ + _compat_funcs_accept_unitless_all, \ + _compat_funcs_match_unit_all, \ + _compat_funcs_remove_unit_all, \ + _compat_get_attribute_all, \ + _compat_funcs_bit_operation_all, \ + _compat_funcs_logic_all, \ + _compat_funcs_indexing_all, \ + _compat_funcs_window_all, \ + _compat_linear_algebra_all, \ + _compat_misc_all, \ + _other_all diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 9080502..4feb08d 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -436,18 +436,18 @@ def asarray( from builtins import all as origin_all from builtins import any as origin_any if isinstance(a, Quantity): - return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), unit=a.unit) + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): return jnp.asarray(a, dtype=dtype, order=order) # list[Quantity] elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a): # check all elements have the same unit - if origin_any(x.unit != a[0].unit for x in a): + if origin_any(x.dim != a[0].dim for x in a): raise ValueError('Units do not match for asarray operation.') values = [x.value for x in a] - unit = a[0].unit + unit = a[0].dim # Convert the values to a jnp.ndarray and create a Quantity object - return Quantity(jnp.asarray(values, dtype=dtype, order=order), unit=unit) + return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) else: return jnp.asarray(a, dtype=dtype, order=order) @@ -501,7 +501,7 @@ def arange(*args, **kwargs): if stop is None: raise TypeError("Missing stop argument.") if stop is not None and not is_unitless(stop): - start = Quantity(start, unit=stop.unit) + start = Quantity(start, dim=stop.dim) fail_for_dimension_mismatch( start, @@ -533,7 +533,7 @@ def arange(*args, **kwargs): step=step.value if isinstance(step, Quantity) else jnp.asarray(step), **kwargs, ), - unit=unit, + dim=unit, ) else: return Quantity( @@ -543,7 +543,7 @@ def arange(*args, **kwargs): step=step.value if isinstance(step, Quantity) else jnp.asarray(step), **kwargs, ), - unit=unit, + dim=unit, ) @@ -575,12 +575,12 @@ def linspace(start: Union[Quantity, bst.typing.ArrayLike], start=start, stop=stop, ) - unit = getattr(start, "unit", DIMENSIONLESS) + unit = getattr(start, "dim", DIMENSIONLESS) start = start.value if isinstance(start, Quantity) else start stop = stop.value if isinstance(stop, Quantity) else stop result = jnp.linspace(start, stop, num=num, endpoint=endpoint, retstep=retstep, dtype=dtype) - return Quantity(result, unit=unit) + return Quantity(result, dim=unit) @set_module_as('brainunit.math') @@ -611,12 +611,12 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike], start=start, stop=stop, ) - unit = getattr(start, "unit", DIMENSIONLESS) + unit = getattr(start, "dim", DIMENSIONLESS) start = start.value if isinstance(start, Quantity) else start stop = stop.value if isinstance(stop, Quantity) else stop result = jnp.logspace(start, stop, num=num, endpoint=endpoint, base=base, dtype=dtype) - return Quantity(result, unit=unit) + return Quantity(result, dim=unit) @set_module_as('brainunit.math') @@ -638,7 +638,7 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], ''' if isinstance(a, Quantity) and isinstance(val, Quantity): fail_for_dimension_mismatch(a, val) - return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), unit=a.unit) + return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)) and isinstance(val, (jax.Array, np.ndarray)): return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) elif is_unitless(a) or is_unitless(val): @@ -663,7 +663,7 @@ def array_split(ary: Union[Quantity, bst.typing.ArrayLike], Union[jax.Array, Quantity]: Quantity if `ary` is a Quantity, else an array. ''' if isinstance(ary, Quantity): - return [Quantity(x, unit=ary.unit) for x in jnp.array_split(ary.value, indices_or_sections, axis)] + return [Quantity(x, dim=ary.dim) for x in jnp.array_split(ary.value, indices_or_sections, axis)] elif isinstance(ary, bst.typing.ArrayLike): return jnp.array_split(ary, indices_or_sections, axis) else: @@ -690,7 +690,7 @@ def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike], from builtins import all as origin_all if origin_all(isinstance(x, Quantity) for x in xi): fail_for_dimension_mismatch(*xi) - return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), unit=xi[0].unit) + return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), dim=xi[0].dim) elif origin_all(isinstance(x, (jax.Array, np.ndarray)) for x in xi): return jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) else: @@ -713,7 +713,7 @@ def vander(x: Union[Quantity, bst.typing.ArrayLike], Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. ''' if isinstance(x, Quantity): - return Quantity(jnp.vander(x.value, N=N, increasing=increasing), unit=x.unit) + return Quantity(jnp.vander(x.value, N=N, increasing=increasing), dim=x.dim) elif isinstance(x, (jax.Array, np.ndarray)): return jnp.vander(x, N=N, increasing=increasing) else: diff --git a/brainunit/math/_compat_numpy_array_manipulation.py b/brainunit/math/_compat_numpy_array_manipulation.py index dbfca5e..c4a7c26 100644 --- a/brainunit/math/_compat_numpy_array_manipulation.py +++ b/brainunit/math/_compat_numpy_array_manipulation.py @@ -777,7 +777,7 @@ def wrap_function_to_method(func): @wraps(func) def f(x, *args, **kwargs): if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), unit=x.unit) + return Quantity(func(x.value, *args, **kwargs), dim=x.dim) else: return func(x, *args, **kwargs) diff --git a/brainunit/math/_compat_numpy_funcs_change_unit.py b/brainunit/math/_compat_numpy_funcs_change_unit.py index ced279e..e649b14 100644 --- a/brainunit/math/_compat_numpy_funcs_change_unit.py +++ b/brainunit/math/_compat_numpy_funcs_change_unit.py @@ -49,7 +49,7 @@ def decorator(func: Callable) -> Callable: @wraps(func) def f(x, *args, **kwargs): if isinstance(x, Quantity): - return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), unit=change_unit_func(x.unit))) + return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dtype=change_unit_func(x.dim))) elif isinstance(x, (jnp.ndarray, np.ndarray)): return func(x, *args, **kwargs) else: @@ -298,16 +298,16 @@ def decorator(func: Callable) -> Callable: def f(x, y, *args, **kwargs): if isinstance(x, Quantity) and isinstance(y, Quantity): return _return_check_unitless( - Quantity(func(x.value, y.value, *args, **kwargs), unit=change_unit_func(x.unit, y.unit)) + Quantity(func(x.value, y.value, *args, **kwargs), dim=change_unit_func(x.dim, y.dim)) ) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): return func(x, y, *args, **kwargs) elif isinstance(x, Quantity): return _return_check_unitless( - Quantity(func(x.value, y, *args, **kwargs), unit=change_unit_func(x.unit, DIMENSIONLESS))) + Quantity(func(x.value, y, *args, **kwargs), dim=change_unit_func(x.dim, DIMENSIONLESS))) elif isinstance(y, Quantity): return _return_check_unitless( - Quantity(func(x, y.value, *args, **kwargs), unit=change_unit_func(DIMENSIONLESS, y.unit))) + Quantity(func(x, y.value, *args, **kwargs), dim=change_unit_func(DIMENSIONLESS, y.dim))) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') @@ -443,13 +443,13 @@ def power(x: Union[Quantity, bst.typing.ArrayLike], Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. ''' if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y.value), unit=x.unit ** y.unit)) + return _return_check_unitless(Quantity(jnp.power(x.value, y.value), dim=x.dim ** y.dim)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): return jnp.power(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y), unit=x.unit ** y)) + return _return_check_unitless(Quantity(jnp.power(x.value, y), dim=x.dim ** y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x, y.value), unit=x ** y.unit)) + return _return_check_unitless(Quantity(jnp.power(x, y.value), dim=x ** y.dim)) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') @@ -468,13 +468,13 @@ def floor_divide(x: Union[Quantity, bst.typing.ArrayLike], Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. ''' if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), unit=x.unit / y.unit)) + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), dim=x.dim / y.dim)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): return jnp.floor_divide(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), unit=x.unit / y)) + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), dim=x.dim / y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), unit=x / y.unit)) + return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), dim=x / y.dim)) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') @@ -495,7 +495,7 @@ def float_power(x: Union[Quantity, bst.typing.ArrayLike], if isinstance(y, Quantity): assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' if isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y), unit=x.unit ** y)) + return _return_check_unitless(Quantity(jnp.float_power(x.value, y), dim=x.dim ** y)) elif isinstance(x, (jax.Array, np.ndarray)): return jnp.float_power(x, y) else: @@ -516,12 +516,12 @@ def remainder(x: Union[Quantity, bst.typing.ArrayLike], Union[jax.Array, Quantity]: Quantity if the final unit is the remainder of the unit of `x` and the unit of `y`, else an array. ''' if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), unit=x.unit / y.unit)) + return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), dim=x.dim / y.dim)) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): return jnp.remainder(x, y) elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y), unit=x.unit % y)) + return _return_check_unitless(Quantity(jnp.remainder(x.value, y), dim=x.dim % y)) elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x, y.value), unit=x % y.unit)) + return _return_check_unitless(Quantity(jnp.remainder(x, y.value), dim=x % y.dim)) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}') diff --git a/brainunit/math/_compat_numpy_funcs_indexing.py b/brainunit/math/_compat_numpy_funcs_indexing.py index bf21d75..7f8d8fc 100644 --- a/brainunit/math/_compat_numpy_funcs_indexing.py +++ b/brainunit/math/_compat_numpy_funcs_indexing.py @@ -61,7 +61,7 @@ def where(condition: Union[bool, bst.typing.ArrayLike], # as both arguments have the same unit, just use the first one's dimensionless_args = [jnp.asarray(arg.value) if isinstance(arg, Quantity) else jnp.asarray(arg) for arg in args] return Quantity.with_units( - jnp.where(condition, *dimensionless_args), args[0].unit + jnp.where(condition, *dimensionless_args), args[0].dim ) else: # illegal number of arguments @@ -155,11 +155,11 @@ def select(condlist: list[Union[bst.typing.ArrayLike]], from builtins import all as origin_all from builtins import any as origin_any if origin_all(isinstance(choice, Quantity) for choice in choicelist): - if origin_any(choice.unit != choicelist[0].unit for choice in choicelist): + if origin_any(choice.dim != choicelist[0].dim for choice in choicelist): raise ValueError("All choices must have the same unit") else: return Quantity(jnp.select(condlist, [choice.value for choice in choicelist], default=default), - unit=choicelist[0].unit) + dim=choicelist[0].dim) elif origin_all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist): return jnp.select(condlist, choicelist, default=default) else: diff --git a/brainunit/math/_compat_numpy_funcs_keep_unit.py b/brainunit/math/_compat_numpy_funcs_keep_unit.py index b11f4c4..4a6616e 100644 --- a/brainunit/math/_compat_numpy_funcs_keep_unit.py +++ b/brainunit/math/_compat_numpy_funcs_keep_unit.py @@ -49,7 +49,7 @@ def wrap_math_funcs_keep_unit_unary(func): @wraps(func) def f(x, *args, **kwargs): if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), unit=x.unit) + return Quantity(func(x.value, *args, **kwargs), dim=x.dim) elif isinstance(x, (jax.Array, np.ndarray)): return func(x, *args, **kwargs) else: @@ -578,7 +578,7 @@ def wrap_math_funcs_keep_unit_binary(func): @wraps(func) def f(x1, x2, *args, **kwargs): if isinstance(x1, Quantity) and isinstance(x2, Quantity): - return Quantity(func(x1.value, x2.value, *args, **kwargs), unit=x1.unit) + return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.dim) elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): return func(x1, x2, *args, **kwargs) else: @@ -775,7 +775,7 @@ def interp(x: Union[Quantity, bst.typing.ArrayLike], ''' unit = None if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): - unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit + unit = x.dim if isinstance(x, Quantity) else xp.dim if isinstance(xp, Quantity) else fp.dim if isinstance(x, Quantity): x_value = x.value else: @@ -790,7 +790,7 @@ def interp(x: Union[Quantity, bst.typing.ArrayLike], fp_value = fp result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) if unit is not None: - return Quantity(result, unit=unit) + return Quantity(result, dim=unit) else: return result @@ -812,7 +812,7 @@ def clip(a: Union[Quantity, bst.typing.ArrayLike], ''' unit = None if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): - unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit + unit = a.dim if isinstance(a, Quantity) else a_min.dim if isinstance(a_min, Quantity) else a_max.dim if isinstance(a, Quantity): a_value = a.value else: @@ -827,6 +827,6 @@ def clip(a: Union[Quantity, bst.typing.ArrayLike], a_max_value = a_max result = jnp.clip(a_value, a_min_value, a_max_value) if unit is not None: - return Quantity(result, unit=unit) + return Quantity(result, dim=unit) else: return result diff --git a/brainunit/math/_compat_numpy_funcs_match_unit.py b/brainunit/math/_compat_numpy_funcs_match_unit.py index b863d87..d9926ad 100644 --- a/brainunit/math/_compat_numpy_funcs_match_unit.py +++ b/brainunit/math/_compat_numpy_funcs_match_unit.py @@ -38,17 +38,17 @@ def wrap_math_funcs_match_unit_binary(func): def f(x, y, *args, **kwargs): if isinstance(x, Quantity) and isinstance(y, Quantity): fail_for_dimension_mismatch(x, y) - return Quantity(func(x.value, y.value, *args, **kwargs), unit=x.unit) + return Quantity(func(x.value, y.value, *args, **kwargs), dim=x.dim) elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): return func(x, y, *args, **kwargs) elif isinstance(x, Quantity): if x.is_unitless: - return Quantity(func(x.value, y, *args, **kwargs), unit=x.unit) + return Quantity(func(x.value, y, *args, **kwargs), dim=x.dim) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') elif isinstance(y, Quantity): if y.is_unitless: - return Quantity(func(x, y.value, *args, **kwargs), unit=y.unit) + return Quantity(func(x, y.value, *args, **kwargs), dim=y.dim) else: raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') else: diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py index cebb5aa..4a26216 100644 --- a/brainunit/math/_compat_numpy_misc.py +++ b/brainunit/math/_compat_numpy_misc.py @@ -81,9 +81,9 @@ def broadcast_arrays(*args: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quan from builtins import all as origin_all from builtins import any as origin_any if origin_all(isinstance(arg, Quantity) for arg in args): - if origin_any(arg.unit != args[0].unit for arg in args): + if origin_any(arg.dim != args[0].dim for arg in args): raise ValueError("All arguments must have the same unit") - return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), unit=args[0].unit) + return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), dim=args[0].dim) elif origin_all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args): return jnp.broadcast_arrays(*args) else: @@ -151,7 +151,7 @@ def einsum( if contractions[i][4] == 'False': fail_for_dimension_mismatch( - Quantity([], unit=unit), operands[i + 1], 'einsum' + Quantity([], dim=unit), operands[i + 1], 'einsum' ) elif contractions[i][4] == 'DOT' or \ contractions[i][4] == 'TDOT' or \ @@ -159,14 +159,14 @@ def einsum( contractions[i][4] == 'OUTER/EINSUM': if i == 0: if isinstance(operands[i], Quantity) and isinstance(operands[i + 1], Quantity): - unit = operands[i].unit * operands[i + 1].unit + unit = operands[i].dim * operands[i + 1].dim elif isinstance(operands[i], Quantity): - unit = operands[i].unit + unit = operands[i].dim elif isinstance(operands[i + 1], Quantity): - unit = operands[i + 1].unit + unit = operands[i + 1].dim else: if isinstance(operands[i + 1], Quantity): - unit = unit * operands[i + 1].unit + unit = unit * operands[i + 1].dim contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) @@ -177,7 +177,7 @@ def einsum( r = einsum(operands, contractions, precision, # type: ignore[operator] preferred_element_type, _dot_general) if unit is not None: - return Quantity(r, unit=unit) + return Quantity(r, dim=unit) else: return r @@ -206,7 +206,7 @@ def gradient( if len(varargs) == 0: if isinstance(f, Quantity) and not is_unitless(f): - return Quantity(jnp.gradient(f.value, axis=axis), unit=f.unit) + return Quantity(jnp.gradient(f.value, axis=axis), dim=f.dim) else: return jnp.gradient(f) elif len(varargs) == 1: @@ -214,13 +214,13 @@ def gradient( if unit is None or unit == DIMENSIONLESS: return jnp.gradient(f, varargs[0], axis=axis) else: - return [Quantity(r, unit=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] + return [Quantity(r, dim=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] else: unit_list = [get_unit(f) / get_unit(v) for v in varargs] f = f.value if isinstance(f, Quantity) else f varargs = [v.value if isinstance(v, Quantity) else v for v in varargs] result_list = jnp.gradient(f, *varargs, axis=axis) - return [Quantity(r, unit=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)] + return [Quantity(r, dim=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)] @set_module_as('brainunit.math') @@ -251,12 +251,12 @@ def intersect1d( result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) if return_indices: if unit is not None: - return (Quantity(result[0], unit=unit), result[1], result[2]) + return (Quantity(result[0], dim=unit), result[1], result[2]) else: return result else: if unit is not None: - return Quantity(result, unit=unit) + return Quantity(result, dim=unit) else: return result From 6fc6adddd2f9b126d6307200341d8df7b595e8d4 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 17:38:06 +0800 Subject: [PATCH 11/23] Fix bugs --- brainunit/_unit_test.py | 30 +++++++++---------- .../math/_compat_numpy_array_creation.py | 2 +- .../math/_compat_numpy_funcs_change_unit.py | 2 +- brainunit/math/_compat_numpy_misc.py | 2 +- brainunit/math/_compat_numpy_test.py | 2 +- brainunit/math/_others.py | 2 +- 6 files changed, 20 insertions(+), 20 deletions(-) diff --git a/brainunit/_unit_test.py b/brainunit/_unit_test.py index 5948c85..19095a9 100644 --- a/brainunit/_unit_test.py +++ b/brainunit/_unit_test.py @@ -73,7 +73,7 @@ def assert_allclose(actual, desired, rtol=4.5e8, atol=0, **kwds): def assert_quantity(q, values, unit): values = jnp.asarray(values) if isinstance(q, Quantity): - assert have_same_unit(q.unit, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})" + assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})" if not jnp.allclose(q.value, values): raise AssertionError(f"Values do not match: {q.value} != {values}") elif isinstance(q, jnp.ndarray): @@ -144,10 +144,10 @@ def test_get_dimensions(): Test various ways of getting/comparing the dimensions of a Array. """ q = 500 * ms - assert get_unit(q) is get_or_create_dimension(q.unit._dims) - assert get_unit(q) is q.unit + assert get_unit(q) is get_or_create_dimension(q.dim._dims) + assert get_unit(q) is q.dim assert q.has_same_unit(3 * second) - dims = q.unit + dims = q.dim assert_equal(dims.get_dimension("time"), 1.0) assert_equal(dims.get_dimension("length"), 0) @@ -201,11 +201,11 @@ def test_unary_operations(): def test_operations(): - q1 = Quantity(5, dim=mV) - q2 = Quantity(10, dim=mV) - assert_quantity(q1 + q2, 15, mV) - assert_quantity(q1 - q2, -5, mV) - assert_quantity(q1 * q2, 50, mV * mV) + q1 = 5 * second + q2 = 10 * second + assert_quantity(q1 + q2, 15, second) + assert_quantity(q1 - q2, -5, second) + assert_quantity(q1 * q2, 50, second * second) assert_quantity(q2 / q1, 2, DIMENSIONLESS) assert_quantity(q2 // q1, 2, DIMENSIONLESS) assert_quantity(q2 % q1, 0, second) @@ -215,21 +215,21 @@ def test_operations(): assert_quantity(round(q1, 0), 5, second) # matmul - q1 = Quantity([1, 2], dim=mV) - q2 = Quantity([3, 4], dim=mV) - assert_quantity(q1 @ q2, 11, mV ** 2) + q1 = [1, 2] * second + q2 = [3, 4] * second + assert_quantity(q1 @ q2, 11, second ** 2) q1 = Quantity([1, 2], unit=second) q2 = Quantity([3, 4], unit=second) assert_quantity(q1 @ q2, 11, second ** 2) # shift - q1 = Quantity(0b1100, dtype=jnp.int32, unit=DIMENSIONLESS) + q1 = Quantity(0b1100, dtype=jnp.int32, dim=DIMENSIONLESS) assert_quantity(q1 << 1, 0b11000, second) assert_quantity(q1 >> 1, 0b110, second) def test_numpy_methods(): - q = Quantity([[1, 2], [3, 4]], dim=mV) + q = [[1, 2], [3, 4]] * second assert q.all() assert q.any() assert q.nonzero()[0].tolist() == [0, 0, 1, 1] @@ -1603,7 +1603,7 @@ def test_constants(): import brainunit._unit_constants as constants # Check that the expected names exist and have the correct dimensions - assert constants.avogadro_constant.dim == (1 / mole).unit + assert constants.avogadro_constant.dim == (1 / mole).dim assert constants.boltzmann_constant.dim == (joule / kelvin).dim assert constants.electric_constant.dim == (farad / meter).dim assert constants.electron_mass.dim == kilogram.dim diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 4feb08d..16e34a6 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -521,7 +521,7 @@ def arange(*args, **kwargs): stop=stop, step=step, ) - unit = getattr(stop, "unit", DIMENSIONLESS) + unit = getattr(stop, "dim", DIMENSIONLESS) # start is a position-only argument in numpy 2.0 # https://numpy.org/devdocs/release/2.0.0-notes.html#arange-s-start-argument-is-positional-only # TODO: check whether this is still the case in the final release diff --git a/brainunit/math/_compat_numpy_funcs_change_unit.py b/brainunit/math/_compat_numpy_funcs_change_unit.py index e649b14..227234c 100644 --- a/brainunit/math/_compat_numpy_funcs_change_unit.py +++ b/brainunit/math/_compat_numpy_funcs_change_unit.py @@ -49,7 +49,7 @@ def decorator(func: Callable) -> Callable: @wraps(func) def f(x, *args, **kwargs): if isinstance(x, Quantity): - return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dtype=change_unit_func(x.dim))) + return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dim=change_unit_func(x.dim))) elif isinstance(x, (jnp.ndarray, np.ndarray)): return func(x, *args, **kwargs) else: diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py index 4a26216..0deb591 100644 --- a/brainunit/math/_compat_numpy_misc.py +++ b/brainunit/math/_compat_numpy_misc.py @@ -245,7 +245,7 @@ def intersect1d( fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') unit = None if isinstance(ar1, Quantity): - unit = ar1.unit + unit = ar1.dim ar1 = ar1.value if isinstance(ar1, Quantity) else ar1 ar2 = ar2.value if isinstance(ar2, Quantity) else ar2 result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index cdb47f0..8e39796 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -32,7 +32,7 @@ def assert_quantity(q, values, unit): values = jnp.asarray(values) if isinstance(q, Quantity): - assert q.unit == unit.dim, f"Unit mismatch: {q.unit} != {unit}" + assert q.dim == unit.dim, f"Unit mismatch: {q.dim} != {unit}" assert jnp.allclose(q.value, values), f"Values do not match: {q.value} != {values}" else: assert jnp.allclose(q, values), f"Values do not match: {q} != {values}" diff --git a/brainunit/math/_others.py b/brainunit/math/_others.py index 720edba..d316eb4 100644 --- a/brainunit/math/_others.py +++ b/brainunit/math/_others.py @@ -16,7 +16,7 @@ import brainstate as bst -from ._compat_numpy import wrap_math_funcs_only_accept_unitless_unary +from ._compat_numpy_funcs_accept_unitless import wrap_math_funcs_only_accept_unitless_unary __all__ = [ 'exprel', From 08b90cd120f27e384ab648247fe5632b1ea4c0dc Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 17:41:23 +0800 Subject: [PATCH 12/23] Fix bugs in Python 3.9 --- brainunit/math/_compat_numpy_array_creation.py | 2 +- brainunit/math/_compat_numpy_funcs_bit_operation.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 16e34a6..156a553 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -664,7 +664,7 @@ def array_split(ary: Union[Quantity, bst.typing.ArrayLike], ''' if isinstance(ary, Quantity): return [Quantity(x, dim=ary.dim) for x in jnp.array_split(ary.value, indices_or_sections, axis)] - elif isinstance(ary, bst.typing.ArrayLike): + elif isinstance(ary, (jax.Array, np.ndarray)): return jnp.array_split(ary, indices_or_sections, axis) else: raise ValueError(f'Unsupported type: {type(ary)} for array_split') diff --git a/brainunit/math/_compat_numpy_funcs_bit_operation.py b/brainunit/math/_compat_numpy_funcs_bit_operation.py index c48fa18..0ebe542 100644 --- a/brainunit/math/_compat_numpy_funcs_bit_operation.py +++ b/brainunit/math/_compat_numpy_funcs_bit_operation.py @@ -91,7 +91,7 @@ def wrap_elementwise_bit_operation_binary(func): def f(x, y, *args, **kwargs): if isinstance(x, Quantity) or isinstance(y, Quantity): raise ValueError(f'Expected integers, got {x} and {y}') - elif isinstance(x, bst.typing.ArrayLike) and isinstance(y, bst.typing.ArrayLike): + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): return func(x, y, *args, **kwargs) else: raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') From d702944662c6e622038dd4dcd92780341f113b26 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 17:48:12 +0800 Subject: [PATCH 13/23] Update _compat_numpy_funcs_bit_operation.py --- brainunit/math/_compat_numpy_funcs_bit_operation.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/brainunit/math/_compat_numpy_funcs_bit_operation.py b/brainunit/math/_compat_numpy_funcs_bit_operation.py index 0ebe542..280c22f 100644 --- a/brainunit/math/_compat_numpy_funcs_bit_operation.py +++ b/brainunit/math/_compat_numpy_funcs_bit_operation.py @@ -20,6 +20,7 @@ import jax.numpy as jnp import numpy as np from jax import Array +from numpy import number from .._base import (Quantity, ) @@ -91,7 +92,7 @@ def wrap_elementwise_bit_operation_binary(func): def f(x, y, *args, **kwargs): if isinstance(x, Quantity) or isinstance(y, Quantity): raise ValueError(f'Expected integers, got {x} and {y}') - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, number)): return func(x, y, *args, **kwargs) else: raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') From b0154ab40fb695503d68fff21e88167563e97806 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 17:51:01 +0800 Subject: [PATCH 14/23] Update _compat_numpy_funcs_bit_operation.py --- brainunit/math/_compat_numpy_funcs_bit_operation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainunit/math/_compat_numpy_funcs_bit_operation.py b/brainunit/math/_compat_numpy_funcs_bit_operation.py index 280c22f..1325539 100644 --- a/brainunit/math/_compat_numpy_funcs_bit_operation.py +++ b/brainunit/math/_compat_numpy_funcs_bit_operation.py @@ -92,7 +92,7 @@ def wrap_elementwise_bit_operation_binary(func): def f(x, y, *args, **kwargs): if isinstance(x, Quantity) or isinstance(y, Quantity): raise ValueError(f'Expected integers, got {x} and {y}') - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, number)): + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, int, float)): return func(x, y, *args, **kwargs) else: raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') From d0fcce6777f46d52638b2f42a8183cbd6aec000a Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 18:03:57 +0800 Subject: [PATCH 15/23] Fix logic of `asarray` --- .../math/_compat_numpy_array_creation.py | 35 +++++++++++++------ brainunit/math/_compat_numpy_test.py | 2 +- 2 files changed, 25 insertions(+), 12 deletions(-) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 156a553..16ea689 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -433,21 +433,34 @@ def asarray( order: Optional[str] = None, unit: Optional[Unit] = None, ) -> Union[Quantity, jax.Array]: + ''' + Convert the input to a quantity or array. + + If unit is provided, the input is converted to a Quantity object with the given unit. + + Args: + a: array_like, Quantity, or Sequence[Quantity] + dtype: data-type, optional + order: {'C', 'F', 'A', 'K'}, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' from builtins import all as origin_all from builtins import any as origin_any if isinstance(a, Quantity): - return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) + if unit is not None: + assert isinstance(unit, Unit) + return jnp.asarray(a.value, dtype=dtype, order=order) * unit + else: + return jnp.asarray(a.value, dtype=dtype, order=order) elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.asarray(a, dtype=dtype, order=order) - # list[Quantity] - elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a): - # check all elements have the same unit - if origin_any(x.dim != a[0].dim for x in a): - raise ValueError('Units do not match for asarray operation.') - values = [x.value for x in a] - unit = a[0].dim - # Convert the values to a jnp.ndarray and create a Quantity object - return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) + if unit is not None: + assert isinstance(unit, Unit) + return jnp.asarray(a, dtype=dtype, order=order) * unit + else: + return jnp.asarray(a, dtype=dtype, order=order) else: return jnp.asarray(a, dtype=dtype, order=order) diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 8e39796..18cdd4a 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -159,7 +159,7 @@ def test_asarray(self): self.assertEqual(result.shape, (3,)) self.assertTrue(jnp.all(result == jnp.asarray([1, 2, 3]))) - result_q = bu.math.asarray([1 * bu.second, 2 * bu.second, 3 * bu.second]) + result_q = bu.math.asarray([1, 2, 3], unit=bu.second) assert_quantity(result_q, jnp.asarray([1, 2, 3]), bu.second) def test_arange(self): From c0f817106508dc92f020369f39dcea13f1d267e8 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Tue, 11 Jun 2024 18:24:18 +0800 Subject: [PATCH 16/23] update --- .../math/_compat_numpy_array_creation.py | 26 ++++++++++--------- 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 16ea689..f2d7527 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -53,10 +53,12 @@ def f(*args, unit: Unit = None, **kwargs): @wrap_array_creation_function -def full(shape: Sequence[int], - fill_value: Any, - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: +def full( + shape: Sequence[int], + fill_value: Any, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: return jnp.full(shape, fill_value, dtype=dtype) @@ -447,8 +449,6 @@ def asarray( Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - from builtins import all as origin_all - from builtins import any as origin_any if isinstance(a, Quantity): if unit is not None: assert isinstance(unit, Unit) @@ -561,12 +561,14 @@ def arange(*args, **kwargs): @set_module_as('brainunit.math') -def linspace(start: Union[Quantity, bst.typing.ArrayLike], - stop: Union[Quantity, bst.typing.ArrayLike], - num: int = 50, - endpoint: Optional[bool] = True, - retstep: Optional[bool] = False, - dtype: Optional[bst.typing.DTypeLike] = None) -> Union[Quantity, jax.Array]: +def linspace( + start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], + num: int = 50, + endpoint: Optional[bool] = True, + retstep: Optional[bool] = False, + dtype: Optional[bst.typing.DTypeLike] = None +) -> Union[Quantity, jax.Array]: ''' Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. From 43d9a018c94f54cedb031db457e92df0e6ce961b Mon Sep 17 00:00:00 2001 From: Sichao He <1310722434@qq.com> Date: Tue, 11 Jun 2024 18:24:48 +0800 Subject: [PATCH 17/23] [docs] Update docs for `brainunit.math` (#4) * Update _compat_numpy.py * Update _compat_numpy.py * Update * Update _compat_numpy.py * Fix * Update brainunit.math.rst * Update _compat_numpy.py * Update _unit_test.py * Restruct * Update * Fix bugs * Fix bugs in Python 3.9 * Update _compat_numpy_funcs_bit_operation.py * Update _compat_numpy_funcs_bit_operation.py --- brainunit/_unit_test.py | 73 +- brainunit/math/__init__.py | 65 +- brainunit/math/_compat_numpy.py | 1455 ----------------- .../math/_compat_numpy_array_creation.py | 720 ++++++++ .../math/_compat_numpy_array_manipulation.py | 821 ++++++++++ .../_compat_numpy_funcs_accept_unitless.py | 588 +++++++ .../math/_compat_numpy_funcs_bit_operation.py | 183 +++ .../math/_compat_numpy_funcs_change_unit.py | 527 ++++++ .../math/_compat_numpy_funcs_indexing.py | 166 ++ .../math/_compat_numpy_funcs_keep_unit.py | 832 ++++++++++ brainunit/math/_compat_numpy_funcs_logic.py | 343 ++++ .../math/_compat_numpy_funcs_match_unit.py | 108 ++ .../math/_compat_numpy_funcs_remove_unit.py | 191 +++ brainunit/math/_compat_numpy_funcs_window.py | 69 + brainunit/math/_compat_numpy_get_attribute.py | 215 +++ .../math/_compat_numpy_linear_algebra.py | 149 ++ brainunit/math/_compat_numpy_misc.py | 354 ++++ brainunit/math/_compat_numpy_test.py | 49 +- brainunit/math/_others.py | 2 +- brainunit/math/_utils.py | 117 +- docs/apis/brainunit.math.rst | 395 ++++- docs/auto_generater.py | 32 +- 22 files changed, 5863 insertions(+), 1591 deletions(-) delete mode 100644 brainunit/math/_compat_numpy.py create mode 100644 brainunit/math/_compat_numpy_array_creation.py create mode 100644 brainunit/math/_compat_numpy_array_manipulation.py create mode 100644 brainunit/math/_compat_numpy_funcs_accept_unitless.py create mode 100644 brainunit/math/_compat_numpy_funcs_bit_operation.py create mode 100644 brainunit/math/_compat_numpy_funcs_change_unit.py create mode 100644 brainunit/math/_compat_numpy_funcs_indexing.py create mode 100644 brainunit/math/_compat_numpy_funcs_keep_unit.py create mode 100644 brainunit/math/_compat_numpy_funcs_logic.py create mode 100644 brainunit/math/_compat_numpy_funcs_match_unit.py create mode 100644 brainunit/math/_compat_numpy_funcs_remove_unit.py create mode 100644 brainunit/math/_compat_numpy_funcs_window.py create mode 100644 brainunit/math/_compat_numpy_get_attribute.py create mode 100644 brainunit/math/_compat_numpy_linear_algebra.py create mode 100644 brainunit/math/_compat_numpy_misc.py diff --git a/brainunit/_unit_test.py b/brainunit/_unit_test.py index 4510221..19095a9 100644 --- a/brainunit/_unit_test.py +++ b/brainunit/_unit_test.py @@ -73,7 +73,7 @@ def assert_allclose(actual, desired, rtol=4.5e8, atol=0, **kwds): def assert_quantity(q, values, unit): values = jnp.asarray(values) if isinstance(q, Quantity): - assert have_same_unit(q.unit, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})" + assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})" if not jnp.allclose(q.value, values): raise AssertionError(f"Values do not match: {q.value} != {values}") elif isinstance(q, jnp.ndarray): @@ -144,10 +144,10 @@ def test_get_dimensions(): Test various ways of getting/comparing the dimensions of a Array. """ q = 500 * ms - assert get_unit(q) is get_or_create_dimension(q.unit._dims) - assert get_unit(q) is q.unit + assert get_unit(q) is get_or_create_dimension(q.dim._dims) + assert get_unit(q) is q.dim assert q.has_same_unit(3 * second) - dims = q.unit + dims = q.dim assert_equal(dims.get_dimension("time"), 1.0) assert_equal(dims.get_dimension("length"), 0) @@ -201,47 +201,54 @@ def test_unary_operations(): def test_operations(): - q1 = Quantity(5, dim=mV) - q2 = Quantity(10, dim=mV) - assert_quantity(q1 + q2, 15, mV) - assert_quantity(q1 - q2, -5, mV) - assert_quantity(q1 * q2, 50, mV * mV) + q1 = 5 * second + q2 = 10 * second + assert_quantity(q1 + q2, 15, second) + assert_quantity(q1 - q2, -5, second) + assert_quantity(q1 * q2, 50, second * second) assert_quantity(q2 / q1, 2, DIMENSIONLESS) assert_quantity(q2 // q1, 2, DIMENSIONLESS) - assert_quantity(q2 % q1, 0, mV) + assert_quantity(q2 % q1, 0, second) assert_quantity(divmod(q2, q1)[0], 2, DIMENSIONLESS) - assert_quantity(divmod(q2, q1)[1], 0, mV) - assert_quantity(q1 ** 2, 25, mV ** 2) - assert_quantity(q1 << 1, 10, mV) - assert_quantity(q1 >> 1, 2, mV) - assert_quantity(round(q1, 0), 5, mV) + assert_quantity(divmod(q2, q1)[1], 0, second) + assert_quantity(q1 ** 2, 25, second ** 2) + assert_quantity(round(q1, 0), 5, second) + # matmul - q1 = Quantity([1, 2], dim=mV) - q2 = Quantity([3, 4], dim=mV) - assert_quantity(q1 @ q2, 11, mV ** 2) + q1 = [1, 2] * second + q2 = [3, 4] * second + assert_quantity(q1 @ q2, 11, second ** 2) + q1 = Quantity([1, 2], unit=second) + q2 = Quantity([3, 4], unit=second) + assert_quantity(q1 @ q2, 11, second ** 2) + + # shift + q1 = Quantity(0b1100, dtype=jnp.int32, dim=DIMENSIONLESS) + assert_quantity(q1 << 1, 0b11000, second) + assert_quantity(q1 >> 1, 0b110, second) def test_numpy_methods(): - q = Quantity([[1, 2], [3, 4]], dim=mV) + q = [[1, 2], [3, 4]] * second assert q.all() assert q.any() assert q.nonzero()[0].tolist() == [0, 0, 1, 1] assert q.argmax() == 3 assert q.argmin() == 0 assert q.argsort(axis=None).tolist() == [0, 1, 2, 3] - assert_quantity(q.var(), 1.25, mV ** 2) - assert_quantity(q.round(), [[1, 2], [3, 4]], mV) - assert_quantity(q.std(), 1.11803398875, mV) - assert_quantity(q.sum(), 10, mV) - assert_quantity(q.trace(), 5, mV) - assert_quantity(q.cumsum(), [1, 3, 6, 10], mV) - assert_quantity(q.cumprod(), [1, 2, 6, 24], mV ** 4) - assert_quantity(q.diagonal(), [1, 4], mV) - assert_quantity(q.max(), 4, mV) - assert_quantity(q.mean(), 2.5, mV) - assert_quantity(q.min(), 1, mV) - assert_quantity(q.ptp(), 3, mV) - assert_quantity(q.ravel(), [1, 2, 3, 4], mV) + assert_quantity(q.var(), 1.25, second ** 2) + assert_quantity(q.round(), [[1, 2], [3, 4]], second) + assert_quantity(q.std(), 1.11803398875, second) + assert_quantity(q.sum(), 10, second) + assert_quantity(q.trace(), 5, second) + assert_quantity(q.cumsum(), [1, 3, 6, 10], second) + assert_quantity(q.cumprod(), [1, 2, 6, 24], second ** 4) + assert_quantity(q.diagonal(), [1, 4], second) + assert_quantity(q.max(), 4, second) + assert_quantity(q.mean(), 2.5, second) + assert_quantity(q.min(), 1, second) + assert_quantity(q.ptp(), 3, second) + assert_quantity(q.ravel(), [1, 2, 3, 4], second) def test_shape_manipulation(): @@ -1596,7 +1603,7 @@ def test_constants(): import brainunit._unit_constants as constants # Check that the expected names exist and have the correct dimensions - assert constants.avogadro_constant.dim == (1 / mole).unit + assert constants.avogadro_constant.dim == (1 / mole).dim assert constants.boltzmann_constant.dim == (joule / kelvin).dim assert constants.electric_constant.dim == (farad / meter).dim assert constants.electron_mass.dim == kilogram.dim diff --git a/brainunit/math/__init__.py b/brainunit/math/__init__.py index 68b77d5..e574603 100644 --- a/brainunit/math/__init__.py +++ b/brainunit/math/__init__.py @@ -13,12 +13,67 @@ # limitations under the License. # ============================================================================== -from ._compat_numpy import * -from ._compat_numpy import __all__ as _compat_numpy_all +# from ._compat_numpy import * +# from ._compat_numpy import __all__ as _compat_numpy_all from ._others import * from ._others import __all__ as _other_all +from ._compat_numpy_array_creation import * +from ._compat_numpy_array_creation import __all__ as _compat_array_creation_all +from ._compat_numpy_array_manipulation import * +from ._compat_numpy_array_manipulation import __all__ as _compat_array_manipulation_all +from ._compat_numpy_funcs_accept_unitless import * +from ._compat_numpy_funcs_accept_unitless import __all__ as _compat_funcs_accept_unitless_all +from ._compat_numpy_funcs_bit_operation import * +from ._compat_numpy_funcs_bit_operation import __all__ as _compat_funcs_bit_operation_all +from ._compat_numpy_funcs_change_unit import * +from ._compat_numpy_funcs_change_unit import __all__ as _compat_funcs_change_unit_all +from ._compat_numpy_funcs_indexing import * +from ._compat_numpy_funcs_indexing import __all__ as _compat_funcs_indexing_all +from ._compat_numpy_funcs_keep_unit import * +from ._compat_numpy_funcs_keep_unit import __all__ as _compat_funcs_keep_unit_all +from ._compat_numpy_funcs_logic import * +from ._compat_numpy_funcs_logic import __all__ as _compat_funcs_logic_all +from ._compat_numpy_funcs_match_unit import * +from ._compat_numpy_funcs_match_unit import __all__ as _compat_funcs_match_unit_all +from ._compat_numpy_funcs_remove_unit import * +from ._compat_numpy_funcs_remove_unit import __all__ as _compat_funcs_remove_unit_all +from ._compat_numpy_funcs_window import * +from ._compat_numpy_funcs_window import __all__ as _compat_funcs_window_all +from ._compat_numpy_get_attribute import * +from ._compat_numpy_get_attribute import __all__ as _compat_get_attribute_all +from ._compat_numpy_linear_algebra import * +from ._compat_numpy_linear_algebra import __all__ as _compat_linear_algebra_all +from ._compat_numpy_misc import * +from ._compat_numpy_misc import __all__ as _compat_misc_all -__all__ = _compat_numpy_all + _other_all - -del _compat_numpy_all, _other_all +__all__ = _compat_array_creation_all + \ + _compat_array_manipulation_all + \ + _compat_funcs_change_unit_all + \ + _compat_funcs_keep_unit_all + \ + _compat_funcs_accept_unitless_all + \ + _compat_funcs_match_unit_all + \ + _compat_funcs_remove_unit_all + \ + _compat_get_attribute_all + \ + _compat_funcs_bit_operation_all + \ + _compat_funcs_logic_all + \ + _compat_funcs_indexing_all + \ + _compat_funcs_window_all + \ + _compat_linear_algebra_all + \ + _compat_misc_all + _other_all + \ + _other_all +del _compat_array_creation_all, \ + _compat_array_manipulation_all, \ + _compat_funcs_change_unit_all, \ + _compat_funcs_keep_unit_all, \ + _compat_funcs_accept_unitless_all, \ + _compat_funcs_match_unit_all, \ + _compat_funcs_remove_unit_all, \ + _compat_get_attribute_all, \ + _compat_funcs_bit_operation_all, \ + _compat_funcs_logic_all, \ + _compat_funcs_indexing_all, \ + _compat_funcs_window_all, \ + _compat_linear_algebra_all, \ + _compat_misc_all, \ + _other_all diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py deleted file mode 100644 index b150455..0000000 --- a/brainunit/math/_compat_numpy.py +++ /dev/null @@ -1,1455 +0,0 @@ -# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from collections.abc import Sequence -from functools import wraps -from typing import (Callable, Union, Optional) - -import brainstate as bst -import jax -import jax.numpy as jnp -import numpy as np -import opt_einsum -from brainstate._utils import set_module_as -from jax._src.numpy.lax_numpy import _einsum - -from ._utils import _compatible_with_quantity -from .._base import (DIMENSIONLESS, - Quantity, - Unit, - fail_for_dimension_mismatch, - is_unitless, - get_unit, ) -from .._base import _return_check_unitless - -__all__ = [ - # array creation - 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', - 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', - 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', - 'array_split', 'meshgrid', 'vander', - - # getting attribute funcs - 'ndim', 'isreal', 'isscalar', 'isfinite', 'isinf', - 'isnan', 'shape', 'size', - - # math funcs keep unit (unary) - 'real', 'imag', 'conj', 'conjugate', 'negative', 'positive', - 'abs', 'round', 'around', 'round_', 'rint', - 'floor', 'ceil', 'trunc', 'fix', 'sum', 'nancumsum', 'nansum', - 'cumsum', 'ediff1d', 'absolute', 'fabs', 'median', - 'nanmin', 'nanmax', 'ptp', 'average', 'mean', 'std', - 'nanmedian', 'nanmean', 'nanstd', 'diff', 'modf', - - # math funcs keep unit (binary) - 'fmod', 'mod', 'copysign', 'heaviside', - 'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', - - # math funcs keep unit (n-ary) - 'interp', 'clip', - - # math funcs match unit (binary) - 'add', 'subtract', 'nextafter', - - # math funcs change unit (unary) - 'reciprocal', 'prod', 'product', 'nancumprod', 'nanprod', 'cumprod', - 'cumproduct', 'var', 'nanvar', 'cbrt', 'square', 'frexp', 'sqrt', - - # math funcs change unit (binary) - 'multiply', 'divide', 'power', 'cross', 'ldexp', - 'true_divide', 'floor_divide', 'float_power', - 'divmod', 'remainder', 'convolve', - - # math funcs only accept unitless (unary) - 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2', - 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', - 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', - 'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', - 'percentile', 'nanpercentile', 'quantile', 'nanquantile', - - # math funcs only accept unitless (binary) - 'hypot', 'arctan2', 'logaddexp', 'logaddexp2', - - # math funcs remove unit (unary) - 'signbit', 'sign', 'histogram', 'bincount', - - # math funcs remove unit (binary) - 'corrcoef', 'correlate', 'cov', 'digitize', - - # array manipulation - 'reshape', 'moveaxis', 'transpose', 'swapaxes', 'row_stack', - 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', - 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', - 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', - 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', - 'argwhere', 'nonzero', 'flatnonzero', 'searchsorted', 'extract', - 'count_nonzero', 'max', 'min', 'amax', 'amin', 'block', 'compress', - 'diagflat', 'diagonal', 'choose', 'ravel', - - # Elementwise bit operations (unary) - 'bitwise_not', 'invert', 'left_shift', 'right_shift', - - # Elementwise bit operations (binary) - 'bitwise_and', 'bitwise_or', 'bitwise_xor', - - # logic funcs (unary) - 'all', 'any', 'logical_not', - - # logic funcs (binary) - 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', - 'array_equal', 'isclose', 'allclose', 'logical_and', - 'logical_or', 'logical_xor', "alltrue", 'sometrue', - - # indexing funcs - 'nonzero', 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', - 'triu_indices_from', 'take', 'select', - - # window funcs - 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', - - # constants - 'e', 'pi', 'inf', - - # linear algebra - 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', - - # data types - 'dtype', 'finfo', 'iinfo', - - # more - 'broadcast_arrays', 'broadcast_shapes', - 'einsum', 'gradient', 'intersect1d', 'nan_to_num', 'nanargmax', 'nanargmin', - 'rot90', 'tensordot', - -] - - -# array creation -# -------------- - -def wrap_array_creation_function(func): - def f(*args, unit: Unit = None, **kwargs): - if unit is not None: - assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - return func(*args, **kwargs) * unit - else: - return func(*args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -# array creation -# -------------- - -full = wrap_array_creation_function(jnp.full) -eye = wrap_array_creation_function(jnp.eye) -identity = wrap_array_creation_function(jnp.identity) -tri = wrap_array_creation_function(jnp.tri) -empty = wrap_array_creation_function(jnp.empty) -ones = wrap_array_creation_function(jnp.ones) -zeros = wrap_array_creation_function(jnp.zeros) - - -@set_module_as('brainunit.math') -def full_like(a, fill_value, dtype=None, shape=None): - if isinstance(a, Quantity) and isinstance(fill_value, Quantity): - fail_for_dimension_mismatch(a, fill_value, error_message='Units do not match for full_like operation.') - return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)) and not isinstance(fill_value, Quantity): - return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) - else: - raise ValueError(f'Unsupported types : {type(a)} abd {type(fill_value)} for full_like') - - -@set_module_as('brainunit.math') -def diag(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.diag(a.value, k=k), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.diag(a, k=k) - else: - raise ValueError(f'Unsupported type: {type(a)} for diag') - - -@set_module_as('brainunit.math') -def tril(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.tril(a.value, k=k), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.tril(a, k=k) - else: - raise ValueError(f'Unsupported type: {type(a)} for tril') - - -@set_module_as('brainunit.math') -def triu(a, k=0): - if isinstance(a, Quantity): - return Quantity(jnp.triu(a.value, k=k), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.triu(a, k=k) - else: - raise ValueError(f'Unsupported type: {type(a)} for triu') - - -@set_module_as('brainunit.math') -def empty_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.empty_like(a, dtype=dtype, shape=shape) - else: - raise ValueError(f'Unsupported type: {type(a)} for empty_like') - - -@set_module_as('brainunit.math') -def ones_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.ones_like(a, dtype=dtype, shape=shape) - else: - raise ValueError(f'Unsupported type: {type(a)} for ones_like') - - -@set_module_as('brainunit.math') -def zeros_like(a, dtype=None, shape=None): - if isinstance(a, Quantity): - return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.zeros_like(a, dtype=dtype, shape=shape) - else: - raise ValueError(f'Unsupported type: {type(a)} for zeros_like') - - -@set_module_as('brainunit.math') -def asarray( - a, - dtype: Optional[bst.typing.DTypeLike] = None, - order: Optional[str] = None, - unit: Optional[Unit] = None, -): - from builtins import all as origin_all - from builtins import any as origin_any - if isinstance(a, Quantity): - return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)): - return jnp.asarray(a, dtype=dtype, order=order) - # list[Quantity] - elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a): - # check all elements have the same unit - if origin_any(x.dim != a[0].dim for x in a): - raise ValueError('Units do not match for asarray operation.') - values = [x.value for x in a] - unit = a[0].dim - # Convert the values to a jnp.ndarray and create a Quantity object - return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) - else: - return jnp.asarray(a, dtype=dtype, order=order) - - -array = asarray - - -@set_module_as('brainunit.math') -def arange(*args, **kwargs): - # arange has a bit of a complicated argument structure unfortunately - # we leave the actual checking of the number of arguments to numpy, though - - # default values - start = kwargs.pop("start", 0) - step = kwargs.pop("step", 1) - stop = kwargs.pop("stop", None) - if len(args) == 1: - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - stop = args[0] - elif len(args) == 2: - if start != 0: - raise TypeError("Duplicate definition of 'start'") - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - start, stop = args - elif len(args) == 3: - if start != 0: - raise TypeError("Duplicate definition of 'start'") - if stop is not None: - raise TypeError("Duplicate definition of 'stop'") - if step != 1: - raise TypeError("Duplicate definition of 'step'") - start, stop, step = args - elif len(args) > 3: - raise TypeError("Need between 1 and 3 non-keyword arguments") - - if stop is None: - raise TypeError("Missing stop argument.") - if stop is not None and not is_unitless(stop): - start = Quantity(start, dim=stop.unit) - - fail_for_dimension_mismatch( - start, - stop, - error_message=( - "Start value {start} and stop value {stop} have to have the same units." - ), - start=start, - stop=stop, - ) - fail_for_dimension_mismatch( - stop, - step, - error_message=( - "Stop value {stop} and step value {step} have to have the same units." - ), - stop=stop, - step=step, - ) - unit = getattr(stop, "unit", DIMENSIONLESS) - # start is a position-only argument in numpy 2.0 - # https://numpy.org/devdocs/release/2.0.0-notes.html#arange-s-start-argument-is-positional-only - # TODO: check whether this is still the case in the final release - if start == 0: - return Quantity( - jnp.arange( - start=start.value if isinstance(start, Quantity) else jnp.asarray(start), - stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), - step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - **kwargs, - ), - dim=unit, - ) - else: - return Quantity( - jnp.arange( - start.value if isinstance(start, Quantity) else jnp.asarray(start), - stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), - step=step.value if isinstance(step, Quantity) else jnp.asarray(step), - **kwargs, - ), - dim=unit, - ) - - -@set_module_as('brainunit.math') -def linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None): - fail_for_dimension_mismatch( - start, - stop, - error_message="Start value {start} and stop value {stop} have to have the same units.", - start=start, - stop=stop, - ) - unit = getattr(start, "unit", DIMENSIONLESS) - start = start.value if isinstance(start, Quantity) else start - stop = stop.value if isinstance(stop, Quantity) else stop - - result = jnp.linspace(start, stop, num=num, endpoint=endpoint, retstep=retstep, dtype=dtype) - return Quantity(result, dim=unit) - - -@set_module_as('brainunit.math') -def logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None): - fail_for_dimension_mismatch( - start, - stop, - error_message="Start value {start} and stop value {stop} have to have the same units.", - start=start, - stop=stop, - ) - unit = getattr(start, "unit", DIMENSIONLESS) - start = start.value if isinstance(start, Quantity) else start - stop = stop.value if isinstance(stop, Quantity) else stop - - result = jnp.logspace(start, stop, num=num, endpoint=endpoint, base=base, dtype=dtype) - return Quantity(result, dim=unit) - - -@set_module_as('brainunit.math') -def fill_diagonal(a, val, wrap=False, inplace=True): - if isinstance(a, Quantity) and isinstance(val, Quantity): - fail_for_dimension_mismatch(a, val) - return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), dim=a.unit) - elif isinstance(a, (jax.Array, np.ndarray)) and isinstance(val, (jax.Array, np.ndarray)): - return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) - elif is_unitless(a) or is_unitless(val): - return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) - else: - raise ValueError(f'Unsupported types : {type(a)} abd {type(val)} for fill_diagonal') - - -@set_module_as('brainunit.math') -def array_split(ary, indices_or_sections, axis=0): - if isinstance(ary, Quantity): - return Quantity(jnp.array_split(ary.value, indices_or_sections, axis), dim=ary.unit) - elif isinstance(ary, (jax.Array, np.ndarray)): - return jnp.array_split(ary, indices_or_sections, axis) - else: - raise ValueError(f'Unsupported type: {type(ary)} for array_split') - - -@set_module_as('brainunit.math') -def meshgrid(*xi, copy=True, sparse=False, indexing='xy'): - from builtins import all as origin_all - if origin_all(isinstance(x, Quantity) for x in xi): - fail_for_dimension_mismatch(*xi) - return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), dim=xi[0].dim) - elif origin_all(isinstance(x, (jax.Array, np.ndarray)) for x in xi): - return jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) - else: - raise ValueError(f'Unsupported types : {type(xi)} for meshgrid') - - -@set_module_as('brainunit.math') -def vander(x, N=None, increasing=False): - if isinstance(x, Quantity): - return Quantity(jnp.vander(x.value, N=N, increasing=increasing), dim=x.unit) - elif isinstance(x, (jax.Array, np.ndarray)): - return jnp.vander(x, N=N, increasing=increasing) - else: - raise ValueError(f'Unsupported type: {type(x)} for vander') - - -# getting attribute funcs -# ----------------------- - -@set_module_as('brainunit.math') -def ndim(a): - if isinstance(a, Quantity): - return a.ndim - else: - return jnp.ndim(a) - - -@set_module_as('brainunit.math') -def isreal(a): - if isinstance(a, Quantity): - return a.isreal - else: - return jnp.isreal(a) - - -@set_module_as('brainunit.math') -def isscalar(a): - if isinstance(a, Quantity): - return a.isscalar - else: - return jnp.isscalar(a) - - -@set_module_as('brainunit.math') -def isfinite(a): - if isinstance(a, Quantity): - return a.isfinite - else: - return jnp.isfinite(a) - - -@set_module_as('brainunit.math') -def isinf(a): - if isinstance(a, Quantity): - return a.isinf - else: - return jnp.isinf(a) - - -@set_module_as('brainunit.math') -def isnan(a): - if isinstance(a, Quantity): - return a.isnan - else: - return jnp.isnan(a) - - -@set_module_as('brainunit.math') -def shape(a): - """ - Return the shape of an array. - - Parameters - ---------- - a : array_like - Input array. - - Returns - ------- - shape : tuple of ints - The elements of the shape tuple give the lengths of the - corresponding array dimensions. - - See Also - -------- - len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with - ``N>=1``. - ndarray.shape : Equivalent array method. - - Examples - -------- - >>> brainunit.math.shape(brainunit.math.eye(3)) - (3, 3) - >>> brainunit.math.shape([[1, 3]]) - (1, 2) - >>> brainunit.math.shape([0]) - (1,) - >>> brainunit.math.shape(0) - () - - """ - if isinstance(a, (Quantity, jax.Array, np.ndarray)): - return a.shape - else: - return np.shape(a) - - -@set_module_as('brainunit.math') -def size(a, axis=None): - """ - Return the number of elements along a given axis. - - Parameters - ---------- - a : array_like - Input data. - axis : int, optional - Axis along which the elements are counted. By default, give - the total number of elements. - - Returns - ------- - element_count : int - Number of elements along the specified axis. - - See Also - -------- - shape : dimensions of array - Array.shape : dimensions of array - Array.size : number of elements in array - - Examples - -------- - >>> a = Quantity([[1,2,3], [4,5,6]]) - >>> brainunit.math.size(a) - 6 - >>> brainunit.math.size(a, 1) - 3 - >>> brainunit.math.size(a, 0) - 2 - """ - if isinstance(a, (Quantity, jax.Array, np.ndarray)): - if axis is None: - return a.size - else: - return a.shape[axis] - else: - return np.size(a, axis=axis) - - -# math funcs keep unit (unary) -# ---------------------------- - -def wrap_math_funcs_keep_unit_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), dim=x.unit) - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -real = wrap_math_funcs_keep_unit_unary(jnp.real) -imag = wrap_math_funcs_keep_unit_unary(jnp.imag) -conj = wrap_math_funcs_keep_unit_unary(jnp.conj) -conjugate = wrap_math_funcs_keep_unit_unary(jnp.conjugate) -negative = wrap_math_funcs_keep_unit_unary(jnp.negative) -positive = wrap_math_funcs_keep_unit_unary(jnp.positive) -abs = wrap_math_funcs_keep_unit_unary(jnp.abs) -round_ = wrap_math_funcs_keep_unit_unary(jnp.round) -around = wrap_math_funcs_keep_unit_unary(jnp.around) -round = wrap_math_funcs_keep_unit_unary(jnp.round) -rint = wrap_math_funcs_keep_unit_unary(jnp.rint) -floor = wrap_math_funcs_keep_unit_unary(jnp.floor) -ceil = wrap_math_funcs_keep_unit_unary(jnp.ceil) -trunc = wrap_math_funcs_keep_unit_unary(jnp.trunc) -fix = wrap_math_funcs_keep_unit_unary(jnp.fix) -sum = wrap_math_funcs_keep_unit_unary(jnp.sum) -nancumsum = wrap_math_funcs_keep_unit_unary(jnp.nancumsum) -nansum = wrap_math_funcs_keep_unit_unary(jnp.nansum) -cumsum = wrap_math_funcs_keep_unit_unary(jnp.cumsum) -ediff1d = wrap_math_funcs_keep_unit_unary(jnp.ediff1d) -absolute = wrap_math_funcs_keep_unit_unary(jnp.absolute) -fabs = wrap_math_funcs_keep_unit_unary(jnp.fabs) -median = wrap_math_funcs_keep_unit_unary(jnp.median) -nanmin = wrap_math_funcs_keep_unit_unary(jnp.nanmin) -nanmax = wrap_math_funcs_keep_unit_unary(jnp.nanmax) -ptp = wrap_math_funcs_keep_unit_unary(jnp.ptp) -average = wrap_math_funcs_keep_unit_unary(jnp.average) -mean = wrap_math_funcs_keep_unit_unary(jnp.mean) -std = wrap_math_funcs_keep_unit_unary(jnp.std) -nanmedian = wrap_math_funcs_keep_unit_unary(jnp.nanmedian) -nanmean = wrap_math_funcs_keep_unit_unary(jnp.nanmean) -nanstd = wrap_math_funcs_keep_unit_unary(jnp.nanstd) -diff = wrap_math_funcs_keep_unit_unary(jnp.diff) -modf = wrap_math_funcs_keep_unit_unary(jnp.modf) - - -# math funcs keep unit (binary) -# ----------------------------- - -def wrap_math_funcs_keep_unit_binary(func): - def f(x1, x2, *args, **kwargs): - if isinstance(x1, Quantity) and isinstance(x2, Quantity): - return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.unit) - elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): - return func(x1, x2, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -fmod = wrap_math_funcs_keep_unit_binary(jnp.fmod) -mod = wrap_math_funcs_keep_unit_binary(jnp.mod) -copysign = wrap_math_funcs_keep_unit_binary(jnp.copysign) -heaviside = wrap_math_funcs_keep_unit_binary(jnp.heaviside) -maximum = wrap_math_funcs_keep_unit_binary(jnp.maximum) -minimum = wrap_math_funcs_keep_unit_binary(jnp.minimum) -fmax = wrap_math_funcs_keep_unit_binary(jnp.fmax) -fmin = wrap_math_funcs_keep_unit_binary(jnp.fmin) -lcm = wrap_math_funcs_keep_unit_binary(jnp.lcm) -gcd = wrap_math_funcs_keep_unit_binary(jnp.gcd) - - -# math funcs keep unit (n-ary) -# ---------------------------- -@set_module_as('brainunit.math') -def interp(x, xp, fp, left=None, right=None, period=None): - unit = None - if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): - unit = x.unit if isinstance(x, Quantity) else xp.unit if isinstance(xp, Quantity) else fp.unit - if isinstance(x, Quantity): - x_value = x.value - else: - x_value = x - if isinstance(xp, Quantity): - xp_value = xp.value - else: - xp_value = xp - if isinstance(fp, Quantity): - fp_value = fp.value - else: - fp_value = fp - result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) - if unit is not None: - return Quantity(result, dim=unit) - else: - return result - - -@set_module_as('brainunit.math') -def clip(a, a_min, a_max): - unit = None - if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): - unit = a.unit if isinstance(a, Quantity) else a_min.unit if isinstance(a_min, Quantity) else a_max.unit - if isinstance(a, Quantity): - a_value = a.value - else: - a_value = a - if isinstance(a_min, Quantity): - a_min_value = a_min.value - else: - a_min_value = a_min - if isinstance(a_max, Quantity): - a_max_value = a_max.value - else: - a_max_value = a_max - result = jnp.clip(a_value, a_min_value, a_max_value) - if unit is not None: - return Quantity(result, dim=unit) - else: - return result - - -# math funcs match unit (binary) -# ------------------------------ - -def wrap_math_funcs_match_unit_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - fail_for_dimension_mismatch(x, y) - return Quantity(func(x.value, y.value, *args, **kwargs), dim=x.unit) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - if x.is_unitless: - return Quantity(func(x.value, y, *args, **kwargs), dim=x.unit) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - elif isinstance(y, Quantity): - if y.is_unitless: - return Quantity(func(x, y.value, *args, **kwargs), dim=y.unit) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -add = wrap_math_funcs_match_unit_binary(jnp.add) -subtract = wrap_math_funcs_match_unit_binary(jnp.subtract) -nextafter = wrap_math_funcs_match_unit_binary(jnp.nextafter) - - -# math funcs change unit (unary) -# ------------------------------ - -def wrap_math_funcs_change_unit_unary(func, change_unit_func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dim=change_unit_func(x.unit))) - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -reciprocal = wrap_math_funcs_change_unit_unary(jnp.reciprocal, lambda x: x ** -1) - - -@set_module_as('brainunit.math') -def prod(x, axis=None, dtype=None, out=None, keepdims=False, initial=None): - if isinstance(x, Quantity): - return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - else: - return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - - -@set_module_as('brainunit.math') -def nanprod(x, axis=None, dtype=None, out=None, keepdims=False, initial=None): - if isinstance(x, Quantity): - return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - else: - return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial) - - -product = prod - - -@set_module_as('brainunit.math') -def cumprod(x, axis=None, dtype=None, out=None): - if isinstance(x, Quantity): - return x.cumprod(axis=axis, dtype=dtype, out=out) - else: - return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) - - -@set_module_as('brainunit.math') -def nancumprod(x, axis=None, dtype=None, out=None): - if isinstance(x, Quantity): - return x.nancumprod(axis=axis, dtype=dtype, out=out) - else: - return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) - - -cumproduct = cumprod - -var = wrap_math_funcs_change_unit_unary(jnp.var, lambda x: x ** 2) -nanvar = wrap_math_funcs_change_unit_unary(jnp.nanvar, lambda x: x ** 2) -frexp = wrap_math_funcs_change_unit_unary(jnp.frexp, lambda x, y: x * 2 ** y) -sqrt = wrap_math_funcs_change_unit_unary(jnp.sqrt, lambda x: x ** 0.5) -cbrt = wrap_math_funcs_change_unit_unary(jnp.cbrt, lambda x: x ** (1 / 3)) -square = wrap_math_funcs_change_unit_unary(jnp.square, lambda x: x ** 2) - - -# math funcs change unit (binary) -# ------------------------------- - -def wrap_math_funcs_change_unit_binary(func, change_unit_func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless( - Quantity(func(x.value, y.value, *args, **kwargs), dim=change_unit_func(x.unit, y.unit)) - ) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless( - Quantity(func(x.value, y, *args, **kwargs), dim=change_unit_func(x.unit, DIMENSIONLESS))) - elif isinstance(y, Quantity): - return _return_check_unitless( - Quantity(func(x, y.value, *args, **kwargs), dim=change_unit_func(DIMENSIONLESS, y.unit))) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -multiply = wrap_math_funcs_change_unit_binary(jnp.multiply, lambda x, y: x * y) -divide = wrap_math_funcs_change_unit_binary(jnp.divide, lambda x, y: x / y) - - -@set_module_as('brainunit.math') -def power(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y.value, *args, **kwargs), dim=x.unit ** y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.power(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.power(x.value, y, *args, **kwargs), dim=x.unit ** y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.power(x, y.value, *args, **kwargs), dim=x ** y.unit)) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') - - -cross = wrap_math_funcs_change_unit_binary(jnp.cross, lambda x, y: x * y) -ldexp = wrap_math_funcs_change_unit_binary(jnp.ldexp, lambda x, y: x * 2 ** y) -true_divide = wrap_math_funcs_change_unit_binary(jnp.true_divide, lambda x, y: x / y) - - -@set_module_as('brainunit.math') -def floor_divide(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value, *args, **kwargs), dim=x.unit / y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.floor_divide(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y, *args, **kwargs), dim=x.unit / y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value, *args, **kwargs), dim=x / y.unit)) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') - - -@set_module_as('brainunit.math') -def float_power(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y.value, *args, **kwargs), dim=x.unit ** y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.float_power(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x.value, y, *args, **kwargs), dim=x.unit ** y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.float_power(x, y.value, *args, **kwargs), dim=x ** y.unit)) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.float_power.__name__}') - - -divmod = wrap_math_funcs_change_unit_binary(jnp.divmod, lambda x, y: x / y) - - -@set_module_as('brainunit.math') -def remainder(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value, *args, **kwargs), dim=x.unit / y.unit)) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return jnp.remainder(x, y, *args, **kwargs) - elif isinstance(x, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x.value, y, *args, **kwargs), dim=x.unit % y)) - elif isinstance(y, Quantity): - return _return_check_unitless(Quantity(jnp.remainder(x, y.value, *args, **kwargs), dim=x % y.unit)) - else: - raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}') - - -convolve = wrap_math_funcs_change_unit_binary(jnp.convolve, lambda x, y: x * y) - - -# math funcs only accept unitless (unary) -# --------------------------------------- - -def wrap_math_funcs_only_accept_unitless_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - fail_for_dimension_mismatch( - x, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=x, - ) - return func(jnp.array(x.value), *args, **kwargs) - else: - return func(x, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -exp = wrap_math_funcs_only_accept_unitless_unary(jnp.exp) -exp2 = wrap_math_funcs_only_accept_unitless_unary(jnp.exp2) -expm1 = wrap_math_funcs_only_accept_unitless_unary(jnp.expm1) -log = wrap_math_funcs_only_accept_unitless_unary(jnp.log) -log10 = wrap_math_funcs_only_accept_unitless_unary(jnp.log10) -log1p = wrap_math_funcs_only_accept_unitless_unary(jnp.log1p) -log2 = wrap_math_funcs_only_accept_unitless_unary(jnp.log2) -arccos = wrap_math_funcs_only_accept_unitless_unary(jnp.arccos) -arccosh = wrap_math_funcs_only_accept_unitless_unary(jnp.arccosh) -arcsin = wrap_math_funcs_only_accept_unitless_unary(jnp.arcsin) -arcsinh = wrap_math_funcs_only_accept_unitless_unary(jnp.arcsinh) -arctan = wrap_math_funcs_only_accept_unitless_unary(jnp.arctan) -arctanh = wrap_math_funcs_only_accept_unitless_unary(jnp.arctanh) -cos = wrap_math_funcs_only_accept_unitless_unary(jnp.cos) -cosh = wrap_math_funcs_only_accept_unitless_unary(jnp.cosh) -sin = wrap_math_funcs_only_accept_unitless_unary(jnp.sin) -sinc = wrap_math_funcs_only_accept_unitless_unary(jnp.sinc) -sinh = wrap_math_funcs_only_accept_unitless_unary(jnp.sinh) -tan = wrap_math_funcs_only_accept_unitless_unary(jnp.tan) -tanh = wrap_math_funcs_only_accept_unitless_unary(jnp.tanh) -deg2rad = wrap_math_funcs_only_accept_unitless_unary(jnp.deg2rad) -rad2deg = wrap_math_funcs_only_accept_unitless_unary(jnp.rad2deg) -degrees = wrap_math_funcs_only_accept_unitless_unary(jnp.degrees) -radians = wrap_math_funcs_only_accept_unitless_unary(jnp.radians) -angle = wrap_math_funcs_only_accept_unitless_unary(jnp.angle) -percentile = wrap_math_funcs_only_accept_unitless_unary(jnp.percentile) -nanpercentile = wrap_math_funcs_only_accept_unitless_unary(jnp.nanpercentile) -quantile = wrap_math_funcs_only_accept_unitless_unary(jnp.quantile) -nanquantile = wrap_math_funcs_only_accept_unitless_unary(jnp.nanquantile) - - -# math funcs only accept unitless (binary) -# ---------------------------------------- - -def wrap_math_funcs_only_accept_unitless_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity): - x_value = x.value - if isinstance(y, Quantity): - y_value = y.value - if isinstance(x, Quantity) or isinstance(y, Quantity): - fail_for_dimension_mismatch( - x, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=x, - ) - fail_for_dimension_mismatch( - y, - error_message="%s expects a dimensionless argument but got {value}" % func.__name__, - value=y, - ) - return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) - else: - return func(x, y, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -hypot = wrap_math_funcs_only_accept_unitless_binary(jnp.hypot) -arctan2 = wrap_math_funcs_only_accept_unitless_binary(jnp.arctan2) -logaddexp = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp) -logaddexp2 = wrap_math_funcs_only_accept_unitless_binary(jnp.logaddexp2) - - -# math funcs remove unit (unary) -# ------------------------------ -def wrap_math_funcs_remove_unit_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return func(x.value, *args, **kwargs) - else: - return func(x, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -signbit = wrap_math_funcs_remove_unit_unary(jnp.signbit) -sign = wrap_math_funcs_remove_unit_unary(jnp.sign) -histogram = wrap_math_funcs_remove_unit_unary(jnp.histogram) -bincount = wrap_math_funcs_remove_unit_unary(jnp.bincount) - - -# math funcs remove unit (binary) -# ------------------------------- -def wrap_math_funcs_remove_unit_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity): - x_value = x.value - if isinstance(y, Quantity): - y_value = y.value - if isinstance(x, Quantity) or isinstance(y, Quantity): - return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) - else: - return func(x, y, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -corrcoef = wrap_math_funcs_remove_unit_binary(jnp.corrcoef) -correlate = wrap_math_funcs_remove_unit_binary(jnp.correlate) -cov = wrap_math_funcs_remove_unit_binary(jnp.cov) -digitize = wrap_math_funcs_remove_unit_binary(jnp.digitize) - -# array manipulation -# ------------------ - -reshape = _compatible_with_quantity(jnp.reshape) -moveaxis = _compatible_with_quantity(jnp.moveaxis) -transpose = _compatible_with_quantity(jnp.transpose) -swapaxes = _compatible_with_quantity(jnp.swapaxes) -concatenate = _compatible_with_quantity(jnp.concatenate) -stack = _compatible_with_quantity(jnp.stack) -vstack = _compatible_with_quantity(jnp.vstack) -row_stack = vstack -hstack = _compatible_with_quantity(jnp.hstack) -dstack = _compatible_with_quantity(jnp.dstack) -column_stack = _compatible_with_quantity(jnp.column_stack) -split = _compatible_with_quantity(jnp.split) -dsplit = _compatible_with_quantity(jnp.dsplit) -hsplit = _compatible_with_quantity(jnp.hsplit) -vsplit = _compatible_with_quantity(jnp.vsplit) -tile = _compatible_with_quantity(jnp.tile) -repeat = _compatible_with_quantity(jnp.repeat) -unique = _compatible_with_quantity(jnp.unique) -append = _compatible_with_quantity(jnp.append) -flip = _compatible_with_quantity(jnp.flip) -fliplr = _compatible_with_quantity(jnp.fliplr) -flipud = _compatible_with_quantity(jnp.flipud) -roll = _compatible_with_quantity(jnp.roll) -atleast_1d = _compatible_with_quantity(jnp.atleast_1d) -atleast_2d = _compatible_with_quantity(jnp.atleast_2d) -atleast_3d = _compatible_with_quantity(jnp.atleast_3d) -expand_dims = _compatible_with_quantity(jnp.expand_dims) -squeeze = _compatible_with_quantity(jnp.squeeze) -sort = _compatible_with_quantity(jnp.sort) - -max = _compatible_with_quantity(jnp.max) -min = _compatible_with_quantity(jnp.min) - -amax = max -amin = min - -choose = _compatible_with_quantity(jnp.choose) -block = _compatible_with_quantity(jnp.block) -compress = _compatible_with_quantity(jnp.compress) -diagflat = _compatible_with_quantity(jnp.diagflat) - -# return jax.numpy.Array, not Quantity -argsort = _compatible_with_quantity(jnp.argsort, return_quantity=False) -argmax = _compatible_with_quantity(jnp.argmax, return_quantity=False) -argmin = _compatible_with_quantity(jnp.argmin, return_quantity=False) -argwhere = _compatible_with_quantity(jnp.argwhere, return_quantity=False) -nonzero = _compatible_with_quantity(jnp.nonzero, return_quantity=False) -flatnonzero = _compatible_with_quantity(jnp.flatnonzero, return_quantity=False) -searchsorted = _compatible_with_quantity(jnp.searchsorted, return_quantity=False) -extract = _compatible_with_quantity(jnp.extract, return_quantity=False) -count_nonzero = _compatible_with_quantity(jnp.count_nonzero, return_quantity=False) - - -def wrap_function_to_method(func): - @wraps(func) - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - return Quantity(func(x.value, *args, **kwargs), dim=x.unit) - else: - return func(x, *args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -diagonal = wrap_function_to_method(jnp.diagonal) -ravel = wrap_function_to_method(jnp.ravel) - - -# Elementwise bit operations (unary) -# ---------------------------------- - -def wrap_elementwise_bit_operation_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - raise ValueError(f'Expected integers, got {x}') - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -bitwise_not = wrap_elementwise_bit_operation_unary(jnp.bitwise_not) -invert = wrap_elementwise_bit_operation_unary(jnp.invert) -left_shift = wrap_elementwise_bit_operation_unary(jnp.left_shift) -right_shift = wrap_elementwise_bit_operation_unary(jnp.right_shift) - - -# Elementwise bit operations (binary) -# ----------------------------------- - -def wrap_elementwise_bit_operation_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) or isinstance(y, Quantity): - raise ValueError(f'Expected integers, got {x} and {y}') - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -bitwise_and = wrap_elementwise_bit_operation_binary(jnp.bitwise_and) -bitwise_or = wrap_elementwise_bit_operation_binary(jnp.bitwise_or) -bitwise_xor = wrap_elementwise_bit_operation_binary(jnp.bitwise_xor) - - -# logic funcs (unary) -# ------------------- - -def wrap_logic_func_unary(func): - def f(x, *args, **kwargs): - if isinstance(x, Quantity): - raise ValueError(f'Expected booleans, got {x}') - elif isinstance(x, (jax.Array, np.ndarray)): - return func(x, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -all = wrap_logic_func_unary(jnp.all) -any = wrap_logic_func_unary(jnp.any) -alltrue = all -sometrue = any -logical_not = wrap_logic_func_unary(jnp.logical_not) - - -# logic funcs (binary) -# -------------------- - -def wrap_logic_func_binary(func): - def f(x, y, *args, **kwargs): - if isinstance(x, Quantity) and isinstance(y, Quantity): - fail_for_dimension_mismatch(x, y) - return func(x.value, y.value, *args, **kwargs) - elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): - return func(x, y, *args, **kwargs) - else: - raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') - - f.__module__ = 'brainunit.math' - return f - - -equal = wrap_logic_func_binary(jnp.equal) -not_equal = wrap_logic_func_binary(jnp.not_equal) -greater = wrap_logic_func_binary(jnp.greater) -greater_equal = wrap_logic_func_binary(jnp.greater_equal) -less = wrap_logic_func_binary(jnp.less) -less_equal = wrap_logic_func_binary(jnp.less_equal) -array_equal = wrap_logic_func_binary(jnp.array_equal) -isclose = wrap_logic_func_binary(jnp.isclose) -allclose = wrap_logic_func_binary(jnp.allclose) -logical_and = wrap_logic_func_binary(jnp.logical_and) - -logical_or = wrap_logic_func_binary(jnp.logical_or) -logical_xor = wrap_logic_func_binary(jnp.logical_xor) - - -# indexing funcs -# -------------- -@set_module_as('brainunit.math') -def where(condition, *args, **kwds): # pylint: disable=C0111 - condition = jnp.asarray(condition) - if len(args) == 0: - # nothing to do - return jnp.where(condition, *args, **kwds) - elif len(args) == 2: - # check that x and y have the same dimensions - fail_for_dimension_mismatch( - args[0], args[1], "x and y need to have the same dimensions" - ) - new_args = [] - for arg in args: - if isinstance(arg, Quantity): - new_args.append(arg.value) - if is_unitless(args[0]): - if len(new_args) == 2: - return jnp.where(condition, *new_args, **kwds) - else: - return jnp.where(condition, *args, **kwds) - else: - # as both arguments have the same unit, just use the first one's - dimensionless_args = [jnp.asarray(arg.value) if isinstance(arg, Quantity) else jnp.asarray(arg) for arg in args] - return Quantity.with_units( - jnp.where(condition, *dimensionless_args), args[0].dim - ) - else: - # illegal number of arguments - if len(args) == 1: - raise ValueError("where() takes 2 or 3 positional arguments but 1 was given") - elif len(args) > 2: - raise TypeError("where() takes 2 or 3 positional arguments but {} were given".format(len(args))) - - -tril_indices = jnp.tril_indices - - -@set_module_as('brainunit.math') -def tril_indices_from(arr, k=0): - if isinstance(arr, Quantity): - return jnp.tril_indices_from(arr.value, k=k) - else: - return jnp.tril_indices_from(arr, k=k) - - -triu_indices = jnp.triu_indices - - -@set_module_as('brainunit.math') -def triu_indices_from(arr, k=0): - if isinstance(arr, Quantity): - return jnp.triu_indices_from(arr.value, k=k) - else: - return jnp.triu_indices_from(arr, k=k) - - -@set_module_as('brainunit.math') -def take(a, indices, axis=None, mode=None): - if isinstance(a, Quantity): - return a.take(indices, axis=axis, mode=mode) - else: - return jnp.take(a, indices, axis=axis, mode=mode) - - -@set_module_as('brainunit.math') -def select(condlist: list[Union[jnp.array, np.ndarray]], choicelist: Union[Quantity, jax.Array, np.ndarray], default=0): - from builtins import all as origin_all - from builtins import any as origin_any - if origin_all(isinstance(choice, Quantity) for choice in choicelist): - if origin_any(choice.dim != choicelist[0].dim for choice in choicelist): - raise ValueError("All choices must have the same unit") - else: - return Quantity(jnp.select(condlist, [choice.value for choice in choicelist], default=default), - dim=choicelist[0].dim) - elif origin_all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist): - return jnp.select(condlist, choicelist, default=default) - else: - raise ValueError(f"Unsupported types : {type(condlist)} and {type(choicelist)} for select") - - -# window funcs -# ------------ - -def wrap_window_funcs(func): - def f(*args, **kwargs): - return Quantity(func(*args, **kwargs)) - - f.__module__ = 'brainunit.math' - return f - - -bartlett = wrap_window_funcs(jnp.bartlett) -blackman = wrap_window_funcs(jnp.blackman) -hamming = wrap_window_funcs(jnp.hamming) -hanning = wrap_window_funcs(jnp.hanning) -kaiser = wrap_window_funcs(jnp.kaiser) - -# constants -# --------- -e = jnp.e -pi = jnp.pi -inf = jnp.inf - -# linear algebra -# -------------- -dot = wrap_math_funcs_change_unit_binary(jnp.dot, lambda x, y: x * y) -vdot = wrap_math_funcs_change_unit_binary(jnp.vdot, lambda x, y: x * y) -inner = wrap_math_funcs_change_unit_binary(jnp.inner, lambda x, y: x * y) -outer = wrap_math_funcs_change_unit_binary(jnp.outer, lambda x, y: x * y) -kron = wrap_math_funcs_change_unit_binary(jnp.kron, lambda x, y: x * y) -matmul = wrap_math_funcs_change_unit_binary(jnp.matmul, lambda x, y: x * y) -trace = wrap_math_funcs_keep_unit_unary(jnp.trace) - -# data types -# ---------- -dtype = jnp.dtype - - -@set_module_as('brainunit.math') -def finfo(a): - if isinstance(a, Quantity): - return jnp.finfo(a.value) - else: - return jnp.finfo(a) - - -@set_module_as('brainunit.math') -def iinfo(a): - if isinstance(a, Quantity): - return jnp.iinfo(a.value) - else: - return jnp.iinfo(a) - - -# more -# ---- -@set_module_as('brainunit.math') -def broadcast_arrays(*args): - from builtins import all as origin_all - from builtins import any as origin_any - if origin_all(isinstance(arg, Quantity) for arg in args): - if origin_any(arg.dim != args[0].dim for arg in args): - raise ValueError("All arguments must have the same unit") - return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), dim=args[0].dim) - elif origin_all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args): - return jnp.broadcast_arrays(*args) - else: - raise ValueError(f"Unsupported types : {type(args)} for broadcast_arrays") - - -broadcast_shapes = jnp.broadcast_shapes - - -@set_module_as('brainunit.math') -def einsum( - subscripts, /, - *operands, - out: None = None, - optimize: Union[str, bool] = "optimal", - precision: jax.lax.PrecisionLike = None, - preferred_element_type: Union[jax.typing.DTypeLike, None] = None, - _dot_general: Callable[..., jax.Array] = jax.lax.dot_general, -) -> Union[jax.Array, Quantity]: - operands = (subscripts, *operands) - if out is not None: - raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.") - spec = operands[0] if isinstance(operands[0], str) else None - optimize = 'optimal' if optimize is True else optimize - - # Allow handling of shape polymorphism - non_constant_dim_types = { - type(d) for op in operands if not isinstance(op, str) - for d in np.shape(op) if not jax.core.is_constant_dim(d) - } - if not non_constant_dim_types: - contract_path = opt_einsum.contract_path - else: - from jax._src.numpy.lax_numpy import _default_poly_einsum_handler - contract_path = _default_poly_einsum_handler - - operands, contractions = contract_path( - *operands, einsum_call=True, use_blas=True, optimize=optimize) - - unit = None - for i in range(len(contractions) - 1): - if contractions[i][4] == 'False': - - fail_for_dimension_mismatch( - Quantity([], dim=unit), operands[i + 1], 'einsum' - ) - elif contractions[i][4] == 'DOT' or \ - contractions[i][4] == 'TDOT' or \ - contractions[i][4] == 'GEMM' or \ - contractions[i][4] == 'OUTER/EINSUM': - if i == 0: - if isinstance(operands[i], Quantity) and isinstance(operands[i + 1], Quantity): - unit = operands[i].dim * operands[i + 1].dim - elif isinstance(operands[i], Quantity): - unit = operands[i].dim - elif isinstance(operands[i + 1], Quantity): - unit = operands[i + 1].dim - else: - if isinstance(operands[i + 1], Quantity): - unit = unit * operands[i + 1].dim - - contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) - - einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True) - if spec is not None: - einsum = jax.named_call(einsum, name=spec) - operands = [op.value if isinstance(op, Quantity) else op for op in operands] - r = einsum(operands, contractions, precision, # type: ignore[operator] - preferred_element_type, _dot_general) - if unit is not None: - return Quantity(r, dim=unit) - else: - return r - - -@set_module_as('brainunit.math') -def gradient( - f: Union[jax.Array, np.ndarray, Quantity], - *varargs: Union[jax.Array, np.ndarray, Quantity], - axis: Union[int, Sequence[int], None] = None, - edge_order: Union[int, None] = None, -) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: - if edge_order is not None: - raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") - - if len(varargs) == 0: - if isinstance(f, Quantity) and not is_unitless(f): - return Quantity(jnp.gradient(f.value, axis=axis), dim=f.unit) - else: - return jnp.gradient(f) - elif len(varargs) == 1: - unit = get_unit(f) / get_unit(varargs[0]) - if unit is None or unit == DIMENSIONLESS: - return jnp.gradient(f, varargs[0], axis=axis) - else: - return [Quantity(r, dim=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] - else: - unit_list = [get_unit(f) / get_unit(v) for v in varargs] - f = f.value if isinstance(f, Quantity) else f - varargs = [v.value if isinstance(v, Quantity) else v for v in varargs] - result_list = jnp.gradient(f, *varargs, axis=axis) - return [Quantity(r, dim=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)] - - -@set_module_as('brainunit.math') -def intersect1d( - ar1: Union[jax.Array, np.ndarray], - ar2: Union[jax.Array, np.ndarray], - assume_unique: bool = False, - return_indices: bool = False -) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: - fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') - unit = None - if isinstance(ar1, Quantity): - unit = ar1.unit - ar1 = ar1.value if isinstance(ar1, Quantity) else ar1 - ar2 = ar2.value if isinstance(ar2, Quantity) else ar2 - result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) - if return_indices: - if unit is not None: - return (Quantity(result[0], dim=unit), result[1], result[2]) - else: - return result - else: - if unit is not None: - return Quantity(result, dim=unit) - else: - return result - - -nan_to_num = wrap_math_funcs_keep_unit_unary(jnp.nan_to_num) -nanargmax = _compatible_with_quantity(jnp.nanargmax, return_quantity=False) -nanargmin = _compatible_with_quantity(jnp.nanargmin, return_quantity=False) - -rot90 = wrap_math_funcs_keep_unit_unary(jnp.rot90) -tensordot = wrap_math_funcs_change_unit_binary(jnp.tensordot, lambda x, y: x * y) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py new file mode 100644 index 0000000..156a553 --- /dev/null +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -0,0 +1,720 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from functools import wraps +from typing import (Callable, Union, Optional, Any) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as +from jax import Array + +from .._base import (DIMENSIONLESS, + Quantity, + Unit, + fail_for_dimension_mismatch, + is_unitless, + ) + +__all__ = [ + # array creation + 'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu', + 'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like', + 'array', 'asarray', 'arange', 'linspace', 'logspace', 'fill_diagonal', + 'array_split', 'meshgrid', 'vander', +] + + +def wrap_array_creation_function(func: Callable) -> Callable: + @wraps(func) + def f(*args, unit: Unit = None, **kwargs): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return func(*args, **kwargs) * unit + else: + return func(*args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_array_creation_function +def full(shape: Sequence[int], + fill_value: Any, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.full(shape, fill_value, dtype=dtype) + + +@wrap_array_creation_function +def eye(N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.eye(N, M, k, dtype=dtype) + + +@wrap_array_creation_function +def identity(n: int, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.identity(n, dtype=dtype) + + +@wrap_array_creation_function +def tri(N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.tri(N, M, k, dtype=dtype) + + +@wrap_array_creation_function +def empty(shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.empty(shape, dtype=dtype) + + +@wrap_array_creation_function +def ones(shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.ones(shape, dtype=dtype) + + +@wrap_array_creation_function +def zeros(shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None) -> Union[Array, Quantity]: + return jnp.zeros(shape, dtype=dtype) + + +full.__doc__ = ''' + Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. + else return an array of `shape` filled with `fill_value`. + + Args: + shape: sequence of integers, describing the shape of the output array. + fill_value: the value to fill the new array with. + dtype: the type of the output array, or `None`. If not `None`, `fill_value` + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + +eye.__doc__ = """ + Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. + else return an identity matrix of `shape`. + + Args: + n: the number of rows (and columns) in the output array. + k: the index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value to a + lower diagonal. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +identity.__doc__ = """ + Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. + else return an identity matrix of `shape`. + + Args: + n: the number of rows (and columns) in the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +tri.__doc__ = """ + Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. + else return a triangular matrix of `shape`. + + Args: + n: the number of rows in the output array. + m: the number of columns with default being `n`. + k: the index of the diagonal: 0 (the default) refers to the main diagonal, + a positive value refers to an upper diagonal, and a negative value to a + lower diagonal. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +# empty +empty.__doc__ = """ + Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. + else return an array of `shape` with uninitialized values. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be of type `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +# ones +ones.__doc__ = """ + Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. + else return an array of `shape` filled with 1. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + +# zeros +zeros.__doc__ = """ + Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. + else return an array of `shape` filled with 0. + + Args: + shape: sequence of integers, describing the shape of the output array. + dtype: the type of the output array, or `None`. If not `None`, elements + will be cast to `dtype`. + sharding: an optional sharding specification for the resulting array, + note, sharding will currently be ignored in jitted mode, this might change + in the future. + unit: the unit of the output array, or `None`. + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. +""" + + +@set_module_as('brainunit.math') +def full_like(a: Union[Quantity, bst.typing.ArrayLike], + fill_value: Union[bst.typing.ArrayLike], + unit: Unit = None, + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with `fill_value` if `unit` is provided. + else return an array of `a` filled with `fill_value`. + + Args: + a: array_like, Quantity, shape, or dtype + fill_value: scalar or array_like + unit: Unit, optional + dtype: data-type, optional + shape: sequence of ints, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) * unit + else: + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) * unit + else: + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def diag(a: Union[Quantity, bst.typing.ArrayLike], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Extract a diagonal or construct a diagonal array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.diag(a.value, k=k) * unit + else: + return jnp.diag(a, k=k) * unit + else: + return jnp.diag(a, k=k) + + +@set_module_as('brainunit.math') +def tril(a: Union[Quantity, bst.typing.ArrayLike], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Lower triangle of an array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.tril(a.value, k=k) * unit + else: + return jnp.tril(a, k=k) * unit + else: + return jnp.tril(a, k=k) + + +@set_module_as('brainunit.math') +def triu(a: Union[Quantity, bst.typing.ArrayLike], + k: int = 0, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Upper triangle of an array. + + Args: + a: array_like, Quantity + k: int, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.triu(a.value, k=k) * unit + else: + return jnp.triu(a, k=k) * unit + else: + return jnp.triu(a, k=k) + + +@set_module_as('brainunit.math') +def empty_like(a: Union[Quantity, bst.typing.ArrayLike], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, with uninitialized values if `unit` is provided. + else return an array of `a` with uninitialized values. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.empty_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.empty_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.empty_like(a, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def ones_like(a: Union[Quantity, bst.typing.ArrayLike], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with 1 if `unit` is provided. + else return an array of `a` filled with 1. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.ones_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.ones_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.ones_like(a, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], + dtype: Optional[bst.typing.DTypeLike] = None, + shape: Any = None, + unit: Unit = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `a` and `unit`, filled with 0 if `unit` is provided. + else return an array of `a` filled with 0. + + Args: + a: array_like, Quantity, shape, or dtype + dtype: data-type, optional + shape: sequence of ints, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. + ''' + if unit is not None: + assert isinstance(unit, Unit) + if isinstance(a, Quantity): + return jnp.zeros_like(a.value, dtype=dtype, shape=shape) * unit + else: + return jnp.zeros_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.zeros_like(a, dtype=dtype, shape=shape) + + +@set_module_as('brainunit.math') +def asarray( + a: Union[Quantity, bst.typing.ArrayLike, Sequence[Quantity]], + dtype: Optional[bst.typing.DTypeLike] = None, + order: Optional[str] = None, + unit: Optional[Unit] = None, +) -> Union[Quantity, jax.Array]: + from builtins import all as origin_all + from builtins import any as origin_any + if isinstance(a, Quantity): + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + return jnp.asarray(a, dtype=dtype, order=order) + # list[Quantity] + elif isinstance(a, Sequence) and origin_all(isinstance(x, Quantity) for x in a): + # check all elements have the same unit + if origin_any(x.dim != a[0].dim for x in a): + raise ValueError('Units do not match for asarray operation.') + values = [x.value for x in a] + unit = a[0].dim + # Convert the values to a jnp.ndarray and create a Quantity object + return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) + else: + return jnp.asarray(a, dtype=dtype, order=order) + + +array = asarray + + +@set_module_as('brainunit.math') +def arange(*args, **kwargs): + ''' + Return a Quantity of `arange` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity, optional + stop: number, Quantity, optional + step: number, optional + dtype: dtype, optional + unit: Unit, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' + # arange has a bit of a complicated argument structure unfortunately + # we leave the actual checking of the number of arguments to numpy, though + + # default values + start = kwargs.pop("start", 0) + step = kwargs.pop("step", 1) + stop = kwargs.pop("stop", None) + if len(args) == 1: + if stop is not None: + raise TypeError("Duplicate definition of 'stop'") + stop = args[0] + elif len(args) == 2: + if start != 0: + raise TypeError("Duplicate definition of 'start'") + if stop is not None: + raise TypeError("Duplicate definition of 'stop'") + start, stop = args + elif len(args) == 3: + if start != 0: + raise TypeError("Duplicate definition of 'start'") + if stop is not None: + raise TypeError("Duplicate definition of 'stop'") + if step != 1: + raise TypeError("Duplicate definition of 'step'") + start, stop, step = args + elif len(args) > 3: + raise TypeError("Need between 1 and 3 non-keyword arguments") + + if stop is None: + raise TypeError("Missing stop argument.") + if stop is not None and not is_unitless(stop): + start = Quantity(start, dim=stop.dim) + + fail_for_dimension_mismatch( + start, + stop, + error_message=( + "Start value {start} and stop value {stop} have to have the same units." + ), + start=start, + stop=stop, + ) + fail_for_dimension_mismatch( + stop, + step, + error_message=( + "Stop value {stop} and step value {step} have to have the same units." + ), + stop=stop, + step=step, + ) + unit = getattr(stop, "dim", DIMENSIONLESS) + # start is a position-only argument in numpy 2.0 + # https://numpy.org/devdocs/release/2.0.0-notes.html#arange-s-start-argument-is-positional-only + # TODO: check whether this is still the case in the final release + if start == 0: + return Quantity( + jnp.arange( + start=start.value if isinstance(start, Quantity) else jnp.asarray(start), + stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), + step=step.value if isinstance(step, Quantity) else jnp.asarray(step), + **kwargs, + ), + dim=unit, + ) + else: + return Quantity( + jnp.arange( + start.value if isinstance(start, Quantity) else jnp.asarray(start), + stop=stop.value if isinstance(stop, Quantity) else jnp.asarray(stop), + step=step.value if isinstance(step, Quantity) else jnp.asarray(step), + **kwargs, + ), + dim=unit, + ) + + +@set_module_as('brainunit.math') +def linspace(start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], + num: int = 50, + endpoint: Optional[bool] = True, + retstep: Optional[bool] = False, + dtype: Optional[bst.typing.DTypeLike] = None) -> Union[Quantity, jax.Array]: + ''' + Return a Quantity of `linspace` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity + stop: number, Quantity + num: int, optional + endpoint: bool, optional + retstep: bool, optional + dtype: dtype, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' + fail_for_dimension_mismatch( + start, + stop, + error_message="Start value {start} and stop value {stop} have to have the same units.", + start=start, + stop=stop, + ) + unit = getattr(start, "dim", DIMENSIONLESS) + start = start.value if isinstance(start, Quantity) else start + stop = stop.value if isinstance(stop, Quantity) else stop + + result = jnp.linspace(start, stop, num=num, endpoint=endpoint, retstep=retstep, dtype=dtype) + return Quantity(result, dim=unit) + + +@set_module_as('brainunit.math') +def logspace(start: Union[Quantity, bst.typing.ArrayLike], + stop: Union[Quantity, bst.typing.ArrayLike], + num: Optional[int] = 50, + endpoint: Optional[bool] = True, + base: Optional[float] = 10.0, + dtype: Optional[bst.typing.DTypeLike] = None): + ''' + Return a Quantity of `logspace` and `unit`, with uninitialized values if `unit` is provided. + + Args: + start: number, Quantity + stop: number, Quantity + num: int, optional + endpoint: bool, optional + base: float, optional + dtype: dtype, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if start and stop are Quantities that have the same unit, else an array. + ''' + fail_for_dimension_mismatch( + start, + stop, + error_message="Start value {start} and stop value {stop} have to have the same units.", + start=start, + stop=stop, + ) + unit = getattr(start, "dim", DIMENSIONLESS) + start = start.value if isinstance(start, Quantity) else start + stop = stop.value if isinstance(stop, Quantity) else stop + + result = jnp.logspace(start, stop, num=num, endpoint=endpoint, base=base, dtype=dtype) + return Quantity(result, dim=unit) + + +@set_module_as('brainunit.math') +def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], + val: Union[Quantity, bst.typing.ArrayLike], + wrap: Optional[bool] = False, + inplace: Optional[bool] = True) -> Union[Quantity, jax.Array]: + ''' + Fill the main diagonal of the given array of `a` with `val`. + + Args: + a: array_like, Quantity + val: scalar, Quantity + wrap: bool, optional + inplace: bool, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `a` and `val` are Quantities that have the same unit, else an array. + ''' + if isinstance(a, Quantity) and isinstance(val, Quantity): + fail_for_dimension_mismatch(a, val) + return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)) and isinstance(val, (jax.Array, np.ndarray)): + return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) + elif is_unitless(a) or is_unitless(val): + return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) + else: + raise ValueError(f'Unsupported types : {type(a)} abd {type(val)} for fill_diagonal') + + +@set_module_as('brainunit.math') +def array_split(ary: Union[Quantity, bst.typing.ArrayLike], + indices_or_sections: Union[int, bst.typing.ArrayLike], + axis: Optional[int] = 0) -> Union[list[Quantity], list[Array]]: + ''' + Split an array into multiple sub-arrays. + + Args: + ary: array_like, Quantity + indices_or_sections: int, array_like + axis: int, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `ary` is a Quantity, else an array. + ''' + if isinstance(ary, Quantity): + return [Quantity(x, dim=ary.dim) for x in jnp.array_split(ary.value, indices_or_sections, axis)] + elif isinstance(ary, (jax.Array, np.ndarray)): + return jnp.array_split(ary, indices_or_sections, axis) + else: + raise ValueError(f'Unsupported type: {type(ary)} for array_split') + + +@set_module_as('brainunit.math') +def meshgrid(*xi: Union[Quantity, bst.typing.ArrayLike], + copy: Optional[bool] = True, + sparse: Optional[bool] = False, + indexing: Optional[str] = 'xy'): + ''' + Return coordinate matrices from coordinate vectors. + + Args: + xi: array_like, Quantity + copy: bool, optional + sparse: bool, optional + indexing: str, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `xi` are Quantities that have the same unit, else an array. + ''' + from builtins import all as origin_all + if origin_all(isinstance(x, Quantity) for x in xi): + fail_for_dimension_mismatch(*xi) + return Quantity(jnp.meshgrid(*[x.value for x in xi], copy=copy, sparse=sparse, indexing=indexing), dim=xi[0].dim) + elif origin_all(isinstance(x, (jax.Array, np.ndarray)) for x in xi): + return jnp.meshgrid(*xi, copy=copy, sparse=sparse, indexing=indexing) + else: + raise ValueError(f'Unsupported types : {type(xi)} for meshgrid') + + +@set_module_as('brainunit.math') +def vander(x: Union[Quantity, bst.typing.ArrayLike], + N: Optional[bool] = None, + increasing: Optional[bool] = False) -> Union[Quantity, jax.Array]: + ''' + Generate a Vandermonde matrix. + + Args: + x: array_like, Quantity + N: int, optional + increasing: bool, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return Quantity(jnp.vander(x.value, N=N, increasing=increasing), dim=x.dim) + elif isinstance(x, (jax.Array, np.ndarray)): + return jnp.vander(x, N=N, increasing=increasing) + else: + raise ValueError(f'Unsupported type: {type(x)} for vander') diff --git a/brainunit/math/_compat_numpy_array_manipulation.py b/brainunit/math/_compat_numpy_array_manipulation.py new file mode 100644 index 0000000..c4a7c26 --- /dev/null +++ b/brainunit/math/_compat_numpy_array_manipulation.py @@ -0,0 +1,821 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from functools import wraps +from typing import (Union, Optional, Tuple, List) + +import jax +import jax.numpy as jnp +from jax import Array + +from ._utils import _compatible_with_quantity +from .._base import (Quantity, + ) + +__all__ = [ + # array manipulation + 'reshape', 'moveaxis', 'transpose', 'swapaxes', 'row_stack', + 'concatenate', 'stack', 'vstack', 'hstack', 'dstack', 'column_stack', + 'split', 'dsplit', 'hsplit', 'vsplit', 'tile', 'repeat', 'unique', + 'append', 'flip', 'fliplr', 'flipud', 'roll', 'atleast_1d', 'atleast_2d', + 'atleast_3d', 'expand_dims', 'squeeze', 'sort', 'argsort', 'argmax', 'argmin', + 'argwhere', 'nonzero', 'flatnonzero', 'searchsorted', 'extract', + 'count_nonzero', 'max', 'min', 'amax', 'amin', 'block', 'compress', + 'diagflat', 'diagonal', 'choose', 'ravel', +] + + +# array manipulation +# ------------------ + + +@_compatible_with_quantity() +def reshape(a: Union[Array, Quantity], shape: Union[int, Tuple[int, ...]], order: str = 'C') -> Union[Array, Quantity]: + return jnp.reshape(a, shape, order) + + +@_compatible_with_quantity() +def moveaxis(a: Union[Array, Quantity], source: Union[int, Tuple[int, ...]], + destination: Union[int, Tuple[int, ...]]) -> Union[Array, Quantity]: + return jnp.moveaxis(a, source, destination) + + +@_compatible_with_quantity() +def transpose(a: Union[Array, Quantity], axes: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.transpose(a, axes) + + +@_compatible_with_quantity() +def swapaxes(a: Union[Array, Quantity], axis1: int, axis2: int) -> Union[Array, Quantity]: + return jnp.swapaxes(a, axis1, axis2) + + +@_compatible_with_quantity() +def concatenate(arrays: Union[Sequence[Array], Sequence[Quantity]], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.concatenate(arrays, axis) + + +@_compatible_with_quantity() +def stack(arrays: Union[Sequence[Array], Sequence[Quantity]], axis: int = 0) -> Union[Array, Quantity]: + return jnp.stack(arrays, axis) + + +@_compatible_with_quantity() +def vstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.vstack(arrays) + + +row_stack = vstack + + +@_compatible_with_quantity() +def hstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.hstack(arrays) + + +@_compatible_with_quantity() +def dstack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.dstack(arrays) + + +@_compatible_with_quantity() +def column_stack(arrays: Union[Sequence[Array], Sequence[Quantity]]) -> Union[Array, Quantity]: + return jnp.column_stack(arrays) + + +@_compatible_with_quantity() +def split(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]], axis: int = 0) -> Union[ + List[Array], List[Quantity]]: + return jnp.split(a, indices_or_sections, axis) + + +@_compatible_with_quantity() +def dsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ + List[Array], List[Quantity]]: + return jnp.dsplit(a, indices_or_sections) + + +@_compatible_with_quantity() +def hsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ + List[Array], List[Quantity]]: + return jnp.hsplit(a, indices_or_sections) + + +@_compatible_with_quantity() +def vsplit(a: Union[Array, Quantity], indices_or_sections: Union[int, Sequence[int]]) -> Union[ + List[Array], List[Quantity]]: + return jnp.vsplit(a, indices_or_sections) + + +@_compatible_with_quantity() +def tile(A: Union[Array, Quantity], reps: Union[int, Tuple[int, ...]]) -> Union[Array, Quantity]: + return jnp.tile(A, reps) + + +@_compatible_with_quantity() +def repeat(a: Union[Array, Quantity], repeats: Union[int, Tuple[int, ...]], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.repeat(a, repeats, axis) + + +@_compatible_with_quantity() +def unique(a: Union[Array, Quantity], return_index: bool = False, return_inverse: bool = False, + return_counts: bool = False, axis: Optional[int] = None) -> Union[Array, Quantity]: + return jnp.unique(a, return_index, return_inverse, return_counts, axis) + + +@_compatible_with_quantity() +def append(arr: Union[Array, Quantity], values: Union[Array, Quantity], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.append(arr, values, axis) + + +@_compatible_with_quantity() +def flip(m: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.flip(m, axis) + + +@_compatible_with_quantity() +def fliplr(m: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.fliplr(m) + + +@_compatible_with_quantity() +def flipud(m: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.flipud(m) + + +@_compatible_with_quantity() +def roll(a: Union[Array, Quantity], shift: Union[int, Tuple[int, ...]], + axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.roll(a, shift, axis) + + +@_compatible_with_quantity() +def atleast_1d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.atleast_1d(*arys) + + +@_compatible_with_quantity() +def atleast_2d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.atleast_2d(*arys) + + +@_compatible_with_quantity() +def atleast_3d(*arys: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.atleast_3d(*arys) + + +@_compatible_with_quantity() +def expand_dims(a: Union[Array, Quantity], axis: int) -> Union[Array, Quantity]: + return jnp.expand_dims(a, axis) + + +@_compatible_with_quantity() +def squeeze(a: Union[Array, Quantity], axis: Optional[Union[int, Tuple[int, ...]]] = None) -> Union[Array, Quantity]: + return jnp.squeeze(a, axis) + + +@_compatible_with_quantity() +def sort(a: Union[Array, Quantity], + axis: Optional[int] = -1, + kind: None = None, + order: None = None, + stable: bool = True, + descending: bool = False, ) -> Union[Array, Quantity]: + return jnp.sort(a, axis, kind=kind, order=order, stable=stable, descending=descending) + + +@_compatible_with_quantity() +def argsort(a: Union[Array, Quantity], + axis: Optional[int] = -1, + kind: None = None, + order: None = None, + stable: bool = True, + descending: bool = False, ) -> Array: + return jnp.argsort(a, axis, kind=kind, order=order, stable=stable, descending=descending) + + +@_compatible_with_quantity() +def max(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None, + keepdims: bool = False) -> Union[Array, Quantity]: + return jnp.max(a, axis, out, keepdims) + + +@_compatible_with_quantity() +def min(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None, + keepdims: bool = False) -> Union[Array, Quantity]: + return jnp.min(a, axis, out, keepdims) + + +@_compatible_with_quantity() +def choose(a: Union[Array, Quantity], choices: Sequence[Union[Array, Quantity]]) -> Union[Array, Quantity]: + return jnp.choose(a, choices) + + +@_compatible_with_quantity() +def block(arrays: Sequence[Union[Array, Quantity]]) -> Union[Array, Quantity]: + return jnp.block(arrays) + + +@_compatible_with_quantity() +def compress(condition: Union[Array, Quantity], a: Union[Array, Quantity], axis: Optional[int] = None) -> Union[ + Array, Quantity]: + return jnp.compress(condition, a, axis) + + +@_compatible_with_quantity() +def diagflat(v: Union[Array, Quantity], k: int = 0) -> Union[Array, Quantity]: + return jnp.diagflat(v, k) + + +# return jax.numpy.Array, not Quantity + +@_compatible_with_quantity(return_quantity=False) +def argmax(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None) -> Array: + return jnp.argmax(a, axis, out) + + +@_compatible_with_quantity(return_quantity=False) +def argmin(a: Union[Array, Quantity], axis: Optional[int] = None, out: Optional[Array] = None) -> Array: + return jnp.argmin(a, axis, out) + + +@_compatible_with_quantity(return_quantity=False) +def argwhere(a: Union[Array, Quantity]) -> Array: + return jnp.argwhere(a) + + +@_compatible_with_quantity(return_quantity=False) +def nonzero(a: Union[Array, Quantity]) -> Tuple[Array, ...]: + return jnp.nonzero(a) + + +@_compatible_with_quantity(return_quantity=False) +def flatnonzero(a: Union[Array, Quantity]) -> Array: + return jnp.flatnonzero(a) + + +@_compatible_with_quantity(return_quantity=False) +def searchsorted(a: Union[Array, Quantity], v: Union[Array, Quantity], side: str = 'left', + sorter: Optional[Array] = None) -> Array: + return jnp.searchsorted(a, v, side, sorter) + + +@_compatible_with_quantity(return_quantity=False) +def extract(condition: Union[Array, Quantity], arr: Union[Array, Quantity]) -> Array: + return jnp.extract(condition, arr) + + +@_compatible_with_quantity(return_quantity=False) +def count_nonzero(a: Union[Array, Quantity], axis: Optional[int] = None) -> Array: + return jnp.count_nonzero(a, axis) + + +amax = max +amin = min + +# docs for the functions above +reshape.__doc__ = ''' + Return a reshaped copy of an array or a Quantity. + + Args: + a: input array or Quantity to reshape + shape: integer or sequence of integers giving the new shape, which must match the + size of the input array. If any single dimension is given size ``-1``, it will be + replaced with a value such that the output has the correct size. + order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major + (fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``. + brainunit does not support ``order="A"``. + + Returns: + reshaped copy of input array with the specified shape. +''' + +moveaxis.__doc__ = ''' + Moves axes of an array to new positions. Other axes remain in their original order. + + Args: + a: array_like, Quantity + source: int or sequence of ints + destination: int or sequence of ints + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +transpose.__doc__ = ''' + Returns a view of the array with axes transposed. + + Args: + a: array_like, Quantity + axes: tuple or list of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +swapaxes.__doc__ = ''' + Interchanges two axes of an array. + + Args: + a: array_like, Quantity + axis1: int + axis2: int + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +concatenate.__doc__ = ''' + Join a sequence of arrays along an existing axis. + + Args: + arrays: sequence of array_like, Quantity + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +stack.__doc__ = ''' + Join a sequence of arrays along a new axis. + + Args: + arrays: sequence of array_like, Quantity + axis: int + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +vstack.__doc__ = ''' + Stack arrays in sequence vertically (row wise). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +hstack.__doc__ = ''' + Stack arrays in sequence horizontally (column wise). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +dstack.__doc__ = ''' + Stack arrays in sequence depth wise (along third axis). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +column_stack.__doc__ = ''' + Stack 1-D arrays as columns into a 2-D array. + + Args: + arrays: sequence of 1-D or 2-D array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +split.__doc__ = ''' + Split an array into multiple sub-arrays. + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +dsplit.__doc__ = ''' + Split array along third axis (depth). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +hsplit.__doc__ = ''' + Split an array into multiple sub-arrays horizontally (column-wise). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +vsplit.__doc__ = ''' + Split an array into multiple sub-arrays vertically (row-wise). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + Union[jax.Array, Quantity] a list of Quantity if a is a Quantity, otherwise a list of jax.Array +''' + +tile.__doc__ = ''' + Construct an array by repeating A the number of times given by reps. + + Args: + A: array_like, Quantity + reps: array_like + + Returns: + Union[jax.Array, Quantity] a Quantity if A is a Quantity, otherwise a jax.Array +''' + +repeat.__doc__ = ''' + Repeat elements of an array. + + Args: + a: array_like, Quantity + repeats: array_like + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +unique.__doc__ = ''' + Find the unique elements of an array. + + Args: + a: array_like, Quantity + return_index: bool, optional + return_inverse: bool, optional + return_counts: bool, optional + axis: int or None, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +append.__doc__ = ''' + Append values to the end of an array. + + Args: + arr: array_like, Quantity + values: array_like, Quantity + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if arr and values are Quantity, otherwise a jax.Array +''' + +flip.__doc__ = ''' + Reverse the order of elements in an array along the given axis. + + Args: + m: array_like, Quantity + axis: int or tuple of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array +''' + +fliplr.__doc__ = ''' + Flip array in the left/right direction. + + Args: + m: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array +''' + +flipud.__doc__ = ''' + Flip array in the up/down direction. + + Args: + m: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if m is a Quantity, otherwise a jax.Array +''' + +roll.__doc__ = ''' + Roll array elements along a given axis. + + Args: + a: array_like, Quantity + shift: int or tuple of ints + axis: int or tuple of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +atleast_1d.__doc__ = ''' + View inputs as arrays with at least one dimension. + + Args: + *args: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array +''' + +atleast_2d.__doc__ = ''' + View inputs as arrays with at least two dimensions. + + Args: + *args: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array +''' + +atleast_3d.__doc__ = ''' + View inputs as arrays with at least three dimensions. + + Args: + *args: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if any input is a Quantity, otherwise a jax.Array +''' + +expand_dims.__doc__ = ''' + Expand the shape of an array. + + Args: + a: array_like, Quantity + axis: int or tuple of ints + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +squeeze.__doc__ = ''' + Remove single-dimensional entries from the shape of an array. + + Args: + a: array_like, Quantity + axis: None or int or tuple of ints, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +sort.__doc__ = ''' + Return a sorted copy of an array. + + Args: + a: array_like, Quantity + axis: int or None, optional + kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional + order: str or list of str, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' +max.__doc__ = ''' + Return the maximum of an array or maximum along an axis. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + keepdims: bool, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +min.__doc__ = ''' + Return the minimum of an array or minimum along an axis. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + keepdims: bool, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +choose.__doc__ = ''' + Use an index array to construct a new array from a set of choices. + + Args: + a: array_like, Quantity + choices: array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if a and choices are Quantity, otherwise a jax.Array +''' + +block.__doc__ = ''' + Assemble an nd-array from nested lists of blocks. + + Args: + arrays: sequence of array_like, Quantity + + Returns: + Union[jax.Array, Quantity] a Quantity if all input arrays are Quantity, otherwise a jax.Array +''' + +compress.__doc__ = ''' + Return selected slices of an array along given axis. + + Args: + condition: array_like, Quantity + a: array_like, Quantity + axis: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +diagflat.__doc__ = ''' + Create a two-dimensional array with the flattened input as a diagonal. + + Args: + a: array_like, Quantity + offset: int, optional + + Returns: + Union[jax.Array, Quantity] a Quantity if a is a Quantity, otherwise a jax.Array +''' + +argsort.__doc__ = ''' + Returns the indices that would sort an array. + + Args: + a: array_like, Quantity + axis: int or None, optional + kind: {'quicksort', 'mergesort', 'heapsort'}, optional + order: str or list of str, optional + + Returns: + jax.Array jax.numpy.Array (does not return a Quantity) +''' + +argmax.__doc__ = ''' + Returns indices of the max value along an axis. + + Args: + a: array_like, Quantity + axis: int, optional + out: array, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +argmin.__doc__ = ''' + Returns indices of the min value along an axis. + + Args: + a: array_like, Quantity + axis: int, optional + out: array, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +argwhere.__doc__ = ''' + Find indices of non-zero elements. + + Args: + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +nonzero.__doc__ = ''' + Return the indices of the elements that are non-zero. + + Args: + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +flatnonzero.__doc__ = ''' + Return indices that are non-zero in the flattened version of a. + + Args: + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +searchsorted.__doc__ = ''' + Find indices where elements should be inserted to maintain order. + + Args: + a: array_like, Quantity + v: array_like, Quantity + side: {'left', 'right'}, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +extract.__doc__ = ''' + Return the elements of an array that satisfy some condition. + + Args: + condition: array_like, Quantity + a: array_like, Quantity + + Returns: + jax.Array: an array (does not return a Quantity) +''' + +count_nonzero.__doc__ = ''' + Counts the number of non-zero values in the array a. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + + Returns: + jax.Array: an array (does not return a Quantity) +''' + + +def wrap_function_to_method(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return Quantity(func(x.value, *args, **kwargs), dim=x.dim) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_function_to_method +def diagonal(a: Union[jax.Array, Quantity], offset: int = 0, axis1: int = 0, axis2: int = 1) -> Union[ + jax.Array, Quantity]: + return jnp.diagonal(a, offset, axis1, axis2) + + +@wrap_function_to_method +def ravel(a: Union[jax.Array, Quantity], order: str = 'C') -> Union[jax.Array, Quantity]: + return jnp.ravel(a, order) + + +diagonal.__doc__ = ''' + Return specified diagonals. + + Args: + a: array_like, Quantity + offset: int, optional + axis1: int, optional + axis2: int, optional + + Returns: + Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +ravel.__doc__ = ''' + Return a contiguous flattened array. + + Args: + a: array_like, Quantity + order: {'C', 'F', 'A', 'K'}, optional + + Returns: + Union[jax.Array, Quantity]: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' diff --git a/brainunit/math/_compat_numpy_funcs_accept_unitless.py b/brainunit/math/_compat_numpy_funcs_accept_unitless.py new file mode 100644 index 0000000..c87890a --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_accept_unitless.py @@ -0,0 +1,588 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import brainstate as bst +import jax.numpy as jnp +from jax import Array + +from .._base import (Quantity, + fail_for_dimension_mismatch, + ) + +__all__ = [ + # math funcs only accept unitless (unary) + 'exp', 'exp2', 'expm1', 'log', 'log10', 'log1p', 'log2', + 'arccos', 'arccosh', 'arcsin', 'arcsinh', 'arctan', + 'arctanh', 'cos', 'cosh', 'sin', 'sinc', 'sinh', 'tan', + 'tanh', 'deg2rad', 'rad2deg', 'degrees', 'radians', 'angle', + 'percentile', 'nanpercentile', 'quantile', 'nanquantile', + + # math funcs only accept unitless (binary) + 'hypot', 'arctan2', 'logaddexp', 'logaddexp2', +] + + +# math funcs only accept unitless (unary) +# --------------------------------------- + +def wrap_math_funcs_only_accept_unitless_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + fail_for_dimension_mismatch( + x, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=x, + ) + return func(jnp.array(x.value), *args, **kwargs) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_only_accept_unitless_unary +def exp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: + return jnp.exp(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def exp2(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Array, Quantity]: + return jnp.exp2(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def expm1(x: Union[Array, Quantity]) -> Array: + return jnp.expm1(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log(x: Union[Array, Quantity]) -> Array: + return jnp.log(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log10(x: Union[Array, Quantity]) -> Array: + return jnp.log10(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log1p(x: Union[Array, Quantity]) -> Array: + return jnp.log1p(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def log2(x: Union[Array, Quantity]) -> Array: + return jnp.log2(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arccos(x: Union[Array, Quantity]) -> Array: + return jnp.arccos(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arccosh(x: Union[Array, Quantity]) -> Array: + return jnp.arccosh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arcsin(x: Union[Array, Quantity]) -> Array: + return jnp.arcsin(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arcsinh(x: Union[Array, Quantity]) -> Array: + return jnp.arcsinh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arctan(x: Union[Array, Quantity]) -> Array: + return jnp.arctan(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def arctanh(x: Union[Array, Quantity]) -> Array: + return jnp.arctanh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def cos(x: Union[Array, Quantity]) -> Array: + return jnp.cos(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def cosh(x: Union[Array, Quantity]) -> Array: + return jnp.cosh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def sin(x: Union[Array, Quantity]) -> Array: + return jnp.sin(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def sinc(x: Union[Array, Quantity]) -> Array: + return jnp.sinc(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def sinh(x: Union[Array, Quantity]) -> Array: + return jnp.sinh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def tan(x: Union[Array, Quantity]) -> Array: + return jnp.tan(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def tanh(x: Union[Array, Quantity]) -> Array: + return jnp.tanh(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def deg2rad(x: Union[Array, Quantity]) -> Array: + return jnp.deg2rad(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def rad2deg(x: Union[Array, Quantity]) -> Array: + return jnp.rad2deg(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def degrees(x: Union[Array, Quantity]) -> Array: + return jnp.degrees(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def radians(x: Union[Array, Quantity]) -> Array: + return jnp.radians(x) + + +@wrap_math_funcs_only_accept_unitless_unary +def angle(x: Union[Array, Quantity]) -> Array: + return jnp.angle(x) + + +# docs for the functions above +exp.__doc__ = ''' + Calculate the exponential of all elements in the input array. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +exp2.__doc__ = ''' + Calculate 2 raised to the power of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +expm1.__doc__ = ''' + Calculate the exponential of the input elements minus 1. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log.__doc__ = ''' + Natural logarithm, element-wise. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log10.__doc__ = ''' + Base-10 logarithm of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log1p.__doc__ = ''' + Natural logarithm of 1 + the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +log2.__doc__ = ''' + Base-2 logarithm of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arccos.__doc__ = ''' + Compute the arccosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arccosh.__doc__ = ''' + Compute the hyperbolic arccosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arcsin.__doc__ = ''' + Compute the arcsine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arcsinh.__doc__ = ''' + Compute the hyperbolic arcsine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arctan.__doc__ = ''' + Compute the arctangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arctanh.__doc__ = ''' + Compute the hyperbolic arctangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +cos.__doc__ = ''' + Compute the cosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +cosh.__doc__ = ''' + Compute the hyperbolic cosine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sin.__doc__ = ''' + Compute the sine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sinc.__doc__ = ''' + Compute the sinc function of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sinh.__doc__ = ''' + Compute the hyperbolic sine of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +tan.__doc__ = ''' + Compute the tangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +tanh.__doc__ = ''' + Compute the hyperbolic tangent of the input elements. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +deg2rad.__doc__ = ''' + Convert angles from degrees to radians. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +rad2deg.__doc__ = ''' + Convert angles from radians to degrees. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +degrees.__doc__ = ''' + Convert angles from radians to degrees. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +radians.__doc__ = ''' + Convert angles from degrees to radians. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +angle.__doc__ = ''' + Return the angle of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + + +# math funcs only accept unitless (binary) +# ---------------------------------------- + +def wrap_math_funcs_only_accept_unitless_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity): + x_value = x.value + if isinstance(y, Quantity): + y_value = y.value + if isinstance(x, Quantity) or isinstance(y, Quantity): + fail_for_dimension_mismatch( + x, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=x, + ) + fail_for_dimension_mismatch( + y, + error_message="%s expects a dimensionless argument but got {value}" % func.__name__, + value=y, + ) + return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) + else: + return func(x, y, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_only_accept_unitless_binary +def hypot(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.hypot(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def arctan2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.arctan2(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def logaddexp(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.logaddexp(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def logaddexp2(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.logaddexp2(x, y) + + +@wrap_math_funcs_only_accept_unitless_binary +def percentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.percentile(a, q, *args, **kwargs) + + +@wrap_math_funcs_only_accept_unitless_binary +def nanpercentile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.nanpercentile(a, q, *args, **kwargs) + + +@wrap_math_funcs_only_accept_unitless_binary +def quantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.quantile(a, q, *args, **kwargs) + + +@wrap_math_funcs_only_accept_unitless_binary +def nanquantile(a: Union[Array, Quantity], q: Union[Array, Quantity], *args, **kwargs) -> Array: + return jnp.nanquantile(a, q, *args, **kwargs) + + +# docs for the functions above +hypot.__doc__ = ''' + Given the “legs” of a right triangle, return its hypotenuse. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + jax.Array: an array +''' + +arctan2.__doc__ = ''' + Element-wise arc tangent of `x1/x2` choosing the quadrant correctly. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + jax.Array: an array +''' + +logaddexp.__doc__ = ''' + Logarithm of the sum of exponentiations of the inputs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + jax.Array: an array +''' + +logaddexp2.__doc__ = ''' + Logarithm of the sum of exponentiations of the inputs in base-2. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + jax.Array: an array +''' + +percentile.__doc__ = ''' + Compute the nth percentile of the input array along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +nanpercentile.__doc__ = ''' + Compute the nth percentile of the input array along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +quantile.__doc__ = ''' + Compute the qth quantile of the input array along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +nanquantile.__doc__ = ''' + Compute the qth quantile of the input array along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' diff --git a/brainunit/math/_compat_numpy_funcs_bit_operation.py b/brainunit/math/_compat_numpy_funcs_bit_operation.py new file mode 100644 index 0000000..1325539 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_bit_operation.py @@ -0,0 +1,183 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array +from numpy import number + +from .._base import (Quantity, + ) + +__all__ = [ + + # Elementwise bit operations (unary) + 'bitwise_not', 'invert', + + # Elementwise bit operations (binary) + 'bitwise_and', 'bitwise_or', 'bitwise_xor', 'left_shift', 'right_shift', +] + + +# Elementwise bit operations (unary) +# ---------------------------------- + +def wrap_elementwise_bit_operation_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + raise ValueError(f'Expected integers, got {x}') + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_elementwise_bit_operation_unary +def bitwise_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_not(x) + + +@wrap_elementwise_bit_operation_unary +def invert(x: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.invert(x) + + +# docs for functions above +bitwise_not.__doc__ = ''' + Compute the bit-wise NOT of an array, element-wise. + + Args: + x: array_like + + Returns: + jax.Array: an array +''' + +invert.__doc__ = ''' + Compute bit-wise inversion, or bit-wise NOT, element-wise. + + Args: + x: array_like + + Returns: + jax.Array: an array +''' + + +# Elementwise bit operations (binary) +# ----------------------------------- + +def wrap_elementwise_bit_operation_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) or isinstance(y, Quantity): + raise ValueError(f'Expected integers, got {x} and {y}') + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray, int, float)): + return func(x, y, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_elementwise_bit_operation_binary +def bitwise_and(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_and(x, y) + + +@wrap_elementwise_bit_operation_binary +def bitwise_or(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_or(x, y) + + +@wrap_elementwise_bit_operation_binary +def bitwise_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.bitwise_xor(x, y) + + +@wrap_elementwise_bit_operation_binary +def left_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.left_shift(x, y) + + +@wrap_elementwise_bit_operation_binary +def right_shift(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Array: + return jnp.right_shift(x, y) + + +# docs for functions above +bitwise_and.__doc__ = ''' + Compute the bit-wise AND of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +bitwise_or.__doc__ = ''' + Compute the bit-wise OR of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +bitwise_xor.__doc__ = ''' + Compute the bit-wise XOR of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +left_shift.__doc__ = ''' + Shift the bits of an integer to the left. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' + +right_shift.__doc__ = ''' + Shift the bits of an integer to the right. + + Args: + x: array_like + y: array_like + + Returns: + jax.Array: an array +''' diff --git a/brainunit/math/_compat_numpy_funcs_change_unit.py b/brainunit/math/_compat_numpy_funcs_change_unit.py new file mode 100644 index 0000000..227234c --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_change_unit.py @@ -0,0 +1,527 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from functools import wraps +from typing import (Callable, Union, Optional) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from ._compat_numpy_get_attribute import isscalar +from .._base import (DIMENSIONLESS, + Quantity, + ) +from .._base import _return_check_unitless + +__all__ = [ + + # math funcs change unit (unary) + 'reciprocal', 'prod', 'product', 'nancumprod', 'nanprod', 'cumprod', + 'cumproduct', 'var', 'nanvar', 'cbrt', 'square', 'frexp', 'sqrt', + + # math funcs change unit (binary) + 'multiply', 'divide', 'power', 'cross', 'ldexp', + 'true_divide', 'floor_divide', 'float_power', + 'divmod', 'remainder', 'convolve', +] + + +# math funcs change unit (unary) +# ------------------------------ + +def wrap_math_funcs_change_unit_unary(change_unit_func: Callable) -> Callable: + def decorator(func: Callable) -> Callable: + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return _return_check_unitless(Quantity(func(x.value, *args, **kwargs), dim=change_unit_func(x.dim))) + elif isinstance(x, (jnp.ndarray, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + return decorator + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** -1) +def reciprocal(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.reciprocal(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) +def var(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[Union[int, Sequence[int]]] = None, + ddof: int = 0, + keepdims: bool = False) -> Union[Quantity, jax.Array]: + return jnp.var(x, axis=axis, ddof=ddof, keepdims=keepdims) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) +def nanvar(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[Union[int, Sequence[int]]] = None, + ddof: int = 0, + keepdims: bool = False) -> Union[Quantity, jax.Array]: + return jnp.nanvar(x, axis=axis, ddof=ddof, keepdims=keepdims) + + +@wrap_math_funcs_change_unit_unary(lambda x: x * 2 ** -1) +def frexp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.frexp(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 0.5) +def sqrt(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.sqrt(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** (1 / 3)) +def cbrt(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.cbrt(x) + + +@wrap_math_funcs_change_unit_unary(lambda x: x ** 2) +def square(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.square(x) + + +# docs for the functions above + +reciprocal.__doc__ = ''' + Return the reciprocal of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +var.__doc__ = ''' + Compute the variance along the specified axis. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + +nanvar.__doc__ = ''' + Compute the variance along the specified axis, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + +frexp.__doc__ = ''' + Decompose a floating-point number into its mantissa and exponent. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Tuple of Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the exponent, else a tuple of arrays. +''' + +sqrt.__doc__ = ''' + Compute the square root of each element. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square root of the unit of `x`, else an array. +''' + +cbrt.__doc__ = ''' + Compute the cube root of each element. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the cube root of the unit of `x`, else an array. +''' + +square.__doc__ = ''' + Compute the square of each element. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the square of the unit of `x`, else an array. +''' + + +@set_module_as('brainunit.math') +def prod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None, + keepdims: Optional[bool] = False, + initial: Union[Quantity, bst.typing.ArrayLike] = None, + where: Union[Quantity, bst.typing.ArrayLike] = None, + promote_integers: bool = True) -> Union[Quantity, jax.Array]: + ''' + Return the product of array elements over a given axis. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + keepdims: bool, optional + initial: array_like, Quantity, optional + where: array_like, Quantity, optional + promote_integers: bool, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.prod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + promote_integers=promote_integers) + else: + return jnp.prod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where, + promote_integers=promote_integers) + + +@set_module_as('brainunit.math') +def nanprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None, + keepdims: bool = False, + initial: Union[Quantity, bst.typing.ArrayLike] = None, + where: Union[Quantity, bst.typing.ArrayLike] = None): + ''' + Return the product of array elements over a given axis treating Not a Numbers (NaNs) as one. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + keepdims: bool, optional + initial: array_like, Quantity, optional + where: array_like, Quantity, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.nanprod(axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + else: + return jnp.nanprod(x, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where) + + +product = prod + + +@set_module_as('brainunit.math') +def cumprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: + ''' + Return the cumulative product of elements along a given axis. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.cumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.cumprod(x, axis=axis, dtype=dtype, out=out) + + +@set_module_as('brainunit.math') +def nancumprod(x: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + dtype: Optional[bst.typing.DTypeLike] = None, + out: None = None) -> Union[Quantity, bst.typing.ArrayLike]: + ''' + Return the cumulative product of elements along a given axis treating Not a Numbers (NaNs) as one. + + Args: + x: array_like, Quantity + axis: int, optional + dtype: dtype, optional + out: array, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. + ''' + if isinstance(x, Quantity): + return x.nancumprod(axis=axis, dtype=dtype, out=out) + else: + return jnp.nancumprod(x, axis=axis, dtype=dtype, out=out) + + +cumproduct = cumprod + + +# math funcs change unit (binary) +# ------------------------------- + +def wrap_math_funcs_change_unit_binary(change_unit_func): + def decorator(func: Callable) -> Callable: + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y.value, *args, **kwargs), dim=change_unit_func(x.dim, y.dim)) + ) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + return _return_check_unitless( + Quantity(func(x.value, y, *args, **kwargs), dim=change_unit_func(x.dim, DIMENSIONLESS))) + elif isinstance(y, Quantity): + return _return_check_unitless( + Quantity(func(x, y.value, *args, **kwargs), dim=change_unit_func(DIMENSIONLESS, y.dim))) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + return decorator + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def multiply(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.multiply(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) +def divide(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.divide(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def cross(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.cross(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * 2 ** y) +def ldexp(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.ldexp(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) +def true_divide(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.true_divide(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x / y) +def divmod(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.divmod(x, y) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def convolve(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]): + return jnp.convolve(x, y) + + +# docs for the functions above +multiply.__doc__ = ''' + Multiply arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +divide.__doc__ = ''' + Divide arguments element-wise. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +cross.__doc__ = ''' + Return the cross product of two (arrays of) vectors. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +ldexp.__doc__ = ''' + Return x1 * 2**x2, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and 2 raised to the power of the unit of `y`, else an array. +''' + +true_divide.__doc__ = ''' + Returns a true division of the inputs, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +divmod.__doc__ = ''' + Return element-wise quotient and remainder simultaneously. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. +''' + +convolve.__doc__ = ''' + Returns the discrete, linear convolution of two one-dimensional sequences. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + + +@set_module_as('brainunit.math') +def power(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike], ) -> Union[Quantity, jax.Array]: + ''' + First array elements raised to powers from second array, element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.power(x.value, y.value), dim=x.dim ** y.dim)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.power(x, y) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.power(x.value, y), dim=x.dim ** y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.power(x, y.value), dim=x ** y.dim)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.power.__name__}') + + +@set_module_as('brainunit.math') +def floor_divide(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Return the largest integer smaller or equal to the division of the inputs. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the quotient of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y.value), dim=x.dim / y.dim)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.floor_divide(x, y) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x.value, y), dim=x.dim / y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.floor_divide(x, y.value), dim=x / y.dim)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.floor_divide.__name__}') + + +@set_module_as('brainunit.math') +def float_power(x: Union[Quantity, bst.typing.ArrayLike], + y: bst.typing.ArrayLike) -> Union[Quantity, jax.Array]: + ''' + First array elements raised to powers from second array, element-wise. + + Args: + x: array_like, Quantity + y: array_like + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(y, Quantity): + assert isscalar(y), f'{jnp.float_power.__name__} only supports scalar exponent' + if isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.float_power(x.value, y), dim=x.dim ** y)) + elif isinstance(x, (jax.Array, np.ndarray)): + return jnp.float_power(x, y) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.float_power.__name__}') + + +@set_module_as('brainunit.math') +def remainder(x: Union[Quantity, bst.typing.ArrayLike], + y: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Return element-wise remainder of division. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the remainder of the unit of `x` and the unit of `y`, else an array. + ''' + if isinstance(x, Quantity) and isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x.value, y.value), dim=x.dim / y.dim)) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return jnp.remainder(x, y) + elif isinstance(x, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x.value, y), dim=x.dim % y)) + elif isinstance(y, Quantity): + return _return_check_unitless(Quantity(jnp.remainder(x, y.value), dim=x % y.dim)) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {jnp.remainder.__name__}') diff --git a/brainunit/math/_compat_numpy_funcs_indexing.py b/brainunit/math/_compat_numpy_funcs_indexing.py new file mode 100644 index 0000000..7f8d8fc --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_indexing.py @@ -0,0 +1,166 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import (Union, Optional) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from .._base import (Quantity, + fail_for_dimension_mismatch, + is_unitless, + ) + +__all__ = [ + + # indexing funcs + 'where', 'tril_indices', 'tril_indices_from', 'triu_indices', + 'triu_indices_from', 'take', 'select', +] + + +# indexing funcs +# -------------- +@set_module_as('brainunit.math') +def where(condition: Union[bool, bst.typing.ArrayLike], + *args: Union[Quantity, bst.typing.ArrayLike], + **kwds) -> Union[Quantity, jax.Array]: + condition = jnp.asarray(condition) + if len(args) == 0: + # nothing to do + return jnp.where(condition, *args, **kwds) + elif len(args) == 2: + # check that x and y have the same dimensions + fail_for_dimension_mismatch( + args[0], args[1], "x and y need to have the same dimensions" + ) + new_args = [] + for arg in args: + if isinstance(arg, Quantity): + new_args.append(arg.value) + if is_unitless(args[0]): + if len(new_args) == 2: + return jnp.where(condition, *new_args, **kwds) + else: + return jnp.where(condition, *args, **kwds) + else: + # as both arguments have the same unit, just use the first one's + dimensionless_args = [jnp.asarray(arg.value) if isinstance(arg, Quantity) else jnp.asarray(arg) for arg in args] + return Quantity.with_units( + jnp.where(condition, *dimensionless_args), args[0].dim + ) + else: + # illegal number of arguments + if len(args) == 1: + raise ValueError("where() takes 2 or 3 positional arguments but 1 was given") + elif len(args) > 2: + raise TypeError("where() takes 2 or 3 positional arguments but {} were given".format(len(args))) + + +tril_indices = jnp.tril_indices +tril_indices.__doc__ = ''' + Return the indices for the lower-triangle of an (n, m) array. + + Args: + n: int + m: int + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] +''' + + +@set_module_as('brainunit.math') +def tril_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], + k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: + ''' + Return the indices for the lower-triangle of an (n, m) array. + + Args: + arr: array_like, Quantity + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] + ''' + if isinstance(arr, Quantity): + return jnp.tril_indices_from(arr.value, k=k) + else: + return jnp.tril_indices_from(arr, k=k) + + +triu_indices = jnp.triu_indices +triu_indices.__doc__ = ''' + Return the indices for the upper-triangle of an (n, m) array. + + Args: + n: int + m: int + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] +''' + + +@set_module_as('brainunit.math') +def triu_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], + k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: + ''' + Return the indices for the upper-triangle of an (n, m) array. + + Args: + arr: array_like, Quantity + k: int, optional + + Returns: + tuple[jax.Array]: tuple[array] + ''' + if isinstance(arr, Quantity): + return jnp.triu_indices_from(arr.value, k=k) + else: + return jnp.triu_indices_from(arr, k=k) + + +@set_module_as('brainunit.math') +def take(a: Union[Quantity, bst.typing.ArrayLike], + indices: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + mode: Optional[str] = None) -> Union[Quantity, jax.Array]: + if isinstance(a, Quantity): + return a.take(indices, axis=axis, mode=mode) + else: + return jnp.take(a, indices, axis=axis, mode=mode) + + +@set_module_as('brainunit.math') +def select(condlist: list[Union[bst.typing.ArrayLike]], + choicelist: Union[Quantity, bst.typing.ArrayLike], + default: int = 0) -> Union[Quantity, jax.Array]: + from builtins import all as origin_all + from builtins import any as origin_any + if origin_all(isinstance(choice, Quantity) for choice in choicelist): + if origin_any(choice.dim != choicelist[0].dim for choice in choicelist): + raise ValueError("All choices must have the same unit") + else: + return Quantity(jnp.select(condlist, [choice.value for choice in choicelist], default=default), + dim=choicelist[0].dim) + elif origin_all(isinstance(choice, (jax.Array, np.ndarray)) for choice in choicelist): + return jnp.select(condlist, choicelist, default=default) + else: + raise ValueError(f"Unsupported types : {type(condlist)} and {type(choicelist)} for select") diff --git a/brainunit/math/_compat_numpy_funcs_keep_unit.py b/brainunit/math/_compat_numpy_funcs_keep_unit.py new file mode 100644 index 0000000..4a6616e --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_keep_unit.py @@ -0,0 +1,832 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from .._base import (Quantity, + ) + +__all__ = [ + # math funcs keep unit (unary) + 'real', 'imag', 'conj', 'conjugate', 'negative', 'positive', + 'abs', 'round', 'around', 'round_', 'rint', + 'floor', 'ceil', 'trunc', 'fix', 'sum', 'nancumsum', 'nansum', + 'cumsum', 'ediff1d', 'absolute', 'fabs', 'median', + 'nanmin', 'nanmax', 'ptp', 'average', 'mean', 'std', + 'nanmedian', 'nanmean', 'nanstd', 'diff', 'modf', + + # math funcs keep unit (binary) + 'fmod', 'mod', 'copysign', 'heaviside', + 'maximum', 'minimum', 'fmax', 'fmin', 'lcm', 'gcd', + + # math funcs keep unit (n-ary) + 'interp', 'clip', +] + + +# math funcs keep unit (unary) +# ---------------------------- + +def wrap_math_funcs_keep_unit_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return Quantity(func(x.value, *args, **kwargs), dim=x.dim) + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_keep_unit_unary +def real(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.real(x) + + +@wrap_math_funcs_keep_unit_unary +def imag(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.imag(x) + + +@wrap_math_funcs_keep_unit_unary +def conj(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.conj(x) + + +@wrap_math_funcs_keep_unit_unary +def conjugate(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.conjugate(x) + + +@wrap_math_funcs_keep_unit_unary +def negative(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.negative(x) + + +@wrap_math_funcs_keep_unit_unary +def positive(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.positive(x) + + +@wrap_math_funcs_keep_unit_unary +def abs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.abs(x) + + +@wrap_math_funcs_keep_unit_unary +def round_(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.round(x) + + +@wrap_math_funcs_keep_unit_unary +def around(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.around(x) + + +@wrap_math_funcs_keep_unit_unary +def round(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.round(x) + + +@wrap_math_funcs_keep_unit_unary +def rint(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.rint(x) + + +@wrap_math_funcs_keep_unit_unary +def floor(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.floor(x) + + +@wrap_math_funcs_keep_unit_unary +def ceil(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.ceil(x) + + +@wrap_math_funcs_keep_unit_unary +def trunc(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.trunc(x) + + +@wrap_math_funcs_keep_unit_unary +def fix(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.fix(x) + + +@wrap_math_funcs_keep_unit_unary +def sum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.sum(x) + + +@wrap_math_funcs_keep_unit_unary +def nancumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nancumsum(x) + + +@wrap_math_funcs_keep_unit_unary +def nansum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nansum(x) + + +@wrap_math_funcs_keep_unit_unary +def cumsum(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.cumsum(x) + + +@wrap_math_funcs_keep_unit_unary +def ediff1d(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.ediff1d(x) + + +@wrap_math_funcs_keep_unit_unary +def absolute(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.absolute(x) + + +@wrap_math_funcs_keep_unit_unary +def fabs(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.fabs(x) + + +@wrap_math_funcs_keep_unit_unary +def median(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.median(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmin(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmin(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmax(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmax(x) + + +@wrap_math_funcs_keep_unit_unary +def ptp(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.ptp(x) + + +@wrap_math_funcs_keep_unit_unary +def average(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.average(x) + + +@wrap_math_funcs_keep_unit_unary +def mean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.mean(x) + + +@wrap_math_funcs_keep_unit_unary +def std(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.std(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmedian(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmedian(x) + + +@wrap_math_funcs_keep_unit_unary +def nanmean(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanmean(x) + + +@wrap_math_funcs_keep_unit_unary +def nanstd(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.nanstd(x) + + +@wrap_math_funcs_keep_unit_unary +def diff(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.diff(x) + + +@wrap_math_funcs_keep_unit_unary +def modf(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + return jnp.modf(x) + + +# docs for the functions above +real.__doc__ = ''' + Return the real part of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +imag.__doc__ = ''' + Return the imaginary part of the complex argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +conj.__doc__ = ''' + Return the complex conjugate of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +conjugate.__doc__ = ''' + Return the complex conjugate of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +negative.__doc__ = ''' + Return the negative of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +positive.__doc__ = ''' + Return the positive of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +abs.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +round_.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +around.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +round.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +rint.__doc__ = ''' + Round an array to the nearest integer. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +floor.__doc__ = ''' + Return the floor of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +ceil.__doc__ = ''' + Return the ceiling of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +trunc.__doc__ = ''' + Return the truncated value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +fix.__doc__ = ''' + Return the nearest integer towards zero. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +sum.__doc__ = ''' + Return the sum of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nancumsum.__doc__ = ''' + Return the cumulative sum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nansum.__doc__ = ''' + Return the sum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +cumsum.__doc__ = ''' + Return the cumulative sum of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +ediff1d.__doc__ = ''' + Return the differences between consecutive elements of the array. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +absolute.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +fabs.__doc__ = ''' + Return the absolute value of the argument. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +median.__doc__ = ''' + Return the median of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmin.__doc__ = ''' + Return the minimum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmax.__doc__ = ''' + Return the maximum of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +ptp.__doc__ = ''' + Return the range of the array elements (maximum - minimum). + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +average.__doc__ = ''' + Return the weighted average of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +mean.__doc__ = ''' + Return the mean of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +std.__doc__ = ''' + Return the standard deviation of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmedian.__doc__ = ''' + Return the median of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanmean.__doc__ = ''' + Return the mean of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +nanstd.__doc__ = ''' + Return the standard deviation of the array elements, ignoring NaNs. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +diff.__doc__ = ''' + Return the differences between consecutive elements of the array. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` is a Quantity, else an array. +''' + +modf.__doc__ = ''' + Return the fractional and integer parts of the array elements. + + Args: + x: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity tuple if `x` is a Quantity, else an array tuple. +''' + + +# math funcs keep unit (binary) +# ----------------------------- + +def wrap_math_funcs_keep_unit_binary(func): + @wraps(func) + def f(x1, x2, *args, **kwargs): + if isinstance(x1, Quantity) and isinstance(x2, Quantity): + return Quantity(func(x1.value, x2.value, *args, **kwargs), dim=x1.dim) + elif isinstance(x1, (jax.Array, np.ndarray)) and isinstance(x2, (jax.Array, np.ndarray)): + return func(x1, x2, *args, **kwargs) + else: + raise ValueError(f'Unsupported type: {type(x1)} and {type(x2)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_keep_unit_binary +def fmod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.fmod(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def mod(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.mod(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def copysign(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.copysign(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def heaviside(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.heaviside(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def maximum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.maximum(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def minimum(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.minimum(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def fmax(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.fmax(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def fmin(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.fmin(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def lcm(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.lcm(x1, x2) + + +@wrap_math_funcs_keep_unit_binary +def gcd(x1: Union[Quantity, jax.Array], x2: Union[Quantity, jax.Array]) -> Union[Quantity, jax.Array]: + return jnp.gcd(x1, x2) + + +# docs for the functions above +fmod.__doc__ = ''' + Return the element-wise remainder of division. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +mod.__doc__ = ''' + Return the element-wise modulus of division. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +copysign.__doc__ = ''' + Return a copy of the first array elements with the sign of the second array. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +heaviside.__doc__ = ''' + Compute the Heaviside step function. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +maximum.__doc__ = ''' + Element-wise maximum of array elements. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +minimum.__doc__ = ''' + Element-wise minimum of array elements. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +fmax.__doc__ = ''' + Element-wise maximum of array elements ignoring NaNs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +fmin.__doc__ = ''' + Element-wise minimum of array elements ignoring NaNs. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +lcm.__doc__ = ''' + Return the least common multiple of `x1` and `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + +gcd.__doc__ = ''' + Return the greatest common divisor of `x1` and `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' + + +# math funcs keep unit (n-ary) +# ---------------------------- +@set_module_as('brainunit.math') +def interp(x: Union[Quantity, bst.typing.ArrayLike], + xp: Union[Quantity, bst.typing.ArrayLike], + fp: Union[Quantity, bst.typing.ArrayLike], + left: Union[Quantity, bst.typing.ArrayLike] = None, + right: Union[Quantity, bst.typing.ArrayLike] = None, + period: Union[Quantity, bst.typing.ArrayLike] = None) -> Union[Quantity, jax.Array]: + ''' + One-dimensional linear interpolation. + + Args: + x: array_like, Quantity + xp: array_like, Quantity + fp: array_like, Quantity + left: array_like, Quantity, optional + right: array_like, Quantity, optional + period: array_like, Quantity, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if `x`, `xp`, and `fp` are Quantities that have the same unit, else an array. + ''' + unit = None + if isinstance(x, Quantity) or isinstance(xp, Quantity) or isinstance(fp, Quantity): + unit = x.dim if isinstance(x, Quantity) else xp.dim if isinstance(xp, Quantity) else fp.dim + if isinstance(x, Quantity): + x_value = x.value + else: + x_value = x + if isinstance(xp, Quantity): + xp_value = xp.value + else: + xp_value = xp + if isinstance(fp, Quantity): + fp_value = fp.value + else: + fp_value = fp + result = jnp.interp(x_value, xp_value, fp_value, left=left, right=right, period=period) + if unit is not None: + return Quantity(result, dim=unit) + else: + return result + + +@set_module_as('brainunit.math') +def clip(a: Union[Quantity, bst.typing.ArrayLike], + a_min: Union[Quantity, bst.typing.ArrayLike], + a_max: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: + ''' + Clip (limit) the values in an array. + + Args: + a: array_like, Quantity + a_min: array_like, Quantity + a_max: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `a`, `a_min`, and `a_max` are Quantities that have the same unit, else an array. + ''' + unit = None + if isinstance(a, Quantity) or isinstance(a_min, Quantity) or isinstance(a_max, Quantity): + unit = a.dim if isinstance(a, Quantity) else a_min.dim if isinstance(a_min, Quantity) else a_max.dim + if isinstance(a, Quantity): + a_value = a.value + else: + a_value = a + if isinstance(a_min, Quantity): + a_min_value = a_min.value + else: + a_min_value = a_min + if isinstance(a_max, Quantity): + a_max_value = a_max.value + else: + a_max_value = a_max + result = jnp.clip(a_value, a_min_value, a_max_value) + if unit is not None: + return Quantity(result, dim=unit) + else: + return result diff --git a/brainunit/math/_compat_numpy_funcs_logic.py b/brainunit/math/_compat_numpy_funcs_logic.py new file mode 100644 index 0000000..e7d69e7 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_logic.py @@ -0,0 +1,343 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union, Optional) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +from .._base import (Quantity, + fail_for_dimension_mismatch, + ) + +__all__ = [ + # logic funcs (unary) + 'all', 'any', 'logical_not', + + # logic funcs (binary) + 'equal', 'not_equal', 'greater', 'greater_equal', 'less', 'less_equal', + 'array_equal', 'isclose', 'allclose', 'logical_and', + 'logical_or', 'logical_xor', "alltrue", 'sometrue', +] + + +# logic funcs (unary) +# ------------------- + +def wrap_logic_func_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + raise ValueError(f'Expected booleans, got {x}') + elif isinstance(x, (jax.Array, np.ndarray)): + return func(x, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_logic_func_unary +def all(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, + out: Optional[Array] = None, keepdims: bool = False, + where: Optional[Array] = None) -> Union[bool, Array]: + return jnp.all(x, axis=axis, out=out, keepdims=keepdims, where=where) + + +@wrap_logic_func_unary +def any(x: Union[Quantity, bst.typing.ArrayLike], axis: Optional[int] = None, + out: Optional[Array] = None, keepdims: bool = False, + where: Optional[Array] = None) -> Union[bool, Array]: + return jnp.any(x, axis=axis, out=out, keepdims=keepdims, where=where) + + +@wrap_logic_func_unary +def logical_not(x: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.logical_not(x) + + +alltrue = all +sometrue = any + +# docs for functions above +all.__doc__ = ''' + Test whether all array elements along a given axis evaluate to True. + + Args: + a: array_like + axis: int, optional + out: array, optional + keepdims: bool, optional + where: array_like of bool, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +any.__doc__ = ''' + Test whether any array element along a given axis evaluates to True. + + Args: + a: array_like + axis: int, optional + out: array, optional + keepdims: bool, optional + where: array_like of bool, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +logical_not.__doc__ = ''' + Compute the truth value of NOT x element-wise. + + Args: + x: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + + +# logic funcs (binary) +# -------------------- + +def wrap_logic_func_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return func(x.value, y.value, *args, **kwargs) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + else: + raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_logic_func_binary +def equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.equal(x, y) + + +@wrap_logic_func_binary +def not_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.not_equal(x, y) + + +@wrap_logic_func_binary +def greater(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.greater(x, y) + + +@wrap_logic_func_binary +def greater_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.greater_equal(x, y) + + +@wrap_logic_func_binary +def less(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.less(x, y) + + +@wrap_logic_func_binary +def less_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[bool, Array]: + return jnp.less_equal(x, y) + + +@wrap_logic_func_binary +def array_equal(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.array_equal(x, y) + + +@wrap_logic_func_binary +def isclose(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike], + rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Union[bool, Array]: + return jnp.isclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +@wrap_logic_func_binary +def allclose(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike], + rtol: float = 1e-05, atol: float = 1e-08, equal_nan: bool = False) -> Union[bool, Array]: + return jnp.allclose(x, y, rtol=rtol, atol=atol, equal_nan=equal_nan) + + +@wrap_logic_func_binary +def logical_and(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.logical_and(x, y) + + +@wrap_logic_func_binary +def logical_or(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.logical_or(x, y) + + +@wrap_logic_func_binary +def logical_xor(x: Union[Quantity, bst.typing.ArrayLike], y: Union[Quantity, bst.typing.ArrayLike]) -> Union[ + bool, Array]: + return jnp.logical_xor(x, y) + + +# docs for functions above +equal.__doc__ = ''' + Return (x == y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +not_equal.__doc__ = ''' + Return (x != y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +greater.__doc__ = ''' + Return (x > y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +greater_equal.__doc__ = ''' + Return (x >= y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +less.__doc__ = ''' + Return (x < y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +less_equal.__doc__ = ''' + Return (x <= y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +array_equal.__doc__ = ''' + Return True if two arrays have the same shape, elements, and units (if they are Quantity), False otherwise. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[bool, jax.Array]: bool or array +''' + +isclose.__doc__ = ''' + Returns a boolean array where two arrays are element-wise equal within a tolerance and have the same unit if they are Quantity. + + Args: + a: array_like, Quantity + b: array_like, Quantity + rtol: float, optional + atol: float, optional + equal_nan: bool, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +allclose.__doc__ = ''' + Returns True if the two arrays are equal within the given tolerance and have the same unit if they are Quantity; False otherwise. + + Args: + a: array_like, Quantity + b: array_like, Quantity + rtol: float, optional + atol: float, optional + equal_nan: bool, optional + + Returns: + bool: boolean result +''' + +logical_and.__doc__ = ''' + Compute the truth value of x AND y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +logical_or.__doc__ = ''' + Compute the truth value of x OR y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' + +logical_xor.__doc__ = ''' + Compute the truth value of x XOR y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + Union[bool, jax.Array]: bool or array +''' diff --git a/brainunit/math/_compat_numpy_funcs_match_unit.py b/brainunit/math/_compat_numpy_funcs_match_unit.py new file mode 100644 index 0000000..d9926ad --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_match_unit.py @@ -0,0 +1,108 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union) + +import jax +import jax.numpy as jnp +import numpy as np +from jax import Array + +from .._base import (Quantity, + fail_for_dimension_mismatch, + ) + +__all__ = [ + # math funcs match unit (binary) + 'add', 'subtract', 'nextafter', +] + + +# math funcs match unit (binary) +# ------------------------------ + +def wrap_math_funcs_match_unit_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity) and isinstance(y, Quantity): + fail_for_dimension_mismatch(x, y) + return Quantity(func(x.value, y.value, *args, **kwargs), dim=x.dim) + elif isinstance(x, (jax.Array, np.ndarray)) and isinstance(y, (jax.Array, np.ndarray)): + return func(x, y, *args, **kwargs) + elif isinstance(x, Quantity): + if x.is_unitless: + return Quantity(func(x.value, y, *args, **kwargs), dim=x.dim) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + elif isinstance(y, Quantity): + if y.is_unitless: + return Quantity(func(x, y.value, *args, **kwargs), dim=y.dim) + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + else: + raise ValueError(f'Unsupported types : {type(x)} abd {type(y)} for {func.__name__}') + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_match_unit_binary +def add(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: + return jnp.add(x, y) + + +@wrap_math_funcs_match_unit_binary +def subtract(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: + return jnp.subtract(x, y) + + +@wrap_math_funcs_match_unit_binary +def nextafter(x: Union[Quantity, Array], y: Union[Quantity, Array]) -> Union[Quantity, Array]: + return jnp.nextafter(x, y) + + +# docs for the functions above +add.__doc__ = ''' + Add arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. +''' + +subtract.__doc__ = ''' + Subtract arguments element-wise. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x` and `y` are Quantities that have the same unit, else an array. +''' + +nextafter.__doc__ = ''' + Return the next floating-point value after `x1` towards `x2`. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if `x1` and `x2` are Quantities that have the same unit, else an array. +''' diff --git a/brainunit/math/_compat_numpy_funcs_remove_unit.py b/brainunit/math/_compat_numpy_funcs_remove_unit.py new file mode 100644 index 0000000..afea533 --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_remove_unit.py @@ -0,0 +1,191 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps +from typing import (Union, Optional) + +import jax.numpy as jnp +from jax import Array + +from .._base import (Quantity, + ) + +__all__ = [ + + # math funcs remove unit (unary) + 'signbit', 'sign', 'histogram', 'bincount', + + # math funcs remove unit (binary) + 'corrcoef', 'correlate', 'cov', 'digitize', +] + + +# math funcs remove unit (unary) +# ------------------------------ +def wrap_math_funcs_remove_unit_unary(func): + @wraps(func) + def f(x, *args, **kwargs): + if isinstance(x, Quantity): + return func(x.value, *args, **kwargs) + else: + return func(x, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_remove_unit_unary +def signbit(x: Union[Array, Quantity]) -> Array: + return jnp.signbit(x) + + +@wrap_math_funcs_remove_unit_unary +def sign(x: Union[Array, Quantity]) -> Array: + return jnp.sign(x) + + +@wrap_math_funcs_remove_unit_unary +def histogram(x: Union[Array, Quantity]) -> tuple[Array, Array]: + return jnp.histogram(x) + + +@wrap_math_funcs_remove_unit_unary +def bincount(x: Union[Array, Quantity]) -> Array: + return jnp.bincount(x) + + +# docs for the functions above +signbit.__doc__ = ''' + Returns element-wise True where signbit is set (less than zero). + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +sign.__doc__ = ''' + Returns the sign of each element in the input array. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + +histogram.__doc__ = ''' + Compute the histogram of a set of data. + + Args: + x: array_like, Quantity + + Returns: + tuple[jax.Array]: Tuple of arrays (hist, bin_edges) +''' + +bincount.__doc__ = ''' + Count number of occurrences of each value in array of non-negative integers. + + Args: + x: array_like, Quantity + + Returns: + jax.Array: an array +''' + + +# math funcs remove unit (binary) +# ------------------------------- +def wrap_math_funcs_remove_unit_binary(func): + @wraps(func) + def f(x, y, *args, **kwargs): + if isinstance(x, Quantity): + x_value = x.value + if isinstance(y, Quantity): + y_value = y.value + if isinstance(x, Quantity) or isinstance(y, Quantity): + return func(jnp.array(x_value), jnp.array(y_value), *args, **kwargs) + else: + return func(x, y, *args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_math_funcs_remove_unit_binary +def corrcoef(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.corrcoef(x, y) + + +@wrap_math_funcs_remove_unit_binary +def correlate(x: Union[Array, Quantity], y: Union[Array, Quantity]) -> Array: + return jnp.correlate(x, y) + + +@wrap_math_funcs_remove_unit_binary +def cov(x: Union[Array, Quantity], y: Optional[Union[Array, Quantity]] = None) -> Array: + return jnp.cov(x, y) + + +@wrap_math_funcs_remove_unit_binary +def digitize(x: Union[Array, Quantity], bins: Union[Array, Quantity]) -> Array: + return jnp.digitize(x, bins) + + +# docs for the functions above +corrcoef.__doc__ = ''' + Return Pearson product-moment correlation coefficients. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + jax.Array: an array +''' + +correlate.__doc__ = ''' + Cross-correlation of two sequences. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + jax.Array: an array +''' + +cov.__doc__ = ''' + Covariance matrix. + + Args: + x: array_like, Quantity + y: array_like, Quantity (optional, if not provided, x is assumed to be a 2D array) + + Returns: + jax.Array: an array +''' + +digitize.__doc__ = ''' + Return the indices of the bins to which each value in input array belongs. + + Args: + x: array_like, Quantity + bins: array_like, Quantity + + Returns: + jax.Array: an array +''' diff --git a/brainunit/math/_compat_numpy_funcs_window.py b/brainunit/math/_compat_numpy_funcs_window.py new file mode 100644 index 0000000..776450f --- /dev/null +++ b/brainunit/math/_compat_numpy_funcs_window.py @@ -0,0 +1,69 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from functools import wraps + +import jax.numpy as jnp +from jax import Array + +__all__ = [ + + # window funcs + 'bartlett', 'blackman', 'hamming', 'hanning', 'kaiser', +] + + +# window funcs +# ------------ + +def wrap_window_funcs(func): + @wraps(func) + def f(*args, **kwargs): + return func(*args, **kwargs) + + f.__module__ = 'brainunit.math' + return f + + +@wrap_window_funcs +def bartlett(M: int) -> Array: + return jnp.bartlett(M) + + +@wrap_window_funcs +def blackman(M: int) -> Array: + return jnp.blackman(M) + + +@wrap_window_funcs +def hamming(M: int) -> Array: + return jnp.hamming(M) + + +@wrap_window_funcs +def hanning(M: int) -> Array: + return jnp.hanning(M) + + +@wrap_window_funcs +def kaiser(M: int, beta: float) -> Array: + return jnp.kaiser(M, beta) + + +# docs for functions above +bartlett.__doc__ = jnp.bartlett.__doc__ +blackman.__doc__ = jnp.blackman.__doc__ +hamming.__doc__ = jnp.hamming.__doc__ +hanning.__doc__ = jnp.hanning.__doc__ +kaiser.__doc__ = jnp.kaiser.__doc__ diff --git a/brainunit/math/_compat_numpy_get_attribute.py b/brainunit/math/_compat_numpy_get_attribute.py new file mode 100644 index 0000000..03bec0d --- /dev/null +++ b/brainunit/math/_compat_numpy_get_attribute.py @@ -0,0 +1,215 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import (Union) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +from brainstate._utils import set_module_as + +from .._base import (Quantity, + ) + +__all__ = [ + # getting attribute funcs + 'ndim', 'isreal', 'isscalar', 'isfinite', 'isinf', + 'isnan', 'shape', 'size', +] + + +@set_module_as('brainunit.math') +def ndim(a: Union[Quantity, bst.typing.ArrayLike]) -> int: + ''' + Return the number of dimensions of an array. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: int + ''' + if isinstance(a, Quantity): + return a.ndim + else: + return jnp.ndim(a) + + +@set_module_as('brainunit.math') +def isreal(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return True if the input array is real. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isreal + else: + return jnp.isreal(a) + + +@set_module_as('brainunit.math') +def isscalar(a: Union[Quantity, bst.typing.ArrayLike]) -> bool: + ''' + Return True if the input is a scalar. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isscalar + else: + return jnp.isscalar(a) + + +@set_module_as('brainunit.math') +def isfinite(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is finite or not. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isfinite + else: + return jnp.isfinite(a) + + +@set_module_as('brainunit.math') +def isinf(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is infinite or not. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isinf + else: + return jnp.isinf(a) + + +@set_module_as('brainunit.math') +def isnan(a: Union[Quantity, bst.typing.ArrayLike]) -> jax.Array: + ''' + Return each element of the array is NaN or not. + + Args: + a: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: boolean array + ''' + if isinstance(a, Quantity): + return a.isnan + else: + return jnp.isnan(a) + + +@set_module_as('brainunit.math') +def shape(a: Union[Quantity, bst.typing.ArrayLike]) -> tuple[int, ...]: + """ + Return the shape of an array. + + Parameters + ---------- + a : array_like + Input array. + + Returns + ------- + shape : tuple of ints + The elements of the shape tuple give the lengths of the + corresponding array dimensions. + + See Also + -------- + len : ``len(a)`` is equivalent to ``np.shape(a)[0]`` for N-D arrays with + ``N>=1``. + ndarray.shape : Equivalent array method. + + Examples + -------- + >>> brainunit.math.shape(brainunit.math.eye(3)) + (3, 3) + >>> brainunit.math.shape([[1, 3]]) + (1, 2) + >>> brainunit.math.shape([0]) + (1,) + >>> brainunit.math.shape(0) + () + + """ + if isinstance(a, (Quantity, jax.Array, np.ndarray)): + return a.shape + else: + return np.shape(a) + + +@set_module_as('brainunit.math') +def size(a: Union[Quantity, bst.typing.ArrayLike], axis: int = None) -> int: + """ + Return the number of elements along a given axis. + + Parameters + ---------- + a : array_like + Input data. + axis : int, optional + Axis along which the elements are counted. By default, give + the total number of elements. + + Returns + ------- + element_count : int + Number of elements along the specified axis. + + See Also + -------- + shape : dimensions of array + Array.shape : dimensions of array + Array.size : number of elements in array + + Examples + -------- + >>> a = Quantity([[1,2,3], [4,5,6]]) + >>> brainunit.math.size(a) + 6 + >>> brainunit.math.size(a, 1) + 3 + >>> brainunit.math.size(a, 0) + 2 + """ + if isinstance(a, (Quantity, jax.Array, np.ndarray)): + if axis is None: + return a.size + else: + return a.shape[axis] + else: + return np.size(a, axis=axis) diff --git a/brainunit/math/_compat_numpy_linear_algebra.py b/brainunit/math/_compat_numpy_linear_algebra.py new file mode 100644 index 0000000..88f27e9 --- /dev/null +++ b/brainunit/math/_compat_numpy_linear_algebra.py @@ -0,0 +1,149 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from typing import (Union) + +import jax.numpy as jnp +from jax import Array + +from ._compat_numpy_funcs_change_unit import wrap_math_funcs_change_unit_binary +from ._compat_numpy_funcs_keep_unit import wrap_math_funcs_keep_unit_unary +from .._base import (Quantity, + ) + +__all__ = [ + + # linear algebra + 'dot', 'vdot', 'inner', 'outer', 'kron', 'matmul', 'trace', + +] + + + + +# linear algebra +# -------------- + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def dot(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.dot(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def vdot(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.vdot(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def inner(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.inner(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def outer(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.outer(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def kron(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.kron(a, b) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def matmul(a: Union[Array, Quantity], b: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.matmul(a, b) + + +@wrap_math_funcs_keep_unit_unary +def trace(a: Union[Array, Quantity]) -> Union[Array, Quantity]: + return jnp.trace(a) + + +# docs for functions above +dot.__doc__ = ''' + Dot product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +vdot.__doc__ = ''' + Return the dot product of two vectors or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +inner.__doc__ = ''' + Inner product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +outer.__doc__ = ''' + Compute the outer product of two vectors or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +kron.__doc__ = ''' + Compute the Kronecker product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +matmul.__doc__ = ''' + Matrix product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + Union[jax.Array, Quantity]: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +trace.__doc__ = ''' + Return the sum of the diagonal elements of a matrix or quantity. + + Args: + a: array_like, Quantity + offset: int, optional + + Returns: + Union[jax.Array, Quantity]: Quantity if the input is a Quantity, else an array. +''' diff --git a/brainunit/math/_compat_numpy_misc.py b/brainunit/math/_compat_numpy_misc.py new file mode 100644 index 0000000..0deb591 --- /dev/null +++ b/brainunit/math/_compat_numpy_misc.py @@ -0,0 +1,354 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +from collections.abc import Sequence +from typing import (Callable, Union, Tuple) + +import brainstate as bst +import jax +import jax.numpy as jnp +import numpy as np +import opt_einsum +from brainstate._utils import set_module_as +from jax import Array +from jax._src.numpy.lax_numpy import _einsum + +from ._compat_numpy_funcs_change_unit import wrap_math_funcs_change_unit_binary +from ._compat_numpy_funcs_keep_unit import wrap_math_funcs_keep_unit_unary +from ._utils import _compatible_with_quantity +from .._base import (DIMENSIONLESS, + Quantity, + fail_for_dimension_mismatch, + is_unitless, + get_unit, ) + +__all__ = [ + + # constants + 'e', 'pi', 'inf', + + # data types + 'dtype', 'finfo', 'iinfo', + + # more + 'broadcast_arrays', 'broadcast_shapes', + 'einsum', 'gradient', 'intersect1d', 'nan_to_num', 'nanargmax', 'nanargmin', + 'rot90', 'tensordot', +] + +# constants +# --------- +e = jnp.e +pi = jnp.pi +inf = jnp.inf + +# data types +# ---------- +dtype = jnp.dtype + + +@set_module_as('brainunit.math') +def finfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.finfo: + if isinstance(a, Quantity): + return jnp.finfo(a.value) + else: + return jnp.finfo(a) + + +@set_module_as('brainunit.math') +def iinfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.iinfo: + if isinstance(a, Quantity): + return jnp.iinfo(a.value) + else: + return jnp.iinfo(a) + + +# more +# ---- +@set_module_as('brainunit.math') +def broadcast_arrays(*args: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, list[Array]]: + from builtins import all as origin_all + from builtins import any as origin_any + if origin_all(isinstance(arg, Quantity) for arg in args): + if origin_any(arg.dim != args[0].dim for arg in args): + raise ValueError("All arguments must have the same unit") + return Quantity(jnp.broadcast_arrays(*[arg.value for arg in args]), dim=args[0].dim) + elif origin_all(isinstance(arg, (jax.Array, np.ndarray)) for arg in args): + return jnp.broadcast_arrays(*args) + else: + raise ValueError(f"Unsupported types : {type(args)} for broadcast_arrays") + + +broadcast_shapes = jnp.broadcast_shapes + + +@set_module_as('brainunit.math') +def einsum( + subscripts: str, + /, + *operands: Union[Quantity, jax.Array], + out: None = None, + optimize: Union[str, bool] = "optimal", + precision: jax.lax.PrecisionLike = None, + preferred_element_type: Union[jax.typing.DTypeLike, None] = None, + _dot_general: Callable[..., jax.Array] = jax.lax.dot_general, +) -> Union[jax.Array, Quantity]: + ''' + Evaluates the Einstein summation convention on the operands. + + Args: + subscripts: string containing axes names separated by commas. + *operands: sequence of one or more arrays or quantities corresponding to the subscripts. + optimize: determine whether to optimize the order of computation. In JAX + this defaults to ``"optimize"`` which produces optimized expressions via + the opt_einsum_ package. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``). + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + out: unsupported by JAX + _dot_general: optionally override the ``dot_general`` callable used by ``einsum``. + This parameter is experimental, and may be removed without warning at any time. + + Returns: + array containing the result of the einstein summation. + ''' + operands = (subscripts, *operands) + if out is not None: + raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.") + spec = operands[0] if isinstance(operands[0], str) else None + optimize = 'optimal' if optimize is True else optimize + + # Allow handling of shape polymorphism + non_constant_dim_types = { + type(d) for op in operands if not isinstance(op, str) + for d in np.shape(op) if not jax.core.is_constant_dim(d) + } + if not non_constant_dim_types: + contract_path = opt_einsum.contract_path + else: + from jax._src.numpy.lax_numpy import _default_poly_einsum_handler + contract_path = _default_poly_einsum_handler + + operands, contractions = contract_path( + *operands, einsum_call=True, use_blas=True, optimize=optimize) + + unit = None + for i in range(len(contractions) - 1): + if contractions[i][4] == 'False': + + fail_for_dimension_mismatch( + Quantity([], dim=unit), operands[i + 1], 'einsum' + ) + elif contractions[i][4] == 'DOT' or \ + contractions[i][4] == 'TDOT' or \ + contractions[i][4] == 'GEMM' or \ + contractions[i][4] == 'OUTER/EINSUM': + if i == 0: + if isinstance(operands[i], Quantity) and isinstance(operands[i + 1], Quantity): + unit = operands[i].dim * operands[i + 1].dim + elif isinstance(operands[i], Quantity): + unit = operands[i].dim + elif isinstance(operands[i + 1], Quantity): + unit = operands[i + 1].dim + else: + if isinstance(operands[i + 1], Quantity): + unit = unit * operands[i + 1].dim + + contractions = tuple((a, frozenset(b), c) for a, b, c, *_ in contractions) + + einsum = jax.jit(_einsum, static_argnums=(1, 2, 3, 4), inline=True) + if spec is not None: + einsum = jax.named_call(einsum, name=spec) + operands = [op.value if isinstance(op, Quantity) else op for op in operands] + r = einsum(operands, contractions, precision, # type: ignore[operator] + preferred_element_type, _dot_general) + if unit is not None: + return Quantity(r, dim=unit) + else: + return r + + +@set_module_as('brainunit.math') +def gradient( + f: Union[bst.typing.ArrayLike, Quantity], + *varargs: Union[bst.typing.ArrayLike, Quantity], + axis: Union[int, Sequence[int], None] = None, + edge_order: Union[int, None] = None, +) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: + ''' + Computes the gradient of a scalar field. + + Args: + f: input array. + *varargs: list of scalar fields to compute the gradient. + axis: axis or axes along which to compute the gradient. The default is to compute the gradient along all axes. + edge_order: order of the edge used for the finite difference computation. The default is 1. + + Returns: + array containing the gradient of the scalar field. + ''' + if edge_order is not None: + raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") + + if len(varargs) == 0: + if isinstance(f, Quantity) and not is_unitless(f): + return Quantity(jnp.gradient(f.value, axis=axis), dim=f.dim) + else: + return jnp.gradient(f) + elif len(varargs) == 1: + unit = get_unit(f) / get_unit(varargs[0]) + if unit is None or unit == DIMENSIONLESS: + return jnp.gradient(f, varargs[0], axis=axis) + else: + return [Quantity(r, dim=unit) for r in jnp.gradient(f.value, varargs[0].value, axis=axis)] + else: + unit_list = [get_unit(f) / get_unit(v) for v in varargs] + f = f.value if isinstance(f, Quantity) else f + varargs = [v.value if isinstance(v, Quantity) else v for v in varargs] + result_list = jnp.gradient(f, *varargs, axis=axis) + return [Quantity(r, dim=unit) if unit is not None else r for r, unit in zip(result_list, unit_list)] + + +@set_module_as('brainunit.math') +def intersect1d( + ar1: Union[bst.typing.ArrayLike], + ar2: Union[bst.typing.ArrayLike], + assume_unique: bool = False, + return_indices: bool = False +) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: + ''' + Find the intersection of two arrays. + + Args: + ar1: input array. + ar2: input array. + assume_unique: if True, the input arrays are both assumed to be unique. + return_indices: if True, the indices which correspond to the intersection of the two arrays are returned. + + Returns: + array containing the intersection of the two arrays. + ''' + fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') + unit = None + if isinstance(ar1, Quantity): + unit = ar1.dim + ar1 = ar1.value if isinstance(ar1, Quantity) else ar1 + ar2 = ar2.value if isinstance(ar2, Quantity) else ar2 + result = jnp.intersect1d(ar1, ar2, assume_unique=assume_unique, return_indices=return_indices) + if return_indices: + if unit is not None: + return (Quantity(result[0], dim=unit), result[1], result[2]) + else: + return result + else: + if unit is not None: + return Quantity(result, dim=unit) + else: + return result + + +@wrap_math_funcs_keep_unit_unary +def nan_to_num(x: Union[bst.typing.ArrayLike, Quantity], nan: float = 0.0, posinf: float = jnp.inf, + neginf: float = -jnp.inf) -> Union[jax.Array, Quantity]: + return jnp.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) + + +@wrap_math_funcs_keep_unit_unary +def rot90(m: Union[bst.typing.ArrayLike, Quantity], k: int = 1, axes: Tuple[int, int] = (0, 1)) -> Union[ + jax.Array, Quantity]: + return jnp.rot90(m, k=k, axes=axes) + + +@wrap_math_funcs_change_unit_binary(lambda x, y: x * y) +def tensordot(a: Union[bst.typing.ArrayLike, Quantity], b: Union[bst.typing.ArrayLike, Quantity], + axes: Union[int, Tuple[int, int]] = 2) -> Union[jax.Array, Quantity]: + return jnp.tensordot(a, b, axes=axes) + + +@_compatible_with_quantity(return_quantity=False) +def nanargmax(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: + return jnp.nanargmax(a, axis=axis) + + +@_compatible_with_quantity(return_quantity=False) +def nanargmin(a: Union[bst.typing.ArrayLike, Quantity], axis: int = None) -> jax.Array: + return jnp.nanargmin(a, axis=axis) + + +# docs for functions above +nan_to_num.__doc__ = ''' + Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the `nan`, `posinf` and `neginf` arguments. + + Args: + x: input array. + nan: value to replace NaNs with. + posinf: value to replace positive infinity with. + neginf: value to replace negative infinity with. + + Returns: + array with NaNs replaced by zero and infinities replaced by large finite numbers. +''' + +nanargmax.__doc__ = ''' + Return the index of the maximum value in an array, ignoring NaNs. + + Args: + a: array like, Quantity. + axis: axis along which to operate. The default is to compute the index of the maximum over all the dimensions of the input array. + out: output array, optional. + keepdims: if True, the result is broadcast to the input array with the same number of dimensions. + + Returns: + index of the maximum value in the array. +''' + +nanargmin.__doc__ = ''' + Return the index of the minimum value in an array, ignoring NaNs. + + Args: + a: array like, Quantity. + axis: axis along which to operate. The default is to compute the index of the minimum over all the dimensions of the input array. + out: output array, optional. + keepdims: if True, the result is broadcast to the input array with the same number of dimensions. + + Returns: + index of the minimum value in the array. +''' + +rot90.__doc__ = ''' + Rotate an array by 90 degrees in the plane specified by axes. + + Args: + m: array like, Quantity. + k: number of times the array is rotated by 90 degrees. + axes: plane of rotation. Default is the last two axes. + + Returns: + rotated array. +''' + +tensordot.__doc__ = ''' + Compute tensor dot product along specified axes for arrays. + + Args: + a: array like, Quantity. + b: array like, Quantity. + axes: axes along which to compute the tensor dot product. + + Returns: + tensor dot product of the two arrays. +''' diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 258cd85..8e39796 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -24,6 +24,7 @@ from brainunit import DimensionMismatchError from brainunit._base import Quantity from brainunit._unit_shortcuts import ms, mV +from brainunit._unit_common import second bst.environ.set(precision=64) @@ -31,7 +32,7 @@ def assert_quantity(q, values, unit): values = jnp.asarray(values) if isinstance(q, Quantity): - assert q.unit == unit.dim, f"Unit mismatch: {q.unit} != {unit}" + assert q.dim == unit.dim, f"Unit mismatch: {q.dim} != {unit}" assert jnp.allclose(q.value, values), f"Values do not match: {q.value} != {values}" else: assert jnp.allclose(q, values), f"Values do not match: {q} != {values}" @@ -44,6 +45,10 @@ def test_full(self): self.assertEqual(result.shape, (3,)) self.assertTrue(jnp.all(result == 4)) + q = bu.math.full(3, 4, unit=second) + self.assertEqual(q.shape, (3,)) + assert_quantity(q, result, second) + def test_eye(self): result = bu.math.eye(3) self.assertEqual(result.shape, (3, 3)) @@ -87,7 +92,7 @@ def test_full_like(self): self.assertTrue(jnp.all(result == 4)) q = [1, 2, 3] * bu.second - result_q = bu.math.full_like(q, 4 * bu.second) + result_q = bu.math.full_like(q, 4, unit=bu.second) assert_quantity(result_q, jnp.full_like(jnp.array([1, 2, 3]), 4), bu.second) def test_diag(self): @@ -97,7 +102,7 @@ def test_diag(self): self.assertTrue(jnp.all(result == jnp.diag(array))) q = [1, 2, 3] * bu.second - result_q = bu.math.diag(q) + result_q = bu.math.diag(q, unit=bu.second) assert_quantity(result_q, jnp.diag(jnp.array([1, 2, 3])), bu.second) def test_tril(self): @@ -107,7 +112,7 @@ def test_tril(self): self.assertTrue(jnp.all(result == jnp.tril(array))) q = jnp.ones((3, 3)) * bu.second - result_q = bu.math.tril(q) + result_q = bu.math.tril(q, unit=bu.second) assert_quantity(result_q, jnp.tril(jnp.ones((3, 3))), bu.second) def test_triu(self): @@ -117,7 +122,7 @@ def test_triu(self): self.assertTrue(jnp.all(result == jnp.triu(array))) q = jnp.ones((3, 3)) * bu.second - result_q = bu.math.triu(q) + result_q = bu.math.triu(q, unit=bu.second) assert_quantity(result_q, jnp.triu(jnp.ones((3, 3))), bu.second) def test_empty_like(self): @@ -1706,7 +1711,7 @@ def test_argsort(self): q = [2, 3, 1] * bu.second result_q = bu.math.argsort(q) expected_q = jnp.argsort(jnp.array([2, 3, 1])) - assert jnp.all(result_q == expected_q) + assert_quantity(result_q, expected_q, bu.second) def test_argmax(self): array = jnp.array([2, 3, 1]) @@ -1810,22 +1815,6 @@ def test_invert(self): q = [0b1100] * bu.second result_q = bu.math.invert(q) - def test_left_shift(self): - result = bu.math.left_shift(jnp.array([0b0100]), 2) - self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b0100]), 2))) - - with pytest.raises(ValueError): - q = [0b0100] * bu.second - result_q = bu.math.left_shift(q, 2) - - def test_right_shift(self): - result = bu.math.right_shift(jnp.array([0b0100]), 2) - self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b0100]), 2))) - - with pytest.raises(ValueError): - q = [0b0100] * bu.second - result_q = bu.math.right_shift(q, 2) - class TestElementwiseBitOperationsBinary(unittest.TestCase): @@ -1856,6 +1845,22 @@ def test_bitwise_xor(self): q2 = [0b1010] * bu.second result_q = bu.math.bitwise_xor(q1, q2) + def test_left_shift(self): + result = bu.math.left_shift(jnp.array([0b1100]), 2) + self.assertTrue(jnp.all(result == jnp.left_shift(jnp.array([0b1100]), 2))) + + with pytest.raises(ValueError): + q = [0b1100] * bu.second + result_q = bu.math.left_shift(q, 2) + + def test_right_shift(self): + result = bu.math.right_shift(jnp.array([0b1100]), 2) + self.assertTrue(jnp.all(result == jnp.right_shift(jnp.array([0b1100]), 2))) + + with pytest.raises(ValueError): + q = [0b1100] * bu.second + result_q = bu.math.right_shift(q, 2) + class TestLogicFuncsUnary(unittest.TestCase): def test_all(self): diff --git a/brainunit/math/_others.py b/brainunit/math/_others.py index 720edba..d316eb4 100644 --- a/brainunit/math/_others.py +++ b/brainunit/math/_others.py @@ -16,7 +16,7 @@ import brainstate as bst -from ._compat_numpy import wrap_math_funcs_only_accept_unitless_unary +from ._compat_numpy_funcs_accept_unitless import wrap_math_funcs_only_accept_unitless_unary __all__ = [ 'exprel', diff --git a/brainunit/math/_utils.py b/brainunit/math/_utils.py index b934c1f..61242e0 100644 --- a/brainunit/math/_utils.py +++ b/brainunit/math/_utils.py @@ -15,8 +15,9 @@ import functools -from typing import Callable +from typing import Callable, Union +import jax from jax.tree_util import tree_map from .._base import Quantity @@ -31,74 +32,64 @@ def _is_leaf(a): def _compatible_with_quantity( - fun: Callable, return_quantity: bool = True, - module: str = '' ): - func_to_wrap = fun.__np_wrapped__ if hasattr(fun, '__np_wrapped__') else fun - - @functools.wraps(func_to_wrap) - def new_fun(*args, **kwargs): - unit = None - if isinstance(args[0], Quantity): - unit = args[0].dim - elif isinstance(args[0], tuple): - if len(args[0]) == 1: - unit = args[0][0].dim if isinstance(args[0][0], Quantity) else None - elif len(args[0]) == 2: - # check all args[0] have the same unit - if all(isinstance(a, Quantity) for a in args[0]): - if all(a.dim == args[0][0].dim for a in args[0]): - unit = args[0][0].dim + def decorator(fun: Callable) -> Callable: + @functools.wraps(fun) + def new_fun(*args, **kwargs) -> Union[list[Quantity], Quantity, jax.Array]: + unit = None + if isinstance(args[0], Quantity): + unit = args[0].dim + elif isinstance(args[0], tuple): + if len(args[0]) == 1: + unit = args[0][0].dim if isinstance(args[0][0], Quantity) else None + elif len(args[0]) == 2: + # check all args[0] have the same unit + if all(isinstance(a, Quantity) for a in args[0]): + if all(a.dim == args[0][0].dim for a in args[0]): + unit = args[0][0].dim + else: + raise ValueError(f'Units do not match for {fun.__name__} operation.') + elif all(not isinstance(a, Quantity) for a in args[0]): + unit = None else: raise ValueError(f'Units do not match for {fun.__name__} operation.') - elif all(not isinstance(a, Quantity) for a in args[0]): - unit = None - else: - raise ValueError(f'Units do not match for {fun.__name__} operation.') - args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) - out = None - if len(kwargs): - # compatible with PyTorch syntax - if 'dim' in kwargs: - kwargs['axis'] = kwargs.pop('dim') - if 'keepdim' in kwargs: - kwargs['keepdims'] = kwargs.pop('keepdim') - # compatible with TensorFlow syntax - if 'keep_dims' in kwargs: - kwargs['keepdims'] = kwargs.pop('keep_dims') - # compatible with NumPy/PyTorch syntax - if 'out' in kwargs: - out = kwargs.pop('out') - if not isinstance(out, Quantity): - raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}') - # format - kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) + args = tree_map(_as_jax_array_, args, is_leaf=_is_leaf) + out = None + if len(kwargs): + # compatible with PyTorch syntax + if 'dim' in kwargs: + kwargs['axis'] = kwargs.pop('dim') + if 'keepdim' in kwargs: + kwargs['keepdims'] = kwargs.pop('keepdim') + # compatible with TensorFlow syntax + if 'keep_dims' in kwargs: + kwargs['keepdims'] = kwargs.pop('keep_dims') + # compatible with NumPy/PyTorch syntax + if 'out' in kwargs: + out = kwargs.pop('out') + if not isinstance(out, Quantity): + raise TypeError(f'"out" must be an instance of brainpy Array. While we got {type(out)}') + # format + kwargs = tree_map(_as_jax_array_, kwargs, is_leaf=_is_leaf) - if not return_quantity: - unit = None + if not return_quantity: + unit = None - r = fun(*args, **kwargs) - if unit is not None: - if isinstance(r, (list, tuple)): - return [Quantity(rr, dim=unit) for rr in r] - else: - if out is None: - return Quantity(r, dim=unit) + r = fun(*args, **kwargs) + if unit is not None: + if isinstance(r, (list, tuple)): + return [Quantity(rr, dim=unit) for rr in r] else: - out.value = r - if out is None: - return r - else: - out.value = r + if out is None: + return Quantity(r, dim=unit) + else: + out.value = r + if out is None: + return r + else: + out.value = r - new_fun.__doc__ = ( - f'Similar to ``jax.numpy.{module + fun.__name__}`` function, ' - f'while it is compatible with brainpy Array/Variable. \n\n' - f'Note that this function is also compatible with:\n\n' - f'1. NumPy or PyTorch syntax when receiving ``out`` argument.\n' - f'2. PyTorch syntax when receiving ``keepdim`` or ``dim`` argument.\n' - f'3. TensorFlow syntax when receiving ``keep_dims`` argument.' - ) + return new_fun - return new_fun + return decorator diff --git a/docs/apis/brainunit.math.rst b/docs/apis/brainunit.math.rst index a6ab19c..7d3601d 100644 --- a/docs/apis/brainunit.math.rst +++ b/docs/apis/brainunit.math.rst @@ -1,9 +1,398 @@ ``brainunit.math`` module -========================== +========================= -.. currentmodule:: brainunit.math -.. automodule:: brainunit.math +.. currentmodule:: brainunit.math +.. automodule:: brainunit.math + +Array Creation +-------------- .. autosummary:: :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + full + full_like + eye + identity + diag + tri + tril + triu + empty + empty_like + ones + ones_like + zeros + zeros_like + array + asarray + arange + linspace + logspace + fill_diagonal + array_split + meshgrid + vander + + +Array Manipulation +------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + reshape + moveaxis + transpose + swapaxes + row_stack + concatenate + stack + vstack + hstack + dstack + column_stack + split + dsplit + hsplit + vsplit + tile + repeat + unique + append + flip + fliplr + flipud + roll + atleast_1d + atleast_2d + atleast_3d + expand_dims + squeeze + sort + argsort + argmax + argmin + argwhere + nonzero + flatnonzero + searchsorted + extract + count_nonzero + max + min + amax + amin + block + compress + diagflat + diagonal + choose + ravel + + +Functions Accepting Unitless +---------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + exp + exp2 + expm1 + log + log10 + log1p + log2 + arccos + arccosh + arcsin + arcsinh + arctan + arctanh + cos + cosh + sin + sinc + sinh + tan + tanh + deg2rad + rad2deg + degrees + radians + angle + percentile + nanpercentile + quantile + nanquantile + hypot + arctan2 + logaddexp + logaddexp2 + + +Functions with Bitwise Operations +--------------------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + bitwise_not + invert + bitwise_and + bitwise_or + bitwise_xor + left_shift + right_shift + + +Functions Changing Unit +----------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + reciprocal + prod + product + nancumprod + nanprod + cumprod + cumproduct + var + nanvar + cbrt + square + frexp + sqrt + multiply + divide + power + cross + ldexp + true_divide + floor_divide + float_power + divmod + remainder + convolve + + +Indexing Functions +------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + where + tril_indices + tril_indices_from + triu_indices + triu_indices_from + take + select + + +Functions Keeping Unit +---------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + real + imag + conj + conjugate + negative + positive + abs + round + around + round_ + rint + floor + ceil + trunc + fix + sum + nancumsum + nansum + cumsum + ediff1d + absolute + fabs + median + nanmin + nanmax + ptp + average + mean + std + nanmedian + nanmean + nanstd + diff + modf + fmod + mod + copysign + heaviside + maximum + minimum + fmax + fmin + lcm + gcd + interp + clip + + +Logical Functions +----------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + all + any + logical_not + equal + not_equal + greater + greater_equal + less + less_equal + array_equal + isclose + allclose + logical_and + logical_or + logical_xor + alltrue + sometrue + + +Functions Matching Unit +----------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + add + subtract + nextafter + + +Functions Removing Unit +----------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + signbit + sign + histogram + bincount + corrcoef + correlate + cov + digitize + + +Window Functions +---------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + bartlett + blackman + hamming + hanning + kaiser + + +Get Attribute Functions +----------------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + ndim + isreal + isscalar + isfinite + isinf + isnan + shape + size + + +Linear Algebra Functions +------------------------ + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + dot + vdot + inner + outer + kron + matmul + trace + + +More Functions +-------------- + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: classtemplate.rst + + finfo + iinfo + broadcast_arrays + broadcast_shapes + einsum + gradient + intersect1d + nan_to_num + nanargmax + nanargmin + rot90 + tensordot + dtype + e + pi + inf + diff --git a/docs/auto_generater.py b/docs/auto_generater.py index b192b88..7b76528 100644 --- a/docs/auto_generater.py +++ b/docs/auto_generater.py @@ -226,7 +226,6 @@ def _write_subsections_v4(module_path, fout.write(f'.. currentmodule:: {out_path} \n') fout.write(f'.. automodule:: {out_path} \n\n') - fout.write('.. autosummary::\n') fout.write(' :toctree: generated/\n') fout.write(' :nosignatures:\n') @@ -319,14 +318,29 @@ def _section(header, numpy_mod, brainpy_mod, jax_mod, klass=None, is_jax=False): def main(): os.makedirs('apis/auto/', exist_ok=True) - _write_module(module_name='brainunit', - filename='apis/brainunit.math.rst', - header='``brainunit.init`` module') - - -if __name__ == '__main__': - main() - + module_and_name = [ + ('_compat_numpy_array_creation', 'Array Creation'), + ('_compat_numpy_array_manipulation', 'Array Manipulation'), + ('_compat_numpy_funcs_accept_unitless', 'Functions Accepting Unitless'), + ('_compat_numpy_funcs_bit_operation', 'Functions with Bitwise Operations'), + ('_compat_numpy_funcs_change_unit', 'Functions Changing Unit'), + ('_compat_numpy_funcs_indexing', 'Indexing Functions'), + ('_compat_numpy_funcs_keep_unit', 'Functions Keeping Unit'), + ('_compat_numpy_funcs_logic', 'Logical Functions'), + ('_compat_numpy_funcs_match_unit', 'Functions Matching Unit'), + ('_compat_numpy_funcs_remove_unit', 'Functions Removing Unit'), + ('_compat_numpy_funcs_window', 'Window Functions'), + ('_compat_numpy_get_attribute', 'Get Attribute Functions'), + ('_compat_numpy_linear_algebra', 'Linear Algebra Functions'), + ('_compat_numpy_misc', 'More Functions'), + ] + _write_submodules(module_name='brainunit.math', + filename='apis/brainunit.math.rst', + header='``brainunit.math`` module', + submodule_names=[k[0] for k in module_and_name], + section_names=[k[1] for k in module_and_name]) +if __name__ == '__main__': + main() From 1b0d380c6513e9c6f4a3e88abf3130ab3022ad94 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 22:33:15 +0800 Subject: [PATCH 18/23] Update array creation funcs --- .../math/_compat_numpy_array_creation.py | 301 +++++++++++------- brainunit/math/_compat_numpy_test.py | 2 +- 2 files changed, 187 insertions(+), 116 deletions(-) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index f2d7527..32219a8 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -39,76 +39,14 @@ ] -def wrap_array_creation_function(func: Callable) -> Callable: - @wraps(func) - def f(*args, unit: Unit = None, **kwargs): - if unit is not None: - assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - return func(*args, **kwargs) * unit - else: - return func(*args, **kwargs) - - f.__module__ = 'brainunit.math' - return f - - -@wrap_array_creation_function +@set_module_as('brainunit.math') def full( shape: Sequence[int], fill_value: Any, dtype: Optional[Any] = None, unit: Optional[Unit] = None ) -> Union[Array, Quantity]: - return jnp.full(shape, fill_value, dtype=dtype) - - -@wrap_array_creation_function -def eye(N: int, - M: Optional[int] = None, - k: int = 0, - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.eye(N, M, k, dtype=dtype) - - -@wrap_array_creation_function -def identity(n: int, - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.identity(n, dtype=dtype) - - -@wrap_array_creation_function -def tri(N: int, - M: Optional[int] = None, - k: int = 0, - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.tri(N, M, k, dtype=dtype) - - -@wrap_array_creation_function -def empty(shape: Sequence[int], - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.empty(shape, dtype=dtype) - - -@wrap_array_creation_function -def ones(shape: Sequence[int], - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.ones(shape, dtype=dtype) - - -@wrap_array_creation_function -def zeros(shape: Sequence[int], - dtype: Optional[Any] = None, - unit: Optional[Unit] = None) -> Union[Array, Quantity]: - return jnp.zeros(shape, dtype=dtype) - - -full.__doc__ = ''' + ''' Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. else return an array of `shape` filled with `fill_value`. @@ -125,8 +63,22 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.full(shape, fill_value, dtype=dtype) * unit + else: + return jnp.full(shape, fill_value, dtype=dtype) -eye.__doc__ = """ + +@set_module_as('brainunit.math') +def eye( + N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. else return an identity matrix of `shape`. @@ -144,9 +96,21 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.eye(N, M, k, dtype=dtype) * unit + else: + return jnp.eye(N, M, k, dtype=dtype) + -identity.__doc__ = """ +@set_module_as('brainunit.math') +def identity( + n: int, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, representing an identity matrix if `unit` is provided. else return an identity matrix of `shape`. @@ -161,9 +125,23 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.identity(n, dtype=dtype) * unit + else: + return jnp.identity(n, dtype=dtype) + -tri.__doc__ = """ +@set_module_as('brainunit.math') +def tri( + N: int, + M: Optional[int] = None, + k: int = 0, + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, representing a triangular matrix if `unit` is provided. else return a triangular matrix of `shape`. @@ -182,10 +160,21 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.tri(N, M, k, dtype=dtype) * unit + else: + return jnp.tri(N, M, k, dtype=dtype) -# empty -empty.__doc__ = """ + +@set_module_as('brainunit.math') +def empty( + shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, with uninitialized values if `unit` is provided. else return an array of `shape` with uninitialized values. @@ -200,10 +189,21 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.empty(shape, dtype=dtype) * unit + else: + return jnp.empty(shape, dtype=dtype) + -# ones -ones.__doc__ = """ +@set_module_as('brainunit.math') +def ones( + shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, filled with 1 if `unit` is provided. else return an array of `shape` filled with 1. @@ -218,10 +218,21 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.ones(shape, dtype=dtype) * unit + else: + return jnp.ones(shape, dtype=dtype) -# zeros -zeros.__doc__ = """ + +@set_module_as('brainunit.math') +def zeros( + shape: Sequence[int], + dtype: Optional[Any] = None, + unit: Optional[Unit] = None +) -> Union[Array, Quantity]: + ''' Returns a Quantity of `shape` and `unit`, filled with 0 if `unit` is provided. else return an array of `shape` filled with 0. @@ -236,35 +247,40 @@ def zeros(shape: Sequence[int], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. -""" + ''' + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + return jnp.zeros(shape, dtype=dtype) * unit + else: + return jnp.zeros(shape, dtype=dtype) @set_module_as('brainunit.math') def full_like(a: Union[Quantity, bst.typing.ArrayLike], - fill_value: Union[bst.typing.ArrayLike], - unit: Unit = None, + fill_value: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None) -> Union[Quantity, jax.Array]: ''' - Return a Quantity of `a` and `unit`, filled with `fill_value` if `unit` is provided. + Return a Quantity if `a` and `fill_value` are Quantities that have the same unit or only `fill_value` is a Quantity. else return an array of `a` filled with `fill_value`. Args: a: array_like, Quantity, shape, or dtype fill_value: scalar or array_like - unit: Unit, optional dtype: data-type, optional shape: sequence of ints, optional Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) * unit + if isinstance(a, Quantity): + fail_for_dimension_mismatch(a, fill_value, error_message="a and fill_value have to have the same units.") + return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if isinstance(fill_value, Quantity): + return Quantity(jnp.full_like(a, fill_value.value, dtype=dtype, shape=shape), dim=fill_value.dim) else: - return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) * unit + return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) else: return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) @@ -284,12 +300,19 @@ def diag(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.diag(a.value, k=k) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.diag(a.value, k=k), dim=a.dim) else: + return Quantity(jnp.diag(a.value, k=k), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' return jnp.diag(a, k=k) * unit + else: + return jnp.diag(a, k=k) else: return jnp.diag(a, k=k) @@ -309,12 +332,19 @@ def tril(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.tril(a.value, k=k) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.tril(a.value, k=k), dim=a.dim) else: + return Quantity(jnp.tril(a.value, k=k), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' return jnp.tril(a, k=k) * unit + else: + return jnp.tril(a, k=k) else: return jnp.tril(a, k=k) @@ -334,12 +364,19 @@ def triu(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.triu(a.value, k=k) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.triu(a.value, k=k), dim=a.dim) else: + return Quantity(jnp.triu(a.value, k=k), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' return jnp.triu(a, k=k) * unit + else: + return jnp.triu(a, k=k) else: return jnp.triu(a, k=k) @@ -362,16 +399,24 @@ def empty_like(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.empty_like(a.value, dtype=dtype, shape=shape) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit) + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), dim=a.dim) else: + return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit) return jnp.empty_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.empty_like(a, dtype=dtype, shape=shape) else: return jnp.empty_like(a, dtype=dtype, shape=shape) + @set_module_as('brainunit.math') def ones_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, @@ -390,12 +435,19 @@ def ones_like(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.ones_like(a.value, dtype=dtype, shape=shape) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit) + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), dim=a.dim) else: + return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit) return jnp.ones_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.ones_like(a, dtype=dtype, shape=shape) else: return jnp.ones_like(a, dtype=dtype, shape=shape) @@ -418,12 +470,19 @@ def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit) - if isinstance(a, Quantity): - return jnp.zeros_like(a.value, dtype=dtype, shape=shape) * unit + if isinstance(a, Quantity): + if unit is not None: + assert isinstance(unit, Unit) + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), dim=a.dim) else: + return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + elif isinstance(a, (jax.Array, np.ndarray)): + if unit is not None: + assert isinstance(unit, Unit) return jnp.zeros_like(a, dtype=dtype, shape=shape) * unit + else: + return jnp.zeros_like(a, dtype=dtype, shape=shape) else: return jnp.zeros_like(a, dtype=dtype, shape=shape) @@ -438,7 +497,8 @@ def asarray( ''' Convert the input to a quantity or array. - If unit is provided, the input is converted to a Quantity object with the given unit. + If unit is provided, the input will be checked whether it has the same unit as the provided unit. + If unit is not provided, the input will be converted to an array. Args: a: array_like, Quantity, or Sequence[Quantity] @@ -452,19 +512,30 @@ def asarray( if isinstance(a, Quantity): if unit is not None: assert isinstance(unit, Unit) - return jnp.asarray(a.value, dtype=dtype, order=order) * unit + fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) else: - return jnp.asarray(a.value, dtype=dtype, order=order) + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit) return jnp.asarray(a, dtype=dtype, order=order) * unit else: return jnp.asarray(a, dtype=dtype, order=order) + # list[Quantity] + elif isinstance(a, Sequence) and all(isinstance(x, Quantity) for x in a): + # check all elements have the same unit + if any(x.dim != a[0].dim for x in a): + raise ValueError('Units do not match for asarray operation.') + values = [x.value for x in a] + unit = a[0].dim + # Convert the values to a jnp.ndarray and create a Quantity object + return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) else: return jnp.asarray(a, dtype=dtype, order=order) + array = asarray diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 18cdd4a..2e73403 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -92,7 +92,7 @@ def test_full_like(self): self.assertTrue(jnp.all(result == 4)) q = [1, 2, 3] * bu.second - result_q = bu.math.full_like(q, 4, unit=bu.second) + result_q = bu.math.full_like(q, 4 * bu.second) assert_quantity(result_q, jnp.full_like(jnp.array([1, 2, 3]), 4), bu.second) def test_diag(self): From bb439c6efc5b3bf3877eaa495cfca3343af3a723 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 12 Jun 2024 09:08:15 +0800 Subject: [PATCH 19/23] Update _compat_numpy_test.py --- brainunit/math/_compat_numpy_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 18cdd4a..2e73403 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -92,7 +92,7 @@ def test_full_like(self): self.assertTrue(jnp.all(result == 4)) q = [1, 2, 3] * bu.second - result_q = bu.math.full_like(q, 4, unit=bu.second) + result_q = bu.math.full_like(q, 4 * bu.second) assert_quantity(result_q, jnp.full_like(jnp.array([1, 2, 3]), 4), bu.second) def test_diag(self): From ea4e9d5e9b360a0c55dcbf90f04101c3be4f3cfb Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 12 Jun 2024 13:10:58 +0800 Subject: [PATCH 20/23] Add magnitude conversion for `asarray` --- brainunit/math/_compat_numpy_array_creation.py | 9 +++++++-- brainunit/math/_compat_numpy_test.py | 8 ++++++-- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 32219a8..68873dc 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -498,6 +498,7 @@ def asarray( Convert the input to a quantity or array. If unit is provided, the input will be checked whether it has the same unit as the provided unit. + (If they have same dimension but different magnitude, the input will be converted to the provided unit.) If unit is not provided, the input will be converted to an array. Args: @@ -513,9 +514,13 @@ def asarray( if unit is not None: assert isinstance(unit, Unit) fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) + if a.dim == unit: + return a + else: + # Convert to the magnitude of the provided unit + return Quantity(a.value / unit.value, dim=unit) else: - return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order) / unit.value, dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit) diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 2e73403..0bd8c10 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -32,7 +32,7 @@ def assert_quantity(q, values, unit): values = jnp.asarray(values) if isinstance(q, Quantity): - assert q.dim == unit.dim, f"Unit mismatch: {q.dim} != {unit}" + assert q.dim == unit.dim or q.dim == unit, f"Unit mismatch: {q.dim} != {unit}" assert jnp.allclose(q.value, values), f"Values do not match: {q.value} != {values}" else: assert jnp.allclose(q, values), f"Values do not match: {q} != {values}" @@ -162,6 +162,10 @@ def test_asarray(self): result_q = bu.math.asarray([1, 2, 3], unit=bu.second) assert_quantity(result_q, jnp.asarray([1, 2, 3]), bu.second) + q1 = [1, 2, 3] * bu.second + result_q = bu.math.asarray(q1, unit=bu.ms) + assert_quantity(result_q, jnp.asarray([1, 2, 3]) * 1000, bu.ms) + def test_arange(self): result = bu.math.arange(5) self.assertEqual(result.shape, (5,)) @@ -171,7 +175,7 @@ def test_arange(self): assert_quantity(result_q, jnp.arange(5, step=1), bu.second) result_q = bu.math.arange(3 * bu.second, 9 * bu.second, 1 * bu.second) - assert_quantity(result_q, jnp.arange(3, 9, 1), bu.second) + assert_quantity(result_q, jnp.arange(3, 9, 1), bu.ms) def test_linspace(self): result = bu.math.linspace(0, 10, 5) From b967edc79ed7c19269f31dba4c1c9733b9c1174c Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 12 Jun 2024 13:45:27 +0800 Subject: [PATCH 21/23] Update _compat_numpy_array_creation.py --- .../math/_compat_numpy_array_creation.py | 149 +++++++----------- 1 file changed, 59 insertions(+), 90 deletions(-) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 68873dc..912a87d 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -42,12 +42,11 @@ @set_module_as('brainunit.math') def full( shape: Sequence[int], - fill_value: Any, + fill_value: Union[Quantity, int, float], dtype: Optional[Any] = None, - unit: Optional[Unit] = None ) -> Union[Array, Quantity]: ''' - Returns a Quantity of `shape` and `unit`, filled with `fill_value` if `unit` is provided. + Returns a Quantity of `shape`, filled with `fill_value` if `fill_value` is a Quantity. else return an array of `shape` filled with `fill_value`. Args: @@ -55,19 +54,13 @@ def full( fill_value: the value to fill the new array with. dtype: the type of the output array, or `None`. If not `None`, `fill_value` will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. - unit: the unit of the output array, or `None`. Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if unit is not None: - assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' - return jnp.full(shape, fill_value, dtype=dtype) * unit - else: - return jnp.full(shape, fill_value, dtype=dtype) + if isinstance(fill_value, Quantity): + return jnp.full(shape, fill_value.magnitude, dtype=dtype) * fill_value.unit + return jnp.full(shape, fill_value, dtype=dtype) @set_module_as('brainunit.math') @@ -89,9 +82,6 @@ def eye( lower diagonal. dtype: the type of the output array, or `None`. If not `None`, elements will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: @@ -118,9 +108,6 @@ def identity( n: the number of rows (and columns) in the output array. dtype: the type of the output array, or `None`. If not `None`, elements will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: @@ -153,9 +140,6 @@ def tri( lower diagonal. dtype: the type of the output array, or `None`. If not `None`, elements will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: @@ -182,9 +166,6 @@ def empty( shape: sequence of integers, describing the shape of the output array. dtype: the type of the output array, or `None`. If not `None`, elements will be of type `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: @@ -211,9 +192,6 @@ def ones( shape: sequence of integers, describing the shape of the output array. dtype: the type of the output array, or `None`. If not `None`, elements will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: @@ -240,9 +218,6 @@ def zeros( shape: sequence of integers, describing the shape of the output array. dtype: the type of the output array, or `None`. If not `None`, elements will be cast to `dtype`. - sharding: an optional sharding specification for the resulting array, - note, sharding will currently be ignored in jitted mode, this might change - in the future. unit: the unit of the output array, or `None`. Returns: @@ -273,22 +248,23 @@ def full_like(a: Union[Quantity, bst.typing.ArrayLike], Returns: Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' - if isinstance(a, Quantity): - fail_for_dimension_mismatch(a, fill_value, error_message="a and fill_value have to have the same units.") - return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), dim=a.dim) - elif isinstance(a, (jax.Array, np.ndarray)): - if isinstance(fill_value, Quantity): + if isinstance(fill_value, Quantity): + if isinstance(a, Quantity): + fail_for_dimension_mismatch(a, fill_value, error_message="a and fill_value have to have the same units.") + return Quantity(jnp.full_like(a.value, fill_value.value, dtype=dtype, shape=shape), dim=a.dim) + else: return Quantity(jnp.full_like(a, fill_value.value, dtype=dtype, shape=shape), dim=fill_value.dim) + else: + if isinstance(a, Quantity): + return jnp.full_like(a.value, fill_value, dtype=dtype, shape=shape) else: return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) - else: - return jnp.full_like(a, fill_value, dtype=dtype, shape=shape) @set_module_as('brainunit.math') def diag(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Extract a diagonal or construct a diagonal array. @@ -304,9 +280,7 @@ def diag(a: Union[Quantity, bst.typing.ArrayLike], if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.diag(a.value, k=k), dim=a.dim) - else: - return Quantity(jnp.diag(a.value, k=k), dim=a.dim) + return Quantity(jnp.diag(a.value, k=k), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' @@ -320,7 +294,7 @@ def diag(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') def tril(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Lower triangle of an array. @@ -336,9 +310,7 @@ def tril(a: Union[Quantity, bst.typing.ArrayLike], if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.tril(a.value, k=k), dim=a.dim) - else: - return Quantity(jnp.tril(a.value, k=k), dim=a.dim) + return Quantity(jnp.tril(a.value, k=k), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' @@ -352,7 +324,7 @@ def tril(a: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') def triu(a: Union[Quantity, bst.typing.ArrayLike], k: int = 0, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Upper triangle of an array. @@ -368,9 +340,7 @@ def triu(a: Union[Quantity, bst.typing.ArrayLike], if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.triu(a.value, k=k), dim=a.dim) - else: - return Quantity(jnp.triu(a.value, k=k), dim=a.dim) + return Quantity(jnp.triu(a.value, k=k), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit), f'unit must be an instance of Unit, got {type(unit)}' @@ -385,7 +355,7 @@ def triu(a: Union[Quantity, bst.typing.ArrayLike], def empty_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Return a Quantity of `a` and `unit`, with uninitialized values if `unit` is provided. else return an array of `a` with uninitialized values. @@ -403,9 +373,7 @@ def empty_like(a: Union[Quantity, bst.typing.ArrayLike], if unit is not None: assert isinstance(unit, Unit) fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), dim=a.dim) - else: - return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + return Quantity(jnp.empty_like(a.value, dtype=dtype, shape=shape), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit) @@ -416,12 +384,11 @@ def empty_like(a: Union[Quantity, bst.typing.ArrayLike], return jnp.empty_like(a, dtype=dtype, shape=shape) - @set_module_as('brainunit.math') def ones_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Return a Quantity of `a` and `unit`, filled with 1 if `unit` is provided. else return an array of `a` filled with 1. @@ -439,9 +406,7 @@ def ones_like(a: Union[Quantity, bst.typing.ArrayLike], if unit is not None: assert isinstance(unit, Unit) fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), dim=a.dim) - else: - return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + return Quantity(jnp.ones_like(a.value, dtype=dtype, shape=shape), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit) @@ -456,7 +421,7 @@ def ones_like(a: Union[Quantity, bst.typing.ArrayLike], def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], dtype: Optional[bst.typing.DTypeLike] = None, shape: Any = None, - unit: Unit = None) -> Union[Quantity, jax.Array]: + unit: Optional[Unit] = None) -> Union[Quantity, jax.Array]: ''' Return a Quantity of `a` and `unit`, filled with 0 if `unit` is provided. else return an array of `a` filled with 0. @@ -474,9 +439,7 @@ def zeros_like(a: Union[Quantity, bst.typing.ArrayLike], if unit is not None: assert isinstance(unit, Unit) fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), dim=a.dim) - else: - return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), dim=a.dim) + return Quantity(jnp.zeros_like(a.value, dtype=dtype, shape=shape), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit) @@ -514,13 +477,7 @@ def asarray( if unit is not None: assert isinstance(unit, Unit) fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.") - if a.dim == unit: - return a - else: - # Convert to the magnitude of the provided unit - return Quantity(a.value / unit.value, dim=unit) - else: - return Quantity(jnp.asarray(a.value, dtype=dtype, order=order) / unit.value, dim=a.dim) + return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim) elif isinstance(a, (jax.Array, np.ndarray)): if unit is not None: assert isinstance(unit, Unit) @@ -528,17 +485,28 @@ def asarray( else: return jnp.asarray(a, dtype=dtype, order=order) # list[Quantity] - elif isinstance(a, Sequence) and all(isinstance(x, Quantity) for x in a): - # check all elements have the same unit - if any(x.dim != a[0].dim for x in a): - raise ValueError('Units do not match for asarray operation.') - values = [x.value for x in a] - unit = a[0].dim - # Convert the values to a jnp.ndarray and create a Quantity object - return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) + elif isinstance(a, Sequence): + leaves, tree = jax.tree.flatten(a, is_leaf=lambda x: isinstance(x, Quantity)) + if all([isinstance(leaf, Quantity) for leaf in leaves]): + # check all elements have the same unit + if any(x.dim != leaves[0].dim for x in leaves): + raise ValueError('Units do not match for asarray operation.') + values = jax.tree.unflatten(tree, [x.value for x in a]) + + fail_for_dimension_mismatch(a[0], unit, error_message="a and unit have to have the same units.") + unit = a[0].dim + # Convert the values to a jnp.ndarray and create a Quantity object + return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) + else: + values = jax.tree.unflatten(tree, [x.value for x in a]) + val = jnp.asarray(values, dtype=dtype, order=order) + if unit is not None: + assert isinstance(unit, Unit) + return val * unit + else: + return val else: - return jnp.asarray(a, dtype=dtype, order=order) - + raise TypeError('Invalid input type for asarray.') array = asarray @@ -713,8 +681,7 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], val: Union[Quantity, bst.typing.ArrayLike], - wrap: Optional[bool] = False, - inplace: Optional[bool] = True) -> Union[Quantity, jax.Array]: + wrap: Optional[bool] = False) -> Union[Quantity, jax.Array]: ''' Fill the main diagonal of the given array of `a` with `val`. @@ -722,20 +689,22 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], a: array_like, Quantity val: scalar, Quantity wrap: bool, optional - inplace: bool, optional + unit: Unit, optional Returns: Union[jax.Array, Quantity]: Quantity if `a` and `val` are Quantities that have the same unit, else an array. ''' - if isinstance(a, Quantity) and isinstance(val, Quantity): - fail_for_dimension_mismatch(a, val) - return Quantity(jnp.fill_diagonal(a.value, val.value, wrap=wrap, inplace=inplace), dim=a.dim) - elif isinstance(a, (jax.Array, np.ndarray)) and isinstance(val, (jax.Array, np.ndarray)): - return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) - elif is_unitless(a) or is_unitless(val): - return jnp.fill_diagonal(a, val, wrap=wrap, inplace=inplace) + if isinstance(val, Quantity): + if isinstance(a, Quantity): + fail_for_dimension_mismatch(a, val, error_message="Array and value have to have the same units.") + return Quantity(jnp.fill_diagonal(a.value, val.value, wrap), dim=a.dim) + else: + return Quantity(jnp.fill_diagonal(a, val.value, wrap), dim=val.dim) else: - raise ValueError(f'Unsupported types : {type(a)} abd {type(val)} for fill_diagonal') + if isinstance(a, Quantity): + return jnp.fill_diagonal(a.value, val, wrap) + else: + return jnp.fill_diagonal(a, val, wrap) @set_module_as('brainunit.math') From 512f7343b1998c2a4d7b0936559335dced34bf0a Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 12 Jun 2024 13:45:43 +0800 Subject: [PATCH 22/23] Update _compat_numpy_test.py --- brainunit/math/_compat_numpy_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 0bd8c10..2f8de29 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -162,9 +162,6 @@ def test_asarray(self): result_q = bu.math.asarray([1, 2, 3], unit=bu.second) assert_quantity(result_q, jnp.asarray([1, 2, 3]), bu.second) - q1 = [1, 2, 3] * bu.second - result_q = bu.math.asarray(q1, unit=bu.ms) - assert_quantity(result_q, jnp.asarray([1, 2, 3]) * 1000, bu.ms) def test_arange(self): result = bu.math.arange(5) From 1647baa38bdfab38822d8d1d767606a0807f4841 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Wed, 12 Jun 2024 13:53:04 +0800 Subject: [PATCH 23/23] Fix bugs --- brainunit/math/_compat_numpy_array_creation.py | 15 ++++++++------- brainunit/math/_compat_numpy_test.py | 6 +++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 912a87d..1e31f4d 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -59,7 +59,7 @@ def full( Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' if isinstance(fill_value, Quantity): - return jnp.full(shape, fill_value.magnitude, dtype=dtype) * fill_value.unit + return Quantity(jnp.full(shape, fill_value.value, dtype=dtype), dim=fill_value.dim) return jnp.full(shape, fill_value, dtype=dtype) @@ -498,7 +498,7 @@ def asarray( # Convert the values to a jnp.ndarray and create a Quantity object return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) else: - values = jax.tree.unflatten(tree, [x.value for x in a]) + values = jax.tree.unflatten(tree, leaves) val = jnp.asarray(values, dtype=dtype, order=order) if unit is not None: assert isinstance(unit, Unit) @@ -681,7 +681,8 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], val: Union[Quantity, bst.typing.ArrayLike], - wrap: Optional[bool] = False) -> Union[Quantity, jax.Array]: + wrap: Optional[bool] = False, + inplace: Optional[bool] = False) -> Union[Quantity, jax.Array]: ''' Fill the main diagonal of the given array of `a` with `val`. @@ -697,14 +698,14 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], if isinstance(val, Quantity): if isinstance(a, Quantity): fail_for_dimension_mismatch(a, val, error_message="Array and value have to have the same units.") - return Quantity(jnp.fill_diagonal(a.value, val.value, wrap), dim=a.dim) + return Quantity(jnp.fill_diagonal(a.value, val.value, wrap, inplace=inplace), dim=a.dim) else: - return Quantity(jnp.fill_diagonal(a, val.value, wrap), dim=val.dim) + return Quantity(jnp.fill_diagonal(a, val.value, wrap, inplace=inplace), dim=val.dim) else: if isinstance(a, Quantity): - return jnp.fill_diagonal(a.value, val, wrap) + return jnp.fill_diagonal(a.value, val, wrap, inplace=inplace) else: - return jnp.fill_diagonal(a, val, wrap) + return jnp.fill_diagonal(a, val, wrap, inplace=inplace) @set_module_as('brainunit.math') diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 2f8de29..615d720 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -45,7 +45,7 @@ def test_full(self): self.assertEqual(result.shape, (3,)) self.assertTrue(jnp.all(result == 4)) - q = bu.math.full(3, 4, unit=second) + q = bu.math.full(3, 4 * second) self.assertEqual(q.shape, (3,)) assert_quantity(q, result, second) @@ -192,11 +192,11 @@ def test_logspace(self): def test_fill_diagonal(self): array = jnp.zeros((3, 3)) - result = bu.math.fill_diagonal(array, 5, inplace=False) + result = bu.math.fill_diagonal(array, 5) self.assertTrue(jnp.all(result == jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]]))) q = jnp.zeros((3, 3)) * bu.second - result_q = bu.math.fill_diagonal(q, 5 * bu.second, inplace=False) + result_q = bu.math.fill_diagonal(q, 5 * bu.second) assert_quantity(result_q, jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]]), bu.second) def test_array_split(self):