Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

data pass of the Node is default SingleData #148

Merged
merged 4 commits into from
Apr 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion brainpy/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""Neural Networks (nn)"""

from .base import *
from .constants import *
from .datatypes import *
from .graph_flow import *
from .nodes import *
from .graph_flow import *
Expand Down
50 changes: 24 additions & 26 deletions brainpy/nn/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
MathError)
from brainpy.nn.algorithms.offline import OfflineAlgorithm
from brainpy.nn.algorithms.online import OnlineAlgorithm
from brainpy.nn.constants import (PASS_SEQUENCE,
DATA_PASS_FUNC,
DATA_PASS_TYPES)
from brainpy.nn.datatypes import (DataType, SingleData, MultipleData)
from brainpy.nn.graph_flow import (find_senders_and_receivers,
find_entries_and_exits,
detect_cycle,
Expand Down Expand Up @@ -83,13 +81,13 @@ def feedback(self):
class Node(Base):
"""Basic Node class for neural network building in BrainPy."""

'''Support multiple types of data pass, including "PASS_SEQUENCE" (by default),
"PASS_NAME_DICT", "PASS_NODE_DICT" and user-customized type which registered
by ``brainpy.nn.register_data_pass_type()`` function.
'''Support multiple types of data pass, including "PassOnlyOne" (by default),
"PassSequence", "PassNameDict", etc. and user-customized type which inherits
from basic "SingleData" or "MultipleData".

This setting will change the feedforward/feedback input data which pass into
the "call()" function and the sizes of the feedforward/feedback input data.'''
data_pass_type = PASS_SEQUENCE
data_pass = SingleData()

'''Offline fitting method.'''
offline_fit_by: Union[Callable, OfflineAlgorithm]
Expand All @@ -115,11 +113,10 @@ def __init__(
self._trainable = trainable
self._state = None # the state of the current node
self._fb_output = None # the feedback output of the current node
# data pass function
if self.data_pass_type not in DATA_PASS_FUNC:
raise ValueError(f'Unsupported data pass type {self.data_pass_type}. '
f'Only support {DATA_PASS_TYPES}')
self.data_pass_func = DATA_PASS_FUNC[self.data_pass_type]
# data pass
if not isinstance(self.data_pass, DataType):
raise ValueError(f'Unsupported data pass type {type(self.data_pass)}. '
f'Only support {DataType.__class__}')

# super initialization
super(Node, self).__init__(name=name)
Expand All @@ -129,11 +126,10 @@ def __init__(
self._feedforward_shapes = {self.name: (None,) + tools.to_size(input_shape)}

def __repr__(self):
name = type(self).__name__
prefix = ' ' * (len(name) + 1)
line1 = f"{name}(name={self.name}, forwards={self.feedforward_shapes}, \n"
line2 = f"{prefix}feedbacks={self.feedback_shapes}, output={self.output_shape})"
return line1 + line2
return (f"{type(self).__name__}(name={self.name}, "
f"forwards={self.feedforward_shapes}, "
f"feedbacks={self.feedback_shapes}, "
f"output={self.output_shape})")

def __call__(self, *args, **kwargs) -> Tensor:
"""The main computation function of a Node.
Expand Down Expand Up @@ -298,7 +294,7 @@ def trainable(self, value: bool):
@property
def feedforward_shapes(self):
"""Input data size."""
return self.data_pass_func(self._feedforward_shapes)
return self.data_pass.filter(self._feedforward_shapes)

@feedforward_shapes.setter
def feedforward_shapes(self, size):
Expand All @@ -324,7 +320,7 @@ def set_feedforward_shapes(self, feedforward_shapes: Dict):
@property
def feedback_shapes(self):
"""Output data size."""
return self.data_pass_func(self._feedback_shapes)
return self.data_pass.filter(self._feedback_shapes)

@feedback_shapes.setter
def feedback_shapes(self, size):
Expand Down Expand Up @@ -530,8 +526,8 @@ def _check_inputs(self, ff, fb=None):
f'batch size by ".initialize(num_batch)", or change the data '
f'consistent with the data batch size {self.state.shape[0]}.')
# data
ff = self.data_pass_func(ff)
fb = self.data_pass_func(fb)
ff = self.data_pass.filter(ff)
fb = self.data_pass.filter(fb)
return ff, fb

