Skip to content

Commit

Permalink
Merge branch 'fix-asarray' into fix
Browse files Browse the repository at this point in the history
  • Loading branch information
chaoming0625 committed Jun 12, 2024
2 parents d3f38e9 + 1647baa commit f5bbe80
Show file tree
Hide file tree
Showing 22 changed files with 5,928 additions and 1,594 deletions.
73 changes: 40 additions & 33 deletions brainunit/_unit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def assert_allclose(actual, desired, rtol=4.5e8, atol=0, **kwds):
def assert_quantity(q, values, unit):
values = jnp.asarray(values)
if isinstance(q, Quantity):
assert have_same_unit(q.unit, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})"
assert have_same_unit(q.dim, unit), f"Dimension mismatch: ({get_unit(q)}) ({get_unit(unit)})"
if not jnp.allclose(q.value, values):
raise AssertionError(f"Values do not match: {q.value} != {values}")
elif isinstance(q, jnp.ndarray):
Expand Down Expand Up @@ -144,10 +144,10 @@ def test_get_dimensions():
Test various ways of getting/comparing the dimensions of a Array.
"""
q = 500 * ms
assert get_unit(q) is get_or_create_dimension(q.unit._dims)
assert get_unit(q) is q.unit
assert get_unit(q) is get_or_create_dimension(q.dim._dims)
assert get_unit(q) is q.dim
assert q.has_same_unit(3 * second)
dims = q.unit
dims = q.dim
assert_equal(dims.get_dimension("time"), 1.0)
assert_equal(dims.get_dimension("length"), 0)

Expand Down Expand Up @@ -201,47 +201,54 @@ def test_unary_operations():


def test_operations():
q1 = Quantity(5, dim=mV)
q2 = Quantity(10, dim=mV)
assert_quantity(q1 + q2, 15, mV)
assert_quantity(q1 - q2, -5, mV)
assert_quantity(q1 * q2, 50, mV * mV)
q1 = 5 * second
q2 = 10 * second
assert_quantity(q1 + q2, 15, second)
assert_quantity(q1 - q2, -5, second)
assert_quantity(q1 * q2, 50, second * second)
assert_quantity(q2 / q1, 2, DIMENSIONLESS)
assert_quantity(q2 // q1, 2, DIMENSIONLESS)
assert_quantity(q2 % q1, 0, mV)
assert_quantity(q2 % q1, 0, second)
assert_quantity(divmod(q2, q1)[0], 2, DIMENSIONLESS)
assert_quantity(divmod(q2, q1)[1], 0, mV)
assert_quantity(q1 ** 2, 25, mV ** 2)
assert_quantity(q1 << 1, 10, mV)
assert_quantity(q1 >> 1, 2, mV)
assert_quantity(round(q1, 0), 5, mV)
assert_quantity(divmod(q2, q1)[1], 0, second)
assert_quantity(q1 ** 2, 25, second ** 2)
assert_quantity(round(q1, 0), 5, second)

# matmul
q1 = Quantity([1, 2], dim=mV)
q2 = Quantity([3, 4], dim=mV)
assert_quantity(q1 @ q2, 11, mV ** 2)
q1 = [1, 2] * second
q2 = [3, 4] * second
assert_quantity(q1 @ q2, 11, second ** 2)
q1 = Quantity([1, 2], unit=second)
q2 = Quantity([3, 4], unit=second)
assert_quantity(q1 @ q2, 11, second ** 2)

# shift
q1 = Quantity(0b1100, dtype=jnp.int32, dim=DIMENSIONLESS)
assert_quantity(q1 << 1, 0b11000, second)
assert_quantity(q1 >> 1, 0b110, second)


def test_numpy_methods():
q = Quantity([[1, 2], [3, 4]], dim=mV)
q = [[1, 2], [3, 4]] * second
assert q.all()
assert q.any()
assert q.nonzero()[0].tolist() == [0, 0, 1, 1]
assert q.argmax() == 3
assert q.argmin() == 0
assert q.argsort(axis=None).tolist() == [0, 1, 2, 3]
assert_quantity(q.var(), 1.25, mV ** 2)
assert_quantity(q.round(), [[1, 2], [3, 4]], mV)
assert_quantity(q.std(), 1.11803398875, mV)
assert_quantity(q.sum(), 10, mV)
assert_quantity(q.trace(), 5, mV)
assert_quantity(q.cumsum(), [1, 3, 6, 10], mV)
assert_quantity(q.cumprod(), [1, 2, 6, 24], mV ** 4)
assert_quantity(q.diagonal(), [1, 4], mV)
assert_quantity(q.max(), 4, mV)
assert_quantity(q.mean(), 2.5, mV)
assert_quantity(q.min(), 1, mV)
assert_quantity(q.ptp(), 3, mV)
assert_quantity(q.ravel(), [1, 2, 3, 4], mV)
assert_quantity(q.var(), 1.25, second ** 2)
assert_quantity(q.round(), [[1, 2], [3, 4]], second)
assert_quantity(q.std(), 1.11803398875, second)
assert_quantity(q.sum(), 10, second)
assert_quantity(q.trace(), 5, second)
assert_quantity(q.cumsum(), [1, 3, 6, 10], second)
assert_quantity(q.cumprod(), [1, 2, 6, 24], second ** 4)
assert_quantity(q.diagonal(), [1, 4], second)
assert_quantity(q.max(), 4, second)
assert_quantity(q.mean(), 2.5, second)
assert_quantity(q.min(), 1, second)
assert_quantity(q.ptp(), 3, second)
assert_quantity(q.ravel(), [1, 2, 3, 4], second)


def test_shape_manipulation():
Expand Down Expand Up @@ -1596,7 +1603,7 @@ def test_constants():
import brainunit._unit_constants as constants

# Check that the expected names exist and have the correct dimensions
assert constants.avogadro_constant.dim == (1 / mole).unit
assert constants.avogadro_constant.dim == (1 / mole).dim
assert constants.boltzmann_constant.dim == (joule / kelvin).dim
assert constants.electric_constant.dim == (farad / meter).dim
assert constants.electron_mass.dim == kilogram.dim
Expand Down
65 changes: 60 additions & 5 deletions brainunit/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,67 @@
# limitations under the License.
# ==============================================================================

from ._compat_numpy import *
from ._compat_numpy import __all__ as _compat_numpy_all
# from ._compat_numpy import *
# from ._compat_numpy import __all__ as _compat_numpy_all
from ._others import *
from ._others import __all__ as _other_all
from ._compat_numpy_array_creation import *
from ._compat_numpy_array_creation import __all__ as _compat_array_creation_all
from ._compat_numpy_array_manipulation import *
from ._compat_numpy_array_manipulation import __all__ as _compat_array_manipulation_all
from ._compat_numpy_funcs_accept_unitless import *
from ._compat_numpy_funcs_accept_unitless import __all__ as _compat_funcs_accept_unitless_all
from ._compat_numpy_funcs_bit_operation import *
from ._compat_numpy_funcs_bit_operation import __all__ as _compat_funcs_bit_operation_all
from ._compat_numpy_funcs_change_unit import *
from ._compat_numpy_funcs_change_unit import __all__ as _compat_funcs_change_unit_all
from ._compat_numpy_funcs_indexing import *
from ._compat_numpy_funcs_indexing import __all__ as _compat_funcs_indexing_all
from ._compat_numpy_funcs_keep_unit import *
from ._compat_numpy_funcs_keep_unit import __all__ as _compat_funcs_keep_unit_all
from ._compat_numpy_funcs_logic import *
from ._compat_numpy_funcs_logic import __all__ as _compat_funcs_logic_all
from ._compat_numpy_funcs_match_unit import *
from ._compat_numpy_funcs_match_unit import __all__ as _compat_funcs_match_unit_all
from ._compat_numpy_funcs_remove_unit import *
from ._compat_numpy_funcs_remove_unit import __all__ as _compat_funcs_remove_unit_all
from ._compat_numpy_funcs_window import *
from ._compat_numpy_funcs_window import __all__ as _compat_funcs_window_all
from ._compat_numpy_get_attribute import *
from ._compat_numpy_get_attribute import __all__ as _compat_get_attribute_all
from ._compat_numpy_linear_algebra import *
from ._compat_numpy_linear_algebra import __all__ as _compat_linear_algebra_all
from ._compat_numpy_misc import *
from ._compat_numpy_misc import __all__ as _compat_misc_all

__all__ = _compat_numpy_all + _other_all

del _compat_numpy_all, _other_all
__all__ = _compat_array_creation_all + \
_compat_array_manipulation_all + \
_compat_funcs_change_unit_all + \
_compat_funcs_keep_unit_all + \
_compat_funcs_accept_unitless_all + \
_compat_funcs_match_unit_all + \
_compat_funcs_remove_unit_all + \
_compat_get_attribute_all + \
_compat_funcs_bit_operation_all + \
_compat_funcs_logic_all + \
_compat_funcs_indexing_all + \
_compat_funcs_window_all + \
_compat_linear_algebra_all + \
_compat_misc_all + _other_all + \
_other_all

del _compat_array_creation_all, \
_compat_array_manipulation_all, \
_compat_funcs_change_unit_all, \
_compat_funcs_keep_unit_all, \
_compat_funcs_accept_unitless_all, \
_compat_funcs_match_unit_all, \
_compat_funcs_remove_unit_all, \
_compat_get_attribute_all, \
_compat_funcs_bit_operation_all, \
_compat_funcs_logic_all, \
_compat_funcs_indexing_all, \
_compat_funcs_window_all, \
_compat_linear_algebra_all, \
_compat_misc_all, \
_other_all
Loading

0 comments on commit f5bbe80

Please sign in to comment.