Merge pull request #130 from bckenstler/master
added CLR callback, test
Showing 4 changed files with 408 additions and 0 deletions.
@@ -0,0 +1,146 @@
'''Train a simple deep CNN on the CIFAR10 small images dataset using
a triangular cyclic learning rate (CLR) policy.
It gets to 75% validation accuracy in 15 epochs, and 79% after 40 epochs;
compare to 25 and 50 epochs respectively without CLR.
'''

from __future__ import print_function
from __future__ import absolute_import

import os

import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras_contrib.callbacks import CyclicLR

batch_size = 100
epochs = 50
num_classes = 10
data_augmentation = True
num_predictions = 20
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_cifar10_trained_model.h5'

# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

# Initiate the RMSprop optimizer.
opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)

# Initiate the CyclicLR learning rate scheduler. With batch_size = 100
# (500 iterations per epoch on the 50,000 CIFAR10 training images),
# step_size = 2000 gives a half cycle of 4 epochs.
clr = CyclicLR(
    base_lr=0.0001,
    max_lr=0.0005,
    step_size=2000,
    mode='triangular')

# Let's train the model using RMSprop.
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              callbacks=[clr],
              shuffle=True)
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        zca_epsilon=1e-06,  # epsilon for ZCA whitening
        # randomly rotate images in the range (degrees, 0 to 180)
        rotation_range=0,
        # randomly shift images horizontally (fraction of total width)
        width_shift_range=0.1,
        # randomly shift images vertically (fraction of total height)
        height_shift_range=0.1,
        shear_range=0.,  # set range for random shear
        zoom_range=0.,  # set range for random zoom
        channel_shift_range=0.,  # set range for random channel shifts
        # set mode for filling points outside the input boundaries
        fill_mode='nearest',
        cval=0.,  # value used for fill_mode = "constant"
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False,  # randomly flip images vertically
        # set rescaling factor (applied before any other transformation)
        rescale=None,
        # set function that will be applied on each input
        preprocessing_function=None,
        # image data format, either "channels_first" or "channels_last"
        data_format=None,
        # fraction of images reserved for validation (strictly between 0 and 1)
        validation_split=0.0)

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    model.fit_generator(datagen.flow(x_train, y_train,
                                     batch_size=batch_size),
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        callbacks=[clr],
                        workers=4)

# Save model and weights
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
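
Because `CyclicLR` records the learning rate at every batch (see `on_batch_end` in the callback source below), the schedule can be inspected after training. A minimal sketch, assuming matplotlib is available and that it runs after the `fit` / `fit_generator` call above:

```python
import matplotlib.pyplot as plt

# 'iterations' and 'lr' are appended to clr.history at the end of each batch.
plt.plot(clr.history['iterations'], clr.history['lr'])
plt.xlabel('training iteration')
plt.ylabel('learning rate')
plt.title('CLR triangular schedule')
plt.show()
```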
@@ -1,2 +1,3 @@
from .snapshot import SnapshotCallbackBuilder, SnapshotModelCheckpoint
from .dead_relu_detector import DeadReluDetector
from .cyclical_learning_rate import CyclicLR
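
This re-export from the new `cyclical_learning_rate` module is what the example script's `from keras_contrib.callbacks import CyclicLR` relies on. A quick import sanity check, assuming keras-contrib is installed:

```python
# CyclicLR now lives alongside the other package-level callbacks.
from keras_contrib.callbacks import CyclicLR, DeadReluDetector

clr = CyclicLR(base_lr=0.001, max_lr=0.006,
               step_size=2000., mode='triangular2')
```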
@@ -0,0 +1,146 @@
import numpy as np

from keras import backend as K
from keras.callbacks import Callback


class CyclicLR(Callback):
"""This callback implements a cyclical learning rate policy (CLR). | ||
The method cycles the learning rate between two boundaries with | ||
some constant frequency. | ||
# Arguments | ||
base_lr: initial learning rate which is the | ||
lower boundary in the cycle. | ||
max_lr: upper boundary in the cycle. Functionally, | ||
it defines the cycle amplitude (max_lr - base_lr). | ||
The lr at any cycle is the sum of base_lr | ||
and some scaling of the amplitude; therefore | ||
max_lr may not actually be reached depending on | ||
scaling function. | ||
step_size: number of training iterations per | ||
half cycle. Authors suggest setting step_size | ||
2-8 x training iterations in epoch. | ||
mode: one of {triangular, triangular2, exp_range}. | ||
Default 'triangular'. | ||
Values correspond to policies detailed above. | ||
If scale_fn is not None, this argument is ignored. | ||
gamma: constant in 'exp_range' scaling function: | ||
gamma**(cycle iterations) | ||
scale_fn: Custom scaling policy defined by a single | ||
argument lambda function, where | ||
0 <= scale_fn(x) <= 1 for all x >= 0. | ||
mode paramater is ignored | ||
scale_mode: {'cycle', 'iterations'}. | ||
Defines whether scale_fn is evaluated on | ||
cycle number or cycle iterations (training | ||
iterations since start of cycle). Default is 'cycle'. | ||
The amplitude of the cycle can be scaled on a per-iteration or | ||
per-cycle basis. | ||
This class has three built-in policies, as put forth in the paper. | ||
"triangular": | ||
A basic triangular cycle w/ no amplitude scaling. | ||
"triangular2": | ||
A basic triangular cycle that scales initial amplitude by half each cycle. | ||
"exp_range": | ||
A cycle that scales initial amplitude by gamma**(cycle iterations) at each | ||
cycle iteration. | ||
For more detail, please see paper. | ||
# Example for CIFAR-10 w/ batch size 100: | ||
```python | ||
clr = CyclicLR(base_lr=0.001, max_lr=0.006, | ||
step_size=2000., mode='triangular') | ||
model.fit(X_train, Y_train, callbacks=[clr]) | ||
``` | ||
Class also supports custom scaling functions: | ||
```python | ||
clr_fn = lambda x: 0.5*(1+np.sin(x*np.pi/2.)) | ||
clr = CyclicLR(base_lr=0.001, max_lr=0.006, | ||
step_size=2000., scale_fn=clr_fn, | ||
scale_mode='cycle') | ||
model.fit(X_train, Y_train, callbacks=[clr]) | ||
``` | ||
""" | ||

    def __init__(
            self,
            base_lr=0.001,
            max_lr=0.006,
            step_size=2000.,
            mode='triangular',
            gamma=1.,
            scale_fn=None,
            scale_mode='cycle'):
        super(CyclicLR, self).__init__()

        assert mode in ['triangular', 'triangular2', 'exp_range'], \
            "mode must be one of 'triangular', 'triangular2', or 'exp_range'"
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size
        self.mode = mode
        self.gamma = gamma
        if scale_fn is None:
            if self.mode == 'triangular':
                self.scale_fn = lambda x: 1.
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = lambda x: 1 / (2.**(x - 1))
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = lambda x: gamma**x
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        self.clr_iterations = 0.
        self.trn_iterations = 0.
        self.history = {}

        self._reset()

    def _reset(self, new_base_lr=None, new_max_lr=None,
               new_step_size=None):
        """Resets cycle iterations.

        Optional boundary/step size adjustment.
        """
        if new_base_lr is not None:
            self.base_lr = new_base_lr
        if new_max_lr is not None:
            self.max_lr = new_max_lr
        if new_step_size is not None:
            self.step_size = new_step_size
        self.clr_iterations = 0.

    def clr(self):
        # 'cycle' is the (1-indexed) index of the current full cycle;
        # 'x' measures distance from the cycle midpoint, so (1 - x)
        # traces a triangle wave from 0 up to 1 and back to 0 over
        # 2 * step_size iterations.
        cycle = np.floor(1 + self.clr_iterations / (2 * self.step_size))
        x = np.abs(self.clr_iterations / self.step_size - 2 * cycle + 1)
        if self.scale_mode == 'cycle':
            return self.base_lr + (self.max_lr - self.base_lr) * \
                np.maximum(0, (1 - x)) * self.scale_fn(cycle)
        else:
            return self.base_lr + (self.max_lr - self.base_lr) * \
                np.maximum(0, (1 - x)) * self.scale_fn(self.clr_iterations)

    def on_train_begin(self, logs=None):
        logs = logs or {}

        if self.clr_iterations == 0:
            K.set_value(self.model.optimizer.lr, self.base_lr)
        else:
            K.set_value(self.model.optimizer.lr, self.clr())

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.trn_iterations += 1
        self.clr_iterations += 1
        K.set_value(self.model.optimizer.lr, self.clr())

        self.history.setdefault('lr', []).append(
            K.get_value(self.model.optimizer.lr))
        self.history.setdefault('iterations', []).append(self.trn_iterations)

        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)