def _call(self,
Expand Down Expand Up @@ -747,6 +743,8 @@ def set_state(self, state):
class Network(Node):
"""Basic Network class for neural network building in BrainPy."""

data_pass = MultipleData('sequence')

def __init__(self,
nodes: Optional[Sequence[Node]] = None,
ff_edges: Optional[Sequence[Tuple[Node]]] = None,
Expand Down Expand Up @@ -1145,8 +1143,8 @@ def _check_inputs(self, ff, fb=None):
check_shape_except_batch(size, fb[k].shape)

# data transformation
ff = self.data_pass_func(ff)
fb = self.data_pass_func(fb)
ff = self.data_pass.filter(ff)
fb = self.data_pass.filter(fb)
return ff, fb

def _call(self,
Expand Down Expand Up @@ -1208,12 +1206,12 @@ def _call(self,
def _call_a_node(self, node, ff, fb, monitors, forced_states,
parent_outputs, children_queue, ff_senders,
**shared_kwargs):
ff = node.data_pass_func(ff)
ff = node.data_pass.filter(ff)
if f'{node.name}.inputs' in monitors:
monitors[f'{node.name}.inputs'] = ff
# get the output results
if len(fb):
fb = node.data_pass_func(fb)
fb = node.data_pass.filter(fb)
if f'{node.name}.feedbacks' in monitors:
monitors[f'{node.name}.feedbacks'] = fb
parent_outputs[node] = node.forward(ff, fb, **shared_kwargs)
Expand Down Expand Up @@ -1440,7 +1438,7 @@ def plot_node_graph(self,
if len(nodes_untrainable):
proxie.append(Line2D([], [], color='white', marker='o',
markerfacecolor=untrainable_color))
labels.append('Untrainable')
labels.append('Nontrainable')
if len(ff_edges):
proxie.append(Line2D([], [], color=ff_color, linewidth=2))
labels.append('Feedforward')
Expand Down
114 changes: 0 additions & 114 deletions brainpy/nn/constants.py

This file was deleted.

97 changes: 97 additions & 0 deletions brainpy/nn/datatypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# -*- coding: utf-8 -*-


__all__ = [
# data types
'DataType',

# pass rules
'SingleData',
'MultipleData',
]


class DataType(object):
"""Base class for data type."""

def filter(self, data):
raise NotImplementedError

def __repr__(self):
return self.__class__.__name__


class SingleData(DataType):
"""Pass the only one data into the node.
If there are multiple data, an error will be raised. """

def filter(self, data):
if data is None:
return None
if len(data) > 1:
raise ValueError(f'{self.__class__.__name__} only support one '
f'feedforward/feedback input. But we got {len(data)}.')
return tuple(data.values())[0]

def __repr__(self):
return self.__class__.__name__


class MultipleData(DataType):
"""Pass a list/tuple of data into the node."""

def __init__(self, return_type: str = 'sequence'):
if return_type not in ['sequence', 'name_dict', 'type_dict', 'node_dict']:
raise ValueError(f"Only support return type of 'sequence', 'name_dict', "
f"'type_dict' and 'node_dict'. But we got {return_type}")
self.return_type = return_type

from brainpy.nn.base import Node

if return_type == 'sequence':
f = lambda data: tuple(data.values())

elif return_type == 'name_dict':
# Pass a dict with <node name, data> into the node.

def f(data):
_res = dict()
for node, val in data.items():
if isinstance(node, str):
_res[node] = val
elif isinstance(node, Node):
_res[node.name] = val
else:
raise ValueError(f'Unknown type {type(node)}: node')
return _res

elif return_type == 'type_dict':
# Pass a dict with <node type, data> into the node.

def f(data):
_res = dict()
for node, val in data.items():
if isinstance(node, str):
_res[str] = val
elif isinstance(node, Node):
_res[type(node.name)] = val
else:
raise ValueError(f'Unknown type {type(node)}: node')
return _res

elif return_type == 'node_dict':
# Pass a dict with <node, data> into the node.
f = lambda data: data

else:
raise ValueError
self.return_func = f

def __repr__(self):
return f'{self.__class__.__name__}(return_type={self.return_type})'

def filter(self, data):
if data is None:
return None
else:
return self.return_func(data)
8 changes: 2 additions & 6 deletions brainpy/nn/nodes/ANN/batch_norm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
# -*- coding: utf-8 -*-

from typing import Sequence, Optional, Dict, Callable, Union
from typing import Union

import jax.nn
import jax.numpy as jnp

import brainpy.math as bm
import brainpy
import brainpy.math as bm
from brainpy.initialize import ZeroInit, OneInit, Initializer
from brainpy.nn.base import Node
from brainpy.nn.constants import PASS_ONLY_ONE


__all__ = [
'BatchNorm',
Expand Down Expand Up @@ -43,8 +41,6 @@ class BatchNorm(Node):
gamma_init: brainpy.init.Initializer
an initializer generating the original scaling matrix
"""
data_pass_type = PASS_ONLY_ONE

def __init__(self,
axis: Union[int, tuple, list],
epsilon: float = 1e-5,
Expand Down
Loading