diff --git a/detect.py b/detect.py index 1b70dbb7ef89..fe01f1120135 100644 --- a/detect.py +++ b/detect.py @@ -59,7 +59,7 @@ def detect(save_img=False): t0 = time.time() img = torch.zeros((1, 3, imgsz, imgsz), device=device) # init img _ = model(img.half() if half else img) if device.type != 'cpu' else None # run once - for path, img, im0s, vid_cap in dataset: + for path, img, im0s, vid_cap, rotation in dataset: img = torch.from_numpy(img).to(device) img = img.half() if half else img.float() # uint8 to fp16/32 img /= 255.0 # 0 - 255 to 0.0 - 1.0 @@ -131,8 +131,8 @@ def detect(save_img=False): fourcc = 'mp4v' # output video codec fps = vid_cap.get(cv2.CAP_PROP_FPS) - w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) - h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + w = int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) if not rotation or rotation == '180' else int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + h = int(vid_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) if not rotation or rotation == '180' else int(vid_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) vid_writer = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*fourcc), fps, (w, h)) vid_writer.write(im0) diff --git a/requirements.txt b/requirements.txt index d04638f18e87..cd0e1355de4b 100755 --- a/requirements.txt +++ b/requirements.txt @@ -12,6 +12,8 @@ tensorboard>=2.2 torch>=1.6.0 torchvision>=0.7.0 tqdm>=4.41.0 +scikit-video +ffmpeg # logging ------------------------------------- # wandb diff --git a/utils/datasets.py b/utils/datasets.py index 841879a5cf8f..7e74e8de47e2 100755 --- a/utils/datasets.py +++ b/utils/datasets.py @@ -12,6 +12,8 @@ from threading import Thread import cv2 +import math +import skvideo.io import numpy as np import torch from PIL import Image, ExifTags @@ -128,9 +130,19 @@ def __init__(self, path, img_size=640): images = [x for x in files if os.path.splitext(x)[-1].lower() in img_formats] videos = [x for x in files if os.path.splitext(x)[-1].lower() in vid_formats] ni, nv = len(images), len(videos) + videos_rotation = [None for _ in videos] + for index in range(nv): + metadata = skvideo.io.ffprobe(videos[index]) + if 'video' in metadata and 'tag' in metadata['video']: + tags = metadata['video']['tag'] + + for tag in tags: + if tag['@key'] == 'rotate': + videos_rotation[index] = tag['@value'] self.img_size = img_size self.files = images + videos + self.rotation = [None for _ in images] + videos_rotation self.nf = ni + nv # number of files self.video_flag = [False] * ni + [True] * nv self.mode = 'images' @@ -149,6 +161,7 @@ def __next__(self): if self.count == self.nf: raise StopIteration path = self.files[self.count] + rotation = self.rotation[self.count] if self.video_flag[self.count]: # Read video @@ -174,6 +187,14 @@ def __next__(self): assert img0 is not None, 'Image Not Found ' + path print('image %g/%g %s: ' % (self.count, self.nf, path), end='') + # Rotation Valid + if rotation == '90': + img0 = cv2.rotate(img0, 0) + if rotation == '180': + img0 = cv2.rotate(img0, 1) + if rotation == '270': + img0 = cv2.rotate(img0, 2) + # Padded resize img = letterbox(img0, new_shape=self.img_size)[0] @@ -181,7 +202,7 @@ def __next__(self): img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 img = np.ascontiguousarray(img) - return path, img, img0, self.cap + return path, img, img0, self.cap, rotation def new_video(self, path): self.frame = 0 @@ -243,7 +264,7 @@ def __next__(self): img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 img = np.ascontiguousarray(img) - return img_path, img, img0, None + return img_path, img, img0, None, None def __len__(self): return 0 @@ -316,7 +337,7 @@ def __next__(self): img = img[:, :, :, ::-1].transpose(0, 3, 1, 2) # BGR to RGB, to bsx3x416x416 img = np.ascontiguousarray(img) - return self.sources, img, img0, None + return self.sources, img, img0, None, None def __len__(self): return 0 # 1E12 frames = 32 streams at 30 FPS for 30 years