From a627c778416409c778f6cdb864c8089b25391bd2 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 5 Apr 2022 13:10:18 +0800 Subject: [PATCH] feat: operation on Variable return `jax.numpy.ndarray` not `brainpy.math.JaxArray` --- brainpy/math/jaxarray.py | 487 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 473 insertions(+), 14 deletions(-) diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index 099183863..504349a91 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -26,16 +26,22 @@ _all_slice = slice(None, None, None) +msg = ('JaxArray cannot be updated in JIT mode. You should ' + 'mark it as the brainpy.math.Variable instead.') + class JaxArray(object): """Multiple-dimensional array for JAX backend. """ - __slots__ = "_value" + __slots__ = ("_value", "_jit_mode") def __init__(self, value): if isinstance(value, (list, tuple)): value = jnp.asarray(value) + if isinstance(value, JaxArray): + value = value._value self._value = value + self._jit_mode = False @property def value(self): @@ -48,6 +54,8 @@ def value(self, value): def update(self, value): """Update the value of this JaxArray. """ + if self._jit_mode: + raise MathError(msg) if value.shape != self._value.shape: raise MathError(f"The shape of the original data is {self._value.shape}, " f"while we got {value.shape}.") @@ -138,6 +146,9 @@ def __getitem__(self, index): return self.value[index] def __setitem__(self, index, value): + if self._jit_mode: + raise MathError(msg) + # value is JaxArray if isinstance(value, JaxArray): value = value.value @@ -201,6 +212,8 @@ def __radd__(self, oc): def __iadd__(self, oc): # a += b + if self._jit_mode: + raise MathError(msg) self._value += (oc._value if isinstance(oc, JaxArray) else oc) return self @@ -212,6 +225,8 @@ def __rsub__(self, oc): def __isub__(self, oc): # a -= b + if self._jit_mode: + raise MathError(msg) self._value = self._value.__sub__(oc._value if isinstance(oc, JaxArray) else oc) return self @@ -223,6 +238,8 @@ def __rmul__(self, oc): def __imul__(self, oc): # a *= b + if self._jit_mode: + raise MathError(msg) self._value = self._value.__mul__(oc._value if isinstance(oc, JaxArray) else oc) return self @@ -240,6 +257,8 @@ def __rtruediv__(self, oc): def __itruediv__(self, oc): # a /= b + if self._jit_mode: + raise MathError(msg) self._value = self._value.__truediv__(oc._value if isinstance(oc, JaxArray) else oc) return self @@ -251,6 +270,8 @@ def __rfloordiv__(self, oc): def __ifloordiv__(self, oc): # a //= b + if self._jit_mode: + raise MathError(msg) self._value = self._value.__floordiv__(oc._value if isinstance(oc, JaxArray) else oc) return self @@ -268,6 +289,8 @@ def __rmod__(self, oc): def __imod__(self, oc): # a %= b + if self._jit_mode: + raise MathError(msg) self._value = self._value.__mod__(oc._value if isinstance(oc, JaxArray) else oc) return self @@ -279,6 +302,8 @@ def __rpow__(self, oc): def __ipow__(self, oc): # a **= b + if self._jit_mode: + raise MathError(msg) self._value = self._value ** (oc._value if isinstance(oc, JaxArray) else oc) return self @@ -290,6 +315,8 @@ def __rmatmul__(self, oc): def __imatmul__(self, oc): # a @= b + if self._jit_mode: + raise MathError(msg) self._value = self._value.__matmul__(oc._value if isinstance(oc, JaxArray) else oc) return self @@ -301,6 +328,8 @@ def __rand__(self, oc): def __iand__(self, oc): # a &= b + if self._jit_mode: + raise MathError(msg) self._value = self._value.__and__(oc._value if isinstance(oc, JaxArray) else oc) return self @@ -312,6 +341,8 @@ def __ror__(self, oc): def __ior__(self, oc): # a |= b + if self._jit_mode: + raise MathError(msg) self._value = self._value.__or__(oc._value if isinstance(oc, JaxArray) else oc) return self @@ -323,6 +354,8 @@ def __rxor__(self, oc): def __ixor__(self, oc): # a ^= b + if self._jit_mode: + raise MathError(msg) self._value = self._value.__xor__(oc._value if isinstance(oc, JaxArray) else oc) return self @@ -334,6 +367,8 @@ def __rlshift__(self, oc): def __ilshift__(self, oc): # a <<= b + if self._jit_mode: + raise MathError(msg) self._value = self._value.__lshift__(oc._value if isinstance(oc, JaxArray) else oc) return self @@ -345,6 +380,8 @@ def __rrshift__(self, oc): def __irshift__(self, oc): # a >>= b + if self._jit_mode: + raise MathError(msg) self._value = self._value.__rshift__(oc._value if isinstance(oc, JaxArray) else oc) return self @@ -857,38 +894,460 @@ def __jax_array__(self): class Variable(JaxArray): """The pointer to specify the dynamical variable. - - Parameters - ---------- - value : - Used to specify the data. """ - __slots__ = () + __slots__ = ('_value', ) def __init__(self, value): - if isinstance(value, JaxArray): - value = value.value super(Variable, self).__init__(value) + @property + def real(self): + return self._value.real + + @property + def T(self): + return self.value.T + + def __neg__(self): + return self._value.__neg__() + + def __pos__(self): + return self._value.__pos__() + + def __abs__(self): + return self._value.__abs__() + + def __invert__(self): + return self._value.__invert__() + + def __eq__(self, oc): + return self._value.__eq__(oc._value if isinstance(oc, JaxArray) else oc) + + def __ne__(self, oc): + return self._value.__ne__(oc._value if isinstance(oc, JaxArray) else oc) + + def __lt__(self, oc): + return self._value.__lt__(oc._value if isinstance(oc, JaxArray) else oc) + + def __le__(self, oc): + return self._value.__le__(oc._value if isinstance(oc, JaxArray) else oc) + + def __gt__(self, oc): + return self._value.__gt__(oc._value if isinstance(oc, JaxArray) else oc) + + def __ge__(self, oc): + return self._value.__ge__(oc._value if isinstance(oc, JaxArray) else oc) + + def __add__(self, oc): + return self._value.__add__(oc._value if isinstance(oc, JaxArray) else oc) + + def __radd__(self, oc): + return self._value.__radd__(oc._value if isinstance(oc, JaxArray) else oc) + + def __iadd__(self, oc): + # a += b + self._value += (oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __sub__(self, oc): + return self._value.__sub__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rsub__(self, oc): + return self._value.__rsub__(oc._value if isinstance(oc, JaxArray) else oc) + + def __isub__(self, oc): + # a -= b + self._value = self._value.__sub__(oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __mul__(self, oc): + return self._value.__mul__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rmul__(self, oc): + return self._value.__rmul__(oc._value if isinstance(oc, JaxArray) else oc) + + def __imul__(self, oc): + # a *= b + self._value = self._value.__mul__(oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __div__(self, oc): + return self._value.__div__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rdiv__(self, oc): + return self._value.__rdiv__(oc._value if isinstance(oc, JaxArray) else oc) + + def __truediv__(self, oc): + return self._value.__truediv__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rtruediv__(self, oc): + return self._value.__rtruediv__(oc._value if isinstance(oc, JaxArray) else oc) + + def __itruediv__(self, oc): + # a /= b + self._value = self._value.__truediv__(oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __floordiv__(self, oc): + return self._value.__floordiv__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rfloordiv__(self, oc): + return self._value.__rfloordiv__(oc._value if isinstance(oc, JaxArray) else oc) + + def __ifloordiv__(self, oc): + # a //= b + self._value = self._value.__floordiv__(oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __divmod__(self, oc): + return self._value.__divmod__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rdivmod__(self, oc): + return self._value.__rdivmod__(oc._value if isinstance(oc, JaxArray) else oc) + + def __mod__(self, oc): + return self._value.__mod__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rmod__(self, oc): + return self._value.__rmod__(oc._value if isinstance(oc, JaxArray) else oc) + + def __imod__(self, oc): + # a %= b + self._value = self._value.__mod__(oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __pow__(self, oc): + return self._value.__pow__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rpow__(self, oc): + return self._value.__rpow__(oc._value if isinstance(oc, JaxArray) else oc) + + def __ipow__(self, oc): + # a **= b + self._value = self._value ** (oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __matmul__(self, oc): + return self._value.__matmul__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rmatmul__(self, oc): + return self._value.__rmatmul__(oc._value if isinstance(oc, JaxArray) else oc) + + def __imatmul__(self, oc): + # a @= b + self._value = self._value.__matmul__(oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __and__(self, oc): + return self._value.__and__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rand__(self, oc): + return self._value.__rand__(oc._value if isinstance(oc, JaxArray) else oc) + def __iand__(self, oc): + # a &= b + self._value = self._value.__and__(oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __or__(self, oc): + return self._value.__or__(oc._value if isinstance(oc, JaxArray) else oc) + + def __ror__(self, oc): + return self._value.__ror__(oc._value if isinstance(oc, JaxArray) else oc) + + def __ior__(self, oc): + # a |= b + self._value = self._value.__or__(oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __xor__(self, oc): + return self._value.__xor__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rxor__(self, oc): + return self._value.__rxor__(oc._value if isinstance(oc, JaxArray) else oc) + + def __ixor__(self, oc): + # a ^= b + self._value = self._value.__xor__(oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __lshift__(self, oc): + return self._value.__lshift__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rlshift__(self, oc): + return self._value.__rlshift__(oc._value if isinstance(oc, JaxArray) else oc) + + def __ilshift__(self, oc): + # a <<= b + self._value = self._value.__lshift__(oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __rshift__(self, oc): + return self._value.__rshift__(oc._value if isinstance(oc, JaxArray) else oc) + + def __rrshift__(self, oc): + return self._value.__rrshift__(oc._value if isinstance(oc, JaxArray) else oc) + + def __irshift__(self, oc): + # a >>= b + self._value = self._value.__rshift__(oc._value if isinstance(oc, JaxArray) else oc) + return self + + def __round__(self, ndigits=None): + return self._value.__round__(ndigits) + + # ----------------------- # + # JAX methods # + # ----------------------- # + + def block_host_until_ready(self, *args): + self._value.block_host_until_ready(*args) + + def block_until_ready(self, *args): + self._value.block_until_ready(*args) + + # ----------------------- # + # NumPy methods # + # ----------------------- # + + def all(self, axis=None, keepdims=False): + """Returns True if all elements evaluate to True.""" + r = self.value.all(axis=axis, keepdims=keepdims) + return r + + def any(self, axis=None, keepdims=False): + """Returns True if any of the elements of a evaluate to True.""" + r = self.value.any(axis=axis, keepdims=keepdims) + return r + + def argmax(self, axis=None): + """Return indices of the maximum values along the given axis.""" + return self.value.argmax(axis=axis) + + def argmin(self, axis=None): + """Return indices of the minimum values along the given axis.""" + return self.value.argmin(axis=axis) + + def argpartition(self, kth, axis=-1, kind='introselect', order=None): + """Returns the indices that would partition this array.""" + return self.value.argpartition(kth=kth, axis=axis, kind=kind, order=order) + + def argsort(self, axis=-1, kind=None, order=None): + """Returns the indices that would sort this array.""" + return self.value.argsort(axis=axis, kind=kind, order=order) + + def astype(self, dtype, order='K', casting='unsafe', subok=True, copy=True): + """Copy of the array, cast to a specified type.""" + return self.value.astype(dtype=dtype, order=order, casting=casting, subok=subok, copy=copy) + + def byteswap(self, inplace=False): + """Swap the bytes of the array elements + + Toggle between low-endian and big-endian data representation by + returning a byteswapped array, optionally swapped in-place. + Arrays of byte-strings are not swapped. The real and imaginary + parts of a complex number are swapped individually.""" + return self.value.byteswap(inplace=inplace) + + def choose(self, choices, mode='raise'): + """Use an index array to construct a new array from a set of choices.""" + choices = choices.value if isinstance(choices, JaxArray) else choices + return self.value.choose(choices=choices, mode=mode) + + def clip(self, min=None, max=None): + """Return an array whose values are limited to [min, max]. One of max or min must be given.""" + return self.value.clip(min=min, max=max) + + def compress(self, condition, axis=None): + """Return selected slices of this array along given axis.""" + condition = condition.value if isinstance(condition, JaxArray) else condition + return self.value.compress(condition=condition, axis=axis) + + def conj(self): + """Complex-conjugate all elements.""" + return self.value.conj() + + def conjugate(self): + """Return the complex conjugate, element-wise.""" + return self.value.conjugate() + + def copy(self): + """Return a copy of the array.""" + return self.value.copy() + + def cumprod(self, axis=None, dtype=None): + """Return the cumulative product of the elements along the given axis.""" + return self.value.cumprod(axis=axis, dtype=dtype) + + def cumsum(self, axis=None, dtype=None): + """Return the cumulative sum of the elements along the given axis.""" + return self.value.cumsum(axis=axis, dtype=dtype) + + def diagonal(self, offset=0, axis1=0, axis2=1): + """Return specified diagonals.""" + return self.value.diagonal(offset=offset, axis1=axis1, axis2=axis2) + + def dot(self, b): + """Dot product of two arrays.""" + return self.value.dot(b) + + def fill(self, value): + """Fill the array with a scalar value.""" + self._value = jnp.ones_like(self.value) * value + + def flatten(self, order='C'): + return self.value.flatten(order=order) + + def item(self, *args): + """Copy an element of an array to a standard Python scalar and return it.""" + return self.value.item(*args) + + def max(self, axis=None, keepdims=False, *args, **kwargs): + """Return the maximum along a given axis.""" + res = self.value.max(axis=axis, keepdims=keepdims, *args, **kwargs) + return res + + def mean(self, axis=None, dtype=None, keepdims=False, *args, **kwargs): + """Returns the average of the array elements along given axis.""" + res = self.value.mean(axis=axis, dtype=dtype, keepdims=keepdims, *args, **kwargs) + return res + + def min(self, axis=None, keepdims=False, *args, **kwargs): + """Return the minimum along a given axis.""" + res = self.value.min(axis=axis, keepdims=keepdims, *args, **kwargs) + return res + + def nonzero(self): + """Return the indices of the elements that are non-zero.""" + return self.value.nonzero() + + def prod(self, axis=None, dtype=None, keepdims=False, initial=1, where=True): + """Return the product of the array elements over the given axis.""" + res = self.value.prod(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) + return res + + def ptp(self, axis=None, keepdims=False): + """Peak to peak (maximum - minimum) value along a given axis.""" + r = self.value.ptp(axis=axis, keepdims=keepdims) + return r + + def ravel(self, order=None): + """Return a flattened array.""" + return self.value.ravel(order=order) + + def repeat(self, repeats, axis=None): + """Repeat elements of an array.""" + return self.value.repeat(repeats=repeats, axis=axis) + + def reshape(self, shape, order='C'): + """Returns an array containing the same data with a new shape.""" + return self.value.reshape(*shape, order=order) + + def round(self, decimals=0): + """Return ``a`` with each element rounded to the given number of decimals.""" + return self.value.round(decimals=decimals) + + def searchsorted(self, v, side='left', sorter=None): + """Find indices where elements should be inserted to maintain order.""" + v = v.value if isinstance(v, JaxArray) else v + return self.value.searchsorted(v=v, side=side, sorter=sorter) + + def sort(self, axis=-1, kind=None, order=None): + """Sort an array in-place.""" + self._value = self.value.sort(axis=axis, kind=kind, order=order) + + def squeeze(self, axis=None): + """Remove axes of length one from ``a``.""" + return self.value.squeeze(axis=axis) + + def std(self, axis=None, dtype=None, ddof=0, keepdims=False): + """Compute the standard deviation along the specified axis. + + Returns the standard deviation, a measure of the spread of a distribution, + of the array elements. The standard deviation is computed for the + flattened array by default, otherwise over the specified axis. + """ + r = self.value.std(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) + return r + + def sum(self, axis=None, dtype=None, keepdims=False, initial=0, where=True): + """Return the sum of the array elements over the given axis.""" + res = self.value.sum(axis=axis, dtype=dtype, keepdims=keepdims, initial=initial, where=where) + return res + + def swapaxes(self, axis1, axis2): + """Return a view of the array with `axis1` and `axis2` interchanged.""" + return self.value.swapaxes(axis1, axis2) + + def split(self, indices_or_sections, axis=0): + """Split an array into multiple sub-arrays as views into ``ary``. + """ + return self.value.split(indices_or_sections, axis=axis) + + def take(self, indices, axis=None, mode=None): + """Return an array formed from the elements of a at the given indices.""" + indices = indices.value if isinstance(indices, JaxArray) else indices + return self.value.take(indices=indices, axis=axis, mode=mode) + + def tobytes(self, order='C'): + """Construct Python bytes containing the raw data bytes in the array. + + Constructs Python bytes showing a copy of the raw contents of data memory. + The bytes object is produced in C-order by default. This behavior is + controlled by the ``order`` parameter.""" + return self.value.tobytes(order=order) + + def tolist(self): + """Return the array as an ``a.ndim``-levels deep nested list of Python scalars. + + Return a copy of the array data as a (nested) Python list. + Data items are converted to the nearest compatible builtin Python type, via + the `~numpy.ndarray.item` function. + + If ``a.ndim`` is 0, then since the depth of the nested list is 0, it will + not be a list at all, but a simple Python scalar. + """ + return self.value.tolist() + + def trace(self, offset=0, axis1=0, axis2=1, dtype=None): + """Return the sum along diagonals of the array.""" + return self.value.trace(offset=offset, axis1=axis1, axis2=axis2, dtype=dtype) + + def transpose(self, *axes): + """Returns a view of the array with axes transposed. + """ + return self.value.transpose(*axes) + + def tile(self, reps): + """Construct an array by repeating A the number of times given by reps. + """ + reps = reps.value if isinstance(reps, JaxArray) else reps + return self.value.tile(reps) + + def var(self, axis=None, dtype=None, ddof=0, keepdims=False): + """Returns the variance of the array elements, along given axis.""" + r = self.value.var(axis=axis, dtype=dtype, ddof=ddof, keepdims=keepdims) + return r + + def view(self, dtype=None, *args, **kwargs): + """New view of array with the same data.""" + return self.value.view(dtype=dtype, *args, **kwargs) + class TrainVar(Variable): """The pointer to specify the trainable variable. """ - __slots__ = () + __slots__ = ('_value', ) def __init__(self, value): - if isinstance(value, JaxArray): - value = value.value super(TrainVar, self).__init__(value) class Parameter(Variable): """The pointer to specify the parameter. """ - __slots__ = () + __slots__ = ('_value', ) def __init__(self, value): - if isinstance(value, JaxArray): value = value.value super(Parameter, self).__init__(value)