Skip to content

Commit

Permalink
try fine-tuning
Browse files Browse the repository at this point in the history
  • Loading branch information
kenmaz committed Aug 29, 2017
1 parent 8723fb3 commit 5a018ee
Showing 1 changed file with 112 additions and 0 deletions.
112 changes: 112 additions & 0 deletions keras/mcz_finetuning.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
from __future__ import print_function
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.utils import np_utils
from keras.layers.normalization import BatchNormalization
from keras.applications.vgg16 import VGG16
import mcz_input
import sys

batch_size = 32
nb_classes = 5
nb_epoch = 200
data_augmentation = True

img_rows, img_cols = 112, 112
img_channels = 3

(X_train, y_train)= mcz_input.read_data('../deeplearning/train.txt')
(X_test, y_test)= mcz_input.read_data('../deeplearning/test.txt')

print('X_train shape:', X_train.shape, X_train[0][0][0][0])
print('y_train shape:', y_train.shape, y_train[0][0])
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

Y_train = np_utils.to_categorical(y_train, nb_classes)
Y_test = np_utils.to_categorical(y_test, nb_classes)

vgg16 = VGG16(include_top=False, weights='imagenet', input_shape=X_train.shape[1:])
print(vgg16.summary())
sys.exit()

model = Sequential()

model.add(Conv2D(32, (3, 3), padding='same', input_shape=X_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(BatchNormalization())
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(nb_classes))
model.add(Activation('softmax'))

model.compile(loss='categorical_crossentropy',
optimizer='adam',
metrics=['accuracy'])

X_train = X_train.astype('float32')
X_test = X_test.astype('float32')
X_train /= 255
X_test /= 255

from keras.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping
csv_logger = CSVLogger('log.csv', append=True, separator=';')

fpath = 'weights.{epoch:02d}-{loss:.2f}-{acc:.2f}-{val_loss:.2f}-{val_acc:.2f}.h5'
cp_cb = ModelCheckpoint(fpath, monitor='val_loss', verbose=1, save_best_only=True, mode='auto')

stopping = EarlyStopping(monitor='val_loss', patience=10, verbose=1)

if not data_augmentation:
print('Not using data augmentation.')
model.fit(X_train, Y_train,
batch_size=batch_size,
nb_epoch=nb_epoch,
validation_data=(X_test, Y_test),
shuffle=True,
callbacks=[csv_logger, cp_cb, stopping])
else:
print('Using real-time data augmentation.')
# This will do preprocessing and realtime data augmentation:
datagen = ImageDataGenerator(
featurewise_center=False, # set input mean to 0 over the dataset
samplewise_center=False, # set each sample mean to 0
featurewise_std_normalization=False, # divide inputs by std of the dataset
samplewise_std_normalization=False, # divide each input by its std
zca_whitening=False, # apply ZCA whitening
rotation_range=0, # randomly rotate images in the range (degrees, 0 to 180)
width_shift_range=0.1, # randomly shift images horizontally (fraction of total width)
height_shift_range=0.1, # randomly shift images vertically (fraction of total height)
horizontal_flip=True, # randomly flip images
vertical_flip=False) # randomly flip images

# Compute quantities required for featurewise normalization
# (std, mean, and principal components if ZCA whitening is applied).
datagen.fit(X_train)

# Fit the model on the batches generated by datagen.flow().
model.fit_generator(datagen.flow(X_train, Y_train, batch_size=batch_size),
steps_per_epoch=len(X_train),
epochs=nb_epoch,
validation_data=(X_test, Y_test),
callbacks=[csv_logger])

model.save('model.h5')

0 comments on commit 5a018ee

Please sign in to comment.