diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py
index eac1ba8e32a..3e12e2d90d1 100644
--- a/src/transformers/models/beit/feature_extraction_beit.py
+++ b/src/transformers/models/beit/feature_extraction_beit.py
@@ -226,43 +226,43 @@ def __call__(
 
         return encoded_inputs
 
-    def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None):
+    def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None):
         """
         Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch.
 
         Args:
             outputs ([`BeitForSemanticSegmentation`]):
                 Raw outputs of the model.
-            target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*):
-                Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to
+            target_sizes (`List[Tuple]` of length `batch_size`, *optional*):
+                List of tuples corresponding to the requested final size (height, width) of each prediction. If left to
                 None, predictions will not be resized.
 
         Returns:
-            semantic_segmentation: `torch.Tensor` of shape `(batch_size, 2)` or `List[torch.Tensor]` of length
-            `batch_size`, where each item is a semantic segmentation map of of the corresponding target_sizes entry (if
-            `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id.
+            semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic
+            segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is
+            specified). Each entry of each `torch.Tensor` corresponds to a semantic class id.
         """
         logits = outputs.logits
 
-        if len(logits) != len(target_sizes):
-            raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits")
-
-        if target_sizes is not None and target_sizes.shape[1] != 2:
-            raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch")
-
-        semantic_segmentation = logits.argmax(dim=1)
-
-        # Resize semantic segmentation maps
+        # Resize logits and compute semantic segmentation maps
        if target_sizes is not None:
+            if len(logits) != len(target_sizes):
+                raise ValueError(
+                    "Make sure that you pass in as many target sizes as the batch dimension of the logits"
+                )
+
             if is_torch_tensor(target_sizes):
                 target_sizes = target_sizes.numpy()
-            resized_maps = []
-            semantic_segmentation = semantic_segmentation.numpy()
+            semantic_segmentation = []
 
-            for idx in range(len(semantic_segmentation)):
-                resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx])
-                resized_maps.append(resized)
-
-            semantic_segmentation = [torch.Tensor(np.array(image)) for image in resized_maps]
+            for idx in range(len(logits)):
+                resized_logits = torch.nn.functional.interpolate(
+                    logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False
+                )
+                semantic_map = resized_logits[0].argmax(dim=0)
+                semantic_segmentation.append(semantic_map)
+        else:
+            semantic_segmentation = logits.argmax(dim=1)
+            semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])]
 
         return semantic_segmentation
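For context, a minimal usage sketch of the reworked method (not part of the diff; it assumes the same checkpoint and fixture dataset that the test below uses):

```python
import torch
from datasets import load_dataset
from PIL import Image

from transformers import BeitFeatureExtractor, BeitForSemanticSegmentation

checkpoint = "microsoft/beit-base-finetuned-ade-640-640"
model = BeitForSemanticSegmentation.from_pretrained(checkpoint)
feature_extractor = BeitFeatureExtractor.from_pretrained(checkpoint)

ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
image = Image.open(ds[0]["file"])

inputs = feature_extractor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# PIL reports (width, height); the method expects (height, width) tuples
segmentation = feature_extractor.post_process_semantic_segmentation(
    outputs, target_sizes=[image.size[::-1]]
)
print(segmentation[0].shape)  # one (height, width) map of class ids per image
```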
diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py
index 7d2d75d2881..377ed8e8e94 100644
--- a/tests/models/beit/test_modeling_beit.py
+++ b/tests/models/beit/test_modeling_beit.py
@@ -455,3 +455,28 @@ def test_inference_semantic_segmentation(self):
         )
 
         self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    def test_post_processing_semantic_segmentation(self):
+        model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640")
+        model = model.to(torch_device)
+
+        feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False)
+
+        ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test")
+        image = Image.open(ds[0]["file"])
+        inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device)
+
+        # forward pass
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        outputs.logits = outputs.logits.detach().cpu()
+
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)])
+        expected_shape = torch.Size((500, 300))
+        self.assertEqual(segmentation[0].shape, expected_shape)
+
+        segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)
+        expected_shape = torch.Size((160, 160))
+        self.assertEqual(segmentation[0].shape, expected_shape)
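A note on the design choice: the removed implementation took `argmax` first and then resized the integer class map, which can blend neighboring class ids into values that belong to neither pixel; interpolating the logits bilinearly and taking `argmax` afterwards avoids this. A toy sketch illustrating the hazard (illustrative only — it uses bilinear interpolation in both paths rather than the removed `self.resize` path):

```python
import torch
import torch.nn.functional as F

# Three classes over a 1x2 image: left pixel is class 0, right pixel is class 2.
logits = torch.tensor([[10.0, 1.0], [0.0, 0.0], [0.0, 10.0]]).reshape(1, 3, 1, 2)

# Argmax first, then resize the label map: the interpolated ids pass through 1,
# a class that appears nowhere in the prediction.
labels = logits.argmax(dim=1, keepdim=True).float()
resized_labels = F.interpolate(labels, size=(1, 3), mode="bilinear", align_corners=False)
print(resized_labels.round().long())  # tensor([[[[0, 1, 2]]]])

# Resize the logits first, then argmax: only real classes survive.
resized_logits = F.interpolate(logits, size=(1, 3), mode="bilinear", align_corners=False)
print(resized_logits.argmax(dim=1))  # tensor([[[0, 0, 2]]])
```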