diff --git a/models/common.py b/models/common.py index 4f93887c55e0..f914c9d60fdb 100644 --- a/models/common.py +++ b/models/common.py @@ -617,7 +617,7 @@ def forward(self, imgs, size=640, augment=False, profile=False): files.append(Path(f).with_suffix('.jpg').name) if im.shape[0] < 5: # image in CHW im = im.transpose((1, 2, 0)) # reverse dataloader .transpose(2, 0, 1) - im = im[..., :3] if im.ndim == 3 else np.tile(im[..., None], 3) # enforce 3ch input + im = im[..., :3] if im.ndim == 3 else cv2.cvtColor(im, cv2.COLOR_GRAY2BGR) # enforce 3ch input s = im.shape[:2] # HWC shape0.append(s) # image shape g = (size / max(s)) # gain