From 48c5371842c68c75b0c43370c0bf7d12e1f15353 Mon Sep 17 00:00:00 2001 From: "George D. Torres" Date: Thu, 1 Aug 2024 16:26:30 -0500 Subject: [PATCH 1/2] dont squeeze in batch mode --- sam2/sam2_image_predictor.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index 94111316..2ba27872 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -203,11 +203,11 @@ def predict_batch( return_logits=return_logits, img_idx=img_idx, ) - masks_np = masks.squeeze(0).float().detach().cpu().numpy() + masks_np = masks.float().detach().cpu().numpy() iou_predictions_np = ( - iou_predictions.squeeze(0).float().detach().cpu().numpy() + iou_predictions.float().detach().cpu().numpy() ) - low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + low_res_masks_np = low_res_masks.float().detach().cpu().numpy() all_masks.append(masks_np) all_ious.append(iou_predictions_np) all_low_res_masks.append(low_res_masks_np) From 482409d0cbd014654803906274b9d801c6d507c9 Mon Sep 17 00:00:00 2001 From: "George D. Torres" Date: Thu, 1 Aug 2024 17:08:55 -0500 Subject: [PATCH 2/2] use squeeze appropriately --- sam2/sam2_image_predictor.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/sam2/sam2_image_predictor.py b/sam2/sam2_image_predictor.py index 2ba27872..17cd1a14 100644 --- a/sam2/sam2_image_predictor.py +++ b/sam2/sam2_image_predictor.py @@ -203,11 +203,19 @@ def predict_batch( return_logits=return_logits, img_idx=img_idx, ) + + using_batch_points = point_coords is not None and len(point_coords.shape) == 3 + using_batch_box = box is not None and len(box.shape) == 3 + + if not using_batch_points and not using_batch_box: + masks = masks.squeeze(0) + iou_predictions = iou_predictions.squeeze(0) + low_res_masks = low_res_masks.squeeze(0) + masks_np = masks.float().detach().cpu().numpy() - iou_predictions_np = ( - iou_predictions.float().detach().cpu().numpy() - ) + iou_predictions_np = iou_predictions.float().detach().cpu().numpy() low_res_masks_np = low_res_masks.float().detach().cpu().numpy() + all_masks.append(masks_np) all_ious.append(iou_predictions_np) all_low_res_masks.append(low_res_masks_np) @@ -277,9 +285,17 @@ def predict( return_logits=return_logits, ) - masks_np = masks.squeeze(0).float().detach().cpu().numpy() - iou_predictions_np = iou_predictions.squeeze(0).float().detach().cpu().numpy() - low_res_masks_np = low_res_masks.squeeze(0).float().detach().cpu().numpy() + using_batch_points = point_coords is not None and len(point_coords.shape) == 3 + using_batch_box = box is not None and len(box.shape) == 3 + + if not using_batch_points and not using_batch_box: + masks = masks.squeeze(0) + iou_predictions = iou_predictions.squeeze(0) + low_res_masks = low_res_masks.squeeze(0) + + masks_np = masks.float().detach().cpu().numpy() + iou_predictions_np = iou_predictions.float().detach().cpu().numpy() + low_res_masks_np = low_res_masks.float().detach().cpu().numpy() return masks_np, iou_predictions_np, low_res_masks_np def _prep_prompts(