Skip to content

Commit

Permalink
Merge pull request #348 from chaoming0625/master
Browse files Browse the repository at this point in the history
Fix `Runner(jit=False)`` bug
  • Loading branch information
chaoming0625 authored Mar 23, 2023
2 parents c2d732e + f19cbc2 commit 51d700d
Show file tree
Hide file tree
Showing 9 changed files with 156 additions and 155 deletions.
264 changes: 134 additions & 130 deletions brainpy/_src/dyn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,8 @@ class DynamicalSystem(BrainPyObject):
The model computation mode. It should be instance of :py:class:`~.Mode`.
"""

pass_shared: bool = True
_pass_shared_args: bool = True
# pass_shared : bool = False

global_delay_data: Dict[str, Tuple[Union[bm.LengthDelay, None], Variable]] = dict()
'''Global delay data, which stores the delay variables and corresponding delay targets.
Expand Down Expand Up @@ -160,7 +161,7 @@ def __call__(self, *args, **kwargs):
if share is None:
from brainpy._src.dyn.context import share

if self.pass_shared:
if self._pass_shared_args:
if hasattr(self.update, '_new_style') and getattr(self.update, '_new_style'):
if len(args) and isinstance(args[0], dict):
share.save(**args[0])
Expand Down Expand Up @@ -414,12 +415,6 @@ def __rrshift__(self, other):
return self.__call__(other)


class DynamicalSystemNS(DynamicalSystem):
"""Dynamical system without the need of shared parameters passing into ``update()`` function."""

pass_shared = False


class Container(DynamicalSystem):
"""Container object which is designed to add other instances of DynamicalSystem.
Expand Down Expand Up @@ -519,124 +514,6 @@ def clear_input(self):
node.clear_input()


class Sequential(DynamicalSystemNS):
"""A sequential `input-output` module.
Modules will be added to it in the order they are passed in the
constructor. Alternatively, an ``dict`` of modules can be
passed in. The ``update()`` method of ``Sequential`` accepts any
input and forwards it to the first module it contains. It then
"chains" outputs to inputs sequentially for each subsequent module,
finally returning the output of the last module.
The value a ``Sequential`` provides over manually calling a sequence
of modules is that it allows treating the whole container as a
single module, such that performing a transformation on the
``Sequential`` applies to each of the modules it stores (which are
each a registered submodule of the ``Sequential``).
What's the difference between a ``Sequential`` and a
:py:class:`Container`? A ``Container`` is exactly what it
sounds like--a container to store :py:class:`DynamicalSystem` s!
On the other hand, the layers in a ``Sequential`` are connected
in a cascading way.
Examples
--------
>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
>>> # composing ANN models
>>> l = bp.Sequential(bp.layers.Dense(100, 10),
>>> bm.relu,
>>> bp.layers.Dense(10, 2))
>>> l({}, bm.random.random((256, 100)))
>>>
>>> # Using Sequential with Dict. This is functionally the
>>> # same as the above code
>>> l = bp.Sequential(l1=bp.layers.Dense(100, 10),
>>> l2=bm.relu,
>>> l3=bp.layers.Dense(10, 2))
>>> l({}, bm.random.random((256, 100)))
Parameters
----------
name: str
The object name.
mode: Mode
The object computing context/mode. Default is ``None``.
"""

def __init__(
self,
*modules_as_tuple,
name: str = None,
mode: bm.Mode = None,
**modules_as_dict
):
super().__init__(name=name, mode=mode)
self._dyn_modules = bm.NodeDict()
self._static_modules = dict()
i = 0
for m in modules_as_tuple + tuple(modules_as_dict.values()):
key = self.__format_key(i)
if isinstance(m, bm.BrainPyObject):
self._dyn_modules[key] = m
else:
self._static_modules[key] = m
i += 1
self._num = i

def __format_key(self, i):
return f'l-{i}'

def __all_nodes(self):
nodes = []
for i in range(self._num):
key = self.__format_key(i)
if key not in self._dyn_modules:
nodes.append(self._static_modules[key])
else:
nodes.append(self._dyn_modules[key])
return nodes

def __getitem__(self, key: Union[int, slice, str]):
if isinstance(key, str):
if key in self._dyn_modules:
return self._dyn_modules[key]
elif key in self._static_modules:
return self._static_modules[key]
else:
raise KeyError(f'Does not find a component named {key} in\n {str(self)}')
elif isinstance(key, slice):
return Sequential(*(self.__all_nodes()[key]))
elif isinstance(key, int):
key = self.__format_key(key)
return self._static_modules[key] if (key not in self._dyn_modules) else self._dyn_modules[key]
elif isinstance(key, (tuple, list)):
nodes = []
for i in key:
if isinstance(i, int):
i = self.__format_key(i)
assert isinstance(i, str)
nodes.append(self._static_modules[i] if (i not in self._dyn_modules) else self._dyn_modules[i])
return Sequential(*nodes)
else:
raise KeyError(f'Unknown type of key: {type(key)}')

def __repr__(self):
nodes = self.__all_nodes()
entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(nodes))
return f'{self.__class__.__name__}(\n{entries}\n)'

def update(self, x):
"""Update function of a sequential model.
"""
for m in self.__all_nodes():
x = m(x)
return x


class Network(Container):
"""Base class to model network objects, an alias of Container.
Expand Down Expand Up @@ -807,10 +684,6 @@ def __getitem__(self, item):
return NeuGroupView(target=self, index=item)


class NeuGroupNS(NeuGroup):
"""Base class for neuron group without shared arguments passed."""
pass_shared = False


class SynConn(DynamicalSystem):
"""Base class to model two-end synaptic connections.
Expand Down Expand Up @@ -1497,3 +1370,134 @@ def __init__(

# initialization
NeuGroup.__init__(self, tuple(size), name=name, mode=mode)


class DynamicalSystemNS(DynamicalSystem):
"""Dynamical system without the need to pass shared parameters into ``update()`` function."""

_pass_shared_args = False


class Sequential(DynamicalSystemNS):
"""A sequential `input-output` module.
Modules will be added to it in the order they are passed in the
constructor. Alternatively, an ``dict`` of modules can be
passed in. The ``update()`` method of ``Sequential`` accepts any
input and forwards it to the first module it contains. It then
"chains" outputs to inputs sequentially for each subsequent module,
finally returning the output of the last module.
The value a ``Sequential`` provides over manually calling a sequence
of modules is that it allows treating the whole container as a
single module, such that performing a transformation on the
``Sequential`` applies to each of the modules it stores (which are
each a registered submodule of the ``Sequential``).
What's the difference between a ``Sequential`` and a
:py:class:`Container`? A ``Container`` is exactly what it
sounds like--a container to store :py:class:`DynamicalSystem` s!
On the other hand, the layers in a ``Sequential`` are connected
in a cascading way.
Examples
--------
>>> import brainpy as bp
>>> import brainpy.math as bm
>>>
>>> # composing ANN models
>>> l = bp.Sequential(bp.layers.Dense(100, 10),
>>> bm.relu,
>>> bp.layers.Dense(10, 2))
>>> l({}, bm.random.random((256, 100)))
>>>
>>> # Using Sequential with Dict. This is functionally the
>>> # same as the above code
>>> l = bp.Sequential(l1=bp.layers.Dense(100, 10),
>>> l2=bm.relu,
>>> l3=bp.layers.Dense(10, 2))
>>> l({}, bm.random.random((256, 100)))
Parameters
----------
name: str
The object name.
mode: Mode
The object computing context/mode. Default is ``None``.
"""

def __init__(
self,
*modules_as_tuple,
name: str = None,
mode: bm.Mode = None,
**modules_as_dict
):
super().__init__(name=name, mode=mode)
self._dyn_modules = bm.NodeDict()
self._static_modules = dict()
i = 0
for m in modules_as_tuple + tuple(modules_as_dict.values()):
key = self.__format_key(i)
if isinstance(m, bm.BrainPyObject):
self._dyn_modules[key] = m
else:
self._static_modules[key] = m
i += 1
self._num = i

def __format_key(self, i):
return f'l-{i}'

def __all_nodes(self):
nodes = []
for i in range(self._num):
key = self.__format_key(i)
if key not in self._dyn_modules:
nodes.append(self._static_modules[key])
else:
nodes.append(self._dyn_modules[key])
return nodes

def __getitem__(self, key: Union[int, slice, str]):
if isinstance(key, str):
if key in self._dyn_modules:
return self._dyn_modules[key]
elif key in self._static_modules:
return self._static_modules[key]
else:
raise KeyError(f'Does not find a component named {key} in\n {str(self)}')
elif isinstance(key, slice):
return Sequential(*(self.__all_nodes()[key]))
elif isinstance(key, int):
key = self.__format_key(key)
return self._static_modules[key] if (key not in self._dyn_modules) else self._dyn_modules[key]
elif isinstance(key, (tuple, list)):
nodes = []
for i in key:
if isinstance(i, int):
i = self.__format_key(i)
assert isinstance(i, str)
nodes.append(self._static_modules[i] if (i not in self._dyn_modules) else self._dyn_modules[i])
return Sequential(*nodes)
else:
raise KeyError(f'Unknown type of key: {type(key)}')

def __repr__(self):
nodes = self.__all_nodes()
entries = '\n'.join(f' [{i}] {tools.repr_object(x)}' for i, x in enumerate(nodes))
return f'{self.__class__.__name__}(\n{entries}\n)'

def update(self, x):
"""Update function of a sequential model.
"""
for m in self.__all_nodes():
x = m(x)
return x


class NeuGroupNS(NeuGroup):
"""Base class for neuron group without shared arguments passed."""
_pass_shared_args = False

8 changes: 4 additions & 4 deletions brainpy/_src/dyn/runners.py
Original file line number Diff line number Diff line change
Expand Up @@ -647,10 +647,10 @@ def _get_f_predict(self, shared_args: Dict = None):
dyn_vars = dyn_vars.unique()

def run_func(all_inputs):
with jax.disable_jit(not self.jit['predict']):
return bm.for_loop(partial(self._step_func_predict, shared_args),
all_inputs,
dyn_vars=dyn_vars)
return bm.for_loop(partial(self._step_func_predict, shared_args),
all_inputs,
dyn_vars=dyn_vars,
jit=self.jit['predict'])

if self.jit['predict']:
self._f_predict_compiled[shared_kwargs_str] = bm.jit(run_func, dyn_vars=dyn_vars)
Expand Down
11 changes: 5 additions & 6 deletions brainpy/_src/integrators/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,12 +230,11 @@ def __init__(
self.idx = bm.Variable(bm.zeros(1, dtype=bm.int_))

def _run_fun_integration(self, static_args, dyn_args, times, indices):
with jax.disable_jit(not self.jit['predict']):
dyn_vars = self.vars().unique()
dyn_vars = dyn_vars - dyn_vars.subset(bm.VariableView)
return bm.for_loop(partial(self._step_fun_integrator, static_args),
(dyn_args, times, indices),
dyn_vars=dyn_vars)
dyn_vars = self.vars().unique()
return bm.for_loop(partial(self._step_fun_integrator, static_args),
(dyn_args, times, indices),
dyn_vars=dyn_vars,
jit=self.jit['predict'])

def _step_fun_integrator(self, static_args, dyn_args, t, i):
# arguments
Expand Down
2 changes: 1 addition & 1 deletion brainpy/_src/testing/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
pass


class UniTestCase(unittest.TestCase):
class UnitTestCase(unittest.TestCase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
bm.random.seed()
Expand Down
2 changes: 1 addition & 1 deletion brainpy/testing.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from brainpy._src.testing.base import UniTestCase
from brainpy._src.testing.base import UnitTestCase
Loading

0 comments on commit 51d700d

Please sign in to comment.