-
Notifications
You must be signed in to change notification settings - Fork 100
/
train.py
96 lines (77 loc) · 3.82 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from __future__ import absolute_import, division, print_function
import tensorflow as tf
from models.resnet import resnet_18, resnet_34, resnet_50, resnet_101, resnet_152
import config
from prepare_data import generate_datasets
import math
def get_model():
    """Build the ResNet variant selected by ``config.model`` and print its summary.

    The values "resnet18", "resnet34", "resnet101" and "resnet152" select the
    matching constructor. Any other value (including "resnet50") yields a
    ResNet-50, which keeps the previous fallback behaviour.

    Returns:
        The selected model, built for inputs of shape
        ``(None, config.image_height, config.image_width, config.channels)``.
    """
    # Dispatch on the configured name so that only the requested network is
    # instantiated. Previously a ResNet-50 was always constructed first and
    # then thrown away whenever a different variant was configured.
    constructors = {
        "resnet18": resnet_18,
        "resnet34": resnet_34,
        "resnet50": resnet_50,
        "resnet101": resnet_101,
        "resnet152": resnet_152,
    }
    build_network = constructors.get(config.model, resnet_50)
    model = build_network()
    model.build(input_shape=(None, config.image_height, config.image_width, config.channels))
    model.summary()
    return model
if __name__ == '__main__':
    # GPU settings: let TensorFlow grow GPU memory on demand on every
    # visible device (the loop is a no-op when no GPU is found).
    for device in tf.config.experimental.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(device, True)

    # Datasets and their sample counts; the test split is not used below.
    (train_dataset, valid_dataset, test_dataset,
     train_count, valid_count, test_count) = generate_datasets()

    # Network to train.
    model = get_model()

    # Loss and optimizer.
    loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
    optimizer = tf.keras.optimizers.Adadelta()

    # Running metrics, reset at the start of every epoch.
    train_loss = tf.keras.metrics.Mean(name='train_loss')
    train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')
    valid_loss = tf.keras.metrics.Mean(name='valid_loss')
    valid_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_accuracy')
    epoch_metrics = (train_loss, train_accuracy, valid_loss, valid_accuracy)

    @tf.function
    def train_step(images, labels):
        """Run one optimisation step on a batch and update the training metrics."""
        with tf.GradientTape() as tape:
            outputs = model(images, training=True)
            batch_loss = loss_object(y_true=labels, y_pred=outputs)
        trainable = model.trainable_variables
        grads = tape.gradient(batch_loss, trainable)
        optimizer.apply_gradients(grads_and_vars=zip(grads, trainable))

        train_loss(batch_loss)
        train_accuracy(labels, outputs)

    @tf.function
    def valid_step(images, labels):
        """Evaluate one batch in inference mode and update the validation metrics."""
        outputs = model(images, training=False)
        batch_loss = loss_object(labels, outputs)

        valid_loss(batch_loss)
        valid_accuracy(labels, outputs)

    # Number of batches per epoch, shown as the denominator in the step log.
    steps_per_epoch = math.ceil(train_count / config.BATCH_SIZE)

    # Start training.
    for epoch in range(config.EPOCHS):
        for metric in epoch_metrics:
            metric.reset_states()

        # Training pass: log the running loss/accuracy after every batch.
        for step, (images, labels) in enumerate(train_dataset, start=1):
            train_step(images, labels)
            step_message = "Epoch: {}/{}, step: {}/{}, loss: {:.5f}, accuracy: {:.5f}".format(
                epoch + 1,
                config.EPOCHS,
                step,
                steps_per_epoch,
                train_loss.result(),
                train_accuracy.result(),
            )
            print(step_message)

        # Validation pass over the whole validation split.
        for valid_images, valid_labels in valid_dataset:
            valid_step(valid_images, valid_labels)

        epoch_message = ("Epoch: {}/{}, train loss: {:.5f}, train accuracy: {:.5f}, "
                         "valid loss: {:.5f}, valid accuracy: {:.5f}").format(
            epoch + 1,
            config.EPOCHS,
            train_loss.result(),
            train_accuracy.result(),
            valid_loss.result(),
            valid_accuracy.result(),
        )
        print(epoch_message)

    # Persist the trained weights once all epochs have finished.
    model.save_weights(filepath=config.save_model_dir, save_format='tf')