-
Notifications
You must be signed in to change notification settings - Fork 55
/
Copy pathLoadBatches.py
78 lines (53 loc) · 2.21 KB
/
LoadBatches.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import glob
import itertools
import os
import random

import cv2
import matplotlib.pyplot as plt
import numpy as np
def getImageArr(im):
    """Return a float32 copy of ``im`` with a fixed per-channel mean removed.

    ``im`` must be indexable as ``im[:, :, c]`` for c in 0..2 (i.e. a 3-D
    height x width x channels array with at least three channels); any extra
    channels are copied through unchanged. The input array is not modified,
    because ``astype`` always returns a new array here.

    The constants look like the Caffe/VGG ImageNet means in BGR order, which
    would match the BGR images ``cv2.imread(..., 1)`` produces -- NOTE(review):
    confirm against the model this loader feeds.
    """
    channel_means = (103.939, 116.779, 123.68)
    centered = im.astype(np.float32)
    for channel, mean in enumerate(channel_means):
        centered[:, :, channel] -= mean
    return centered
def getSegmentationArr(seg, nClasses, input_height, input_width):
    """One-hot encode an integer label map and flatten it to per-pixel rows.

    Args:
        seg: Label map, expected to broadcast to ``(input_height, input_width)``;
            each value is a class index.
        nClasses: Number of classes, i.e. the length of each one-hot row.
        input_height: Height of the label map.
        input_width: Width of the label map.

    Returns:
        A float64 array of shape ``(input_height * input_width, nClasses)`` in
        row-major pixel order. A pixel whose label is not in
        ``range(nClasses)`` gets an all-zero row.
    """
    one_hot = np.zeros((input_height, input_width, nClasses))
    class_ids = np.arange(nClasses)
    # Compare every pixel against every class id at once; True/False become 1.0/0.0.
    one_hot[...] = seg[..., np.newaxis] == class_ids
    return one_hot.reshape(-1, nClasses)
_IMAGE_PATTERNS = ("*.jpg", "*.png", "*.jpeg")


def _list_image_files(directory):
    """Return the sorted paths of all .jpg/.png/.jpeg files in *directory*."""
    # os.path.join accepts the directory with or without a trailing separator.
    return sorted(
        path
        for pattern in _IMAGE_PATTERNS
        for path in glob.glob(os.path.join(directory, pattern))
    )


def _file_stem(path):
    """Return the file name of *path* without its directory or extension."""
    return os.path.splitext(os.path.basename(path))[0]


def _read_image(path, flags):
    """Read *path* with ``cv2.imread``, raising OSError instead of returning None."""
    data = cv2.imread(path, flags)
    if data is None:
        raise OSError("could not read image file: %r" % (path,))
    return data


def imageSegmentationGenerator(images_path, segs_path, batch_size,
                               n_classes, input_height, input_width):
    """Yield an endless stream of randomly cropped (image, one-hot mask) batches.

    Images and masks are found by globbing ``*.jpg``, ``*.png`` and ``*.jpeg``
    in each directory. They are sorted by path and paired in order, and every
    pair must share the same file stem. The pairs are cycled forever, so the
    generator never stops on its own.

    Args:
        images_path: Directory with the input images. A trailing ``/`` is
            accepted but no longer required.
        segs_path: Directory with the single-channel label masks. Each mask
            must have the same file stem as its image.
        batch_size: Number of samples per yielded batch.
        n_classes: Number of segmentation classes (length of each one-hot row).
        input_height: Height of the random crop taken from every pair.
        input_width: Width of the random crop taken from every pair.

    Yields:
        ``(X, Y)``, where ``X`` is a float32 array of shape
        ``(batch_size, input_height, input_width, 3)`` built by ``getImageArr``,
        and ``Y`` is an array of shape
        ``(batch_size, input_height * input_width, n_classes)`` built by
        ``getSegmentationArr``.

    Raises:
        ValueError: no images are found; the image and mask counts or stems
            differ; an image and its mask differ in size; or an image is
            smaller than the crop.
        OSError: OpenCV cannot read a file.
        Because this is a generator, these errors surface on the first
        ``next()`` call, not when the generator is created.
    """
    images = _list_image_files(images_path)
    segmentations = _list_image_files(segs_path)

    # Validate the dataset up front. An empty cycle would otherwise end in a
    # confusing RuntimeError, and a count mismatch would silently misalign pairs.
    if not images:
        raise ValueError("no .jpg/.png/.jpeg images found in %r" % (images_path,))
    if len(images) != len(segmentations):
        raise ValueError(
            "found %d images in %r but %d masks in %r"
            % (len(images), images_path, len(segmentations), segs_path))
    for im_path, seg_path in zip(images, segmentations):
        if _file_stem(im_path) != _file_stem(seg_path):
            raise ValueError(
                "image %r is paired with mask %r; file names must match"
                % (im_path, seg_path))

    pairs = itertools.cycle(zip(images, segmentations))
    while True:
        X = []
        Y = []
        for _ in range(batch_size):
            im_path, seg_path = next(pairs)
            im = _read_image(im_path, 1)    # flag 1: 3-channel colour image
            seg = _read_image(seg_path, 0)  # flag 0: single-channel label map

            if im.shape[:2] != seg.shape[:2]:
                raise ValueError(
                    "image %r is %dx%d but mask %r is %dx%d"
                    % (im_path, im.shape[0], im.shape[1],
                       seg_path, seg.shape[0], seg.shape[1]))
            if im.shape[0] < input_height or im.shape[1] < input_width:
                raise ValueError(
                    "image %r is %dx%d, smaller than the %dx%d crop"
                    % (im_path, im.shape[0], im.shape[1],
                       input_height, input_width))

            # Use the same random crop window for the image and its mask.
            # The row offset is drawn before the column offset, as in the
            # original, so seeded runs stay reproducible.
            top = random.randint(0, im.shape[0] - input_height)
            left = random.randint(0, im.shape[1] - input_width)
            rows = slice(top, top + input_height)
            cols = slice(left, left + input_width)

            X.append(getImageArr(im[rows, cols]))
            Y.append(getSegmentationArr(
                seg[rows, cols], n_classes, input_height, input_width))
        yield np.array(X), np.array(Y)
if __name__ == '__main__':
    # Smoke test: pull one training batch and report its array shapes.
    train_batches = imageSegmentationGenerator(
        "data/dataset1/images_prepped_train/",
        "data/dataset1/annotations_prepped_train/",
        batch_size=16,
        n_classes=15,
        input_height=320,
        input_width=320,
    )
    batch_x, batch_y = next(train_batches)
    print(batch_x.shape, batch_y.shape)