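"""Person detection and tracking on video.

For each frame: correct lens distortion, build a top-down ("upper")
perspective view, run YOLO to detect people, then greedily associate each
detection with a "bucket" (the position history of one tracked person)
using a combined position/trajectory score, and draw every trajectory in
both views.
"""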

import pickle as pkl
import random

import cv2
import numpy as np
import torch

import conf
import Preprocessing
import Trajectory
import Yolo


def progress_bar(current_value, total):
    # the leading \r returns the cursor to the line start so the bar redraws in place
    increments = 50
    percentual = int((current_value / total) * 100)
    i = int(percentual // (100 / increments))
    text = "\r[{0: <{1}}] {2}%".format('=' * i, increments, percentual)
    print(text, end="\n" if percentual == 100 else "", flush=True)


# loading available class names (only 'person' will be used)
with open(conf.CLASS_NAME_PATH, 'r') as fp:
    classes = fp.read().split("\n")[:-1]  # discard the trailing empty entry

# loading drawing colors from the palette
with open(conf.PALETTE_PATH, "rb") as fp:
    colors = pkl.load(fp)

# checking if a GPU is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Loading network...")
yolo = Yolo.Yolo(device)
print("Network successfully loaded")

cap = cv2.VideoCapture(conf.VIDEO_PATH)
assert cap.isOpened(), 'Cannot capture source, bad video path?'
video_length = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
loading = 0  # frames processed so far, for the status bar

distortion = Preprocessing.Distortion()
prospective = Preprocessing.Prospective()  # perspective ("bird's-eye") transform helper
trajectory = Trajectory.Trajectory()
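
# NOTE: Distortion, Prospective and Trajectory are helper classes from this
# repository's local modules; Prospective presumably wraps OpenCV's
# perspective-transform routines (cv2.warpPerspective for whole frames,
# cv2.perspectiveTransform for single points), given how it is used below.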

# GLOBAL TRACKING VARS
# a "bucket" collects the successive positions of one tracked person
buckets_colors = []      # bucket id -> (R, G, B) drawing color
buckets_cords = []       # bucket id -> [(x, y), (x2, y2), ...] in the top-down view
buckets_cords_orig = []  # bucket id -> the same positions in the original camera view

# saving the result as a new video?
if conf.SAVE_RESULT:
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    vout = cv2.VideoWriter(conf.VIDEO_OUT_PATH, fourcc, 20.0, (640, 480))
    vout1 = cv2.VideoWriter(conf.VIDEO_UP_OUT_PATH, fourcc, 20.0, (640, 480))

print("analyzing video...")
while cap.isOpened():
    loading += 1  # advance the status bar
    progress_bar(loading, video_length)
    ret, frame = cap.read()
    if ret:
        # ----------------------
        # START preprocessing |
        # ----------------------
        # camera distortion correction
        frame = distortion.correct(frame)
        # computing the top-down ("upper") view
        frame_up = prospective.frame_transform(frame)
        # noise filtering
        frame = cv2.bilateralFilter(
            frame,
            conf.BILATERAL_D,
            conf.BILATERAL_SIGMA,
            conf.BILATERAL_SIGMA
        )
        # -------------------------------------------
        # END preprocessing | START YOLO prediction|
        # -------------------------------------------
        output = yolo.predict(frame)
        if not isinstance(output, torch.Tensor):
            continue  # nothing detected in this frame
        # --------------------------------------
        # END YOLO prediction | START tracking|
        # --------------------------------------
        # buckets_score is an ndarray with people on rows and buckets on columns:
        #
        # ----------------------------------------------------
        # -       | bucket1  | bucket2  | ...    | bucketN   |
        # --------|----------|----------|--------|-----------|
        # person1 |          |          |        |           |
        # --------|----------|----------|--------|-----------|
        # person2 |          |          |        |           |
        # --------|----------|----------|--------|-----------|
        # ...     |          |          |        |           |
        # --------|----------|----------|--------|-----------|
        # personM |          |          |        |           |
        # ----------------------------------------------------
        counter = sum(1 for ot in output if classes[int(ot[-1])] == 'person')
        if conf.TESTING:
            print("number of people detected", counter)
        buckets_score = np.ndarray(shape=(counter, len(buckets_cords)))  # person/bucket distance matrix
        new_coords = []       # detection centers in the top-down view, one per matrix row
        new_coords_orig = []  # the same centers in the original camera view
        old_bc = list(buckets_cords)  # snapshot, so buckets added this frame are not scored against
        for person in output:
            # consider only detections classified as 'person'
            if classes[int(person[-1])] == 'person':
                # bounding-box corners -> bottom-center of the box (the feet position)
                c1 = tuple(person[1:3].int().cpu())
                c2 = tuple(person[3:5].int().cpu())
                center = (np.asarray((c1[0] + c2[0]) // 2), np.asarray(c2[1]))  # cx, cy
                # map that point into the top-down perspective; the (1, 1, 2) float32
                # layout is the shape cv2.perspectiveTransform expects, assuming
                # point_transform wraps it
                pts = np.array([[[center[0], center[1]]]], dtype="float32")
                center_upp = prospective.point_transform(pts)
                # SPECIAL CASE: avoid a deadlock on the first iteration (no buckets yet)
                # by turning the first person met into the first bucket
                if len(old_bc) == 0:
                    buckets_cords.append([center_upp])
                    buckets_cords_orig.append([center])
                    buckets_colors.append(random.choice(colors))  # link a color to this new bucket
                    if conf.TESTING:
                        print("new element in bucket list (INIT). buckets size: ", len(buckets_cords))
                    continue
                new_coords.append(center_upp)  # kept for the assignment step below
                new_coords_orig.append(center)
                row = len(new_coords) - 1  # matrix row for this person (not the raw detection index)
                # score this person against every existing bucket
                for k, bucket in enumerate(old_bc):
                    cx1, cy1 = bucket[-1]  # compare against the bucket's last position
                    pos_dist = abs(center_upp[0] - cx1) + abs(center_upp[1] - cy1)  # L1 distance
                    if conf.TESTING:
                        print("pos dist: ", pos_dist)
                    if len(bucket) > 1:  # with enough points, also score the predicted trajectory
                        next_x, next_y = trajectory.next_point_prediction(bucket, cx1, cy1)
                        trj_score = trajectory.score(next_x, next_y, center_upp[0], center_upp[1])
                        pos_dist = (pos_dist + trj_score) / 2
                        if conf.TESTING:
                            print("trj_score: ", trj_score)
                    buckets_score[row][k] = pos_dist
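        # Worked example of the greedy assignment below (illustrative numbers,
        # not taken from a real frame): with two people and two buckets,
        #     buckets_score = [[ 3.0, 40.0],
        #                      [35.0,  5.0]]
        # the global minimum 3.0 matches person1 -> bucket1; its row and column
        # are then masked with sentinels, so the next minimum 5.0 matches
        # person2 -> bucket2.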
        # assign people to buckets greedily, best (lowest) score first
        if buckets_score.shape[0] > 0 and buckets_score.shape[1] > 0:
            # run the assignment once per detected person
            for _ in buckets_score:
                if conf.TESTING:
                    print("buckets_score: ", buckets_score)
                # find the minimum value in the matrix
                ind = np.unravel_index(np.argmin(buckets_score, axis=None), buckets_score.shape)
                if conf.TESTING:
                    print("best index: ", ind)
                    print("new coords: ", new_coords)
                # check whether the minimum score is under the given threshold
                if buckets_score[ind] <= conf.MIN_TRACKING_TH:
                    if buckets_score.shape[1] > 1:
                        # several people may want the same bucket: prefer the person whose
                        # best score dominates their second-best score the most
                        bucket_score_temp = buckets_score[:, ind[1]].copy()
                        for cp, per in enumerate(buckets_score):
                            idx = np.argmin(per)
                            if per[idx] < conf.MIN_TRACKING_TH and idx == ind[1]:
                                A, B = np.partition(per, 1)[0:2]
                                bucket_score_temp[cp] = A / B  # best/second-best ratio; smaller = more confident
                        ind = (np.argmin(bucket_score_temp), ind[1])
                    # append the person's coords to the matched bucket
                    buckets_cords[ind[1]].append(new_coords[ind[0]])
                    buckets_cords_orig[ind[1]].append(new_coords_orig[ind[0]])
                    # mask the consumed bucket (column) and person (row) with sentinels;
                    # the row sentinel is higher, so unassigned people always win the argmin
                    buckets_score[:, ind[1]] = np.maximum(buckets_score[:, ind[1]], 9998)
                    buckets_score[ind[0], :] = 9999
                    if conf.TESTING:
                        print("found position in bucket: ", ind[1])
                    col_idx = ind[1]
                # if no bucket is close enough, start a new bucket for this person
                else:
                    buckets_cords.append([new_coords[ind[0]]])
                    buckets_cords_orig.append([new_coords_orig[ind[0]]])
                    buckets_colors.append(random.choice(colors))
                    if conf.TESTING:
                        print("new element in bucket list. buckets: ", buckets_cords)
                    col_idx = len(buckets_colors) - 1  # index of the bucket just added
                    buckets_score[ind[0], :] = 9999  # this person is consumed too
                # draw the person's current position in both views
                cv2.circle(
                    frame_up, (
                        new_coords[ind[0]][0],
                        new_coords[ind[0]][1]
                    ),
                    7,
                    buckets_colors[col_idx],
                    -1
                )
                cv2.circle(
                    frame, (
                        new_coords_orig[ind[0]][0],
                        new_coords_orig[ind[0]][1]
                    ),
                    7,
                    buckets_colors[col_idx],
                    -1
                )
        # draw every bucket's trajectory in both views
        for bkt_idx, bkt in enumerate(buckets_cords):
            bkt = np.asarray(bkt).reshape((-1, 1, 2))
            cv2.polylines(frame_up, np.int32([bkt]), isClosed=False,
                          color=buckets_colors[bkt_idx], thickness=3, lineType=10)
            bkt_orig = np.asarray(buckets_cords_orig[bkt_idx]).reshape((-1, 1, 2))
            cv2.polylines(frame, np.int32([bkt_orig]), isClosed=False,
                          color=buckets_colors[bkt_idx], thickness=3, lineType=10)
        frame_up = cv2.resize(frame_up, (640, 480))
        if conf.SAVE_RESULT:
            vout1.write(frame_up)
            # the writer was opened at 640x480, so the frame must be resized to match
            vout.write(cv2.resize(frame, (640, 480)))
        if conf.LIVE_RESULTS:
            cv2.imshow("frame", frame)
            cv2.imshow("frame_up", frame_up)
            cv2.waitKey(1)
    else:  # exit when the video is over
        break
# releasing resources and exiting
cap.release()
if conf.SAVE_RESULT:
    vout.release()
    vout1.release()
    print("Your video is ready!")
cv2.destroyAllWindows()
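
# For reference, the local `conf` module is expected (from the uses above) to
# define: CLASS_NAME_PATH, PALETTE_PATH, VIDEO_PATH, VIDEO_OUT_PATH,
# VIDEO_UP_OUT_PATH, SAVE_RESULT, LIVE_RESULTS, TESTING, BILATERAL_D,
# BILATERAL_SIGMA and MIN_TRACKING_TH.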