-
Notifications
You must be signed in to change notification settings - Fork 1.4k
/
image_augmentation.py
102 lines (74 loc) · 3.13 KB
/
image_augmentation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""
Fairly basic set of tools for data augmentation on images.
"""
import numpy as np
import random
import cv2
from os.path import join, expanduser
import matplotlib.pyplot as plt
def perform_augmentation(batch_x, batch_y):
    """
    Perform basic data augmentation on image batches.

    Every sample is handled independently:

    1. with probability 0.5 the image and its ground truth are mirrored
       left-right together, so they stay aligned;
    2. the image alone gets a random hue / saturation / brightness change,
       applied in HSV space.

    Parameters
    ----------
    batch_x: ndarray of shape (b, h, w, c)
        Batch of images in RGB format, values in [0, 255]
    batch_y: ndarray of shape (b, h, w, c)
        Batch of ground truth with road segmentation

    Returns
    -------
    batch_x_aug, batch_y_aug: two ndarray of shape (b, h, w, c)
        Augmented batches. They are copies: the input batches are not modified.
    """
    def mirror(x):
        # Reverse the second axis of a single (h, w, c) sample: left-right flip.
        return x[:, ::-1, :]

    def augment_in_hsv_space(x_rgb):
        # NOTE(review): assumes `x_rgb` is uint8, so that OpenCV stores hue as
        # H/2 in [0, 180) and saturation/value in [0, 255] - confirm with caller.
        x_hsv = np.float32(cv2.cvtColor(x_rgb, cv2.COLOR_RGB2HSV))

        # Draw the gains in a fixed order (hue, saturation, brightness) so that
        # seeded runs consume the random stream exactly as before.
        hue_gain = random.uniform(0.9, 1.1)
        saturation_gain = random.uniform(0.5, 2.0)
        brightness_gain = random.uniform(0.5, 2.0)

        # Hue is an angle: wrap it back into the valid 8-bit range [0, 180)
        # instead of relying on cvtColor to cope with out-of-range values.
        x_hsv[:, :, 0] = (x_hsv[:, :, 0] * hue_gain) % 180
        x_hsv[:, :, 1] = x_hsv[:, :, 1] * saturation_gain
        x_hsv[:, :, 2] = x_hsv[:, :, 2] * brightness_gain

        # Saturation and brightness saturate at the 8-bit limits.
        x_hsv = np.uint8(np.clip(x_hsv, 0, 255))
        return cv2.cvtColor(x_hsv, cv2.COLOR_HSV2RGB)

    # Work on copies so the caller's batches are left untouched.
    batch_x_aug = np.copy(batch_x)
    batch_y_aug = np.copy(batch_y)

    for b in range(batch_x_aug.shape[0]):

        # Random mirroring (image and ground truth must be flipped together)
        should_mirror = random.choice([True, False])
        if should_mirror:
            batch_x_aug[b] = mirror(batch_x[b])
            batch_y_aug[b] = mirror(batch_y[b])

        # Random change in image values (hue, saturation, brightness);
        # the ground truth is not affected by color changes.
        batch_x_aug[b] = augment_in_hsv_space(batch_x_aug[b])

    return batch_x_aug, batch_y_aug
def _stack_image_and_mask(image, mask):
    """
    Stack an image on top of a 3-channel rendering of channel 1 of its ground truth.
    """
    mask_u8 = np.uint8(mask[:, :, 1]) * 255  # cast from boolean to uint8 for visualization
    mask_rgb = np.stack([mask_u8, mask_u8, mask_u8], axis=2)  # turn to 3-channels for visualization
    return np.concatenate([image, mask_rgb], axis=0)


def debug_visualize_data_augmentation():
    """
    Quick-and-dirty visual check to debug the image augmentation functions.

    Loads one training batch, augments it with `perform_augmentation` and, for
    each sample, shows the original image + ground truth in figure 1 and the
    augmented pair in figure 2. A key or mouse press moves on to the next sample.
    """
    from main_27 import gen_batch_function  # keep here to avoid circular dependencies

    # Parameters
    data_dir = join(expanduser("~"), 'code', 'self-driving-car', 'project_12_road_segmentation', 'data')
    image_h, image_w = (160, 576)
    batch_size = 20

    # Create function to get batches
    batch_generator = gen_batch_function(join(data_dir, 'data_road/training'), (image_h, image_w))

    # Load a batch and augment it
    batch_x, batch_y = next(batch_generator(batch_size))
    batch_x_aug, batch_y_aug = perform_augmentation(batch_x, batch_y)

    # Show both original and augmented batch images. Iterate over what was
    # actually returned, so a batch shorter than `batch_size` does not crash.
    for x, y, x_aug, y_aug in zip(batch_x, batch_y, batch_x_aug, batch_y_aug):
        plt.figure(1)
        plt.imshow(_stack_image_and_mask(x, y))

        plt.figure(2)
        plt.imshow(_stack_image_and_mask(x_aug, y_aug))

        # Non-blocking show + explicit wait: both figures stay up until a press.
        plt.show(block=False)
        plt.waitforbuttonpress()
# Script entry point: run the visual debugging routine only when this file is
# executed directly, not when it is imported as a module.
if __name__ == '__main__':
    debug_visualize_data_augmentation()