diff --git a/tests/test_zoo.py b/tests/test_zoo.py index 81a3914..1f8e543 100644 --- a/tests/test_zoo.py +++ b/tests/test_zoo.py @@ -66,7 +66,11 @@ def test_zoo_clf(model_type_name): assert arch.model == arch_dict["model"] -def test_basic_rec(): +@pytest.mark.parametrize("network_type", [ + "gru", + "lstm", +]) +def test_basic_rec(network_type): seq_len = 10 feat_dim = 2 output_dim = 1 @@ -75,7 +79,7 @@ def test_basic_rec(): hidden_dim=2, output_dim=output_dim, n_layers=1, - network_type="gru") + network_type=network_type) model = arch.build() assert model is not None diff --git a/tsgm/models/architectures/zoo.py b/tsgm/models/architectures/zoo.py index c02baad..821e709 100644 --- a/tsgm/models/architectures/zoo.py +++ b/tsgm/models/architectures/zoo.py @@ -477,7 +477,7 @@ def __init__( self.n_layers = n_layers self.network_type = network_type.lower() - assert self.network_type in ["gru", "lstm", "lstmLN"] + assert self.network_type in ["gru", "lstm"] self._name = name @@ -493,11 +493,6 @@ def _rnn_cell(self) -> keras.layers.Layer: # LSTM elif self.network_type == "lstm": cell = keras.layers.LSTMCell(self.hidden_dim, activation="tanh") - # LSTM Layer Normalization - elif self.network_type == "lstmLN": - cell = keras.layers.LayerNormLSTMCell( - num_units=self.hidden_dim, activation="tanh" - ) return cell def _make_network(self, model: keras.models.Model, activation: str, return_sequences: bool) -> keras.models.Model: