Merge pull request #130 from bckenstler/master
added CLR callback, test
ahundt authored Jul 28, 2018
2 parents 3e77ba2 + f25c266 commit 4d48c68
Showing 4 changed files with 408 additions and 0 deletions.
146 changes: 146 additions & 0 deletions examples/cifar10_clr.py
@@ -0,0 +1,146 @@
'''Train a simple deep CNN on the CIFAR10 small images dataset using
a triangular cyclic learning rate (CLR) policy.
It gets to 75% validation accuracy in 15 epochs, and 79% after 40 epochs;
compare to 25 and 50 epochs respectively without CLR.
'''

from __future__ import print_function
from __future__ import absolute_import
import keras
from keras.datasets import cifar10
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras_contrib.callbacks import CyclicLR

import os

batch_size = 100
epochs = 50
num_classes = 10
data_augmentation = True
num_predictions = 20
save_dir = os.path.join(os.getcwd(), 'saved_models')
model_name = 'keras_cifar10_trained_model.h5'

# The data, split between train and test sets:
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, (3, 3), padding='same',
                 input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))

model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes))
model.add(Activation('softmax'))

# initialize the RMSprop optimizer
opt = keras.optimizers.rmsprop(lr=0.0001, decay=1e-6)

# initialize the CyclicLR learning rate scheduler
clr = CyclicLR(
    base_lr=0.0001,
    max_lr=0.0005,
    step_size=2000,
    mode='triangular')


# Let's train the model using RMSprop
model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

if not data_augmentation:
    print('Not using data augmentation.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              callbacks=[clr],
              shuffle=True)
else:
    print('Using real-time data augmentation.')
    # This will do preprocessing and realtime data augmentation:
    datagen = ImageDataGenerator(
        featurewise_center=False,  # set input mean to 0 over the dataset
        samplewise_center=False,  # set each sample mean to 0
        featurewise_std_normalization=False,  # divide inputs by std of the dataset
        samplewise_std_normalization=False,  # divide each input by its std
        zca_whitening=False,  # apply ZCA whitening
        zca_epsilon=1e-06,  # epsilon for ZCA whitening
        # randomly rotate images in the range (degrees, 0 to 180)
        rotation_range=0,
        # randomly shift images horizontally (fraction of total width)
        width_shift_range=0.1,
        # randomly shift images vertically (fraction of total height)
        height_shift_range=0.1,
        shear_range=0.,  # set range for random shear
        zoom_range=0.,  # set range for random zoom
        channel_shift_range=0.,  # set range for random channel shifts
        # set mode for filling points outside the input boundaries
        fill_mode='nearest',
        cval=0.,  # value used for fill_mode = "constant"
        horizontal_flip=True,  # randomly flip images horizontally
        vertical_flip=False,  # randomly flip images vertically
        # set rescaling factor (applied before any other transformation)
        rescale=None,
        # set function that will be applied on each input
        preprocessing_function=None,
        # image data format, either "channels_first" or "channels_last"
        data_format=None,
        # fraction of images reserved for validation (strictly between 0 and 1)
        validation_split=0.0)

    # Compute quantities required for feature-wise normalization
    # (std, mean, and principal components if ZCA whitening is applied).
    datagen.fit(x_train)

    # Fit the model on the batches generated by datagen.flow().
    model.fit_generator(datagen.flow(x_train, y_train,
                                     batch_size=batch_size),
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        callbacks=[clr],
                        workers=4)

# Save model and weights
if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)
model.save(model_path)
print('Saved trained model at %s ' % model_path)

# Score trained model.
scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
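
A quick back-of-the-envelope check of the schedule settings used above (not part of the diff; the numbers assume CIFAR-10's 50,000 training images): with batch_size = 100 there are 500 iterations per epoch, so step_size = 2000 puts each half cycle at 4 epochs.

```python
# Hypothetical sketch: relate step_size to epochs for the settings
# used in examples/cifar10_clr.py above.
train_samples = 50000         # CIFAR-10 training set size
batch_size = 100              # as set in the example
step_size = 2000              # as passed to CyclicLR

iterations_per_epoch = train_samples // batch_size   # 500
epochs_per_half_cycle = step_size / iterations_per_epoch
print(epochs_per_half_cycle)  # 4.0, i.e. one full LR cycle every 8 epochs
```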
1 change: 1 addition & 0 deletions keras_contrib/callbacks/__init__.py
@@ -1,2 +1,3 @@
from .snapshot import SnapshotCallbackBuilder, SnapshotModelCheckpoint
from .dead_relu_detector import DeadReluDetector
from .cyclical_learning_rate import CyclicLR
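
This one-line addition exposes the new callback at the package level, which is the import the example script above relies on:

```python
from keras_contrib.callbacks import CyclicLR
```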
146 changes: 146 additions & 0 deletions keras_contrib/callbacks/cyclical_learning_rate.py
@@ -0,0 +1,146 @@
import numpy as np

from keras import backend as K
from keras.callbacks import Callback


class CyclicLR(Callback):
    """This callback implements a cyclical learning rate policy (CLR).
    The method cycles the learning rate between two boundaries with
    some constant frequency.

    # Arguments
        base_lr: initial learning rate, which is the
            lower boundary in the cycle.
        max_lr: upper boundary in the cycle. Functionally,
            it defines the cycle amplitude (max_lr - base_lr).
            The lr at any cycle is the sum of base_lr
            and some scaling of the amplitude; therefore
            max_lr may not actually be reached, depending on the
            scaling function.
        step_size: number of training iterations per
            half cycle. The authors suggest setting step_size to
            2-8 times the number of training iterations per epoch.
        mode: one of {'triangular', 'triangular2', 'exp_range'}.
            Default 'triangular'.
            Values correspond to the policies detailed below.
            If scale_fn is not None, this argument is ignored.
        gamma: constant in the 'exp_range' scaling function:
            gamma**(cycle iterations).
        scale_fn: custom scaling policy defined by a single-argument
            lambda function, where 0 <= scale_fn(x) <= 1 for all x >= 0.
            If set, the mode parameter is ignored.
        scale_mode: one of {'cycle', 'iterations'}.
            Defines whether scale_fn is evaluated on
            cycle number or cycle iterations (training
            iterations since start of cycle). Default is 'cycle'.

    The amplitude of the cycle can be scaled on a per-iteration or
    per-cycle basis. This class has three built-in policies, as put
    forth in the paper:

    "triangular":
        A basic triangular cycle with no amplitude scaling.
    "triangular2":
        A basic triangular cycle that scales the initial amplitude
        by half each cycle.
    "exp_range":
        A cycle that scales the initial amplitude by
        gamma**(cycle iterations) at each cycle iteration.

    For more detail, please see the paper.

    # Example for CIFAR-10 w/ batch size 100:
    ```python
    clr = CyclicLR(base_lr=0.001, max_lr=0.006,
                   step_size=2000., mode='triangular')
    model.fit(X_train, Y_train, callbacks=[clr])
    ```

    The class also supports custom scaling functions:
    ```python
    clr_fn = lambda x: 0.5 * (1 + np.sin(x * np.pi / 2.))
    clr = CyclicLR(base_lr=0.001, max_lr=0.006,
                   step_size=2000., scale_fn=clr_fn,
                   scale_mode='cycle')
    model.fit(X_train, Y_train, callbacks=[clr])
    ```
    """

    def __init__(self, base_lr=0.001, max_lr=0.006, step_size=2000.,
                 mode='triangular', gamma=1., scale_fn=None,
                 scale_mode='cycle'):
        super(CyclicLR, self).__init__()

        assert mode in ['triangular', 'triangular2', 'exp_range'], \
            "mode must be one of 'triangular', 'triangular2', or 'exp_range'"
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.step_size = step_size
        self.mode = mode
        self.gamma = gamma
        if scale_fn is None:
            if self.mode == 'triangular':
                self.scale_fn = lambda x: 1.
                self.scale_mode = 'cycle'
            elif self.mode == 'triangular2':
                self.scale_fn = lambda x: 1 / (2.**(x - 1))
                self.scale_mode = 'cycle'
            elif self.mode == 'exp_range':
                self.scale_fn = lambda x: gamma**(x)
                self.scale_mode = 'iterations'
        else:
            self.scale_fn = scale_fn
            self.scale_mode = scale_mode
        self.clr_iterations = 0.
        self.trn_iterations = 0.
        self.history = {}

        self._reset()

    def _reset(self, new_base_lr=None, new_max_lr=None,
               new_step_size=None):
        """Resets cycle iterations.
        Optional boundary/step size adjustment.
        """
        if new_base_lr is not None:
            self.base_lr = new_base_lr
        if new_max_lr is not None:
            self.max_lr = new_max_lr
        if new_step_size is not None:
            self.step_size = new_step_size
        self.clr_iterations = 0.

    def clr(self):
        cycle = np.floor(1 + self.clr_iterations / (2 * self.step_size))
        x = np.abs(self.clr_iterations / self.step_size - 2 * cycle + 1)
        if self.scale_mode == 'cycle':
            return self.base_lr + (self.max_lr - self.base_lr) * \
                np.maximum(0, (1 - x)) * self.scale_fn(cycle)
        else:
            return self.base_lr + (self.max_lr - self.base_lr) * \
                np.maximum(0, (1 - x)) * self.scale_fn(self.clr_iterations)

    def on_train_begin(self, logs=None):
        logs = logs or {}

        if self.clr_iterations == 0:
            K.set_value(self.model.optimizer.lr, self.base_lr)
        else:
            K.set_value(self.model.optimizer.lr, self.clr())

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.trn_iterations += 1
        self.clr_iterations += 1
        K.set_value(self.model.optimizer.lr, self.clr())

        self.history.setdefault('lr', []).append(
            K.get_value(self.model.optimizer.lr))
        self.history.setdefault('iterations', []).append(self.trn_iterations)

        for k, v in logs.items():
            self.history.setdefault(k, []).append(v)
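
As a standalone sanity check (not part of the commit), the 'triangular' policy computed by clr() can be reproduced with plain NumPy: the learning rate rises linearly from base_lr to max_lr over step_size iterations, then falls back.

```python
# A minimal sketch of the 'triangular' policy, mirroring CyclicLR.clr()
# with scale_fn(x) == 1. Values here are illustrative only.
import numpy as np

base_lr, max_lr, step_size = 0.001, 0.006, 2000.

def triangular_lr(iteration):
    cycle = np.floor(1 + iteration / (2 * step_size))
    x = np.abs(iteration / step_size - 2 * cycle + 1)
    return base_lr + (max_lr - base_lr) * np.maximum(0, 1 - x)

for it in [0, 1000, 2000, 3000, 4000]:
    print(it, triangular_lr(it))
# lr rises from 0.001 at it=0 to 0.006 at it=2000, then falls back
# to 0.001 at it=4000, completing one cycle.
```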