Fix LMU swapping behaviour during training #28

Merged: 1 commit, Nov 16, 2020
1 change: 1 addition & 0 deletions .gitignore
@@ -5,3 +5,4 @@ keras_lmu.egg-info
 __pycache__
 /.idea
 /docs/_build
+/tmp
5 changes: 5 additions & 0 deletions CHANGES.rst
@@ -27,7 +27,12 @@ Release history
 - Raise a validation error if ``hidden_to_memory`` or ``input_to_hidden`` are True
   when ``hidden_cell=None``. (`#26`_)

+**Fixed**
+
+- Fixed a bug with the autoswapping in ``keras_lmu.LMU`` during training. (`#28`_)
+
 .. _#26: https://github.com/nengo/keras-lmu/pull/26
+.. _#28: https://github.com/nengo/keras-lmu/pull/28


 0.3.0 (November 6, 2020)
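For context on the changelog entry above: before this patch, keras_lmu.LMU.__init__ constructed both an LMUFFT layer and an RNN fallback with independent weights, and call() re-chose between them from the runtime input shape (see the removed lines in keras_lmu/layers.py below). A minimal sketch of why per-call routing breaks training, using hypothetical stand-in layers rather than the real LMU internals:

import tensorflow as tf

class AutoSwapSketch(tf.keras.layers.Layer):
    """Hypothetical layer reproducing the pre-#28 routing scheme."""

    def __init__(self):
        super().__init__()
        self.fast = tf.keras.layers.Dense(4)      # stand-in for the FFT path
        self.fallback = tf.keras.layers.Dense(4)  # stand-in for the RNN path

    def call(self, inputs):
        # The choice is re-made on every call. If training exercises one path
        # and a later trace (e.g. one where the number of timesteps is unknown)
        # routes to the other, the model runs on untrained weights.
        if inputs.shape[1] is None:
            return self.fallback(inputs)
        return self.fast(inputs)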
74 changes: 36 additions & 38 deletions keras_lmu/layers.py
@@ -368,37 +368,7 @@ def __init__(
         self.dropout = dropout
         self.recurrent_dropout = recurrent_dropout
         self.return_sequences = return_sequences
-
-        if not hidden_to_memory and not memory_to_memory and memory_d == 1:
-            self.fft_layer = LMUFFT(
-                memory_d=memory_d,
-                order=order,
-                theta=theta,
-                hidden_cell=hidden_cell,
-                input_to_hidden=input_to_hidden,
-                kernel_initializer=kernel_initializer,
-                dropout=dropout,
-                return_sequences=return_sequences,
-            )
-        else:
-            self.fft_layer = None
-
-        self.rnn_layer = tf.keras.layers.RNN(
-            LMUCell(
-                memory_d=memory_d,
-                order=order,
-                theta=theta,
-                hidden_cell=hidden_cell,
-                hidden_to_memory=hidden_to_memory,
-                memory_to_memory=memory_to_memory,
-                input_to_hidden=input_to_hidden,
-                kernel_initializer=kernel_initializer,
-                recurrent_initializer=recurrent_initializer,
-                dropout=dropout,
-                recurrent_dropout=recurrent_dropout,
-            ),
-            return_sequences=return_sequences,
-        )
+        self.layer = None

     def build(self, input_shapes):
         """
@@ -413,10 +383,41 @@ def build(self, input_shapes):

         super().build(input_shapes)

-        if self.fft_layer is None or input_shapes[1] is None:
-            self.rnn_layer.build(input_shapes)
+        if (
+            not self.hidden_to_memory
+            and not self.memory_to_memory
+            and self.memory_d == 1
+            and input_shapes[1] is not None
+        ):
+            self.layer = LMUFFT(
+                memory_d=self.memory_d,
+                order=self.order,
+                theta=self.theta,
+                hidden_cell=self.hidden_cell,
+                input_to_hidden=self.input_to_hidden,
+                kernel_initializer=self.kernel_initializer,
+                dropout=self.dropout,
+                return_sequences=self.return_sequences,
+            )
         else:
-            self.fft_layer.build(input_shapes)
+            self.layer = tf.keras.layers.RNN(
+                LMUCell(
+                    memory_d=self.memory_d,
+                    order=self.order,
+                    theta=self.theta,
+                    hidden_cell=self.hidden_cell,
+                    hidden_to_memory=self.hidden_to_memory,
+                    memory_to_memory=self.memory_to_memory,
+                    input_to_hidden=self.input_to_hidden,
+                    kernel_initializer=self.kernel_initializer,
+                    recurrent_initializer=self.recurrent_initializer,
+                    dropout=self.dropout,
+                    recurrent_dropout=self.recurrent_dropout,
+                ),
+                return_sequences=self.return_sequences,
+            )
+
+        self.layer.build(input_shapes)

     def call(self, inputs, training=None):
         """
@@ -429,10 +430,7 @@ def call(self, inputs, training=None):
         with some additional bookkeeping.
         """

-        if self.fft_layer is None or inputs.shape[1] is None:
-            return self.rnn_layer.call(inputs, training=training)
-        else:
-            return self.fft_layer.call(inputs, training=training)
+        return self.layer.call(inputs, training=training)

     def get_config(self):
         """Return config of layer (for serialization during model saving/loading)."""
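After this change the implementation is chosen exactly once, in build(), from the static input shape, and call() simply delegates to self.layer. A minimal usage sketch (the order/theta values here are arbitrary), mirroring the assertions in the tests below:

import tensorflow as tf
from keras_lmu import layers

# Known number of timesteps, memory_d == 1, and no hidden/memory feedback:
# build() selects the parallelized FFT implementation.
lmu = layers.LMU(memory_d=1, order=4, theta=5,
                 hidden_cell=tf.keras.layers.SimpleRNNCell(10))
lmu.build((32, 5, 8))  # (batch, timesteps, features)
assert isinstance(lmu.layer, layers.LMUFFT)

# Unknown number of timesteps: build() falls back to the RNN implementation,
# and the choice can no longer flip between traces during training.
lmu = layers.LMU(memory_d=1, order=4, theta=5,
                 hidden_cell=tf.keras.layers.SimpleRNNCell(10))
lmu.build((32, None, 8))
assert isinstance(lmu.layer, tf.keras.layers.RNN)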
59 changes: 53 additions & 6 deletions keras_lmu/tests/test_layers.py
@@ -86,7 +86,7 @@ def test_layer_vs_cell(rng):
         return_sequences=True,
     )
     lmu_layer.build(inp.shape)
-    lmu_layer.rnn_layer.set_weights(lmu_cell.get_weights())
+    lmu_layer.layer.set_weights(lmu_cell.get_weights())
     layer_out = lmu_layer(inp)

     for w0, w1 in zip(
@@ -218,10 +218,16 @@ def test_validation_errors():


 @pytest.mark.parametrize(
-    "hidden_to_memory, memory_to_memory, memory_d",
-    [(False, False, 1), (True, False, 1), (False, True, 1), (False, False, 2)],
+    "hidden_to_memory, memory_to_memory, memory_d, steps",
+    [
+        (False, False, 1, 5),
+        (True, False, 1, 5),
+        (False, True, 1, 5),
+        (False, False, 2, 5),
+        (False, False, 1, None),
+    ],
 )
-def test_fft_auto_swap(hidden_to_memory, memory_to_memory, memory_d):
+def test_fft_auto_swap(hidden_to_memory, memory_to_memory, memory_d, steps):
     lmu = layers.LMU(
         memory_d,
         2,
@@ -230,9 +236,10 @@ def test_fft_auto_swap(hidden_to_memory, memory_to_memory, memory_d):
         hidden_to_memory=hidden_to_memory,
         memory_to_memory=memory_to_memory,
     )
+    lmu.build((32, steps, 8))

-    assert (lmu.fft_layer is None) == (
-        hidden_to_memory or memory_to_memory or memory_d != 1
+    assert isinstance(lmu.layer, tf.keras.layers.RNN) == (
+        hidden_to_memory or memory_to_memory or memory_d != 1 or steps is None
     )

@@ -364,3 +371,43 @@ def test_dropout(dropout, recurrent_dropout, fft):
     y0 = lmu(np.ones((32, 10, 64)), training=False).numpy()
     y1 = lmu(np.ones((32, 10, 64)), training=False).numpy()
     assert np.allclose(y0, y1)
+
+
+@pytest.mark.parametrize("fft", (True, False))
+def test_fit(fft):
+    lmu_layer = layers.LMU(
+        memory_d=1,
+        order=256,
+        theta=784,
+        hidden_cell=tf.keras.layers.SimpleRNNCell(units=10),
+        hidden_to_memory=not fft,
+        memory_to_memory=not fft,
+        input_to_hidden=not fft,
+    )
+
+    inputs = tf.keras.layers.Input((5 if fft else None, 10))
+    lmu = lmu_layer(inputs)
+    outputs = tf.keras.layers.Dense(2)(lmu)
+
+    model = tf.keras.Model(inputs=inputs, outputs=outputs)
+
+    x_train = tf.ones((5, 5, 10))
+    x_test = tf.ones((5, 5, 10))
+    y_train = tf.ones((5, 1))
+    y_test = tf.ones((5, 1))
+    model.compile(
+        loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
+        optimizer=tf.keras.optimizers.RMSprop(),
+        metrics=["accuracy"],
+    )
+
+    model.fit(x_train, y_train, epochs=10, validation_split=0.2)
+
+    _, acc = model.evaluate(x_test, y_test, verbose=0)
+
+    if fft:
+        assert isinstance(lmu_layer.layer, layers.LMUFFT)
+    else:
+        assert isinstance(lmu_layer.layer, tf.keras.layers.RNN)
+
+    assert acc == 1.0
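To run just this regression test, one option is pytest's -k selector (a standard pytest feature, not something added by this PR):

import pytest

# Equivalent to running `pytest keras_lmu/tests/test_layers.py -k test_fit`
# from the repository root.
pytest.main(["keras_lmu/tests/test_layers.py", "-k", "test_fit"])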