Skip to content

Commit

Permalink
feat: add "__jax_array__" support
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Apr 5, 2022
1 parent 352a586 commit 2d546e1
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions brainpy/math/jaxarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,6 +848,9 @@ def __array__(self):
"""Support ``numpy.array()`` and ``numpy.asarray()`` functions."""
return np.asarray(self.value)

def __jax_array__(self):
return self.value


ndarray = JaxArray

Expand All @@ -865,7 +868,6 @@ class Variable(JaxArray):
def __init__(self, value):
if isinstance(value, JaxArray):
value = value.value
# assert jnp.ndim(value) >= 1, 'Must be an array, not scalar.'
super(Variable, self).__init__(value)


Expand All @@ -875,7 +877,8 @@ class TrainVar(Variable):
__slots__ = ()

def __init__(self, value):
if isinstance(value, JaxArray): value = value.value
if isinstance(value, JaxArray):
value = value.value
super(TrainVar, self).__init__(value)


Expand Down

0 comments on commit 2d546e1

Please sign in to comment.