From a047dafa91428a5163fb3bc1f3f7d88b96014fdf Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Fri, 6 Dec 2024 20:31:50 +0800 Subject: [PATCH] add 'fft' and 'linalg' shortcut in brainunit.math (#76) * add 'fft' and 'linalg' shortcut in `brainunit.math` * update README --- .gitignore | 1 + README.md | 28 ++++++++++++++---- brainunit/_base.py | 9 ++++-- brainunit/fft/_fft_change_unit.py | 3 +- brainunit/fft/_fft_keep_unit_test.py | 16 ++-------- brainunit/math/__init__.py | 44 ++++++++++++++++------------ brainunit/math/fft.py | 18 ++++++++++++ brainunit/math/linalg.py | 18 ++++++++++++ 8 files changed, 96 insertions(+), 41 deletions(-) create mode 100644 brainunit/math/fft.py create mode 100644 brainunit/math/linalg.py diff --git a/.gitignore b/.gitignore index 2942400..2d032ae 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ io_test_tmp* brainpy/math/brainpy_object/tests/io_test_tmp* lib/ +.VSCodeCounter development diff --git a/README.md b/README.md index 68977d4..4ae9fec 100644 --- a/README.md +++ b/README.md @@ -3,23 +3,23 @@ # Physical units and unit-aware mathematical system in JAX

- Header image of brainunit. + Header image of brainunit.

Supported Python Version - LICENSE + LICENSE Documentation Status PyPI version - Continuous Integration + Continuous Integration

-[``brainunit``](https://github.com/chaoming0625/brainunit) provides physical units and unit-aware mathematical system in JAX for brain dynamics and AI4Science +[``brainunit``](https://github.com/chaor/brainunit) provides physical units and unit-aware mathematical system in JAX for brain dynamics and AI4Science ## Installation @@ -35,7 +35,25 @@ pip install brainunit --upgrade The official documentation is hosted on Read the Docs: [https://brainunit.readthedocs.io](https://brainunit.readthedocs.io) +## Unit-Aware Computation Ecosystem + + +`brainunit` has been deeply integrated into following diverse projects, such as: + +- [``brainstate``](https://github.com/chaobrain/brainstate): A State-based Transformation System for Program Compilation and Augmentation +- [``braintaichi``](https://github.com/chaobrain/braintaichi): Leveraging Taichi Lang to customize brain dynamics operators +- [``braintools``](https://github.com/chaobrain/braintools): The Common Toolbox for Brain Dynamics Programming. +- [``dendritex``](https://github.com/chaobrain/dendritex): Dendritic Modeling in JAX +- [``pinnx``](https://github.com/chaobrain/pinnx): Physics-Informed Neural Networks for Scientific Machine Learning in JAX. + +- [``diffrax``](https://github.com/chaobrain/diffrax): Numerical differential equation solvers in JAX. +- [``jax-md``](â—‹https://github.com/Routhleck/jax-md): Differentiable Molecular Dynamics in JAX +- [``Catalax``](https://github.com/Routhleck/Catalax): JAX-based framework to model biological systems +- ... + + ## See also the BDP ecosystem -We are building the [brain dynamics programming ecosystem](https://ecosystem-for-brain-dynamics.readthedocs.io/). +We are building the [brain dynamics programming ecosystem](https://ecosystem-for-brain-dynamics.readthedocs.io/). +[``brainunit``](https://github.com/chaobrain/brainunit) has been deeply integrated into our BDP ecosystem. diff --git a/brainunit/_base.py b/brainunit/_base.py index 66da3b9..5b6f2a6 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -4476,7 +4476,8 @@ def new_f(*args, **kwds): if isinstance(v, bool): newkeyset[n] = v else: - raise TypeError(f"Function '{f.__name__}' expected a boolean value for argument '{n}' but got '{v}'") + raise TypeError( + f"Function '{f.__name__}' expected a boolean value for argument '{n}' but got '{v}'") elif specific_unit == 1: if isinstance(v, Quantity): @@ -4521,7 +4522,7 @@ def new_f(*args, **kwds): ) result = jax.tree.map( - partial(_assign_unit, f), result, expected_result + partial(_assign_unit, f), result, expected_result, ) return result @@ -4529,6 +4530,7 @@ def new_f(*args, **kwds): return do_assign_units + def _check_unit(f, val, unit): unit = UNITLESS if unit is None else unit if not has_same_unit(val, unit): @@ -4540,8 +4542,9 @@ def _check_unit(f, val, unit): ) raise UnitMismatchError(error_message, get_unit(val)) + def _assign_unit(f, val, unit): - if unit is None or unit == bool or unit == 1: + if unit is None or unit is bool: return val return Quantity(val, unit=unit) diff --git a/brainunit/fft/_fft_change_unit.py b/brainunit/fft/_fft_change_unit.py index a3de2b9..467a010 100644 --- a/brainunit/fft/_fft_change_unit.py +++ b/brainunit/fft/_fft_change_unit.py @@ -14,7 +14,6 @@ # ============================================================================== from __future__ import annotations -import sys from typing import Callable, Union, Sequence import jax @@ -1046,6 +1045,7 @@ def irfftn( 24: (u.Ysecond, u.yhertz), } + def _find_closet_scale(scale): values = list(_time_freq_map.keys()) @@ -1113,7 +1113,6 @@ def fftfreq( return jnpfft.fftfreq(n, d, dtype=dtype) - @set_module_as('brainunit.fft') def rfftfreq( n: int, diff --git a/brainunit/fft/_fft_keep_unit_test.py b/brainunit/fft/_fft_keep_unit_test.py index 31562b5..7490baf 100644 --- a/brainunit/fft/_fft_keep_unit_test.py +++ b/brainunit/fft/_fft_keep_unit_test.py @@ -13,29 +13,19 @@ # limitations under the License. # ============================================================================== -import jax.numpy as jnp -import pytest -from absl.testing import parameterized - -import brainunit as bu -import brainunit.math as bm -from brainunit import second, meter, ms -from brainunit._base import assert_quantity - import jax.numpy as jnp import jax.numpy.fft as jnpfft -import pytest from absl.testing import parameterized -import brainunit as u import brainunit.fft as ufft from brainunit import meter, second -from brainunit._base import assert_quantity, Unit, get_or_create_dimension +from brainunit._base import assert_quantity fft_keep_unit = [ 'fftshift', 'ifftshift', ] + class TestFftKeepUnit(parameterized.TestCase): def __init__(self, *args, **kwargs): super(TestFftKeepUnit, self).__init__(*args, **kwargs) @@ -67,4 +57,4 @@ def test_fft_keep_unit(self, value_axes, unit): q = value * unit result = ufft_fun(q, axes=axes) expected = ufft_fun(jnp.array(value), axes=axes) - assert_quantity(result, expected, unit=unit) \ No newline at end of file + assert_quantity(result, expected, unit=unit) diff --git a/brainunit/math/__init__.py b/brainunit/math/__init__.py index 022a923..f426994 100644 --- a/brainunit/math/__init__.py +++ b/brainunit/math/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================== +from . import linalg, fft from ._activation import * from ._activation import __all__ as _activation_all from ._alias import * @@ -32,22 +33,29 @@ from ._misc import * from ._misc import __all__ as _compat_misc_all -__all__ = (_compat_array_creation_all + - _alias_all + - _compat_funcs_change_unit_all + - _compat_funcs_keep_unit_all + - _compat_funcs_accept_unitless_all + - _compat_funcs_remove_unit_all + - _compat_misc_all + - _einops_all + - _activation_all) +__all__ = ( + _compat_array_creation_all + + _alias_all + + _compat_funcs_change_unit_all + + _compat_funcs_keep_unit_all + + _compat_funcs_accept_unitless_all + + _compat_funcs_remove_unit_all + + _compat_misc_all + + _einops_all + + _activation_all + + [ + 'linalg', 'fft' + ] +) -del (_compat_array_creation_all, - _alias_all, - _compat_funcs_change_unit_all, - _compat_funcs_keep_unit_all, - _compat_funcs_accept_unitless_all, - _compat_funcs_remove_unit_all, - _compat_misc_all, - _einops_all, - _activation_all) +del ( + _compat_array_creation_all, + _alias_all, + _compat_funcs_change_unit_all, + _compat_funcs_keep_unit_all, + _compat_funcs_accept_unitless_all, + _compat_funcs_remove_unit_all, + _compat_misc_all, + _einops_all, + _activation_all +) diff --git a/brainunit/math/fft.py b/brainunit/math/fft.py new file mode 100644 index 0000000..a8661f6 --- /dev/null +++ b/brainunit/math/fft.py @@ -0,0 +1,18 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from brainunit.fft import * + + diff --git a/brainunit/math/linalg.py b/brainunit/math/linalg.py new file mode 100644 index 0000000..c08fdb5 --- /dev/null +++ b/brainunit/math/linalg.py @@ -0,0 +1,18 @@ +# Copyright 2024 BDP Ecosystem Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + + +from brainunit.linalg import * +