#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Code for the PyCon.DE 2018 talk by Jens Leitloff and Felix M. Riese.
PyCon 2018 talk: Satellite data is for everyone: insights into modern remote
sensing research with open data and Python.
License: MIT
"""
import os
from tensorflow.keras.applications.densenet import DenseNet201 as DenseNet
from tensorflow.keras.applications.vgg16 import VGG16 as VGG
from tensorflow.keras.callbacks import (EarlyStopping, ModelCheckpoint,
TensorBoard)
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import SGD
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from image_functions import preprocessing_image_rgb
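# image_functions is a local helper module from this repository;
# preprocessing_image_rgb is the per-image preprocessing applied to every
# sample before it enters the network (used by the generators below)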
# variables
path_to_split_datasets = "~/Documents/Data/PyCon/RGB"
use_vgg = True
batch_size = 64
# construct paths
path_to_home = os.path.expanduser("~")
path_to_split_datasets = path_to_split_datasets.replace("~", path_to_home)
path_to_train = os.path.join(path_to_split_datasets, "train")
path_to_validation = os.path.join(path_to_split_datasets, "validation")
# get number of classes
sub_dirs = [sub_dir for sub_dir in os.listdir(path_to_train)
            if os.path.isdir(os.path.join(path_to_train, sub_dir))]
num_classes = len(sub_dirs)
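# optional sanity check: each class corresponds to one sub-directory of the
# training split, e.g.
# print(sorted(sub_dirs), num_classes)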
# parameters for CNN
if use_vgg:
    base_model = VGG(include_top=False,
                     weights='imagenet',
                     input_shape=(64, 64, 3))
else:
    base_model = DenseNet(include_top=False,
                          weights='imagenet',
                          input_shape=(64, 64, 3))
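# include_top=False drops the ImageNet classification head, keeping only the
# convolutional feature extractor; a new head for our classes is built below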
# add a global spatial average pooling layer
top_model = base_model.output
top_model = GlobalAveragePooling2D()(top_model)
# alternatively, just flatten the feature maps (Flatten would need importing):
# top_model = Flatten()(top_model)
# let's add a fully-connected layer
if use_vgg:
    # only for VGG a fully connected classifier is added on top;
    # DenseNet tends to overfit when additional dense layers are used
    top_model = Dense(2048, activation='relu')(top_model)
    top_model = Dense(2048, activation='relu')(top_model)
# and a softmax classification layer
predictions = Dense(num_classes, activation='softmax')(top_model)
# this is the model we will train
model = Model(inputs=base_model.input, outputs=predictions)
# print network structure
model.summary()
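# optional (assumes TF 2.x eager mode): count the currently trainable weights
# (all of them at this point; the freezing below changes that), e.g.
# print(sum(w.numpy().size for w in model.trainable_weights))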
# defining ImageDataGenerators
# ... initialization for training
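# flips and rotations are label-preserving for satellite patches, which have
# no canonical orientation; "reflect" fills the borders created by rotation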
train_datagen = ImageDataGenerator(
    fill_mode="reflect",
    rotation_range=45,
    horizontal_flip=True,
    vertical_flip=True,
    preprocessing_function=preprocessing_image_rgb)
# ... initialization for validation
test_datagen = ImageDataGenerator(
    preprocessing_function=preprocessing_image_rgb)
# ... definition for training
train_generator = train_datagen.flow_from_directory(path_to_train,
                                                    target_size=(64, 64),
                                                    batch_size=batch_size,
                                                    class_mode='categorical')
# just for information
class_indices = train_generator.class_indices
print(class_indices)
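# optional: preview one augmented batch to verify shapes and preprocessing
# batch_x, batch_y = next(train_generator)
# print(batch_x.shape, batch_y.shape)  # (batch_size, 64, 64, 3), (batch_size, num_classes)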
# ... definition for validation
validation_generator = test_datagen.flow_from_directory(
    path_to_validation,
    target_size=(64, 64),
    batch_size=batch_size,
    class_mode='categorical')
# first: train only the top layers (which were randomly initialized)
# i.e. freeze all convolutional layers
for layer in base_model.layers:
    layer.trainable = False
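# optional check: every backbone layer should now report trainable=False
# assert not any(layer.trainable for layer in base_model.layers)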
# compile the model (should be done *after* setting layers to non-trainable)
model.compile(optimizer='adadelta', loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])
# generate callback to save best model w.r.t val_categorical_accuracy
if use_vgg:
    file_name = "vgg"
else:
    file_name = "dense"
checkpointer = ModelCheckpoint("../data/models/" + file_name +
                               "_rgb_transfer_init." +
                               "{epoch:02d}-{val_categorical_accuracy:.3f}." +
                               "hdf5",
                               monitor='val_categorical_accuracy',
                               verbose=1,
                               save_best_only=True,
                               mode='max')
earlystopper = EarlyStopping(monitor='val_categorical_accuracy',
                             patience=10,
                             mode='max',
                             restore_best_weights=True)
tensorboard = TensorBoard(log_dir='./logs', write_graph=True,
                          write_images=True, update_freq='epoch')
history = model.fit(
    train_generator,
    steps_per_epoch=1000,
    epochs=10000,
    callbacks=[checkpointer, earlystopper, tensorboard],
    validation_data=validation_generator,
    validation_steps=500)
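# remember how many epochs the warm-up ran so the fine-tuning run below can
# continue the epoch counter (keeps checkpoint names and TensorBoard curves
# consistent)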
initial_epoch = len(history.history['loss']) + 1
# at this point, the top layers are well trained and we can start fine-tuning
# convolutional layers. We will freeze the bottom N layers
# and train the remaining top layers.
# let's visualize layer names and layer indices to see how many layers
# we should freeze:
names = []
for i, layer in enumerate(model.layers):
    names.append([i, layer.name, layer.trainable])
print(names)
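# a hypothetical alternative to the hard-coded indices below: look the
# cut-off up by layer name, e.g. for the VGG backbone
# freeze_until = [layer.name for layer in model.layers].index('block2_conv1')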
if use_vgg:
    # we will freeze the first convolutional block and train all
    # remaining blocks, including the top layers
    for layer in model.layers[:4]:
        layer.trainable = False
    for layer in model.layers[4:]:
        layer.trainable = True
else:
    for layer in model.layers[:7]:
        layer.trainable = False
    for layer in model.layers[7:]:
        layer.trainable = True
# we need to recompile the model for these modifications to take effect;
# SGD with a low learning rate keeps the updates small, so the pretrained
# convolutional features are refined rather than overwritten
model.compile(optimizer=SGD(learning_rate=0.0001, momentum=0.9),
              loss='categorical_crossentropy',
              metrics=['categorical_accuracy'])
# generate callback to save best model w.r.t val_categorical_accuracy
if use_vgg:
    file_name = "vgg"
else:
    file_name = "dense"
checkpointer = ModelCheckpoint("../data/models/" + file_name +
                               "_rgb_transfer_final." +
                               "{epoch:02d}-{val_categorical_accuracy:.3f}" +
                               ".hdf5",
                               monitor='val_categorical_accuracy',
                               verbose=1,
                               save_best_only=True,
                               mode='max')
earlystopper = EarlyStopping(monitor='val_categorical_accuracy',
                             patience=50,
                             mode='max')
model.fit(
    train_generator,
    steps_per_epoch=1000,
    epochs=10000,
    callbacks=[checkpointer, earlystopper, tensorboard],
    validation_data=validation_generator,
    validation_steps=500,
    initial_epoch=initial_epoch)
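# afterwards, the best checkpoint can be reloaded for evaluation; a sketch
# (the actual file name depends on the best epoch and validation accuracy):
# from tensorflow.keras.models import load_model
# model = load_model("../data/models/vgg_rgb_transfer_final.<epoch>-<acc>.hdf5")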