Skip to content

Commit

Permalink
add architecture tests
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderVNikitin committed Nov 19, 2023
1 parent 6e35c7a commit 91bb8a6
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 8 deletions.
8 changes: 6 additions & 2 deletions tests/test_zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,11 @@ def test_zoo_clf(model_type_name):
assert arch.model == arch_dict["model"]


def test_basic_rec():
@pytest.mark.parametrize("network_type", [
"gru",
"lstm",
])
def test_basic_rec(network_type):
seq_len = 10
feat_dim = 2
output_dim = 1
Expand All @@ -75,7 +79,7 @@ def test_basic_rec():
hidden_dim=2,
output_dim=output_dim,
n_layers=1,
network_type="gru")
network_type=network_type)
model = arch.build()
assert model is not None

Expand Down
7 changes: 1 addition & 6 deletions tsgm/models/architectures/zoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ def __init__(
self.n_layers = n_layers

self.network_type = network_type.lower()
assert self.network_type in ["gru", "lstm", "lstmLN"]
assert self.network_type in ["gru", "lstm"]

self._name = name

Expand All @@ -493,11 +493,6 @@ def _rnn_cell(self) -> keras.layers.Layer:
# LSTM
elif self.network_type == "lstm":
cell = keras.layers.LSTMCell(self.hidden_dim, activation="tanh")
# LSTM Layer Normalization
elif self.network_type == "lstmLN":
cell = keras.layers.LayerNormLSTMCell(
num_units=self.hidden_dim, activation="tanh"
)
return cell

def _make_network(self, model: keras.models.Model, activation: str, return_sequences: bool) -> keras.models.Model:
Expand Down

0 comments on commit 91bb8a6

Please sign in to comment.