Skip to content

Commit

Permalink
Fix bugs
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 12, 2024
1 parent 512f734 commit 1647baa
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 10 deletions.
15 changes: 8 additions & 7 deletions brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def full(
Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array.
'''
if isinstance(fill_value, Quantity):
return jnp.full(shape, fill_value.magnitude, dtype=dtype) * fill_value.unit
return Quantity(jnp.full(shape, fill_value.value, dtype=dtype), dim=fill_value.dim)
return jnp.full(shape, fill_value, dtype=dtype)


Expand Down Expand Up @@ -498,7 +498,7 @@ def asarray(
# Convert the values to a jnp.ndarray and create a Quantity object
return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit)
else:
values = jax.tree.unflatten(tree, [x.value for x in a])
values = jax.tree.unflatten(tree, leaves)
val = jnp.asarray(values, dtype=dtype, order=order)
if unit is not None:
assert isinstance(unit, Unit)
Expand Down Expand Up @@ -681,7 +681,8 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike],
@set_module_as('brainunit.math')
def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike],
val: Union[Quantity, bst.typing.ArrayLike],
wrap: Optional[bool] = False) -> Union[Quantity, jax.Array]:
wrap: Optional[bool] = False,
inplace: Optional[bool] = False) -> Union[Quantity, jax.Array]:
'''
Fill the main diagonal of the given array of `a` with `val`.
Expand All @@ -697,14 +698,14 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike],
if isinstance(val, Quantity):
if isinstance(a, Quantity):
fail_for_dimension_mismatch(a, val, error_message="Array and value have to have the same units.")
return Quantity(jnp.fill_diagonal(a.value, val.value, wrap), dim=a.dim)
return Quantity(jnp.fill_diagonal(a.value, val.value, wrap, inplace=inplace), dim=a.dim)
else:
return Quantity(jnp.fill_diagonal(a, val.value, wrap), dim=val.dim)
return Quantity(jnp.fill_diagonal(a, val.value, wrap, inplace=inplace), dim=val.dim)
else:
if isinstance(a, Quantity):
return jnp.fill_diagonal(a.value, val, wrap)
return jnp.fill_diagonal(a.value, val, wrap, inplace=inplace)
else:
return jnp.fill_diagonal(a, val, wrap)
return jnp.fill_diagonal(a, val, wrap, inplace=inplace)


@set_module_as('brainunit.math')
Expand Down
6 changes: 3 additions & 3 deletions brainunit/math/_compat_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_full(self):
self.assertEqual(result.shape, (3,))
self.assertTrue(jnp.all(result == 4))

q = bu.math.full(3, 4, unit=second)
q = bu.math.full(3, 4 * second)
self.assertEqual(q.shape, (3,))
assert_quantity(q, result, second)

Expand Down Expand Up @@ -192,11 +192,11 @@ def test_logspace(self):

def test_fill_diagonal(self):
array = jnp.zeros((3, 3))
result = bu.math.fill_diagonal(array, 5, inplace=False)
result = bu.math.fill_diagonal(array, 5)
self.assertTrue(jnp.all(result == jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]])))

q = jnp.zeros((3, 3)) * bu.second
result_q = bu.math.fill_diagonal(q, 5 * bu.second, inplace=False)
result_q = bu.math.fill_diagonal(q, 5 * bu.second)
assert_quantity(result_q, jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]]), bu.second)

def test_array_split(self):
Expand Down

0 comments on commit 1647baa

Please sign in to comment.