diff --git a/brainpy/_src/dyn/base.py b/brainpy/_src/dyn/base.py
index 50abef90a..25ca31499 100644
--- a/brainpy/_src/dyn/base.py
+++ b/brainpy/_src/dyn/base.py
@@ -114,7 +114,8 @@ class DynamicalSystem(BrainPyObject):
     The model computation mode. It should be instance of :py:class:`~.Mode`.
   """

-  pass_shared: bool = True
+  _pass_shared_args: bool = True
+  # pass_shared: bool = False

   global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], Variable]] = dict()
   '''Global delay data, which stores the delay variables and corresponding delay targets.
@@ -160,7 +161,7 @@ def __call__(self, *args, **kwargs):
     if share is None:
       from brainpy._src.dyn.context import share

-    if self.pass_shared:
+    if self._pass_shared_args:
       if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'):
         if len(args) and isinstance(args[0], dict):
           share.save(**args[0])
@@ -414,12 +415,6 @@ def __rrshift__(self, other):
     return self.__call__(other)


-class DynamicalSystemNS(DynamicalSystem):
-  """Dynamical system without the need of shared parameters passing into ``update()`` function."""
-
-  pass_shared = False
-
-
 class Container(DynamicalSystem):
   """Container object which is designed to add other instances of DynamicalSystem.

@@ -519,124 +514,6 @@ def clear_input(self):
       node.clear_input()


-class Sequential(DynamicalSystemNS):
-  """A sequential `input-output` module.
-
-  Modules will be added to it in the order they are passed in the
-  constructor. Alternatively, an ``dict`` of modules can be
-  passed in. The ``update()`` method of ``Sequential`` accepts any
-  input and forwards it to the first module it contains. It then
-  "chains" outputs to inputs sequentially for each subsequent module,
-  finally returning the output of the last module.
-
-  The value a ``Sequential`` provides over manually calling a sequence
-  of modules is that it allows treating the whole container as a
-  single module, such that performing a transformation on the
-  ``Sequential`` applies to each of the modules it stores (which are
-  each a registered submodule of the ``Sequential``).
-
-  What's the difference between a ``Sequential`` and a
-  :py:class:`Container`? A ``Container`` is exactly what it
-  sounds like--a container to store :py:class:`DynamicalSystem` s!
-  On the other hand, the layers in a ``Sequential`` are connected
-  in a cascading way.
-
-  Examples
-  --------
-
-  >>> import brainpy as bp
-  >>> import brainpy.math as bm
-  >>>
-  >>> # composing ANN models
-  >>> l = bp.Sequential(bp.layers.Dense(100, 10),
-  >>>                   bm.relu,
-  >>>                   bp.layers.Dense(10, 2))
-  >>> l({}, bm.random.random((256, 100)))
-  >>>
-  >>> # Using Sequential with Dict. This is functionally the
-  >>> # same as the above code
-  >>> l = bp.Sequential(l1=bp.layers.Dense(100, 10),
-  >>>                   l2=bm.relu,
-  >>>                   l3=bp.layers.Dense(10, 2))
-  >>> l({}, bm.random.random((256, 100)))
-
-  Parameters
-  ----------
-  name: str
-    The object name.
-  mode: Mode
-    The object computing context/mode. Default is ``None``.
-  """
-
-  def __init__(
-      self,
-      *modules_as_tuple,
-      name: str = None,
-      mode: bm.Mode = None,
-      **modules_as_dict
-  ):
-    super().__init__(name=name, mode=mode)
-    self._dyn_modules = bm.NodeDict()
-    self._static_modules = dict()
-    i = 0
-    for m in modules_as_tuple + tuple(modules_as_dict.values()):
-      key = self.__format_key(i)
-      if isinstance(m, bm.BrainPyObject):
-        self._dyn_modules[key] = m
-      else:
-        self._static_modules[key] = m
-      i += 1
-    self._num = i
-
-  def __format_key(self, i):
-    return f'l-{i}'
-
-  def __all_nodes(self):
-    nodes = []
-    for i in range(self._num):
-      key = self.__format_key(i)
-      if key not in self._dyn_modules:
-        nodes.append(self._static_modules[key])
-      else:
-        nodes.append(self._dyn_modules[key])
-    return nodes
-
-  def __getitem__(self, key: Union[int, slice, str]):
-    if isinstance(key, str):
-      if key in self._dyn_modules:
-        return self._dyn_modules[key]
-      elif key in self._static_modules:
-        return self._static_modules[key]
-      else:
-        raise KeyError(f'Does not find a component named {key} in\n {str(self)}')
-    elif isinstance(key, slice):
-      return Sequential(*(self.__all_nodes()[key]))
-    elif isinstance(key, int):
-      key = self.__format_key(key)
-      return self._static_modules[key] if (key not in self._dyn_modules) else self._dyn_modules[key]
-    elif isinstance(key, (tuple, list)):
-      nodes = []
-      for i in key:
-        if isinstance(i, int):
-          i = self.__format_key(i)
-        assert isinstance(i, str)
-        nodes.append(self._static_modules[i] if (i not in self._dyn_modules) else self._dyn_modules[i])
-      return Sequential(*nodes)
-    else:
-      raise KeyError(f'Unknown type of key: {type(key)}')
-
-  def __repr__(self):
-    nodes = self.__all_nodes()
-    entries = '\n'.join(f'  [{i}] {tools.repr_object(x)}' for i, x in enumerate(nodes))
-    return f'{self.__class__.__name__}(\n{entries}\n)'
-
-  def update(self, x):
-    """Update function of a sequential model.
-    """
-    for m in self.__all_nodes():
-      x = m(x)
-    return x
-

 class Network(Container):
   """Base class to model network objects, an alias of Container.

@@ -807,10 +684,6 @@ def __getitem__(self, item):
     return NeuGroupView(target=self, index=item)


-class NeuGroupNS(NeuGroup):
-  """Base class for neuron group without shared arguments passed."""
-  pass_shared = False
-

 class SynConn(DynamicalSystem):
   """Base class to model two-end synaptic connections.

@@ -1497,3 +1370,134 @@ def __init__(

     # initialization
     NeuGroup.__init__(self, tuple(size), name=name, mode=mode)
+
+
+class DynamicalSystemNS(DynamicalSystem):
+  """Dynamical system without the need to pass shared parameters into the ``update()`` function."""
+
+  _pass_shared_args = False
+
+
+class Sequential(DynamicalSystemNS):
+  """A sequential `input-output` module.
+
+  Modules will be added to it in the order they are passed in the
+  constructor. Alternatively, a ``dict`` of modules can be
+  passed in. The ``update()`` method of ``Sequential`` accepts any
+  input and forwards it to the first module it contains. It then
+  "chains" outputs to inputs sequentially for each subsequent module,
+  finally returning the output of the last module.
+
+  The value a ``Sequential`` provides over manually calling a sequence
+  of modules is that it allows treating the whole container as a
+  single module, such that performing a transformation on the
+  ``Sequential`` applies to each of the modules it stores (which are
+  each a registered submodule of the ``Sequential``).
+
+  What's the difference between a ``Sequential`` and a
+  :py:class:`Container`? A ``Container`` is exactly what it
+  sounds like--a container to store :py:class:`DynamicalSystem` s!
+  On the other hand, the layers in a ``Sequential`` are connected
+  in a cascading way.
+
+  Examples
+  --------
+
+  >>> import brainpy as bp
+  >>> import brainpy.math as bm
+  >>>
+  >>> # composing ANN models
+  >>> l = bp.Sequential(bp.layers.Dense(100, 10),
+  >>>                   bm.relu,
+  >>>                   bp.layers.Dense(10, 2))
+  >>> l(bm.random.random((256, 100)))
+  >>>
+  >>> # Using Sequential with a dict. This is functionally the
+  >>> # same as the above code.
+  >>> l = bp.Sequential(l1=bp.layers.Dense(100, 10),
+  >>>                   l2=bm.relu,
+  >>>                   l3=bp.layers.Dense(10, 2))
+  >>> l(bm.random.random((256, 100)))
+
+  Parameters
+  ----------
+  name: str
+    The object name.
+  mode: Mode
+    The object computing context/mode. Default is ``None``.
+  """
+
+  def __init__(
+      self,
+      *modules_as_tuple,
+      name: str = None,
+      mode: bm.Mode = None,
+      **modules_as_dict
+  ):
+    super().__init__(name=name, mode=mode)
+    self._dyn_modules = bm.NodeDict()
+    self._static_modules = dict()
+    i = 0
+    for m in modules_as_tuple + tuple(modules_as_dict.values()):
+      key = self.__format_key(i)
+      if isinstance(m, bm.BrainPyObject):
+        self._dyn_modules[key] = m
+      else:
+        self._static_modules[key] = m
+      i += 1
+    self._num = i
+
+  def __format_key(self, i):
+    return f'l-{i}'
+
+  def __all_nodes(self):
+    nodes = []
+    for i in range(self._num):
+      key = self.__format_key(i)
+      if key not in self._dyn_modules:
+        nodes.append(self._static_modules[key])
+      else:
+        nodes.append(self._dyn_modules[key])
+    return nodes
+
+  def __getitem__(self, key: Union[int, slice, str]):
+    if isinstance(key, str):
+      if key in self._dyn_modules:
+        return self._dyn_modules[key]
+      elif key in self._static_modules:
+        return self._static_modules[key]
+      else:
+        raise KeyError(f'No component named {key} found in\n {str(self)}')
+    elif isinstance(key, slice):
+      return Sequential(*(self.__all_nodes()[key]))
+    elif isinstance(key, int):
+      key = self.__format_key(key)
+      return self._static_modules[key] if (key not in self._dyn_modules) else self._dyn_modules[key]
+    elif isinstance(key, (tuple, list)):
+      nodes = []
+      for i in key:
+        if isinstance(i, int):
+          i = self.__format_key(i)
+        assert isinstance(i, str)
+        nodes.append(self._static_modules[i] if (i not in self._dyn_modules) else self._dyn_modules[i])
+      return Sequential(*nodes)
+    else:
+      raise KeyError(f'Unknown type of key: {type(key)}')
+
+  def __repr__(self):
+    nodes = self.__all_nodes()
+    entries = '\n'.join(f'  [{i}] {tools.repr_object(x)}' for i, x in enumerate(nodes))
+    return f'{self.__class__.__name__}(\n{entries}\n)'
+
+  def update(self, x):
+    """Update function of a sequential model.
+    """
+    for m in self.__all_nodes():
+      x = m(x)
+    return x
+
+
+class NeuGroupNS(NeuGroup):
+  """Base class for neuron groups that do not require shared arguments to be passed."""
+  _pass_shared_args = False
+
diff --git a/brainpy/_src/dyn/runners.py b/brainpy/_src/dyn/runners.py
index c28d20d2f..cedc1ca76 100644
--- a/brainpy/_src/dyn/runners.py
+++ b/brainpy/_src/dyn/runners.py
@@ -647,10 +647,10 @@ def _get_f_predict(self, shared_args: Dict = None):
       dyn_vars = dyn_vars.unique()

       def run_func(all_inputs):
-        with jax.disable_jit(not self.jit['predict']):
-          return bm.for_loop(partial(self._step_func_predict, shared_args),
-                             all_inputs,
-                             dyn_vars=dyn_vars)
+        return bm.for_loop(partial(self._step_func_predict, shared_args),
+                           all_inputs,
+                           dyn_vars=dyn_vars,
+                           jit=self.jit['predict'])

       if self.jit['predict']:
         self._f_predict_compiled[shared_kwargs_str] = bm.jit(run_func, dyn_vars=dyn_vars)
diff --git a/brainpy/_src/integrators/runner.py b/brainpy/_src/integrators/runner.py
index 351a8fe0d..22dca79c8 100644
--- a/brainpy/_src/integrators/runner.py
+++ b/brainpy/_src/integrators/runner.py
@@ -230,12 +230,11 @@ def __init__(
       self.idx = bm.Variable(bm.zeros(1, dtype=bm.int_))

   def _run_fun_integration(self, static_args, dyn_args, times, indices):
-    with jax.disable_jit(not self.jit['predict']):
-      dyn_vars = self.vars().unique()
-      dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
-      return bm.for_loop(partial(self._step_fun_integrator, static_args),
-                         (dyn_args, times, indices),
-                         dyn_vars=dyn_vars)
+    dyn_vars = self.vars().unique()
+    return bm.for_loop(partial(self._step_fun_integrator, static_args),
+                       (dyn_args, times, indices),
+                       dyn_vars=dyn_vars,
+                       jit=self.jit['predict'])

   def _step_fun_integrator(self, static_args, dyn_args, t, i):
     # arguments
diff --git a/brainpy/_src/testing/base.py b/brainpy/_src/testing/base.py
index 0812a1e4c..9b2f4e7bd 100644
--- a/brainpy/_src/testing/base.py
+++ b/brainpy/_src/testing/base.py
@@ -7,7 +7,7 @@
   pass


-class UniTestCase(unittest.TestCase):
+class UnitTestCase(unittest.TestCase):
   def __init__(self, *args, **kwargs):
     super().__init__(*args, **kwargs)
     bm.random.seed()
diff --git a/brainpy/testing.py b/brainpy/testing.py
index b113e8499..f06131a3b 100644
--- a/brainpy/testing.py
+++ b/brainpy/testing.py
@@ -1 +1 @@
-from brainpy._src.testing.base import UniTestCase
+from brainpy._src.testing.base import UnitTestCase
diff --git a/docs/tutorial_simulation/parallel_computing.ipynb b/docs/tutorial_simulation/parallel_computing.ipynb
index 500c245af..eed73e941 100644
--- a/docs/tutorial_simulation/parallel_computing.ipynb
+++ b/docs/tutorial_simulation/parallel_computing.ipynb
@@ -88,7 +88,7 @@
     "# define your function\n",
     "def run_model(par):\n",
     "  model = YourModel(par)\n",
-    "  runner = bp.dyn.DSRunner(model)\n",
+    "  runner = bp.DSRunner(model)\n",
     "  runner.run(duration)\n",
     "  return runner.mon\n",
     "\n",
@@ -134,7 +134,7 @@
     "  import brainpy as bp  # needed to reimport packages when\n",
     "                        # run the function in Jupyter\n",
     "  model = bp.neurons.HH(1)\n",
-    "  runner = bp.dyn.DSRunner(model, monitors=['spike'], inputs=['input', bg_current])\n",
+    "  runner = bp.DSRunner(model, monitors=['spike'], inputs=['input', bg_current])\n",
     "  runner.run(1000.)\n",
     "  return runner.mon['spike'].sum()  # \"output\" is the spike number"
   ],
@@ -227,7 +227,7 @@
     "\n",
     "  bg_current = bp.math.as_jax(bg_current)\n",
     "  model = bp.neurons.HH(1)\n",
-    "  runner = bp.dyn.DSRunner(model, monitors=['spike'], inputs=['input', bg_current])\n",
+    "  runner = bp.DSRunner(model, monitors=['spike'], inputs=['input', bg_current])\n",
     "  runner.run(1000.)\n",
     "\n",
     "  bp.math.clear_buffer_memory()\n",
@@ -316,7 +316,7 @@
     "\n",
     "def run_model(par):\n",
     "  model = YourModel(par)\n",
-    "  runner = bp.dyn.DSRunner(model)\n",
+    "  runner = bp.DSRunner(model)\n",
     "  runner.run(duration)\n",
     "  return runner.mon\n",
     "\n",
@@ -347,7 +347,7 @@
   "source": [
     "def hh_spike_num3(bg_current):  # \"input\" is the bg_current\n",
     "  model = bp.neurons.HH(1)\n",
-    "  runner = bp.dyn.DSRunner(model, monitors=['spike'], inputs=['input', bg_current],\n",
+    "  runner = bp.DSRunner(model, monitors=['spike'], inputs=['input', bg_current],\n",
     "                       numpy_mon_after_run=False)\n",
     "  runner.run(1000.)\n",
     "  return runner.mon['spike'].sum()  # \"output\" is the spike number"
@@ -533,7 +533,7 @@
     "\n",
     "def run_model(par):\n",
     "  model = YourModel(par)\n",
-    "  runner = bp.dyn.DSRunner(model)\n",
+    "  runner = bp.DSRunner(model)\n",
     "  runner.run()\n",
     "  return runner.mon\n",
     "\n",
diff --git a/examples/dynamics_training/reservoir-mnist.py b/examples/dynamics_training/reservoir-mnist.py
index 63cc289f4..a868b8bf8 100644
--- a/examples/dynamics_training/reservoir-mnist.py
+++ b/examples/dynamics_training/reservoir-mnist.py
@@ -73,8 +73,7 @@ def force_online_train(num_hidden=2000, num_in=28, num_out=10, train_stage='fina
   rls = bp.algorithms.RLS()
   rls.register_target(num_hidden)

-  @bm.jit
-  @bm.to_object(child_objs=(reservoir, readout, rls))
+  @bm.jit(child_objs=(reservoir, readout, rls))
   def train_step(xs, y):
     reservoir.reset_state(xs.shape[0])
     if train_stage == 'final_step':
@@ -92,8 +91,7 @@ def train_step(xs, y):
     else:
       raise ValueError

-  @bm.jit
-  @bm.to_object(child_objs=(reservoir, readout))
+  @bm.jit(child_objs=(reservoir, readout))
   def predict(xs):
     reservoir.reset_state(xs.shape[0])
     for x in xs.transpose(1, 0, 2):
diff --git a/tests/simulation/test_net_rate_SL.py b/tests/simulation/test_net_rate_SL.py
index fbce07d9f..fad1dd6ed 100644
--- a/tests/simulation/test_net_rate_SL.py
+++ b/tests/simulation/test_net_rate_SL.py
@@ -25,7 +25,7 @@ def __init__(self, noise=0.14):
     )


-class TestSL(bp.testing.UniTestCase):
+class TestSL(bp.testing.UnitTestCase):
   def test1(self):
     net = Network()
     runner = bp.DSRunner(net, monitors=['sl.x'])
diff --git a/tests/simulation/test_neu_HH.py b/tests/simulation/test_neu_HH.py
index 907465116..1d5e46d68 100644
--- a/tests/simulation/test_neu_HH.py
+++ b/tests/simulation/test_neu_HH.py
@@ -90,7 +90,7 @@ def update(self, x=None):
     return dV_grad


-class TestHH(bp.testing.UniTestCase):
+class TestHH(bp.testing.UnitTestCase):
   def test1(self):
     bm.random.seed()
     hh = HH(1)
@@ -107,7 +107,7 @@ def test1(self):
   def test2(self):
     bm.random.seed()
     with bp.math.environment(dt=0.1):
-      hh = HH(1)
+      hh = bp.neurons.HH(1)
       looper = bp.LoopOverTime(hh, out_vars=(hh.V, hh.m, hh.n, hh.h))
       grads, (vs, ms, ns, hs) = looper(bm.ones(1000) * 5)
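
For reference, a minimal usage sketch of the relocated ``Sequential`` after this patch, assuming only the APIs already shown in its docstring (``bp.layers.Dense``, ``bm.relu``, ``bm.random.random``). Because ``Sequential`` now derives from ``DynamicalSystemNS`` (``_pass_shared_args = False``), the call takes the input array alone, with no leading shared-argument dict:

import brainpy as bp
import brainpy.math as bm

# build a small feedforward stack; callables and BrainPyObjects may be mixed
l = bp.Sequential(bp.layers.Dense(100, 10),
                  bm.relu,
                  bp.layers.Dense(10, 2))
y = l(bm.random.random((256, 100)))  # no leading {} shared dict anymore
print(y.shape)  # expected: (256, 2)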