From 4ee28ede8830a5990fad4904f9d3c4e9432de8d9 Mon Sep 17 00:00:00 2001 From: He Sichao <1310722434@qq.com> Date: Tue, 11 Jun 2024 11:38:38 +0800 Subject: [PATCH] Update --- brainunit/math/_compat_numpy.py | 1054 ++++++++++++++++++++++++++++++- brainunit/math/_utils.py | 5 +- 2 files changed, 1037 insertions(+), 22 deletions(-) diff --git a/brainunit/math/_compat_numpy.py b/brainunit/math/_compat_numpy.py index 0dae459..184d3dc 100644 --- a/brainunit/math/_compat_numpy.py +++ b/brainunit/math/_compat_numpy.py @@ -101,10 +101,10 @@ 'diagflat', 'diagonal', 'choose', 'ravel', # Elementwise bit operations (unary) - 'bitwise_not', 'invert', 'left_shift', 'right_shift', + 'bitwise_not', 'invert', # Elementwise bit operations (binary) - 'bitwise_and', 'bitwise_or', 'bitwise_xor', + 'bitwise_and', 'bitwise_or', 'bitwise_xor', 'left_shift', 'right_shift', # logic funcs (unary) 'all', 'any', 'logical_not', @@ -2657,6 +2657,490 @@ def f(x, y, *args, **kwargs): extract = _compatible_with_quantity(jnp.extract, return_quantity=False) count_nonzero = _compatible_with_quantity(jnp.count_nonzero, return_quantity=False) +# docs for the functions above +reshape.__doc__ = ''' + Return a reshaped copy of an array or a Quantity. + + Args: + a: input array or Quantity to reshape + shape: integer or sequence of integers giving the new shape, which must match the + size of the input array. If any single dimension is given size ``-1``, it will be + replaced with a value such that the output has the correct size. + order: ``'F'`` or ``'C'``, specifies whether the reshape should apply column-major + (fortran-style, ``"F"``) or row-major (C-style, ``"C"``) order; default is ``"C"``. + brainunit does not support ``order="A"``. + + Returns: + reshaped copy of input array with the specified shape. +''' + +moveaxis.__doc__ = ''' + Moves axes of an array to new positions. Other axes remain in their original order. + + Args: + a: array_like, Quantity + source: int or sequence of ints + destination: int or sequence of ints + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +transpose.__doc__ = ''' + Returns a view of the array with axes transposed. + + Args: + a: array_like, Quantity + axes: tuple or list of ints, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +swapaxes.__doc__ = ''' + Interchanges two axes of an array. + + Args: + a: array_like, Quantity + axis1: int + axis2: int + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +concatenate.__doc__ = ''' + Join a sequence of arrays along an existing axis. + + Args: + arrays: sequence of array_like, Quantity + axis: int, optional + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +stack.__doc__ = ''' + Join a sequence of arrays along a new axis. + + Args: + arrays: sequence of array_like, Quantity + axis: int + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +vstack.__doc__ = ''' + Stack arrays in sequence vertically (row wise). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +hstack.__doc__ = ''' + Stack arrays in sequence horizontally (column wise). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +dstack.__doc__ = ''' + Stack arrays in sequence depth wise (along third axis). + + Args: + arrays: sequence of array_like, Quantity + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +column_stack.__doc__ = ''' + Stack 1-D arrays as columns into a 2-D array. + + Args: + arrays: sequence of 1-D or 2-D array_like, Quantity + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +split.__doc__ = ''' + Split an array into multiple sub-arrays. + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + axis: int, optional + + Returns: + out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array +''' + +dsplit.__doc__ = ''' + Split array along third axis (depth). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array +''' + +hsplit.__doc__ = ''' + Split an array into multiple sub-arrays horizontally (column-wise). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array +''' + +vsplit.__doc__ = ''' + Split an array into multiple sub-arrays vertically (row-wise). + + Args: + a: array_like, Quantity + indices_or_sections: int or 1-D array + + Returns: + out: a list of Quantity if a is a Quantity, otherwise a list of jax.numpy.Array +''' + +tile.__doc__ = ''' + Construct an array by repeating A the number of times given by reps. + + Args: + A: array_like, Quantity + reps: array_like + + Returns: + out: a Quantity if A is a Quantity, otherwise a jax.numpy.Array +''' + +repeat.__doc__ = ''' + Repeat elements of an array. + + Args: + a: array_like, Quantity + repeats: array_like + axis: int, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +unique.__doc__ = ''' + Find the unique elements of an array. + + Args: + a: array_like, Quantity + return_index: bool, optional + return_inverse: bool, optional + return_counts: bool, optional + axis: int or None, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +append.__doc__ = ''' + Append values to the end of an array. + + Args: + arr: array_like, Quantity + values: array_like, Quantity + axis: int, optional + + Returns: + out: a Quantity if arr and values are Quantity, otherwise a jax.numpy.Array +''' + +flip.__doc__ = ''' + Reverse the order of elements in an array along the given axis. + + Args: + m: array_like, Quantity + axis: int or tuple of ints, optional + + Returns: + out: a Quantity if m is a Quantity, otherwise a jax.numpy.Array +''' + +fliplr.__doc__ = ''' + Flip array in the left/right direction. + + Args: + m: array_like, Quantity + + Returns: + out: a Quantity if m is a Quantity, otherwise a jax.numpy.Array +''' + +flipud.__doc__ = ''' + Flip array in the up/down direction. + + Args: + m: array_like, Quantity + + Returns: + out: a Quantity if m is a Quantity, otherwise a jax.numpy.Array +''' + +roll.__doc__ = ''' + Roll array elements along a given axis. + + Args: + a: array_like, Quantity + shift: int or tuple of ints + axis: int or tuple of ints, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +atleast_1d.__doc__ = ''' + View inputs as arrays with at least one dimension. + + Args: + *args: array_like, Quantity + + Returns: + out: a Quantity if any input is a Quantity, otherwise a jax.numpy.Array +''' + +atleast_2d.__doc__ = ''' + View inputs as arrays with at least two dimensions. + + Args: + *args: array_like, Quantity + + Returns: + out: a Quantity if any input is a Quantity, otherwise a jax.numpy.Array +''' + +atleast_3d.__doc__ = ''' + View inputs as arrays with at least three dimensions. + + Args: + *args: array_like, Quantity + + Returns: + out: a Quantity if any input is a Quantity, otherwise a jax.numpy.Array +''' + +expand_dims.__doc__ = ''' + Expand the shape of an array. + + Args: + a: array_like, Quantity + axis: int or tuple of ints + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +squeeze.__doc__ = ''' + Remove single-dimensional entries from the shape of an array. + + Args: + a: array_like, Quantity + axis: None or int or tuple of ints, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +sort.__doc__ = ''' + Return a sorted copy of an array. + + Args: + a: array_like, Quantity + axis: int or None, optional + kind: {'quicksort', 'mergesort', 'heapsort', 'stable'}, optional + order: str or list of str, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' +max.__doc__ = ''' + Return the maximum of an array or maximum along an axis. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + keepdims: bool, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +min.__doc__ = ''' + Return the minimum of an array or minimum along an axis. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + keepdims: bool, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +choose.__doc__ = ''' + Use an index array to construct a new array from a set of choices. + + Args: + a: array_like, Quantity + choices: array_like, Quantity + + Returns: + out: a Quantity if a and choices are Quantity, otherwise a jax.numpy.Array +''' + +block.__doc__ = ''' + Assemble an nd-array from nested lists of blocks. + + Args: + arrays: sequence of array_like, Quantity + + Returns: + out: a Quantity if all input arrays are Quantity, otherwise a jax.numpy.Array +''' + +compress.__doc__ = ''' + Return selected slices of an array along given axis. + + Args: + condition: array_like, Quantity + a: array_like, Quantity + axis: int, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +diagflat.__doc__ = ''' + Create a two-dimensional array with the flattened input as a diagonal. + + Args: + a: array_like, Quantity + offset: int, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +argsort.__doc__ = ''' + Returns the indices that would sort an array. + + Args: + a: array_like, Quantity + axis: int or None, optional + kind: {'quicksort', 'mergesort', 'heapsort'}, optional + order: str or list of str, optional + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +argmax.__doc__ = ''' + Returns indices of the max value along an axis. + + Args: + a: array_like, Quantity + axis: int, optional + out: array, optional + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +argmin.__doc__ = ''' + Returns indices of the min value along an axis. + + Args: + a: array_like, Quantity + axis: int, optional + out: array, optional + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +argwhere.__doc__ = ''' + Find indices of non-zero elements. + + Args: + a: array_like, Quantity + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +nonzero.__doc__ = ''' + Return the indices of the elements that are non-zero. + + Args: + a: array_like, Quantity + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +flatnonzero.__doc__ = ''' + Return indices that are non-zero in the flattened version of a. + + Args: + a: array_like, Quantity + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +searchsorted.__doc__ = ''' + Find indices where elements should be inserted to maintain order. + + Args: + a: array_like, Quantity + v: array_like, Quantity + side: {'left', 'right'}, optional + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +extract.__doc__ = ''' + Return the elements of an array that satisfy some condition. + + Args: + condition: array_like, Quantity + a: array_like, Quantity + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + +count_nonzero.__doc__ = ''' + Counts the number of non-zero values in the array a. + + Args: + a: array_like, Quantity + axis: int or tuple of ints, optional + + Returns: + out: jax.numpy.Array (does not return a Quantity) +''' + def wrap_function_to_method(func): @wraps(func) @@ -2673,6 +3157,30 @@ def f(x, *args, **kwargs): diagonal = wrap_function_to_method(jnp.diagonal) ravel = wrap_function_to_method(jnp.ravel) +diagonal.__doc__ = ''' + Return specified diagonals. + + Args: + a: array_like, Quantity + offset: int, optional + axis1: int, optional + axis2: int, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + +ravel.__doc__ = ''' + Return a contiguous flattened array. + + Args: + a: array_like, Quantity + order: {'C', 'F', 'A', 'K'}, optional + + Returns: + out: a Quantity if a is a Quantity, otherwise a jax.numpy.Array +''' + # Elementwise bit operations (unary) # ---------------------------------- @@ -2692,8 +3200,27 @@ def f(x, *args, **kwargs): bitwise_not = wrap_elementwise_bit_operation_unary(jnp.bitwise_not) invert = wrap_elementwise_bit_operation_unary(jnp.invert) -left_shift = wrap_elementwise_bit_operation_unary(jnp.left_shift) -right_shift = wrap_elementwise_bit_operation_unary(jnp.right_shift) + +# docs for functions above +bitwise_not.__doc__ = ''' + Compute the bit-wise NOT of an array, element-wise. + + Args: + x: array_like + + Returns: + out: an array +''' + +invert.__doc__ = ''' + Compute bit-wise inversion, or bit-wise NOT, element-wise. + + Args: + x: array_like + + Returns: + out: an array +''' # Elementwise bit operations (binary) @@ -2708,13 +3235,71 @@ def f(x, y, *args, **kwargs): else: raise ValueError(f'Unsupported types {type(x)} and {type(y)} for {func.__name__}') - f.__module__ = 'brainunit.math' - return f + f.__module__ = 'brainunit.math' + return f + + +bitwise_and = wrap_elementwise_bit_operation_binary(jnp.bitwise_and) +bitwise_or = wrap_elementwise_bit_operation_binary(jnp.bitwise_or) +bitwise_xor = wrap_elementwise_bit_operation_binary(jnp.bitwise_xor) +left_shift = wrap_elementwise_bit_operation_binary(jnp.left_shift) +right_shift = wrap_elementwise_bit_operation_binary(jnp.right_shift) + +# docs for functions above +bitwise_and.__doc__ = ''' + Compute the bit-wise AND of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + out: an array +''' + +bitwise_or.__doc__ = ''' + Compute the bit-wise OR of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + out: an array +''' + +bitwise_xor.__doc__ = ''' + Compute the bit-wise XOR of two arrays element-wise. + + Args: + x: array_like + y: array_like + + Returns: + out: an array +''' + +left_shift.__doc__ = ''' + Shift the bits of an integer to the left. + + Args: + x: array_like + y: array_like + + Returns: + out: an array +''' +right_shift.__doc__ = ''' + Shift the bits of an integer to the right. -bitwise_and = wrap_elementwise_bit_operation_binary(jnp.bitwise_and) -bitwise_or = wrap_elementwise_bit_operation_binary(jnp.bitwise_or) -bitwise_xor = wrap_elementwise_bit_operation_binary(jnp.bitwise_xor) + Args: + x: array_like + y: array_like + + Returns: + out: an array +''' # logic funcs (unary) @@ -2739,6 +3324,46 @@ def f(x, *args, **kwargs): sometrue = any logical_not = wrap_logic_func_unary(jnp.logical_not) +# docs for functions above +all.__doc__ = ''' + Test whether all array elements along a given axis evaluate to True. + + Args: + a: array_like + axis: int, optional + out: array, optional + keepdims: bool, optional + where: array_like of bool, optional + + Returns: + out: bool or array +''' + +any.__doc__ = ''' + Test whether any array element along a given axis evaluates to True. + + Args: + a: array_like + axis: int, optional + out: array, optional + keepdims: bool, optional + where: array_like of bool, optional + + Returns: + out: bool or array +''' + +logical_not.__doc__ = ''' + Compute the truth value of NOT x element-wise. + + Args: + x: array_like + out: array, optional + + Returns: + out: bool or array +''' + # logic funcs (binary) # -------------------- @@ -2771,11 +3396,155 @@ def f(x, y, *args, **kwargs): logical_or = wrap_logic_func_binary(jnp.logical_or) logical_xor = wrap_logic_func_binary(jnp.logical_xor) +# docs for functions above +equal.__doc__ = ''' + Return (x == y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +not_equal.__doc__ = ''' + Return (x != y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +greater.__doc__ = ''' + Return (x > y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +greater_equal.__doc__ = ''' + Return (x >= y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +less.__doc__ = ''' + Return (x < y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +less_equal.__doc__ = ''' + Return (x <= y) element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like, Quantity + y: array_like, Quantity + + Returns: + out: bool or array +''' + +array_equal.__doc__ = ''' + Return True if two arrays have the same shape, elements, and units (if they are Quantity), False otherwise. + + Args: + x1: array_like, Quantity + x2: array_like, Quantity + + Returns: + out: bool or array +''' + +isclose.__doc__ = ''' + Returns a boolean array where two arrays are element-wise equal within a tolerance and have the same unit if they are Quantity. + + Args: + a: array_like, Quantity + b: array_like, Quantity + rtol: float, optional + atol: float, optional + equal_nan: bool, optional + + Returns: + out: bool or array +''' + +allclose.__doc__ = ''' + Returns True if the two arrays are equal within the given tolerance and have the same unit if they are Quantity; False otherwise. + + Args: + a: array_like, Quantity + b: array_like, Quantity + rtol: float, optional + atol: float, optional + equal_nan: bool, optional + + Returns: + out: bool +''' + +logical_and.__doc__ = ''' + Compute the truth value of x AND y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + out: bool or array +''' + +logical_or.__doc__ = ''' + Compute the truth value of x OR y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + out: bool or array +''' + +logical_xor.__doc__ = ''' + Compute the truth value of x XOR y element-wise and have the same unit if x and y are Quantity. + + Args: + x: array_like + y: array_like + out: array, optional + + Returns: + out: bool or array +''' + # indexing funcs # -------------- @set_module_as('brainunit.math') -def where(condition, *args, **kwds): # pylint: disable=C0111 +def where(condition: Union[bool, bst.typing.ArrayLike], + *args: Union[Quantity, bst.typing.ArrayLike], + **kwds) -> Union[Quantity, jax.Array]: condition = jnp.asarray(condition) if len(args) == 0: # nothing to do @@ -2809,10 +3578,32 @@ def where(condition, *args, **kwds): # pylint: disable=C0111 tril_indices = jnp.tril_indices +tril_indices.__doc__ = ''' + Return the indices for the lower-triangle of an (n, m) array. + + Args: + n: int + m: int + k: int, optional + + Returns: + out: tuple[array] +''' @set_module_as('brainunit.math') -def tril_indices_from(arr, k=0): +def tril_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], + k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: + ''' + Return the indices for the lower-triangle of an (n, m) array. + + Args: + arr: array_like, Quantity + k: int, optional + + Returns: + out: tuple[array] + ''' if isinstance(arr, Quantity): return jnp.tril_indices_from(arr.value, k=k) else: @@ -2820,10 +3611,32 @@ def tril_indices_from(arr, k=0): triu_indices = jnp.triu_indices +triu_indices.__doc__ = ''' + Return the indices for the upper-triangle of an (n, m) array. + + Args: + n: int + m: int + k: int, optional + + Returns: + out: tuple[array] +''' @set_module_as('brainunit.math') -def triu_indices_from(arr, k=0): +def triu_indices_from(arr: Union[Quantity, bst.typing.ArrayLike], + k: Optional[int] = 0) -> tuple[jax.Array, jax.Array]: + ''' + Return the indices for the upper-triangle of an (n, m) array. + + Args: + arr: array_like, Quantity + k: int, optional + + Returns: + out: tuple[array] + ''' if isinstance(arr, Quantity): return jnp.triu_indices_from(arr.value, k=k) else: @@ -2831,7 +3644,10 @@ def triu_indices_from(arr, k=0): @set_module_as('brainunit.math') -def take(a, indices, axis=None, mode=None): +def take(a: Union[Quantity, bst.typing.ArrayLike], + indices: Union[Quantity, bst.typing.ArrayLike], + axis: Optional[int] = None, + mode: Optional[str] = None) -> Union[Quantity, jax.Array]: if isinstance(a, Quantity): return a.take(indices, axis=axis, mode=mode) else: @@ -2839,7 +3655,9 @@ def take(a, indices, axis=None, mode=None): @set_module_as('brainunit.math') -def select(condlist: list[Union[jnp.array, np.ndarray]], choicelist: Union[Quantity, jax.Array, np.ndarray], default=0): +def select(condlist: list[Union[bst.typing.ArrayLike]], + choicelist: Union[Quantity, bst.typing.ArrayLike], + default: int = 0) -> Union[Quantity, jax.Array]: from builtins import all as origin_all from builtins import any as origin_any if origin_all(isinstance(choice, Quantity) for choice in choicelist): @@ -2859,7 +3677,7 @@ def select(condlist: list[Union[jnp.array, np.ndarray]], choicelist: Union[Quant def wrap_window_funcs(func): def f(*args, **kwargs): - return Quantity(func(*args, **kwargs)) + return func(*args, **kwargs) f.__module__ = 'brainunit.math' return f @@ -2871,6 +3689,13 @@ def f(*args, **kwargs): hanning = wrap_window_funcs(jnp.hanning) kaiser = wrap_window_funcs(jnp.kaiser) +# docs for functions above +bartlett.__doc__ = jnp.bartlett.__doc__ +blackman.__doc__ = jnp.blackman.__doc__ +hamming.__doc__ = jnp.hamming.__doc__ +hanning.__doc__ = jnp.hanning.__doc__ +kaiser.__doc__ = jnp.kaiser.__doc__ + # constants # --------- e = jnp.e @@ -2887,13 +3712,91 @@ def f(*args, **kwargs): matmul = wrap_math_funcs_change_unit_binary(jnp.matmul, lambda x, y: x * y) trace = wrap_math_funcs_keep_unit_unary(jnp.trace) +# docs for functions above +dot.__doc__ = ''' + Dot product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `x` and the unit of `y`, else an array. +''' + +vdot.__doc__ = ''' + Return the dot product of two vectors or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +inner.__doc__ = ''' + Inner product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +outer.__doc__ = ''' + Compute the outer product of two vectors or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +kron.__doc__ = ''' + Compute the Kronecker product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +matmul.__doc__ = ''' + Matrix product of two arrays or quantities. + + Args: + a: array_like, Quantity + b: array_like, Quantity + + Returns: + out: Quantity if the final unit is the product of the unit of `a` and the unit of `b`, else an array. +''' + +trace.__doc__ = ''' + Return the sum of the diagonal elements of a matrix or quantity. + + Args: + a: array_like, Quantity + offset: int, optional + + Returns: + out: Quantity if the input is a Quantity, else an array. +''' + # data types # ---------- dtype = jnp.dtype @set_module_as('brainunit.math') -def finfo(a): +def finfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.finfo: if isinstance(a, Quantity): return jnp.finfo(a.value) else: @@ -2901,7 +3804,7 @@ def finfo(a): @set_module_as('brainunit.math') -def iinfo(a): +def iinfo(a: Union[Quantity, bst.typing.ArrayLike]) -> jnp.iinfo: if isinstance(a, Quantity): return jnp.iinfo(a.value) else: @@ -2911,7 +3814,7 @@ def iinfo(a): # more # ---- @set_module_as('brainunit.math') -def broadcast_arrays(*args): +def broadcast_arrays(*args: Union[Quantity, bst.typing.ArrayLike]) -> Union[Quantity, jax.Array]: from builtins import all as origin_all from builtins import any as origin_any if origin_all(isinstance(arg, Quantity) for arg in args): @@ -2929,14 +3832,37 @@ def broadcast_arrays(*args): @set_module_as('brainunit.math') def einsum( - subscripts, /, - *operands, + subscripts: str, + /, + *operands: Union[Quantity, jax.Array], out: None = None, optimize: Union[str, bool] = "optimal", precision: jax.lax.PrecisionLike = None, preferred_element_type: Union[jax.typing.DTypeLike, None] = None, _dot_general: Callable[..., jax.Array] = jax.lax.dot_general, ) -> Union[jax.Array, Quantity]: + ''' + Evaluates the Einstein summation convention on the operands. + + Args: + subscripts: string containing axes names separated by commas. + *operands: sequence of one or more arrays or quantities corresponding to the subscripts. + optimize: determine whether to optimize the order of computation. In JAX + this defaults to ``"optimize"`` which produces optimized expressions via + the opt_einsum_ package. + precision: either ``None`` (default), which means the default precision for + the backend, a :class:`~jax.lax.Precision` enum value (``Precision.DEFAULT``, + ``Precision.HIGH`` or ``Precision.HIGHEST``). + preferred_element_type: either ``None`` (default), which means the default + accumulation type for the input types, or a datatype, indicating to + accumulate results to and return a result with that datatype. + out: unsupported by JAX + _dot_general: optionally override the ``dot_general`` callable used by ``einsum``. + This parameter is experimental, and may be removed without warning at any time. + + Returns: + array containing the result of the einstein summation. + ''' operands = (subscripts, *operands) if out is not None: raise NotImplementedError("The 'out' argument to jnp.einsum is not supported.") @@ -3000,6 +3926,18 @@ def gradient( axis: Union[int, Sequence[int], None] = None, edge_order: Union[int, None] = None, ) -> Union[jax.Array, list[jax.Array], Quantity, list[Quantity]]: + ''' + Computes the gradient of a scalar field. + + Args: + f: input array. + *varargs: list of scalar fields to compute the gradient. + axis: axis or axes along which to compute the gradient. The default is to compute the gradient along all axes. + edge_order: order of the edge used for the finite difference computation. The default is 1. + + Returns: + array containing the gradient of the scalar field. + ''' if edge_order is not None: raise NotImplementedError("The 'edge_order' argument to jnp.gradient is not supported.") @@ -3029,6 +3967,18 @@ def intersect1d( assume_unique: bool = False, return_indices: bool = False ) -> Union[jax.Array, Quantity, tuple[Union[jax.Array, Quantity], jax.Array, jax.Array]]: + ''' + Find the intersection of two arrays. + + Args: + ar1: input array. + ar2: input array. + assume_unique: if True, the input arrays are both assumed to be unique. + return_indices: if True, the indices which correspond to the intersection of the two arrays are returned. + + Returns: + array containing the intersection of the two arrays. + ''' fail_for_dimension_mismatch(ar1, ar2, 'intersect1d') unit = None if isinstance(ar1, Quantity): @@ -3054,3 +4004,67 @@ def intersect1d( rot90 = wrap_math_funcs_keep_unit_unary(jnp.rot90) tensordot = wrap_math_funcs_change_unit_binary(jnp.tensordot, lambda x, y: x * y) + +# docs for functions above +nan_to_num.__doc__ = ''' + Replace NaN with zero and infinity with large finite numbers (default behaviour) or with the numbers defined by the user using the `nan`, `posinf` and `neginf` arguments. + + Args: + x: input array. + nan: value to replace NaNs with. + posinf: value to replace positive infinity with. + neginf: value to replace negative infinity with. + + Returns: + array with NaNs replaced by zero and infinities replaced by large finite numbers. +''' + +nanargmax.__doc__ = ''' + Return the index of the maximum value in an array, ignoring NaNs. + + Args: + a: array like, Quantity. + axis: axis along which to operate. The default is to compute the index of the maximum over all the dimensions of the input array. + out: output array, optional. + keepdims: if True, the result is broadcast to the input array with the same number of dimensions. + + Returns: + index of the maximum value in the array. +''' + +nanargmin.__doc__ = ''' + Return the index of the minimum value in an array, ignoring NaNs. + + Args: + a: array like, Quantity. + axis: axis along which to operate. The default is to compute the index of the minimum over all the dimensions of the input array. + out: output array, optional. + keepdims: if True, the result is broadcast to the input array with the same number of dimensions. + + Returns: + index of the minimum value in the array. +''' + +rot90.__doc__ = ''' + Rotate an array by 90 degrees in the plane specified by axes. + + Args: + m: array like, Quantity. + k: number of times the array is rotated by 90 degrees. + axes: plane of rotation. Default is the last two axes. + + Returns: + rotated array. +''' + +tensordot.__doc__ = ''' + Compute tensor dot product along specified axes for arrays. + + Args: + a: array like, Quantity. + b: array like, Quantity. + axes: axes along which to compute the tensor dot product. + + Returns: + tensor dot product of the two arrays. +''' \ No newline at end of file diff --git a/brainunit/math/_utils.py b/brainunit/math/_utils.py index f30ec85..ae66103 100644 --- a/brainunit/math/_utils.py +++ b/brainunit/math/_utils.py @@ -15,8 +15,9 @@ import functools -from typing import Callable +from typing import Callable, Union +import jax from jax.tree_util import tree_map from .._base import Quantity @@ -38,7 +39,7 @@ def _compatible_with_quantity( func_to_wrap = fun.__np_wrapped__ if hasattr(fun, '__np_wrapped__') else fun @functools.wraps(func_to_wrap) - def new_fun(*args, **kwargs): + def new_fun(*args, **kwargs) -> Union[list[Quantity], Quantity, jax.Array]: unit = None if isinstance(args[0], Quantity): unit = args[0].unit