Skip to content

Commit

Permalink
Remove problematic source tensor sorting
Browse files Browse the repository at this point in the history
This is no longer required, and is problematic for models that
have an output that is used other places in the model (since
the sorting puts all outputs at the end).

Also ensure a better error for unconverted input tensor
  • Loading branch information
hunse authored and drasmuss committed Mar 26, 2020
1 parent eb03d44 commit b8394f6
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 21 deletions.
3 changes: 3 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ Release history
- Added ``tensorflow-cpu`` distributions to installation checks (so Nengo DL will
not attempt to reinstall TensorFlow if ``tensorflow-cpu`` is already installed).
(`#142`_)
- Fixed bug when applying the Converter to Keras models that re-use intermediate
layers as output layers. (`#137`_)

**Deprecated**

Expand All @@ -88,6 +90,7 @@ Release history
.. _#129: https://github.com/nengo/nengo-dl/pull/129
.. _#130: https://github.com/nengo/nengo-dl/pull/130
.. _#134: https://github.com/nengo/nengo-dl/pull/134
.. _#137: https://github.com/nengo/nengo-dl/pull/137
.. _#139: https://github.com/nengo/nengo-dl/pull/139
.. _#140: https://github.com/nengo/nengo-dl/pull/140
.. _#142: https://github.com/nengo/nengo-dl/pull/142
Expand Down
27 changes: 6 additions & 21 deletions nengo_dl/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -627,10 +627,7 @@ def get_input_obj(self, node_id, tensor_idx=0):

input_layer, input_node_id, input_tensor_id = self.get_history(tensor)

if input_node_id in self.converter.layer_map[input_layer]:
return self.converter.layer_map[input_layer][input_node_id][input_tensor_id]
else:
return None
return self.converter.layer_map[input_layer][input_node_id][input_tensor_id]

def _get_shape(self, input_output, node_id, include_batch=False):
"""
Expand Down Expand Up @@ -820,18 +817,6 @@ def convert(self, node_id):
# that need to be built into the Nengo network
source_tensors = self.trace_tensors(self.layer.outputs)

def sort_key(x):
# sort tensors so that order of model inputs/outputs is preserved
for i, y in enumerate(self.layer.inputs):
if x is y:
return -(len(self.layer.inputs) - i)
for i, y in enumerate(self.layer.outputs):
if x is y:
return i + 1
return 0

source_tensors = sorted(source_tensors, key=sort_key)

for tensor in source_tensors:
# look up the layer/node to be converted
model_layer, model_node_id, _ = self.get_history(tensor)
Expand Down Expand Up @@ -1467,11 +1452,11 @@ class ConvertInput(LayerConverter):
"""Convert ``tf.keras.layers.InputLayer`` to Nengo objects."""

def convert(self, node_id):
# if this input layer has an input obj, that means it is a passthrough
# (so we just return the input)
output = self.get_input_obj(node_id)

if output is None:
try:
# if this input layer has an input obj, that means it is a passthrough
# (so we just return the input)
output = self.get_input_obj(node_id)
except KeyError:
# not a passthrough input, so create input node
shape = self.output_shape(node_id)
if any(x is None for x in shape):
Expand Down
17 changes: 17 additions & 0 deletions nengo_dl/tests/test_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -701,3 +701,20 @@ def test_layer_dicts():

assert conv.outputs[x0].target is conv.layers[x0]
assert conv.outputs[x1].target is conv.layers[x1]


def test_mid_model_output(Simulator):
"""Check that converter supports output tensors from the middle of the model.
Previous converter put output tensors last in build order, so having an output
tensor that needed to be built before non-output tensors was problematic.
https://github.com/nengo/nengo-dl/pull/137
"""

# model must have at least three layers, with one layer in between outputs
inp = tf.keras.Input(shape=(1,))
x0 = tf.keras.layers.ReLU()(inp)
x1 = tf.keras.layers.ReLU()(x0)
x2 = tf.keras.layers.ReLU()(x1)

_test_convert(inp, [x0, x2], inp_vals=[np.ones((4, 1))])

0 comments on commit b8394f6

Please sign in to comment.