-
Notifications
You must be signed in to change notification settings - Fork 3
/
gen_data.py
73 lines (63 loc) · 2.02 KB
/
gen_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import glob
import os

import cv2
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
from skimage.feature import hog
#We noticed that the car SVM classifier doesn't detect black cars in images, so a few black car images were added to the training set.
#This file is used to generate additional black car data: some black car images were created, then flipped and translated to produce extra samples.
#read black car file names
def readFile(pattern="black-car/*.png"):
    """Return the paths of the black-car source images.

    Args:
        pattern: Glob pattern selecting the image files. Defaults to the
            PNG files of the ``black-car`` directory (resolved relative to
            the current working directory).

    Returns:
        A sorted list of matching paths; empty when nothing matches.
    """
    # glob.glob yields matches in arbitrary order; sorting keeps the
    # numbering of the generated files reproducible from run to run.
    return sorted(glob.glob(pattern))
#read images
def readImages(filenames):
    """Load every file in *filenames* with ``cv2.imread``.

    Args:
        filenames: Iterable of image file paths.

    Returns:
        A list holding one loaded image per path, in input order.

    Raises:
        IOError: If a file cannot be read as an image.
    """
    images = []
    for name in filenames:
        image = cv2.imread(name)
        # cv2.imread reports failure by returning None rather than raising;
        # stop here with the offending path instead of failing later when
        # the image's shape is accessed.
        if image is None:
            raise IOError("could not read image: " + str(name))
        images.append(image)
    return images
#randomly translates images
def random_trans(image, trans_range, trans_range_y=20):
    """Shift *image* by a random offset while keeping its width and height.

    Args:
        image: Image array whose first two dimensions are (rows, cols).
        trans_range: Width of the horizontal shift interval; the x offset
            is ``trans_range * u - trans_range / 2`` with ``u`` drawn from
            ``np.random.uniform()``.
        trans_range_y: Width of the vertical shift interval, applied the
            same way. Defaults to 20, the value that was previously
            hard-coded.

    Returns:
        The result of ``cv2.warpAffine`` for the drawn translation, with
        output size ``(cols, rows)``.
    """
    # Only the first two dimensions are needed, so 2-D (single-channel)
    # arrays are accepted as well as 3-D ones.
    rows, cols = image.shape[:2]
    # x is drawn before y so the random stream is consumed in the same
    # order as before.
    tr_x = trans_range * np.random.uniform() - trans_range / 2
    tr_y = trans_range_y * np.random.uniform() - trans_range_y / 2
    # 2x3 affine matrix: identity plus a translation column.
    Trans_M = np.float32([[1, 0, tr_x], [0, 1, tr_y]])
    return cv2.warpAffine(image, Trans_M, (cols, rows))
#randomly flip images
def random_flip(image):
    """Return *image* either flipped or untouched, chosen at random.

    One draw from ``np.random.randint(0, 2)`` decides: on 0 the result is
    ``cv2.flip(image, 1)`` (flip code 1 -- left/right mirroring per the
    OpenCV docs); otherwise the input object itself is returned.
    """
    should_flip = np.random.randint(0, 2) == 0
    return cv2.flip(image, 1) if should_flip else image
#generate images
def gen_random_data(images, count):
    """Build an augmented image list by cycling through *images*.

    Each step takes the next source image (wrapping around at the end of
    *images*) and appends three entries: the source image itself, a
    randomly translated copy, and that copy passed through ``random_flip``.

    Args:
        images: Non-empty sequence of source images (unless ``count <= 0``).
        count: Minimum number of images to produce. Images are added three
            at a time, so the result holds the smallest multiple of 3 that
            is >= ``count``.

    Returns:
        The list of generated images; empty when ``count <= 0``.

    Raises:
        ValueError: If ``count > 0`` but *images* is empty.
    """
    if count <= 0:
        return []
    total = len(images)
    if total == 0:
        # Previously this surfaced as a bare IndexError on the first lookup.
        raise ValueError("gen_random_data needs at least one source image")

    new_images = []
    step = 0
    while len(new_images) < count:
        source = images[step % total]  # wrap around the source set
        shifted = random_trans(source, 20)
        maybe_flipped = random_flip(shifted)
        new_images.extend((source, shifted, maybe_flipped))
        step += 1
    return new_images
#write generated images to the training set images folder
def write_images(gen_images, out_dir="extra"):
    """Write each image to ``<out_dir>/bc<index>.png``.

    Args:
        gen_images: Iterable of images to save with ``cv2.imwrite``; the
            index in the file name starts at 0 and follows iteration order.
        out_dir: Destination directory, created when missing. Defaults to
            ``"extra"``, the location that was previously hard-coded.

    Raises:
        IOError: If ``cv2.imwrite`` reports that a file was not written.
    """
    # Make sure the target directory exists before writing; otherwise the
    # writes would fail and no files would be produced.
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    for count, image in enumerate(gen_images):
        path = os.path.join(out_dir, "bc" + str(count) + ".png")
        # imwrite signals failure through its return value, so check it
        # instead of silently dropping the image.
        if not cv2.imwrite(path, image):
            raise IOError("could not write image: " + path)
# Generate the extra black-car samples and save them to disk.
filenames = readFile()
images = readImages(filenames)
gen_images = gen_random_data(images, 1000)
write_images(gen_images)

# Check one of the generated images. The images come from cv2.imread, which
# stores channels in BGR order, while matplotlib's imshow expects RGB, so
# convert before displaying to avoid swapped colours.
plt.imshow(cv2.cvtColor(gen_images[2], cv2.COLOR_BGR2RGB))
plt.show()