diff --git a/brainunit/math/_compat_numpy_array_creation.py b/brainunit/math/_compat_numpy_array_creation.py index 912a87d..1e31f4d 100644 --- a/brainunit/math/_compat_numpy_array_creation.py +++ b/brainunit/math/_compat_numpy_array_creation.py @@ -59,7 +59,7 @@ def full( Union[jax.Array, Quantity]: Quantity if `unit` is provided, else an array. ''' if isinstance(fill_value, Quantity): - return jnp.full(shape, fill_value.magnitude, dtype=dtype) * fill_value.unit + return Quantity(jnp.full(shape, fill_value.value, dtype=dtype), dim=fill_value.dim) return jnp.full(shape, fill_value, dtype=dtype) @@ -498,7 +498,7 @@ def asarray( # Convert the values to a jnp.ndarray and create a Quantity object return Quantity(jnp.asarray(values, dtype=dtype, order=order), dim=unit) else: - values = jax.tree.unflatten(tree, [x.value for x in a]) + values = jax.tree.unflatten(tree, leaves) val = jnp.asarray(values, dtype=dtype, order=order) if unit is not None: assert isinstance(unit, Unit) @@ -681,7 +681,8 @@ def logspace(start: Union[Quantity, bst.typing.ArrayLike], @set_module_as('brainunit.math') def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], val: Union[Quantity, bst.typing.ArrayLike], - wrap: Optional[bool] = False) -> Union[Quantity, jax.Array]: + wrap: Optional[bool] = False, + inplace: Optional[bool] = False) -> Union[Quantity, jax.Array]: ''' Fill the main diagonal of the given array of `a` with `val`. @@ -697,14 +698,14 @@ def fill_diagonal(a: Union[Quantity, bst.typing.ArrayLike], if isinstance(val, Quantity): if isinstance(a, Quantity): fail_for_dimension_mismatch(a, val, error_message="Array and value have to have the same units.") - return Quantity(jnp.fill_diagonal(a.value, val.value, wrap), dim=a.dim) + return Quantity(jnp.fill_diagonal(a.value, val.value, wrap, inplace=inplace), dim=a.dim) else: - return Quantity(jnp.fill_diagonal(a, val.value, wrap), dim=val.dim) + return Quantity(jnp.fill_diagonal(a, val.value, wrap, inplace=inplace), dim=val.dim) else: if isinstance(a, Quantity): - return jnp.fill_diagonal(a.value, val, wrap) + return jnp.fill_diagonal(a.value, val, wrap, inplace=inplace) else: - return jnp.fill_diagonal(a, val, wrap) + return jnp.fill_diagonal(a, val, wrap, inplace=inplace) @set_module_as('brainunit.math') diff --git a/brainunit/math/_compat_numpy_test.py b/brainunit/math/_compat_numpy_test.py index 2f8de29..615d720 100644 --- a/brainunit/math/_compat_numpy_test.py +++ b/brainunit/math/_compat_numpy_test.py @@ -45,7 +45,7 @@ def test_full(self): self.assertEqual(result.shape, (3,)) self.assertTrue(jnp.all(result == 4)) - q = bu.math.full(3, 4, unit=second) + q = bu.math.full(3, 4 * second) self.assertEqual(q.shape, (3,)) assert_quantity(q, result, second) @@ -192,11 +192,11 @@ def test_logspace(self): def test_fill_diagonal(self): array = jnp.zeros((3, 3)) - result = bu.math.fill_diagonal(array, 5, inplace=False) + result = bu.math.fill_diagonal(array, 5) self.assertTrue(jnp.all(result == jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]]))) q = jnp.zeros((3, 3)) * bu.second - result_q = bu.math.fill_diagonal(q, 5 * bu.second, inplace=False) + result_q = bu.math.fill_diagonal(q, 5 * bu.second) assert_quantity(result_q, jnp.array([[5, 0, 0], [0, 5, 0], [0, 0, 5]]), bu.second) def test_array_split(self):