diff --git a/ultralytics/models/yolov10/predict.py b/ultralytics/models/yolov10/predict.py index 98994f71b..e3670f209 100644 --- a/ultralytics/models/yolov10/predict.py +++ b/ultralytics/models/yolov10/predict.py @@ -11,7 +11,7 @@ def postprocess(self, preds, img, orig_imgs): if isinstance(preds, (list, tuple)): preds = preds[0] - + if preds.shape[-1] == 6: pass else: @@ -22,9 +22,7 @@ def postprocess(self, preds, img, orig_imgs): mask = preds[..., 4] > self.args.conf - b, _, c = preds.shape - preds = preds.view(-1, preds.shape[-1])[mask.view(-1)] - preds = preds.view(b, -1, c) + preds = [p[mask[idx]] for idx, p in enumerate(preds)] if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)