Skip to content

Commit

Permalink
fix mot infer video (#3823)
Browse files Browse the repository at this point in the history
  • Loading branch information
nemonameless authored Jul 29, 2021
1 parent a3e2b2e commit e4259d4
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 16 deletions.
21 changes: 15 additions & 6 deletions deploy/python/mot_jde_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,8 @@ def predict_video(detector, camera_id):
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
if not FLAGS.save_images:
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer = MOTTimer()
results = []
Expand All @@ -236,7 +237,7 @@ def predict_video(detector, camera_id):

results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))
fps = 1. / timer.average_time
online_im = mot_vis.plot_tracking(
im = mot_vis.plot_tracking(
frame,
online_tlwhs,
online_ids,
Expand All @@ -249,11 +250,11 @@ def predict_video(detector, camera_id):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
online_im)
im)
else:
writer.write(im)
frame_id += 1
print('detect frame:%d' % (frame_id))
im = np.array(online_im)
writer.write(im)
if camera_id != -1:
cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
Expand All @@ -262,7 +263,15 @@ def predict_video(detector, camera_id):
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
writer.release()

if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
save_dir, out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release()


def main():
Expand Down
17 changes: 13 additions & 4 deletions deploy/python/mot_keypoint_unite_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,8 @@ def mot_keypoint_unite_predict_video(mot_model,
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
if not FLAGS.save_images:
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer_mot = FPSTimer()
timer_kp = FPSTimer()
Expand Down Expand Up @@ -202,8 +203,8 @@ def mot_keypoint_unite_predict_video(mot_model,
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)), im)

writer.write(im)
else:
writer.write(im)
if camera_id != -1:
cv2.imshow('Tracking and keypoint results', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
Expand All @@ -212,7 +213,15 @@ def mot_keypoint_unite_predict_video(mot_model,
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, mot_results)
writer.release()

if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
save_dir, out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release()


def main():
Expand Down
21 changes: 15 additions & 6 deletions deploy/python/mot_sde_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,8 @@ def predict_video(detector, reid_model, camera_id):
if not os.path.exists(FLAGS.output_dir):
os.makedirs(FLAGS.output_dir)
out_path = os.path.join(FLAGS.output_dir, video_name)
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
if not FLAGS.save_images:
writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
frame_id = 0
timer = MOTTimer()
results = []
Expand All @@ -379,7 +380,7 @@ def predict_video(detector, reid_model, camera_id):

results.append((frame_id + 1, online_tlwhs, online_scores, online_ids))
fps = 1. / timer.average_time
online_im = mot_vis.plot_tracking(
im = mot_vis.plot_tracking(
frame,
online_tlwhs,
online_ids,
Expand All @@ -392,11 +393,11 @@ def predict_video(detector, reid_model, camera_id):
os.makedirs(save_dir)
cv2.imwrite(
os.path.join(save_dir, '{:05d}.jpg'.format(frame_id)),
online_im)
im)
else:
writer.write(im)
frame_id += 1
print('detect frame:%d' % (frame_id))
im = np.array(online_im)
writer.write(im)
if camera_id != -1:
cv2.imshow('Tracking Detection', im)
if cv2.waitKey(1) & 0xFF == ord('q'):
Expand All @@ -405,7 +406,15 @@ def predict_video(detector, reid_model, camera_id):
result_filename = os.path.join(FLAGS.output_dir,
video_name.split('.')[-2] + '.txt')
write_mot_results(result_filename, results)
writer.release()

if FLAGS.save_images:
save_dir = os.path.join(FLAGS.output_dir, video_name.split('.')[-2])
cmd_str = 'ffmpeg -f image2 -i {}/%05d.jpg -vf "scale=trunc(iw/2)*2:trunc(ih/2)*2" {}'.format(
save_dir, out_path)
os.system(cmd_str)
print('Save video in {}.'.format(out_path))
else:
writer.release()


def main():
Expand Down

0 comments on commit e4259d4

Please sign in to comment.