From 29ebf2b6589be653aacf00cff9570b0f32afe966 Mon Sep 17 00:00:00 2001 From: Alara Dirik Date: Sat, 17 Sep 2022 17:08:45 +0300 Subject: [PATCH 01/10] add post_process_semantic_segmentation method --- .../models/beit/feature_extraction_beit.py | 48 ++++++++++++++++++- 1 file changed, 46 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py index 62b790621baf3e..b726b7c23a370d 100644 --- a/src/transformers/models/beit/feature_extraction_beit.py +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -14,7 +14,7 @@ # limitations under the License. """Feature extractor class for BEiT.""" -from typing import Optional, Union +from typing import List, Optional, Tuple, Union import numpy as np from PIL import Image @@ -27,9 +27,12 @@ ImageInput, is_torch_tensor, ) -from ...utils import TensorType, logging +from ...utils import TensorType, is_torch_available, logging +if is_torch_available(): + import torch + logger = logging.get_logger(__name__) @@ -222,3 +225,44 @@ def __call__( encoded_inputs = BatchFeature(data=data, tensor_type=return_tensors) return encoded_inputs + + def post_process_semantic_segmentation(self, outputs, target_sizes: Union[torch.Tensor, List[Tuple]] = None): + """ + Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports + PyTorch. + Parameters: + outputs ([`BeitForSemanticSegmentation`]): + Raw outputs of the model. + target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*): + Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to + None, predictions will not be resized. + Returns: + `List[torch.Tensor]`: A list of `torch.Tensor` of length `batch_size`, where each item is a semantic + segmentation map of of the corresponding target_sizes entry (if `target_sizes` is specified). Each + entry of each `torch.Tensor` correspond to a semantic class id. + """ + logits = outputs.logits + + if len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + + if target_sizes is not None and target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + semantic_segmentation = logits.argmax(dim=1) + + # Resize semantic segmentation maps + if target_sizes is not None: + if is_torch_tensor(target_sizes): + target_sizes = target_sizes.numpy() + + resized_maps = [] + semantic_segmentation = semantic_segmentation.numpy() + + for idx in range(len(semantic_segmentation)): + resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx]) + resized_maps.append(resized) + + semantic_segmentation = [torch.Tensor(np.array(image)) for image in resized_maps] + + return semantic_segmentation From 68e3d05dd2690a733927c422286d566f00884196 Mon Sep 17 00:00:00 2001 From: Alara Dirik Date: Sat, 17 Sep 2022 17:11:30 +0300 Subject: [PATCH 02/10] update docs --- docs/source/en/model_doc/beit.mdx | 1 + src/transformers/models/beit/feature_extraction_beit.py | 7 +++---- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/model_doc/beit.mdx b/docs/source/en/model_doc/beit.mdx index 625357810ded9f..040888547adb7f 100644 --- a/docs/source/en/model_doc/beit.mdx +++ b/docs/source/en/model_doc/beit.mdx @@ -77,6 +77,7 @@ contributed by [kamalkraj](https://huggingface.co/kamalkraj). The original code [[autodoc]] BeitFeatureExtractor - __call__ + - post_process_semantic_segmentation ## BeitModel diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py index b726b7c23a370d..9227550696e4f0 100644 --- a/src/transformers/models/beit/feature_extraction_beit.py +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -228,9 +228,8 @@ def __call__( def post_process_semantic_segmentation(self, outputs, target_sizes: Union[torch.Tensor, List[Tuple]] = None): """ - Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports - PyTorch. Parameters: + Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. outputs ([`BeitForSemanticSegmentation`]): Raw outputs of the model. target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*): @@ -238,8 +237,8 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Union[torch. None, predictions will not be resized. Returns: `List[torch.Tensor]`: A list of `torch.Tensor` of length `batch_size`, where each item is a semantic - segmentation map of of the corresponding target_sizes entry (if `target_sizes` is specified). Each - entry of each `torch.Tensor` correspond to a semantic class id. + segmentation map of of the corresponding target_sizes entry (if `target_sizes` is specified). Each entry of + each `torch.Tensor` correspond to a semantic class id. """ logits = outputs.logits From 6dc836b323e468b0cf76520566a638f0e29ef9ce Mon Sep 17 00:00:00 2001 From: Alara Dirik Date: Mon, 19 Sep 2022 14:00:41 +0300 Subject: [PATCH 03/10] fix test errors --- src/transformers/models/beit/feature_extraction_beit.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py index 9227550696e4f0..901c45a1956124 100644 --- a/src/transformers/models/beit/feature_extraction_beit.py +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -226,13 +226,14 @@ def __call__( return encoded_inputs - def post_process_semantic_segmentation(self, outputs, target_sizes: Union[torch.Tensor, List[Tuple]] = None): + def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None): """ Parameters: Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. outputs ([`BeitForSemanticSegmentation`]): Raw outputs of the model. - target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*): + target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, + *optional*): Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to None, predictions will not be resized. Returns: From c1d79b4444f86647a6de00feb67b5b6efbd81103 Mon Sep 17 00:00:00 2001 From: Alara Dirik Date: Mon, 19 Sep 2022 16:42:16 +0300 Subject: [PATCH 04/10] fix formatting --- src/transformers/models/beit/feature_extraction_beit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py index 901c45a1956124..ec85dec001b322 100644 --- a/src/transformers/models/beit/feature_extraction_beit.py +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -228,12 +228,12 @@ def __call__( def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None): """ - Parameters: Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + + Args: outputs ([`BeitForSemanticSegmentation`]): Raw outputs of the model. - target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, - *optional*): + target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*): Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to None, predictions will not be resized. Returns: From c840b4fa265de0a0063dba18fbfecd2c5f7401df Mon Sep 17 00:00:00 2001 From: Alara Dirik Date: Mon, 19 Sep 2022 16:51:33 +0300 Subject: [PATCH 05/10] fix formatting --- src/transformers/models/beit/feature_extraction_beit.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py index ec85dec001b322..eac1ba8e32a241 100644 --- a/src/transformers/models/beit/feature_extraction_beit.py +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -237,9 +237,9 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Union[Tensor Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to None, predictions will not be resized. Returns: - `List[torch.Tensor]`: A list of `torch.Tensor` of length `batch_size`, where each item is a semantic - segmentation map of of the corresponding target_sizes entry (if `target_sizes` is specified). Each entry of - each `torch.Tensor` correspond to a semantic class id. + semantic_segmentation: `torch.Tensor` of shape `(batch_size, 2)` or `List[torch.Tensor]` of length + `batch_size`, where each item is a semantic segmentation map of of the corresponding target_sizes entry (if + `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ logits = outputs.logits From bbf3cee49979b8469a1f70a0a7443b0e8d40b42d Mon Sep 17 00:00:00 2001 From: Alara Dirik Date: Tue, 20 Sep 2022 11:56:57 +0300 Subject: [PATCH 06/10] return post-processed segmentations as list, add test --- .../models/beit/feature_extraction_beit.py | 22 +++++++-------- tests/models/beit/test_modeling_beit.py | 28 +++++++++++++++++++ 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py index eac1ba8e32a241..5ca2da49899a8a 100644 --- a/src/transformers/models/beit/feature_extraction_beit.py +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -237,22 +237,20 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Union[Tensor Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to None, predictions will not be resized. Returns: - semantic_segmentation: `torch.Tensor` of shape `(batch_size, 2)` or `List[torch.Tensor]` of length - `batch_size`, where each item is a semantic segmentation map of of the corresponding target_sizes entry (if - `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (w, h) corresponding to the target_sizes entry (if `target_sizes` is specified). + Each entry of each `torch.Tensor` correspond to a semantic class id. """ logits = outputs.logits - - if len(logits) != len(target_sizes): - raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") - - if target_sizes is not None and target_sizes.shape[1] != 2: - raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") - semantic_segmentation = logits.argmax(dim=1) # Resize semantic segmentation maps if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + if is_torch_tensor(target_sizes): target_sizes = target_sizes.numpy() @@ -263,6 +261,8 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Union[Tensor resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx]) resized_maps.append(resized) - semantic_segmentation = [torch.Tensor(np.array(image)) for image in resized_maps] + semantic_segmentation = [torch.Tensor(np.array(image)).to(torch.int64) for image in resized_maps] + else: + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] return semantic_segmentation diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index 7d2d75d2881b75..b51dca88ecc54b 100644 --- a/tests/models/beit/test_modeling_beit.py +++ b/tests/models/beit/test_modeling_beit.py @@ -455,3 +455,31 @@ def test_inference_semantic_segmentation(self): ) self.assertTrue(torch.allclose(logits[0, :3, :3, :3], expected_slice, atol=1e-4)) + + @slow + def test_post_processing_semantic_segmentation(self): + model = BeitForSemanticSegmentation.from_pretrained("microsoft/beit-base-finetuned-ade-640-640") + model = model.to(torch_device) + + feature_extractor = BeitFeatureExtractor(do_resize=True, size=640, do_center_crop=False) + + ds = load_dataset("hf-internal-testing/fixtures_ade20k", split="test") + image = Image.open(ds[0]["file"]) + inputs = feature_extractor(images=image, return_tensors="pt").to(torch_device) + + # forward pass + with torch.no_grad(): + outputs = model(**inputs) + + outputs.logits = outputs.logits.detach().cpu() + + target_size = torch.Tensor([[500, 300]]) + segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=target_size)[ + 0 + ] + expected_shape = torch.Size((300, 500)) + self.assertEqual(segmentation.shape, expected_shape) + + segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)[0] + expected_shape = torch.Size((160, 160)) + self.assertEqual(segmentation.shape, expected_shape) From 85e4e913f0385b2fda5f6c5cd555d1aa50f70cfe Mon Sep 17 00:00:00 2001 From: Alara Dirik Date: Tue, 20 Sep 2022 12:58:27 +0300 Subject: [PATCH 07/10] minor changes --- src/transformers/models/beit/feature_extraction_beit.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py index 5ca2da49899a8a..07b224feb2821b 100644 --- a/src/transformers/models/beit/feature_extraction_beit.py +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -261,7 +261,7 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Union[Tensor resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx]) resized_maps.append(resized) - semantic_segmentation = [torch.Tensor(np.array(image)).to(torch.int64) for image in resized_maps] + semantic_segmentation = [torch.Tensor(np.array(image)).long() for image in resized_maps] else: semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] From 9d0059d375e7c64b3fa0ba5f3a2f1a525887a14b Mon Sep 17 00:00:00 2001 From: Alara Dirik Date: Tue, 20 Sep 2022 16:55:48 +0300 Subject: [PATCH 08/10] use torch to resize logits --- .../models/beit/feature_extraction_beit.py | 28 +++++++++---------- tests/models/beit/test_modeling_beit.py | 15 ++++------ 2 files changed, 20 insertions(+), 23 deletions(-) diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py index 07b224feb2821b..3e12e2d90d1a78 100644 --- a/src/transformers/models/beit/feature_extraction_beit.py +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -226,25 +226,24 @@ def __call__( return encoded_inputs - def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None): + def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None): """ Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. Args: outputs ([`BeitForSemanticSegmentation`]): Raw outputs of the model. - target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*): - Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to + target_sizes (`List[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If left to None, predictions will not be resized. Returns: semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic - segmentation map of shape (w, h) corresponding to the target_sizes entry (if `target_sizes` is specified). - Each entry of each `torch.Tensor` correspond to a semantic class id. + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ logits = outputs.logits - semantic_segmentation = logits.argmax(dim=1) - # Resize semantic segmentation maps + # Resize logits and compute semantic segmentation maps if target_sizes is not None: if len(logits) != len(target_sizes): raise ValueError( @@ -254,15 +253,16 @@ def post_process_semantic_segmentation(self, outputs, target_sizes: Union[Tensor if is_torch_tensor(target_sizes): target_sizes = target_sizes.numpy() - resized_maps = [] - semantic_segmentation = semantic_segmentation.numpy() + semantic_segmentation = [] - for idx in range(len(semantic_segmentation)): - resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx]) - resized_maps.append(resized) - - semantic_segmentation = [torch.Tensor(np.array(image)).long() for image in resized_maps] + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) else: + semantic_segmentation = logits.argmax(dim=1) semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] return semantic_segmentation diff --git a/tests/models/beit/test_modeling_beit.py b/tests/models/beit/test_modeling_beit.py index b51dca88ecc54b..377ed8e8e94989 100644 --- a/tests/models/beit/test_modeling_beit.py +++ b/tests/models/beit/test_modeling_beit.py @@ -473,13 +473,10 @@ def test_post_processing_semantic_segmentation(self): outputs.logits = outputs.logits.detach().cpu() - target_size = torch.Tensor([[500, 300]]) - segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=target_size)[ - 0 - ] - expected_shape = torch.Size((300, 500)) - self.assertEqual(segmentation.shape, expected_shape) - - segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs)[0] + segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs, target_sizes=[(500, 300)]) + expected_shape = torch.Size((500, 300)) + self.assertEqual(segmentation[0].shape, expected_shape) + + segmentation = feature_extractor.post_process_semantic_segmentation(outputs=outputs) expected_shape = torch.Size((160, 160)) - self.assertEqual(segmentation.shape, expected_shape) + self.assertEqual(segmentation[0].shape, expected_shape) From 71e0f23de95c9354e2f22b8b4805eecec1eb8e4f Mon Sep 17 00:00:00 2001 From: Alara Dirik Date: Tue, 20 Sep 2022 17:38:29 +0300 Subject: [PATCH 09/10] fix conflict --- .../models/beit/feature_extraction_beit.py | 56 ++++++++----------- 1 file changed, 22 insertions(+), 34 deletions(-) diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py index 3e12e2d90d1a78..b168f232cc3e69 100644 --- a/src/transformers/models/beit/feature_extraction_beit.py +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -39,10 +39,8 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): r""" Constructs a BEiT feature extractor. - This feature extractor inherits from [`~feature_extraction_utils.FeatureExtractionMixin`] which contains most of the main methods. Users should refer to this superclass for more information regarding those methods. - Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to resize the input to a certain `size`. @@ -106,34 +104,25 @@ def __call__( ) -> BatchFeature: """ Main method to prepare for the model one or several image(s). - - NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass PIL images. - - Args: images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. - segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*): Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations. - return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): If set, will return tensors of a particular framework. Acceptable values are: - - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - `'jax'`: Return JAX `jnp.ndarray` objects. - Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: - - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, width). - **labels** -- Optional labels to be fed to a model (when `segmentation_maps` are provided) @@ -226,43 +215,42 @@ def __call__( return encoded_inputs - def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None): + def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None): """ Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. - Args: outputs ([`BeitForSemanticSegmentation`]): Raw outputs of the model. - target_sizes (`List[Tuple]` of length `batch_size`, *optional*): - List of tuples corresponding to the requested final size (height, width) of each prediction. If left to + target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*): + Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to None, predictions will not be resized. Returns: - semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic - segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is - specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + semantic_segmentation: `torch.Tensor` of shape `(batch_size, 2)` or `List[torch.Tensor]` of length + `batch_size`, where each item is a semantic segmentation map of of the corresponding target_sizes entry (if + `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ logits = outputs.logits - # Resize logits and compute semantic segmentation maps - if target_sizes is not None: - if len(logits) != len(target_sizes): - raise ValueError( - "Make sure that you pass in as many target sizes as the batch dimension of the logits" - ) + if len(logits) != len(target_sizes): + raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") + if target_sizes is not None and target_sizes.shape[1] != 2: + raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") + + semantic_segmentation = logits.argmax(dim=1) + + # Resize semantic segmentation maps + if target_sizes is not None: if is_torch_tensor(target_sizes): target_sizes = target_sizes.numpy() - semantic_segmentation = [] + resized_maps = [] + semantic_segmentation = semantic_segmentation.numpy() - for idx in range(len(logits)): - resized_logits = torch.nn.functional.interpolate( - logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False - ) - semantic_map = resized_logits[0].argmax(dim=0) - semantic_segmentation.append(semantic_map) - else: - semantic_segmentation = logits.argmax(dim=1) - semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] + for idx in range(len(semantic_segmentation)): + resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx]) + resized_maps.append(resized) + + semantic_segmentation = [torch.Tensor(np.array(image)) for image in resized_maps] return semantic_segmentation From 711cc7c118517ec1e1f3c454c66974bbc4875c6d Mon Sep 17 00:00:00 2001 From: Alara Dirik Date: Tue, 20 Sep 2022 17:39:55 +0300 Subject: [PATCH 10/10] push updates --- .../models/beit/feature_extraction_beit.py | 56 +++++++++++-------- 1 file changed, 34 insertions(+), 22 deletions(-) diff --git a/src/transformers/models/beit/feature_extraction_beit.py b/src/transformers/models/beit/feature_extraction_beit.py index b168f232cc3e69..3e12e2d90d1a78 100644 --- a/src/transformers/models/beit/feature_extraction_beit.py +++ b/src/transformers/models/beit/feature_extraction_beit.py @@ -39,8 +39,10 @@ class BeitFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin): r""" Constructs a BEiT feature extractor. + This feature extractor inherits from [`~feature_extraction_utils.FeatureExtractionMixin`] which contains most of the main methods. Users should refer to this superclass for more information regarding those methods. + Args: do_resize (`bool`, *optional*, defaults to `True`): Whether to resize the input to a certain `size`. @@ -104,25 +106,34 @@ def __call__( ) -> BatchFeature: """ Main method to prepare for the model one or several image(s). + + NumPy arrays and PyTorch tensors are converted to PIL images when resizing, so the most efficient is to pass PIL images. + + Args: images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`): The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch tensor. In case of a NumPy array/PyTorch tensor, each image should be of shape (C, H, W), where C is a number of channels, H and W are image height and width. + segmentation_maps (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`, *optional*): Optionally, the corresponding semantic segmentation maps with the pixel-wise annotations. + return_tensors (`str` or [`~utils.TensorType`], *optional*, defaults to `'np'`): If set, will return tensors of a particular framework. Acceptable values are: + - `'tf'`: Return TensorFlow `tf.constant` objects. - `'pt'`: Return PyTorch `torch.Tensor` objects. - `'np'`: Return NumPy `np.ndarray` objects. - `'jax'`: Return JAX `jnp.ndarray` objects. + Returns: [`BatchFeature`]: A [`BatchFeature`] with the following fields: + - **pixel_values** -- Pixel values to be fed to a model, of shape (batch_size, num_channels, height, width). - **labels** -- Optional labels to be fed to a model (when `segmentation_maps` are provided) @@ -215,42 +226,43 @@ def __call__( return encoded_inputs - def post_process_semantic_segmentation(self, outputs, target_sizes: Union[TensorType, List[Tuple]] = None): + def post_process_semantic_segmentation(self, outputs, target_sizes: List[Tuple] = None): """ Converts the output of [`BeitForSemanticSegmentation`] into semantic segmentation maps. Only supports PyTorch. + Args: outputs ([`BeitForSemanticSegmentation`]): Raw outputs of the model. - target_sizes (`torch.Tensor` of shape `(batch_size, 2)` or `List[Tuple]` of length `batch_size`, *optional*): - Torch Tensor (or list) corresponding to the requested final size (h, w) of each prediction. If left to + target_sizes (`List[Tuple]` of length `batch_size`, *optional*): + List of tuples corresponding to the requested final size (height, width) of each prediction. If left to None, predictions will not be resized. Returns: - semantic_segmentation: `torch.Tensor` of shape `(batch_size, 2)` or `List[torch.Tensor]` of length - `batch_size`, where each item is a semantic segmentation map of of the corresponding target_sizes entry (if - `target_sizes` is specified). Each entry of each `torch.Tensor` correspond to a semantic class id. + semantic_segmentation: `List[torch.Tensor]` of length `batch_size`, where each item is a semantic + segmentation map of shape (height, width) corresponding to the target_sizes entry (if `target_sizes` is + specified). Each entry of each `torch.Tensor` correspond to a semantic class id. """ logits = outputs.logits - if len(logits) != len(target_sizes): - raise ValueError("Make sure that you pass in as many target sizes as the batch dimension of the logits") - - if target_sizes is not None and target_sizes.shape[1] != 2: - raise ValueError("Each element of target_sizes must contain the size (h, w) of each image of the batch") - - semantic_segmentation = logits.argmax(dim=1) - - # Resize semantic segmentation maps + # Resize logits and compute semantic segmentation maps if target_sizes is not None: + if len(logits) != len(target_sizes): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the logits" + ) + if is_torch_tensor(target_sizes): target_sizes = target_sizes.numpy() - resized_maps = [] - semantic_segmentation = semantic_segmentation.numpy() + semantic_segmentation = [] - for idx in range(len(semantic_segmentation)): - resized = self.resize(image=semantic_segmentation[idx], size=target_sizes[idx]) - resized_maps.append(resized) - - semantic_segmentation = [torch.Tensor(np.array(image)) for image in resized_maps] + for idx in range(len(logits)): + resized_logits = torch.nn.functional.interpolate( + logits[idx].unsqueeze(dim=0), size=target_sizes[idx], mode="bilinear", align_corners=False + ) + semantic_map = resized_logits[0].argmax(dim=0) + semantic_segmentation.append(semantic_map) + else: + semantic_segmentation = logits.argmax(dim=1) + semantic_segmentation = [semantic_segmentation[i] for i in range(semantic_segmentation.shape[0])] return semantic_segmentation