Skip to content

Commit

Permalink
fixup! Rework LMU API
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Oct 29, 2020
1 parent a95757d commit 5144d96
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 24 deletions.
11 changes: 7 additions & 4 deletions lmu/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ def __init__(
self.B = None

if self.hidden_cell is None:
self.hidden_output_size = self.memory_d * self.order
# if input_to_hidden=True then we can't determine the output size
# until build time
self.hidden_output_size = (
None if input_to_hidden else self.memory_d * self.order
)
self.hidden_state_size = []
elif hasattr(self.hidden_cell, "state_size"):
self.hidden_output_size = self.hidden_cell.output_size
Expand Down Expand Up @@ -138,12 +142,11 @@ def build(self, input_shape):

super().build(input_shape)

enc_d = input_shape[-1]

if self.input_to_hidden and self.hidden_cell is None:
self.hidden_output_size += input_shape[-1]
self.hidden_output_size = self.memory_d * self.order + input_shape[-1]
self.output_size = self.hidden_output_size

enc_d = input_shape[-1]
if self.hidden_to_memory:
enc_d += self.hidden_output_size

Expand Down
28 changes: 8 additions & 20 deletions lmu/tests/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,33 +354,21 @@ def test_skip_connection(rng, hidden_cell):
input_d = 32

inp = tf.keras.Input(shape=(n_steps, input_d))
input_enc = rng.uniform(0, 1, size=(input_d, memory_d))

lmu = layers.LMUCell(
memory_d=memory_d,
order=order,
theta=n_steps,
kernel_initializer=tf.initializers.constant(input_enc),
hidden_cell=hidden_cell,
input_to_hidden=True,
)
out = tf.keras.layers.RNN(
lmu,
return_sequences=True,
)(inp)
assert lmu.output_size == (None if hidden_cell is None else 10)

assert (
out.shape[-1] == (memory_d * order + input_d)
if hidden_cell is None
else (hidden_cell.units)
)
assert (
lmu.hidden_output_size == (memory_d * order + input_d)
if hidden_cell is None
else (hidden_cell.units)
)
assert (
lmu.output_size == (memory_d * order + input_d)
if hidden_cell is None
else (hidden_cell.units)
out = tf.keras.layers.RNN(lmu)(inp)

output_size = (
(memory_d * order + input_d) if hidden_cell is None else hidden_cell.units
)
assert out.shape[-1] == output_size
assert lmu.hidden_output_size == output_size
assert lmu.output_size == output_size

0 comments on commit 5144d96

Please sign in to comment.