-
Notifications
You must be signed in to change notification settings - Fork 0
/
train.py
79 lines (56 loc) · 3.28 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
Execute this script to train a model specified by "settings.py" script.
In case of an out of memory problem adjust batch_size in "settings.py".
Be sure to download dataset from https://www.kaggle.com/carlolepelaars/camvid
and unpack it to "data" subfolder.
"""
from model import *
import matplotlib.pyplot as plt
print("Tensorflow version: " + tf.__version__)
print("Keras version: " + tf.keras.__version__)
########################################################################################################################
# CREATE MODEL
########################################################################################################################
model = create_model()
model.compile(optimizer=tf.keras.optimizers.RMSprop(lr=init_lr),
loss=tf.keras.losses.categorical_crossentropy,
metrics=[tf.keras.metrics.categorical_accuracy])
model.summary()
########################################################################################################################
# CALLBACKS
########################################################################################################################
path = os.path.join(tmp_folder, 'trained_model.h5')
save_model = tf.keras.callbacks.ModelCheckpoint(path, monitor=cb_monitor, mode=cb_mode, verbose=1,
save_best_only=True)
csv_logger = tf.keras.callbacks.CSVLogger(os.path.join(tmp_folder, 'training.csv'))
early_stopping = tf.keras.callbacks.EarlyStopping(monitor=cb_monitor, mode=cb_mode, verbose=1,
patience=early_stopping_patience, restore_best_weights=True)
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(monitor=cb_monitor, mode=cb_mode, verbose=1,
factor=reduce_lr_factor, patience=reduce_lr_patience)
########################################################################################################################
# TRAINING MODEL
########################################################################################################################
data_gen_train = DataProvider(batch_size, is_validation=False, process_input=preprocessing)
data_gen_valid = DataProvider(batch_size, is_validation=True, process_input=preprocessing)
hist = model.fit(data_gen_train,
epochs=1000,
validation_data=data_gen_valid,
shuffle=True,
callbacks=[save_model, csv_logger, early_stopping, reduce_lr],
verbose=2)
model.save(path, include_optimizer=False)
plt.clf()
plt.plot(hist.history['loss'])
plt.plot(hist.history['val_loss'])
plt.savefig(os.path.join(tmp_folder, 'training_loss.png'))
plt.clf()
plt.plot(hist.history['categorical_accuracy'])
plt.plot(hist.history['val_categorical_accuracy'])
plt.savefig(os.path.join(tmp_folder, 'training_accuracy.png'))
########################################################################################################################
# EVALUATE MODEL
########################################################################################################################
res_train = model.evaluate(data_gen_train)
res_test = model.evaluate(data_gen_valid)
print('[train_loss, train_accuracy] =', res_train)
print('[val_loss, val_accuracy] =', res_test)