diff --git a/utils/general.py b/utils/general.py index efe78b29a..af8550b3b 100644 --- a/utils/general.py +++ b/utils/general.py @@ -900,7 +900,7 @@ def non_max_suppression( """ if isinstance(prediction, (list, tuple)): # YOLO model in validation model, output = (inference_out, loss_out) - prediction = prediction[0] # select only inference output + prediction = prediction[0][1] # select only inference output device = prediction.device mps = 'mps' in device.type # Apple MPS