Skip to content

Commit

Permalink
Add Converter.layers attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Feb 28, 2020
1 parent a13fb9f commit 1131cf3
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 12 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ Release history
- Added support for UpSampling layers to ``nengo_dl.Converter``. (`#130`_)
- Added tolerance parameters to ``nengo_dl.Converter.verify``. (`#130`_)
- Added ``scale_firing_rates`` option to ``nengo_dl.Converter``. (`#134`_)
- Added ``Converter.layers`` attribute which will map Keras layers/tensors to
the converted Nengo objects, to make it easier to access converted components.
(`#134`_)

**Changed**

Expand Down
62 changes: 50 additions & 12 deletions nengo_dl/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,15 @@ class Converter:
equivalent Functional model).
net : `nengo.Network`
The converted Nengo network.
inputs : `.Converter.TensorDict`
inputs : `.Converter.KerasTensorDict`
Maps from Keras model inputs to input Nodes in the converted Nengo network.
For example, ``my_node = Converter(my_model).inputs[my_model.input]``.
outputs : `.Converter.TensorDict`
outputs : `.Converter.KerasTensorDict`
Maps from Keras model outputs to output Probes in the converted Nengo network.
For example, ``my_probe = Converter(my_model).outputs[my_model.output]``.
layers : `.Converter.KerasTensorDict`
Maps from Keras model layers to the converted Nengo object.
For example, ``my_neurons = Converter(my_model).layers[my_layer]``.
"""

converters = {}
Expand Down Expand Up @@ -129,8 +132,8 @@ def __init__(
else:
self.model = model

# track inputs/outputs of model on network object
self.inputs = Converter.TensorDict()
# data structures to track converted objects
self.inputs = Converter.KerasTensorDict()
for input in self.model.inputs:
(
input_layer,
Expand All @@ -141,7 +144,7 @@ def __init__(
input_tensor_id
]

self.outputs = Converter.TensorDict()
self.outputs = Converter.KerasTensorDict()
for output in self.model.outputs:
(
output_layer,
Expand All @@ -156,6 +159,19 @@ def __init__(
logger.info("Probing %s (%s)", output_obj, output)
self.outputs[output] = nengo.Probe(output_obj)

self.layers = Converter.KerasTensorDict()
for layer in self.model.layers:
for node_id, node_outputs in self.layer_map[layer].items():
for nengo_obj in node_outputs:
output_tensor = layer.inbound_nodes[node_id].output_tensors

# assuming layers with only one output (for now)
if isinstance(output_tensor, list):
assert len(output_tensor) == 1
output_tensor = output_tensor[0]

self.layers[output_tensor] = nengo_obj

def verify(self, training=False, inputs=None, atol=1e-8, rtol=1e-5):
"""
Verify that output of converted Nengo network matches the original Keras model.
Expand Down Expand Up @@ -336,23 +352,45 @@ def register_converter(convert_cls):

return register_converter

class TensorDict:
"""A dictionary-like object that works with TensorFlow Tensors."""
class KerasTensorDict(collections.abc.Mapping):
"""
A dictionary-like object that has extra logic to handle Layer/Tensor keys.
"""

def __init__(self):
self.dict = collections.OrderedDict()

def __setitem__(self, key, val):
def _get_key(self, key):
if isinstance(key, tf.keras.layers.Layer):
if len(key.inbound_nodes) > 1 or (
isinstance(key.inbound_nodes[0].output_tensors, list)
and len(key.inbound_nodes[0].output_tensors) > 1
):
raise KeyError(
"Layer %s is ambiguous because it has multiple output tensors; "
"use a specific tensor as key instead" % key
)

# get output tensor
key = key.output

if isinstance(key, tf.Tensor):
# get hashable key
key = key.experimental_ref()

self.dict[key] = val
return key

def __setitem__(self, key, val):
self.dict[self._get_key(key)] = val

def __getitem__(self, key):
if isinstance(key, tf.Tensor):
key = key.experimental_ref()
return self.dict[self._get_key(key)]

def __iter__(self):
return iter(ref._wrapped for ref in self.dict)

return self.dict[key]
def __len__(self):
return len(self.dict)


class LayerConverter:
Expand Down
48 changes: 48 additions & 0 deletions nengo_dl/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -653,3 +653,51 @@ def test_scale_firing_rates_cases(Simulator, scale_firing_rates, expected_rates)
* sim.dt,
atol=1,
)


def test_layer_dicts():
inp0 = tf.keras.Input(shape=(1,))
inp1 = tf.keras.Input(shape=(1,))
add = tf.keras.layers.Add()([inp0, inp1])
dense_node = tf.keras.layers.Dense(units=1)(add)
dense_ens = tf.keras.layers.Dense(units=1, activation=tf.nn.relu)(dense_node)

model = tf.keras.Model([inp0, inp1], [dense_node, dense_ens])

conv = converter.Converter(model)
assert len(conv.inputs) == 2
assert len(conv.outputs) == 2
assert len(conv.layers) == 5

# inputs/outputs/layers referencing the same stuff
assert isinstance(conv.outputs[dense_node], nengo.Probe)
assert conv.outputs[dense_node].target is conv.layers[dense_node]
assert conv.inputs[inp0] is conv.layers[inp0]

# look up by tensor
assert isinstance(conv.layers[dense_node], nengo.Node)
assert isinstance(conv.layers[dense_ens], nengo.ensemble.Neurons)

# look up by layer
assert isinstance(conv.layers[model.layers[-2]], nengo.Node)
assert isinstance(conv.layers[model.layers[-1]], nengo.ensemble.Neurons)

# iterating over dict works as expected
for i, tensor in enumerate(conv.layers):
assert model.layers.index(tensor._keras_history.layer) == i

# applying the same layer multiple times
inp = tf.keras.Input(shape=(1,))
layer = tf.keras.layers.ReLU()
x0 = layer(inp)
x1 = layer(inp)

model = tf.keras.Model(inp, [x0, x1])

conv = converter.Converter(model, split_shared_weights=True)

with pytest.raises(KeyError, match="multiple output tensors"):
assert conv.layers[layer]

assert conv.outputs[x0].target is conv.layers[x0]
assert conv.outputs[x1].target is conv.layers[x1]

0 comments on commit 1131cf3

Please sign in to comment.