Skip to content

Commit

Permalink
Merge pull request #262 from philipperemy/fix_256
Browse files Browse the repository at this point in the history
Fix bug introduced in Keras 3
  • Loading branch information
philipperemy authored Aug 13, 2024
2 parents 55403a2 + f430095 commit fde0141
Show file tree
Hide file tree
Showing 7 changed files with 19 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [ 3.9 ]
python-version: [ "3.10" ]

steps:
- uses: actions/checkout@v4
Expand Down
2 changes: 1 addition & 1 deletion tasks/multi_length_sequences.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ def get_x_y(max_time_steps):
m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

gen = get_x_y(max_time_steps=MAX_TIME_STEP)
m.fit(gen, epochs=1, steps_per_epoch=1000, max_queue_size=1, verbose=2)
m.fit(gen, epochs=1, steps_per_epoch=1000, verbose=2)
4 changes: 2 additions & 2 deletions tasks/plot_tcn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
num_features = 4

inputs = tf.keras.layers.Input(shape=input_shape, name='input')
tcn_out = TCN(nb_filters=64, kernel_size=3, nb_stacks=1, activation='LeakyReLU')(inputs)
tcn_out = TCN(nb_filters=64, kernel_size=3, nb_stacks=1, activation='relu')(inputs)
outputs = tf.keras.layers.Dense(forecast_horizon * num_features, activation='linear')(tcn_out)
outputs = tf.reshape(outputs, shape=(-1, forecast_horizon, num_features), name='ouput')
outputs = tf.keras.layers.Reshape((forecast_horizon, num_features), name='ouput')(outputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

tf.keras.utils.plot_model(
Expand Down
4 changes: 2 additions & 2 deletions tasks/save_reload_sequential_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
with open('model.json', "w") as json_file:
json_file.write(model_as_json)
# save weights to file (for this format, need h5py installed)
model.save_weights('weights.h5')
model.save_weights('model.weights.h5')

# Make inference.
inputs = np.ones(shape=(1, 100))
Expand All @@ -36,7 +36,7 @@
tcn_full_summary(model, expand_residual_blocks=False)

# restore weights
reloaded_model.load_weights('weights.h5')
reloaded_model.load_weights('model.weights.h5')

# Make inference.
out2 = reloaded_model.predict(inputs)[0, 0]
Expand Down
22 changes: 0 additions & 22 deletions tasks/tcn_call_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
from tensorflow.keras import Input
from tensorflow.keras import Model
from tensorflow.keras.models import Sequential

from tcn import TCN

Expand Down Expand Up @@ -100,27 +99,6 @@ def test_non_causal_time_dim_unknown_return_no_sequences(self):
r = predict_with_tcn(time_steps=None, padding='same', return_sequences=False)
self.assertListEqual([list(b.shape) for b in r], [[1, NB_FILTERS], [1, NB_FILTERS], [1, NB_FILTERS]])

def test_norms(self):
Sequential(layers=[TCN(input_shape=(20, 2), use_weight_norm=True)]).compile(optimizer='adam', loss='mse')
Sequential(layers=[TCN(input_shape=(20, 2), use_weight_norm=False)]).compile(optimizer='adam', loss='mse')
Sequential(layers=[TCN(input_shape=(20, 2), use_layer_norm=True)]).compile(optimizer='adam', loss='mse')
Sequential(layers=[TCN(input_shape=(20, 2), use_layer_norm=False)]).compile(optimizer='adam', loss='mse')
Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=True)]).compile(optimizer='adam', loss='mse')
Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=False)]).compile(optimizer='adam', loss='mse')
try:
Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=True, use_weight_norm=True)]).compile(
optimizer='adam', loss='mse')
raise AssertionError('test failed.')
except ValueError:
pass
try:
Sequential(layers=[TCN(input_shape=(20, 2), use_batch_norm=True,
use_weight_norm=True, use_layer_norm=True)]).compile(
optimizer='adam', loss='mse')
raise AssertionError('test failed.')
except ValueError:
pass

def test_receptive_field(self):
self.assertEqual(37, TCN(kernel_size=3, dilations=(1, 3, 5), nb_stacks=1).receptive_field)
self.assertEqual(379, TCN(kernel_size=4, dilations=(1, 2, 4, 8, 16, 32), nb_stacks=1).receptive_field)
Expand Down
14 changes: 10 additions & 4 deletions tcn/tcn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import inspect
from typing import List # noqa
from typing import List # noqa

import tensorflow as tf
# pylint: disable=E0611,E0401
Expand Down Expand Up @@ -270,6 +270,12 @@ def __init__(self,
def receptive_field(self):
return 1 + 2 * (self.kernel_size - 1) * self.nb_stacks * sum(self.dilations)

def tolist(self, shape):
try:
return shape.as_list()
except AttributeError:
return shape

def build(self, input_shape):

# member to hold current output shape of the layer for building purposes
Expand Down Expand Up @@ -305,17 +311,17 @@ def build(self, input_shape):

self.output_slice_index = None
if self.padding == 'same':
time = self.build_output_shape.as_list()[1]
time = self.tolist(self.build_output_shape)[1]
if time is not None: # if time dimension is defined. e.g. shape = (bs, 500, input_dim).
self.output_slice_index = int(self.build_output_shape.as_list()[1] / 2)
self.output_slice_index = int(self.tolist(self.build_output_shape)[1] / 2)
else:
# It will known at call time. c.f. self.call.
self.padding_same_and_time_dim_unknown = True

else:
self.output_slice_index = -1 # causal case.
self.slicer_layer = Lambda(lambda tt: tt[:, self.output_slice_index, :], name='Slice_Output')
self.slicer_layer.build(self.build_output_shape.as_list())
self.slicer_layer.build(self.tolist(self.build_output_shape))

def compute_output_shape(self, input_shape):
"""
Expand Down
10 changes: 3 additions & 7 deletions tox.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[tox]
envlist = {py3}-tensorflow-{2.9,2.10,2.11,2.12,2.13,2.14,2.15}
envlist = {py3}-tensorflow-{2.13,2.16,2.17}

[testenv]
setenv =
Expand All @@ -8,13 +8,9 @@ deps = pytest
pylint
flake8
-rrequirements.txt
tensorflow-2.9: tensorflow==2.9
tensorflow-2.10: tensorflow==2.10
tensorflow-2.11: tensorflow==2.11
tensorflow-2.12: tensorflow==2.12
tensorflow-2.13: tensorflow==2.13
tensorflow-2.14: tensorflow==2.14
tensorflow-2.15: tensorflow==2.15
tensorflow-2.16: tensorflow==2.16.2
tensorflow-2.17: tensorflow==2.17
changedir = tasks/
commands = pylint --disable=R,C,W,E1136 ../tcn
flake8 ../tcn --count --select=E9,F63,F7,F82 --show-source --statistics
Expand Down

0 comments on commit fde0141

Please sign in to comment.