
Commit

Fix bug with hidden_to_memory and no hidden_cell
drasmuss committed Nov 10, 2020
1 parent 4ca879f commit 424f89d
Showing 3 changed files with 27 additions and 13 deletions.
6 changes: 6 additions & 0 deletions CHANGES.rst
@@ -22,6 +22,12 @@ Release history
 0.3.1 (unreleased)
 ==================
 
+**Fixed**
+
+- Fixed a bug when ``hidden_to_memory=True`` and ``hidden_cell=None``. (`#26`_)
+
+.. _#26: https://github.com/nengo/keras-lmu/pull/26
+
 
 0.3.0 (November 6, 2020)
 ========================
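For context, a minimal sketch of the configuration this commit fixes, assuming the 0.3-era constructor arguments visible in the diffs below. Before the fix, stepping such a cell failed because ``call`` indexed a hidden state that does not exist when ``hidden_cell=None``:

    import numpy as np
    import tensorflow as tf
    from keras_lmu import layers

    # no hidden cell, with the (memory-valued) hidden state fed back
    # into the memory input: the case fixed by this commit
    cell = layers.LMUCell(
        memory_d=1, order=4, theta=5, hidden_cell=None, hidden_to_memory=True
    )
    out = tf.keras.layers.RNN(cell)(np.ones((2, 10, 3), dtype="float32"))
    print(out.shape)  # (2, 4): with no hidden cell the output is the memory state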
6 changes: 5 additions & 1 deletion keras_lmu/layers.py
@@ -208,7 +208,11 @@ def call(self, inputs, states, training=None):
         m = states[-1]
 
         # compute memory input
-        u_in = tf.concat((inputs, h[0]), axis=1) if self.hidden_to_memory else inputs
+        u_in = inputs
+        if self.hidden_to_memory:
+            # if the hidden cell is None then the hidden state is equivalent to
+            # the memory state
+            u_in = tf.concat((u_in, m if self.hidden_cell is None else h[0]), axis=1)
         if self.dropout > 0:
             u_in *= self.get_dropout_mask_for_cell(u_in, training)
         u = tf.matmul(u_in, self.kernel)
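The effect of the new branch above: with ``hidden_to_memory=True`` the memory input ``u_in`` is the input concatenated with the hidden output, and with ``hidden_cell=None`` that hidden output is simply the memory state ``m``. A hedged sanity check of the resulting kernel width, mirroring the shape asserted in the tests below (the build call on a 3-D input shape is assumed to match the test's usage):

    from keras_lmu import layers

    cell = layers.LMUCell(
        memory_d=2, order=3, theta=4, hidden_cell=None, hidden_to_memory=True
    )
    cell.build((32, None, 6))  # 6-dimensional inputs, as in the test below
    # u_in = concat(inputs, m), so the kernel maps
    # input_dim + memory_d * order columns down to memory_d
    assert cell.kernel.shape == (6 + 2 * 3, 2)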
28 changes: 16 additions & 12 deletions keras_lmu/tests/test_layers.py
@@ -269,13 +269,14 @@ def test_hidden_types(hidden_cell, fft, rng, seed):
 
 
 @pytest.mark.parametrize("fft", (True, False))
-def test_connection_params(fft):
+@pytest.mark.parametrize("hidden_cell", (None, tf.keras.layers.Dense))
+def test_connection_params(fft, hidden_cell):
     input_shape = (32, 7 if fft else None, 6)
     lmu_args = dict(
         memory_d=1,
         order=3,
         theta=4,
-        hidden_cell=tf.keras.layers.Dense(units=5),
+        hidden_cell=hidden_cell if hidden_cell is None else hidden_cell(units=5),
         input_to_hidden=False,
     )
     if not fft:
@@ -287,32 +288,35 @@ def test_connection_params(fft):
     assert lmu.kernel.shape == (input_shape[-1], lmu.memory_d)
     if not fft:
         assert lmu.recurrent_kernel is None
-    assert lmu.hidden_cell.kernel.shape == (
-        lmu.memory_d * lmu.order,
-        lmu.hidden_cell.units,
-    )
+    if hidden_cell is not None:
+        assert lmu.hidden_cell.kernel.shape == (
+            lmu.memory_d * lmu.order,
+            lmu.hidden_cell.units,
+        )
 
     lmu_args["input_to_hidden"] = True
     if not fft:
         lmu_args["hidden_to_memory"] = True
         lmu_args["memory_to_memory"] = True
 
     lmu = layers.LMUCell(**lmu_args) if not fft else layers.LMUFFT(**lmu_args)
-    lmu.hidden_cell.built = False  # so that the kernel will be rebuilt
+    if hidden_cell is not None:
+        lmu.hidden_cell.built = False  # so that the kernel will be rebuilt
     lmu.build(input_shape)
     assert lmu.kernel.shape == (
-        input_shape[-1] + (lmu.hidden_cell.units if not fft else 0),
+        input_shape[-1] + (lmu.hidden_output_size if not fft else 0),
         lmu.memory_d,
     )
     if not fft:
         assert lmu.recurrent_kernel.shape == (
             lmu.order * lmu.memory_d,
             lmu.memory_d,
         )
-    assert lmu.hidden_cell.kernel.shape == (
-        lmu.memory_d * lmu.order + input_shape[-1],
-        lmu.hidden_cell.units,
-    )
+    if hidden_cell is not None:
+        assert lmu.hidden_cell.kernel.shape == (
+            lmu.memory_d * lmu.order + input_shape[-1],
+            lmu.hidden_cell.units,
+        )
 
 
 @pytest.mark.parametrize("dropout, recurrent_dropout", [(0, 0), (0.5, 0), (0, 0.5)])
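Note the switch from ``lmu.hidden_cell.units`` to ``lmu.hidden_output_size`` in the kernel-shape assertion: the former attribute does not exist when ``hidden_cell=None``. A rough illustrative sketch of the distinction (not the library source):

    def hidden_output_size(cell):
        # stand-in for the attribute used in the updated test: with no
        # hidden cell, the "hidden output" is the flattened memory state
        if cell.hidden_cell is None:
            return cell.memory_d * cell.order
        return cell.hidden_cell.units  # e.g. a Dense hidden cell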
