Skip to content

Commit

Permalink
Correctly set output_size when hidden_cell=None
Browse files Browse the repository at this point in the history
  • Loading branch information
drasmuss committed Apr 19, 2024
1 parent 4d7fe22 commit 330856d
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions keras_lmu/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ def __init__(
order,
theta,
hidden_cell,
input_d=None,
trainable_theta=False,
hidden_to_memory=False,
memory_to_memory=False,
Expand All @@ -141,6 +142,7 @@ def __init__(
self.order = order
self._init_theta = theta
self.hidden_cell = hidden_cell
self.input_d = input_d
self.trainable_theta = trainable_theta
self.hidden_to_memory = hidden_to_memory
self.memory_to_memory = memory_to_memory
Expand Down Expand Up @@ -178,6 +180,15 @@ def __init__(
)

self.hidden_output_size = self.memory_d * self.order

if self.input_to_hidden:
if self.input_d is None:
raise ValueError(
"input_d must be specified when setting input_to_hidden=True "
"with hidden_cell=None"
)
self.hidden_output_size += self.input_d

self.hidden_state_size = []
elif hasattr(self.hidden_cell, "state_size"):
self.hidden_output_size = self.hidden_cell.output_size
Expand Down Expand Up @@ -272,6 +283,12 @@ def build(self, input_shape):

super().build(input_shape)

if self.input_d is not None and input_shape[-1] != self.input_d:
raise ValueError(
f"Input dimensionality ({input_shape[-1]}) does not match expected "
f"dimensionality ({self.input_d})"
)

enc_d = input_shape[-1]
if self.hidden_to_memory:
enc_d += self.hidden_output_size
Expand Down Expand Up @@ -470,6 +487,7 @@ def get_config(self):
"order": self.order,
"theta": self._init_theta,
"hidden_cell": keras.layers.serialize(self.hidden_cell),
"input_d": self.input_d,
"trainable_theta": self.trainable_theta,
"hidden_to_memory": self.hidden_to_memory,
"memory_to_memory": self.memory_to_memory,
Expand Down

0 comments on commit 330856d

Please sign in to comment.