Skip to content

Commit

Permalink
Add magnitude conversion for asarray
Browse files Browse the repository at this point in the history
  • Loading branch information
Routhleck committed Jun 12, 2024
1 parent bb439c6 commit ea4e9d5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
9 changes: 7 additions & 2 deletions brainunit/math/_compat_numpy_array_creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ def asarray(
Convert the input to a quantity or array.
If unit is provided, the input will be checked whether it has the same unit as the provided unit.
(If they have same dimension but different magnitude, the input will be converted to the provided unit.)
If unit is not provided, the input will be converted to an array.
Args:
Expand All @@ -513,9 +514,13 @@ def asarray(
if unit is not None:
assert isinstance(unit, Unit)
fail_for_dimension_mismatch(a, unit, error_message="a and unit have to have the same units.")
return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim)
if a.dim == unit:
return a
else:
# Convert to the magnitude of the provided unit
return Quantity(a.value / unit.value, dim=unit)
else:
return Quantity(jnp.asarray(a.value, dtype=dtype, order=order), dim=a.dim)
return Quantity(jnp.asarray(a.value, dtype=dtype, order=order) / unit.value, dim=a.dim)
elif isinstance(a, (jax.Array, np.ndarray)):
if unit is not None:
assert isinstance(unit, Unit)
Expand Down
8 changes: 6 additions & 2 deletions brainunit/math/_compat_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
def assert_quantity(q, values, unit):
values = jnp.asarray(values)
if isinstance(q, Quantity):
assert q.dim == unit.dim, f"Unit mismatch: {q.dim} != {unit}"
assert q.dim == unit.dim or q.dim == unit, f"Unit mismatch: {q.dim} != {unit}"
assert jnp.allclose(q.value, values), f"Values do not match: {q.value} != {values}"
else:
assert jnp.allclose(q, values), f"Values do not match: {q} != {values}"
Expand Down Expand Up @@ -162,6 +162,10 @@ def test_asarray(self):
result_q = bu.math.asarray([1, 2, 3], unit=bu.second)
assert_quantity(result_q, jnp.asarray([1, 2, 3]), bu.second)

q1 = [1, 2, 3] * bu.second
result_q = bu.math.asarray(q1, unit=bu.ms)
assert_quantity(result_q, jnp.asarray([1, 2, 3]) * 1000, bu.ms)

def test_arange(self):
result = bu.math.arange(5)
self.assertEqual(result.shape, (5,))
Expand All @@ -171,7 +175,7 @@ def test_arange(self):
assert_quantity(result_q, jnp.arange(5, step=1), bu.second)

result_q = bu.math.arange(3 * bu.second, 9 * bu.second, 1 * bu.second)
assert_quantity(result_q, jnp.arange(3, 9, 1), bu.second)
assert_quantity(result_q, jnp.arange(3, 9, 1), bu.ms)

def test_linspace(self):
result = bu.math.linspace(0, 10, 5)
Expand Down

0 comments on commit ea4e9d5

Please sign in to comment.