Merge pull request #10 from marl/augment
Add augmentation for audio and images from the L3 paper and add use of a validation set
hohsiangwu authored Nov 8, 2017
2 parents 0c3e38b + a2aaf0b commit 631d55a
Showing 3 changed files with 177 additions and 23 deletions.
48 changes: 48 additions & 0 deletions l3embedding/image.py
@@ -0,0 +1,48 @@
import numpy as np
import skimage
import skimage.color

def adjust_saturation(rgb_img, factor):
    """
    Adjust the saturation of an RGB image
    Args:
        rgb_img: RGB image data array
        factor: Multiplicative scaling factor to be applied to saturation
    Returns:
        adjusted_img: RGB image with adjusted saturation
    """
    hsv_img = skimage.color.rgb2hsv(rgb_img)
    hsv_img[:,:,1] = np.clip(hsv_img[:,:,1] * factor, 0.0, 1.0)
    return skimage.color.hsv2rgb(hsv_img)


def adjust_brightness(rgb_img, delta):
    """
    Adjust the brightness of an RGB image
    Args:
        rgb_img: RGB image data array
        delta: Additive (normalized) gain factor applied to each pixel
    Returns:
        adjusted_img: RGB image with adjusted brightness
    """
    imin, imax = skimage.dtype_limits(rgb_img)
    # Convert delta into the range of the image data
    delta = rgb_img.dtype.type((imax - imin) * delta)

    return np.clip(rgb_img + delta, imin, imax)

def horiz_flip(rgb_img):
    """
    Horizontally flip the given image
    Args:
        rgb_img: RGB image data array
    Returns:
        flipped_img: Horizontally flipped image
    """
    return rgb_img[:,::-1,:]
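For context, a minimal sketch (not part of the diff; the input frame is a random stand-in) of how the three helpers compose into one jitter pass:

# Illustrative only -- not part of the commit.
import numpy as np

from l3embedding.image import adjust_saturation, adjust_brightness, horiz_flip

frame = (np.random.rand(224, 224, 3) * 255).astype(np.uint8)  # stand-in RGB frame
frame = adjust_saturation(frame, 1.2)    # rgb2hsv/hsv2rgb round trip returns float RGB in [0, 1]
frame = adjust_brightness(frame, -0.05)  # delta is rescaled to the image dtype's range before adding
frame = horiz_flip(frame)                # reverses the column (width) axis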
138 changes: 116 additions & 22 deletions l3embedding/train.py
@@ -3,6 +3,7 @@
import os
import pickle
import random
import math

import keras
from keras.optimizers import Adam
@@ -14,6 +15,7 @@
from tqdm import tqdm

from .model import construct_cnn_L3_orig
from .image import *


#TODO: Consider putting the sampling functionality into another file
@@ -29,9 +31,14 @@ def get_file_list(data_dir):
        video_files: list of video files
    """
    data_dir_contents = set(os.listdir(data_dir))
    if 'audio' in data_dir_contents and 'video' in data_dir_contents:
        audio_files = glob.glob('{}/audio/*'.format(data_dir))
        video_files = glob.glob('{}/video/*'.format(data_dir))
    else:
        audio_files = glob.glob('{}/**/audio/*'.format(data_dir))
        video_files = glob.glob('{}/**/video/*'.format(data_dir))

    return audio_files, video_files


@@ -51,7 +58,7 @@ def video_to_audio(video_file):
    return '/'.join(path + ['audio', name])


def sample_one_second(audio_data, sampling_frequency, start, label, augment=False):
"""Return one second audio samples randomly if start is not specified,
otherwise, return one second audio samples including start (seconds).
Expand All @@ -63,30 +70,59 @@ def sample_one_second(audio_data, sampling_frequency, start, label):
One second samples
"""
    sampling_frequency = int(sampling_frequency)
    if label:
        start = max(0, int(start * sampling_frequency) - random.randint(0, sampling_frequency))
    else:
        start = random.randrange(len(audio_data) - sampling_frequency)

    audio_data = audio_data[start:start+sampling_frequency]
    if augment:
        # Make sure we don't clip
        gain = 1 + random.random()*min(0.1, 1.0/np.abs(audio_data).max() - 1)
        audio_data *= gain
        audio_aug_params = {'gain': gain}
    else:
        audio_aug_params = {}

    return audio_data, start / sampling_frequency, audio_aug_params
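For signals already within full scale (m = max |audio_data| <= 1), the gain draw above cannot clip: gain <= 1 + (1/m - 1) = 1/m, so gain * m <= 1. An illustrative check (not part of the commit):

# Illustrative check that the augmentation gain keeps the signal within [-1, 1].
import random

import numpy as np

audio = np.random.uniform(-0.8, 0.8, size=22050)  # stand-in one-second clip
m = np.abs(audio).max()
gain = 1 + random.random() * min(0.1, 1.0 / m - 1)
assert np.abs(audio * gain).max() <= 1.0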


def l3_frame_scaling(frame_data):
    """
    Scale and crop a video frame, using the method from Look, Listen and Learn
    Args:
        frame_data: video frame data array
    Returns:
        scaled_frame_data: scaled and cropped frame data
        bbox: bounding box for the cropped image
    """
    nx, ny, nc = frame_data.shape
    scaling = 256.0 / min(nx, ny)

    new_nx, new_ny = math.ceil(scaling * nx), math.ceil(scaling * ny)
    assert 256 in (new_nx, new_ny), str((new_nx, new_ny))

    resized_frame_data = scipy.misc.imresize(frame_data, (new_nx, new_ny, nc))

    start_x, start_y = random.randrange(new_nx - 224), random.randrange(new_ny - 224)
    end_x, end_y = start_x + 224, start_y + 224

    bbox = {
        'start_x': start_x,
        'start_y': start_y,
        'end_x': end_x,
        'end_y': end_y
    }

    return resized_frame_data[start_x:end_x, start_y:end_y, :], bbox
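Worked example (illustrative) of the resize/crop arithmetic for a 480x640 frame:

# Illustrative numbers only -- not part of the commit.
import math

nx, ny = 480, 640
scaling = 256.0 / min(nx, ny)                                      # 256/480
new_nx, new_ny = math.ceil(scaling * nx), math.ceil(scaling * ny)  # (256, 342)
# The 224x224 crop origin is then drawn uniformly from [0, 256-224) x [0, 342-224).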


def sample_one_frame(video_data, fps=30, scaling_func=None, augment=False):
    """Return one frame randomly and time (seconds).
    Args:
@@ -102,11 +138,50 @@ def sample_one_frame(video_data, fps=30, scaling_func=None):
    num_frames = video_data.shape[0]
    frame = random.randrange(num_frames - fps)
    frame_data = video_data[frame, :, :, :]
    frame_data, bbox = scaling_func(frame_data)

    video_aug_params = {'bounding_box': bbox}

    if augment:
        # Randomly horizontally flip the image
        horizontal_flip = False
        if random.random() < 0.5:
            frame_data = horiz_flip(frame_data)
            horizontal_flip = True

        # Ranges taken from https://github.com/tensorflow/models/blob/master/research/slim/preprocessing/inception_preprocessing.py

        # Randomize the order of saturation jitter and brightness jitter
        if random.random() < 0.5:
            # Add saturation jitter
            saturation_factor = random.random() + 0.5
            frame_data = adjust_saturation(frame_data, saturation_factor)

            # Add brightness jitter
            max_delta = 32. / 255.
            brightness_delta = (2*random.random() - 1) * max_delta
            frame_data = adjust_brightness(frame_data, brightness_delta)
        else:
            # Add brightness jitter
            max_delta = 32. / 255.
            brightness_delta = (2*random.random() - 1) * max_delta
            frame_data = adjust_brightness(frame_data, brightness_delta)

            # Add saturation jitter
            saturation_factor = random.random() + 0.5
            frame_data = adjust_saturation(frame_data, saturation_factor)

        video_aug_params.update({
            'horizontal_flip': horizontal_flip,
            'saturation_factor': saturation_factor,
            'brightness_delta': brightness_delta
        })

    return frame_data, frame / fps, video_aug_params
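From these draws, saturation_factor is uniform in [0.5, 1.5) and brightness_delta is uniform in (-32/255, 32/255), applied in random order; per the comment above, those ranges follow the TensorFlow Inception preprocessing code.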


def sampler(video_file, audio_files, augment=False):
"""Sample one frame from video_file, with 50% chance sample one second from corresponding audio_file,
50% chance sample one second from another audio_file in the list of audio_files.
Expand All @@ -132,8 +207,11 @@ def sampler(video_file, audio_files):
    audio_data, sampling_frequency = sf.read(audio_file)

    while True:
        sample_video_data, video_start, video_aug_params \
            = sample_one_frame(video_data, augment=augment)
        sample_audio_data, audio_start, audio_aug_params \
            = sample_one_second(audio_data, sampling_frequency, video_start,
                                label, augment=augment)
        sample_audio_data = sample_audio_data[:, 0].reshape((1, len(sample_audio_data)))

        sample = {
@@ -143,12 +221,14 @@ def sampler(video_file, audio_files):
            'audio_file': audio_file,
            'video_file': video_file,
            'audio_start': audio_start,
            'video_start': video_start,
            'audio_augment_params': audio_aug_params,
            'video_augment_params': video_aug_params
        }
        yield sample
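A minimal sketch (illustrative; the file paths are hypothetical) of drawing from one such sampler via pescador, which data_generator below does at scale:

# Illustrative only -- file paths are hypothetical.
import pescador

streamer = pescador.Streamer(sampler, '/data/video/clip0.mp4',
                             ['/data/audio/clip0.flac', '/data/audio/clip1.flac'],
                             augment=True)
for sample in streamer.iterate(max_iter=2):
    print(sample['audio_start'], sample['audio_augment_params'],
          sample['video_augment_params'])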


def data_generator(data_dir, k=32, batch_size=64, random_state=20171021, augment=False):
"""Sample video and audio from data_dir, returns a streamer that yield samples infinitely.
Args:
Expand All @@ -166,7 +246,7 @@ def data_generator(data_dir, k=32, batch_size=64, random_state=20171021):
    audio_files, video_files = get_file_list(data_dir)
    seeds = []
    for video_file in tqdm(random.sample(video_files, k)):
        seeds.append(pescador.Streamer(sampler, video_file, audio_files, augment=augment))

    mux = pescador.Mux(seeds, k)
    if batch_size == 1:
@@ -200,9 +280,9 @@ def on_epoch_end(self, epoch, logs=None):


#def train(train_csv_path, model_id, output_dir, num_epochs=150, epoch_size=512,
def train(train_data_dir, validation_data_dir, model_id, output_dir, num_epochs=150, epoch_size=512,
          batch_size=64, validation_size=1024, num_streamers=16,
          random_state=20171021, verbose=False, checkpoint_interval=100, augment=False):
    m, inputs, outputs = construct_cnn_L3_orig()
    loss = 'binary_crossentropy'
    metrics = ['accuracy']
@@ -251,27 +331,41 @@ def train(train_data_dir, model_id, output_dir, num_epochs=150, epoch_size=512,
separator=','))


    print('Setting up train data generator...')
    train_gen = data_generator(
        #train_csv_path,
        train_data_dir,
        batch_size=batch_size,
        random_state=random_state,
        k=num_streamers,
        augment=augment)

    train_gen = pescador.maps.keras_tuples(train_gen,
                                           ['video', 'audio'],
                                           'label')

    print('Setting up validation data generator...')
    val_gen = data_generator(
        validation_data_dir,
        batch_size=batch_size,
        random_state=random_state,
        k=num_streamers)

    val_gen = pescador.maps.keras_tuples(val_gen,
                                         ['video', 'audio'],
                                         'label')



    # Fit the model
    print('Fit model...')
    if verbose:
        verbosity = 1
    else:
        verbosity = 2
    history = m.fit_generator(train_gen, epoch_size, num_epochs,
                              validation_data=val_gen,
                              validation_steps=validation_size,
                              callbacks=cb,
                              verbose=verbosity)
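An illustrative call with the new signature (paths and model id are hypothetical; the remaining keyword arguments keep their defaults):

# Illustrative only -- hypothetical paths and model id.
train('/data/l3/train', '/data/l3/valid',
      model_id='cnn_l3_orig', output_dir='/models',
      num_epochs=150, epoch_size=512, batch_size=64,
      validation_size=1024, num_streamers=16, augment=True)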

14 changes: 13 additions & 1 deletion train.py
@@ -61,6 +61,13 @@ def parse_arguments():
                        default=20171021,
                        help='Random seed used to set the RNG state')

    parser.add_argument('-a',
                        '--augment',
                        dest='augment',
                        action='store_true',
                        default=False,
                        help='If True, performs data augmentation on audio and images')

    parser.add_argument('-v',
                        '--verbose',
                        dest='verbose',
@@ -77,7 +84,12 @@ def parse_arguments():
    parser.add_argument('train_data_dir',
                        action='store',
                        type=str,
                        help='Path to directory where training set files are stored')

    parser.add_argument('validation_data_dir',
                        action='store',
                        type=str,
                        help='Path to directory where validation set files are stored')

    parser.add_argument('model_id',
                        action='store',
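With these arguments in place, an invocation would look something like: python train.py --augment <train_data_dir> <validation_data_dir> <model_id> ... (the remaining positional arguments fall outside this hunk).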

