diff --git a/src/transformers/models/owlv2/image_processing_owlv2.py b/src/transformers/models/owlv2/image_processing_owlv2.py
index 1e9a5163a1a6fd..d3ef04238a8f80 100644
--- a/src/transformers/models/owlv2/image_processing_owlv2.py
+++ b/src/transformers/models/owlv2/image_processing_owlv2.py
@@ -565,9 +565,9 @@ def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_thresh
         """
         logits, target_boxes = outputs.logits, outputs.target_pred_boxes
 
-        if len(logits) != len(target_sizes):
+        if target_sizes is not None and len(logits) != len(target_sizes):
             raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
-        if target_sizes.shape[1] != 2:
+        if target_sizes is not None and target_sizes.shape[1] != 2:
             raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
 
         probs = torch.max(logits, dim=-1)
@@ -588,9 +588,14 @@ def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_thresh
                     scores[idx][ious > nms_threshold] = 0.0
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = target_sizes.unbind(1)
-        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
-        target_boxes = target_boxes * scale_fct[:, None, :]
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.tensor([i[0] for i in target_sizes])
+                img_w = torch.tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
+            target_boxes = target_boxes * scale_fct[:, None, :]
 
         # Compute box display alphas based on prediction scores
         results = []
diff --git a/src/transformers/models/owlvit/image_processing_owlvit.py b/src/transformers/models/owlvit/image_processing_owlvit.py
index 25ea5f2720d527..5bc889ba85d501 100644
--- a/src/transformers/models/owlvit/image_processing_owlvit.py
+++ b/src/transformers/models/owlvit/image_processing_owlvit.py
@@ -556,9 +556,9 @@ def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_thresh
         """
         logits, target_boxes = outputs.logits, outputs.target_pred_boxes
 
-        if len(logits) != len(target_sizes):
+        if target_sizes is not None and len(logits) != len(target_sizes):
             raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
-        if target_sizes.shape[1] != 2:
+        if target_sizes is not None and target_sizes.shape[1] != 2:
             raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
 
         probs = torch.max(logits, dim=-1)
@@ -579,9 +579,14 @@ def post_process_image_guided_detection(self, outputs, threshold=0.0, nms_thresh
                     scores[idx][ious > nms_threshold] = 0.0
 
         # Convert from relative [0, 1] to absolute [0, height] coordinates
-        img_h, img_w = target_sizes.unbind(1)
-        scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
-        target_boxes = target_boxes * scale_fct[:, None, :]
+        if target_sizes is not None:
+            if isinstance(target_sizes, List):
+                img_h = torch.tensor([i[0] for i in target_sizes])
+                img_w = torch.tensor([i[1] for i in target_sizes])
+            else:
+                img_h, img_w = target_sizes.unbind(1)
+            scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1).to(target_boxes.device)
+            target_boxes = target_boxes * scale_fct[:, None, :]
 
         # Compute box display alphas based on prediction scores
         results = []
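
Usage note, not part of the patch: the sketch below exercises the relaxed target_sizes handling introduced above (a plain list of (height, width) pairs, or None for relative coordinates). It uses the public google/owlv2-base-patch16-ensemble checkpoint; image.jpg and query.jpg are placeholder file names, substitute your own images.

# Minimal usage sketch (not part of the patch) for post_process_image_guided_detection
# after this change. image.jpg / query.jpg are placeholders.
import torch
from PIL import Image
from transformers import Owlv2ForObjectDetection, Owlv2Processor

processor = Owlv2Processor.from_pretrained("google/owlv2-base-patch16-ensemble")
model = Owlv2ForObjectDetection.from_pretrained("google/owlv2-base-patch16-ensemble")

image = Image.open("image.jpg")        # image to search in
query_image = Image.open("query.jpg")  # example of the object to find

inputs = processor(images=image, query_images=query_image, return_tensors="pt")
with torch.no_grad():
    outputs = model.image_guided_detection(**inputs)

# A plain list of (height, width) pairs is now accepted in addition to a tensor.
target_sizes = [image.size[::-1]]  # PIL size is (width, height), so reverse it
results = processor.post_process_image_guided_detection(
    outputs=outputs, threshold=0.9, nms_threshold=0.3, target_sizes=target_sizes
)
print(results[0]["boxes"])  # boxes scaled to absolute pixel coordinates

# Passing target_sizes=None no longer raises; boxes stay in relative [0, 1] coordinates.
relative = processor.post_process_image_guided_detection(outputs=outputs, target_sizes=None)
print(relative[0]["boxes"])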