diff --git a/CHANGES.rst b/CHANGES.rst
index 0e19d09fc..7aac0100d 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -59,6 +59,7 @@ Release history
 - Reduced the amount of state that needs to be stored in the simulation.
   (`#129`_)
 - Added more information to the error message when loading saved parameters
   that don't match the current model. (`#129`_)
+- More efficient implementation of convolutional biases in the Converter. (`#130`_)
 
 **Fixed**
diff --git a/nengo_dl/converter.py b/nengo_dl/converter.py
index 5b9b8f252..f1789eb0c 100644
--- a/nengo_dl/converter.py
+++ b/nengo_dl/converter.py
@@ -436,7 +436,7 @@ def add_nengo_obj(self, node_id, biases=None, activation=None):
         if biases is not None:
             # use a connection from a constant node (so that the bias
             # values will be trainable)
-            bias_node = nengo.Node([1], label="%s.bias_node" % name)
+            bias_node = nengo.Node([1], label="%s.bias" % name)
             nengo.Connection(
                 bias_node, obj, transform=biases[:, None], synapse=None
             )
@@ -1137,7 +1137,9 @@ def convert(self, node_id):
         broadcast_bias = np.ravel(broadcast_bias)
 
         # connect up bias node to output
-        bias_node = nengo.Node(broadcast_bias)
+        bias_node = nengo.Node(
+            broadcast_bias, label="%s.%d.bias" % (self.layer.name, node_id)
+        )
         conn = nengo.Connection(bias_node, output, synapse=None)
         self.converter.net.config[conn].trainable = False
 
@@ -1150,34 +1152,6 @@ def convert(self, node_id):
         )
         self.converter.net.config[conn].trainable = False
 
-        # this is an alternate approach, where rather than broadcasting scale/bias,
-        # we create individual connections for each element in the batch normalization
-        # axis. this will result in smaller weight matrices, but more Connections
-        # TODO: figure out where the tradeoffs lie between these two approaches
-        # bias_node = nengo.Node(np.ones(idxs[slices].size))
-        #
-        # # for each element in the batch normalization axis
-        # for i in range(idxs.shape[axis]):
-        #     # slice out one element of the output along the axis
-        #     slices[axis] = i
-        #     slice_idxs = np.ravel(idxs[slices])
-        #     sliced_output = output[slice_idxs]
-        #
-        #     # connect up bias
-        #     conn = nengo.Connection(
-        #         bias_node, sliced_output, synapse=None, transform=bias[i],
-        #     )
-        #     self.converter.net.config[conn].trainable = False
-        #
-        #     # connect up input with scale applied
-        #     conn = nengo.Connection(
-        #         self.get_input_obj(node_id)[slice_idxs],
-        #         sliced_output,
-        #         synapse=None,
-        #         transform=scale[i],
-        #     )
-        #     self.converter.net.config[conn].trainable = False
-
         return output
 
     @classmethod
@@ -1248,31 +1222,41 @@ def convert(self, node_id, dimensions):
             # conv layer biases are per-output-channel, rather than per-output-element,
             # so we need to set up a nengo connection structure that will have one
             # bias parameter shared across all the spatial dimensions
-            if self.layer.data_format == "channels_first":
-                spatial_size = np.prod(self.output_shape(node_id)[1:])
-                bias_node = nengo.Node(np.ones(spatial_size), label="conv_bias")
-                offset = 0
-                for i in range(self.output_shape(node_id)[0]):
-                    nengo.Connection(
-                        bias_node,
-                        output[offset : offset + spatial_size],
-                        transform=biases[i],
-                        synapse=None,
-                    )
-                    offset += spatial_size
-            else:
-                spatial_size = np.prod(self.output_shape(node_id)[:-1])
-                bias_node = nengo.Node(np.ones(spatial_size), label="conv_bias")
-                idxs = np.arange(np.prod(self.output_shape(node_id))).reshape(
-                    (-1, self.output_shape(node_id)[-1])
+
+            # add trainable bias weights
+            bias_node = nengo.Node([1], label="%s.%d.bias" % (self.layer.name, node_id))
+            bias_relay = nengo.Node(size_in=len(biases))
+            nengo.Connection(
+                bias_node, bias_relay, transform=biases[:, None], synapse=None
+            )
+
+            # use a non-trainable sparse transform to broadcast biases along all
+            # non-channel dimensions
+            broadcast_indices = []
+            idxs = np.arange(np.prod(self.output_shape(node_id))).reshape(
+                self.output_shape(node_id)
+            )
+            slices = [slice(None) for _ in range(len(self.output_shape(node_id)))]
+            n_spatial = np.prod(
+                self.output_shape(node_id)[:-1]
+                if self.layer.data_format == "channels_last"
+                else self.output_shape(node_id)[1:]
+            )
+            axis = -1 if self.layer.data_format == "channels_last" else 0
+            for i in range(self.output_shape(node_id)[axis]):
+                slices[axis] = i
+                broadcast_indices.extend(
+                    tuple(zip(np.ravel(idxs[tuple(slices)]), [i] * n_spatial))
                 )
-                for i in range(self.output_shape(node_id)[-1]):
-                    nengo.Connection(
-                        bias_node,
-                        output[idxs[:, i]],
-                        transform=biases[i],
-                        synapse=None,
-                    )
+            conn = nengo.Connection(
+                bias_relay,
+                output,
+                transform=nengo.Sparse(
+                    (output.size_in, bias_relay.size_out), indices=broadcast_indices
+                ),
+                synapse=None,
+            )
+            self.converter.net.config[conn].trainable = False
 
         # set up a convolutional transform that matches the layer parameters
         transform = nengo.Convolution(
diff --git a/nengo_dl/graph_optimizer.py b/nengo_dl/graph_optimizer.py
index 646e3923d..8600caa25 100644
--- a/nengo_dl/graph_optimizer.py
+++ b/nengo_dl/graph_optimizer.py
@@ -1383,6 +1383,13 @@ def is_identity(x, sig):
     if sig.ndim == 1:
         return x.shape == (d,) and np.allclose(x, 1)
 
+    if isinstance(x, SparseMatrix):
+        return (
+            x.shape == (d, d)
+            and np.allclose(x.data, 1)
+            and np.allclose(x.indices, [[i, i] for i in range(d)])
+        )
+
     return (
         x.shape == (d, d)
         and np.allclose(np.diag(x), 1)
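
The converter.py change above factors the per-channel convolutional bias into two stages: a trainable (n_channels x 1) connection holding the actual bias values, and a fixed 0/1 nengo.Sparse transform that fans each channel's bias out to all of that channel's spatial positions. The sketch below demonstrates the same construction standalone, outside of nengo-dl; the shapes, dummy bias values, and variable names are illustrative assumptions, and it assumes nengo>=3.0 (with SciPy installed for the sparse transform).

    # sketch: broadcasting per-channel biases with a fixed sparse transform
    import numpy as np
    import nengo

    output_shape = (4, 4, 3)  # e.g. height x width x channels ("channels_last")
    n_channels = output_shape[-1]
    n_spatial = int(np.prod(output_shape[:-1]))
    biases = np.arange(1.0, n_channels + 1)  # dummy per-channel bias values

    # flattened output element i reads from column c(i), where c(i) is i's
    # channel index (mirrors the broadcast_indices loop in the patch)
    idxs = np.arange(np.prod(output_shape)).reshape(output_shape)
    broadcast_indices = []
    for c in range(n_channels):
        broadcast_indices.extend(zip(np.ravel(idxs[..., c]), [c] * n_spatial))

    with nengo.Network() as net:
        bias_node = nengo.Node([1])
        bias_relay = nengo.Node(size_in=n_channels)
        # the (n_channels x 1) dense weights are the trainable bias parameters
        nengo.Connection(
            bias_node, bias_relay, transform=biases[:, None], synapse=None
        )

        output = nengo.Node(size_in=int(np.prod(output_shape)))
        # fixed 0/1 sparse matrix fans each channel's bias out spatially
        nengo.Connection(
            bias_relay,
            output,
            transform=nengo.Sparse(
                (output.size_in, n_channels), indices=broadcast_indices
            ),
            synapse=None,
        )
        probe = nengo.Probe(output)

    with nengo.Simulator(net) as sim:
        sim.step()

    # every spatial position within a channel sees that channel's bias
    assert np.allclose(sim.data[probe][-1].reshape(output_shape), biases)

Compared to the old per-channel nengo.Connection loop, this builds a constant number of connections regardless of channel count, and the only trainable weights are the n_channels bias entries themselves; the broadcast lives in a non-trainable sparse matrix rather than being replicated into a dense weight matrix.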