forked from philipperemy/keras-tcn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsave_reload_sequential_model.py
46 lines (36 loc) · 1.36 KB
/
save_reload_sequential_model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import numpy as np
from tensorflow.keras.layers import Dense, Embedding
from tensorflow.keras.models import Sequential, model_from_json
from tcn import TCN, tcn_full_summary
# Round-trip check: build a Sequential model containing a TCN layer, save its
# architecture (JSON) and weights (HDF5), rebuild it from those files, and
# verify that inference on a fixed input gives the same result.

# define input shape
max_len = 100  # sequence length fed to the Embedding layer
max_features = 50  # vocabulary size (input_dim) of the Embedding layer

# make model
model = Sequential(layers=[Embedding(max_features, 16, input_shape=(max_len,)),
                           TCN(nb_filters=12,
                               dropout_rate=0.5,
                               kernel_size=6,
                               dilations=[1, 2, 4]),
                           Dense(units=1, activation='sigmoid')])

# get model as json string and save to file
model_as_json = model.to_json()
with open('model.json', 'w', encoding='utf-8') as json_file:
    json_file.write(model_as_json)
# save weights to file (for this format, need h5py installed)
model.save_weights('weights.h5')

# Make inference on a fixed input. Its length is derived from max_len rather
# than hard-coded, so it cannot drift from the model's declared input shape.
inputs = np.ones(shape=(1, max_len))
out1 = model.predict(inputs)[0, 0]
print('*' * 80)
print('Inference after creation:', out1)

# load model from file; the context manager closes the handle deterministically
with open('model.json', 'r', encoding='utf-8') as json_file:
    loaded_json = json_file.read()
# TCN is not a built-in Keras layer, so it is passed via custom_objects to let
# model_from_json rebuild it from the JSON config.
reloaded_model = model_from_json(loaded_json, custom_objects={'TCN': TCN})

# NOTE(review): this summarizes the original `model`, not `reloaded_model`,
# even though it sits in the reload section -- confirm which was intended.
tcn_full_summary(model, expand_residual_blocks=False)

# restore weights
reloaded_model.load_weights('weights.h5')

# Make inference.
out2 = reloaded_model.predict(inputs)[0, 0]
print('*' * 80)
print('Inference after loading:', out2)

# Explicit check instead of a bare `assert`: `python -O` strips assert
# statements, which would let a broken save/reload round trip pass silently.
if not abs(out1 - out2) < 1e-6:
    raise AssertionError(
        f'Predictions differ after reload: {out1} (before) vs {out2} (after)')