Merge pull request #347 from chaoming0625/master
Update docs and tests
chaoming0625 authored Mar 23, 2023
2 parents 0a90f17 + 6b40112 commit c2d732e
Showing 9 changed files with 129 additions and 68 deletions.
12 changes: 0 additions & 12 deletions .github/workflows/CI-models.yml
@@ -22,8 +22,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
-python -m pip install --upgrade pip
-python -m pip install flake8 pytest
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
@@ -47,8 +45,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
-python -m pip install --upgrade pip
-python -m pip install flake8 pytest
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip install jax==0.3.25
pip install jaxlib==0.3.25
@@ -74,8 +70,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
-python -m pip install --upgrade pip
-python -m pip install flake8 pytest
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip uninstall brainpy -y
python setup.py install
@@ -99,8 +93,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
-python -m pip install --upgrade pip
-python -m pip install flake8 pytest
if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
pip install jax==0.3.25
pip install jaxlib==0.3.25
@@ -127,8 +119,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
-python -m pip install --upgrade pip
-python -m pip install flake8 pytest
python -m pip install numpy>=1.21.0
python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz
@@ -156,8 +146,6 @@ jobs:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
-python -m pip install --upgrade pip
-python -m pip install flake8 pytest
python -m pip install numpy>=1.21.0
python -m pip install "jaxlib==0.3.25" -f https://whls.blob.core.windows.net/unstable/index.html --use-deprecated legacy-resolver
python -m pip install https://github.com/google/jax/archive/refs/tags/jax-v0.3.25.tar.gz
2 changes: 1 addition & 1 deletion brainpy/__init__.py
@@ -105,7 +105,7 @@
# Part 6: Others #
# ------------------ #

-from . import running
+from . import running, testing
from ._src.visualization import (visualize as visualize)
from ._src.running.runner import (Runner as Runner)

Empty file.
19 changes: 19 additions & 0 deletions brainpy/_src/testing/base.py
@@ -0,0 +1,19 @@
import unittest
import brainpy.math as bm

try:
from absl.testing import parameterized
except ImportError:
pass


class UniTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
bm.random.seed()
self.rng = bm.random.default_rng()

def __del__(self):
bm.clear_buffer_memory()


1 change: 1 addition & 0 deletions brainpy/testing.py
@@ -0,0 +1 @@
from brainpy._src.testing.base import UniTestCase
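
For context, a minimal sketch (not part of this commit) of how a downstream test module might adopt the new bp.testing.UniTestCase base class; the test body and names below are illustrative assumptions rather than code from the PR:

import brainpy as bp
import brainpy.math as bm


class ExampleTest(bp.testing.UniTestCase):
    # UniTestCase seeds bm.random in __init__ and calls bm.clear_buffer_memory()
    # in __del__, so individual tests no longer need the manual seed/cleanup
    # calls that the previous unittest.TestCase-based tests carried.
    def test_smoke(self):
        x = bm.random.rand(4, 4)   # drawn from the freshly seeded global RNG
        self.assertEqual(x.shape, (4, 4))
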
147 changes: 98 additions & 49 deletions docs/core_concept/brainpy_dynamical_system.ipynb

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions tests/simulation/test_net_rate_SL.py
@@ -25,7 +25,7 @@ def __init__(self, noise=0.14):
)


-class TestSL(unittest.TestCase):
+class TestSL(bp.testing.UniTestCase):
def test1(self):
net = Network()
runner = bp.DSRunner(net, monitors=['sl.x'])
@@ -41,4 +41,3 @@ def test1(self):
plt.tight_layout()
plt.show()


7 changes: 3 additions & 4 deletions tests/simulation/test_neu_HH.py
@@ -90,8 +90,9 @@ def update(self, x=None):
return dV_grad


-class TestHH(unittest.TestCase):
+class TestHH(bp.testing.UniTestCase):
def test1(self):
+bm.random.seed()
hh = HH(1)
I, length = bp.inputs.section_input(values=[0, 5, 0], durations=[10, 100, 10], return_length=True)
runner = bp.DSRunner(
@@ -102,9 +103,9 @@ def test1(self):

if show:
bp.visualize.line_plot(runner.mon.ts, runner.mon.V, show=True)
-bp.math.clear_buffer_memory()

def test2(self):
+bm.random.seed()
with bp.math.environment(dt=0.1):
hh = HH(1)
looper = bp.LoopOverTime(hh, out_vars=(hh.V, hh.m, hh.n, hh.h))
@@ -124,5 +125,3 @@ def test2(self):
fig.add_subplot(gs[3, 0])
bp.visualize.line_plot(ts, bm.exp(grads * bm.dt), show=True)

-bm.clear_buffer_memory()

6 changes: 6 additions & 0 deletions tests/training/test_ESN.py
@@ -73,6 +73,8 @@ def test_train_esn_with_ridge(self):
print(bp.losses.mean_absolute_error(outputs, Y))

def test_train_esn_with_force(self, num_in=100, num_out=30):
+bm.random.seed()

with bm.batching_environment():
model = ESN(num_in, 2000, num_out)

@@ -97,6 +99,8 @@ def test_train_esn_with_force(self, num_in=100, num_out=30):
print(bp.losses.mean_absolute_error(outputs, Y))

def test_ngrc(self, num_in=10, num_out=30):
+bm.random.seed()

with bm.batching_environment():
model = NGRC(num_in, num_out)

@@ -111,6 +115,8 @@ def test_ngrc(self, num_in=10, num_out=30):
print(bp.losses.mean_absolute_error(outputs, Y))

def test_ngrc_bacth(self, num_in=10, num_out=30):
+bm.random.seed()

with bm.batching_environment():
model = NGRC(num_in, num_out)
batch_size = 10
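
For context, a small sketch of what the added bm.random.seed() calls do; the reproducibility behaviour shown below is an assumption about brainpy.math.random, not something this PR tests:

import brainpy.math as bm

bm.random.seed()            # no argument: re-seed the global RNG, as in the updated tests
a = bm.random.rand(3)

bm.random.seed(42)          # an explicit seed makes draws reproducible
b = bm.random.rand(3)
bm.random.seed(42)
c = bm.random.rand(3)
assert bm.allclose(b, c)    # identical draws because the seed was identical
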
