From d70e9f0d33084e70976a637df8ef6b64c5a5efe4 Mon Sep 17 00:00:00 2001 From: Tanmay patil Date: Fri, 16 Feb 2024 17:08:59 +0530 Subject: [PATCH] Support : Leverage Accelerate for object detection/segmentation models (#28312) * made changes for object detection models * added support for segmentation models. * Made changes for segmentaion models * Changed import statements * solving conflicts * removed conflicts * Resolving commits * Removed conflicts * Fix : Pixel_mask_value set to False --- .../modeling_conditional_detr.py | 16 +++++++++++----- .../deformable_detr/modeling_deformable_detr.py | 16 ++++++++++------ src/transformers/models/detr/modeling_detr.py | 15 ++++++++++----- .../models/mask2former/modeling_mask2former.py | 12 +++++++++++- .../models/maskformer/modeling_maskformer.py | 11 +++++++++++ .../models/oneformer/modeling_oneformer.py | 11 +++++++++++ .../modeling_table_transformer.py | 15 ++++++++++----- .../models/yolos/image_processing_yolos.py | 2 +- src/transformers/models/yolos/modeling_yolos.py | 14 +++++++++----- 9 files changed, 84 insertions(+), 28 deletions(-) diff --git a/src/transformers/models/conditional_detr/modeling_conditional_detr.py b/src/transformers/models/conditional_detr/modeling_conditional_detr.py index 2a9bbdeff6bd7c..2a5e06ea2b4abc 100644 --- a/src/transformers/models/conditional_detr/modeling_conditional_detr.py +++ b/src/transformers/models/conditional_detr/modeling_conditional_detr.py @@ -30,6 +30,7 @@ ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + is_accelerate_available, is_scipy_available, is_timm_available, is_vision_available, @@ -41,6 +42,10 @@ from .configuration_conditional_detr import ConditionalDetrConfig +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce + if is_scipy_available(): from scipy.optimize import linear_sum_assignment @@ -2507,11 +2512,12 @@ def forward(self, outputs, targets): # Compute the average number of target boxes across all nodes, for normalization purposes num_boxes = sum(len(t["class_labels"]) for t in targets) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) - # (Niels): comment out function below, distributed training to be added - # if is_dist_avail_and_initialized(): - # torch.distributed.all_reduce(num_boxes) - # (Niels) in original implementation, num_boxes is divided by get_world_size() - num_boxes = torch.clamp(num_boxes, min=1).item() + + world_size = 1 + if PartialState._shared_state != {}: + num_boxes = reduce(num_boxes) + world_size = PartialState().num_processes + num_boxes = torch.clamp(num_boxes / world_size, min=1).item() # Compute all the requested losses losses = {} diff --git a/src/transformers/models/deformable_detr/modeling_deformable_detr.py b/src/transformers/models/deformable_detr/modeling_deformable_detr.py index 3c6e48a6226221..89682729c651bd 100755 --- a/src/transformers/models/deformable_detr/modeling_deformable_detr.py +++ b/src/transformers/models/deformable_detr/modeling_deformable_detr.py @@ -43,7 +43,7 @@ from ...modeling_outputs import BaseModelOutput from ...modeling_utils import PreTrainedModel from ...pytorch_utils import meshgrid -from ...utils import is_ninja_available, logging +from ...utils import is_accelerate_available, is_ninja_available, logging from ...utils.backbone_utils import load_backbone from .configuration_deformable_detr import DeformableDetrConfig from .load_custom import load_cuda_kernels @@ -65,6 +65,10 @@ if is_vision_available(): from transformers.image_transforms import center_to_corners_format +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce + class MultiScaleDeformableAttentionFunction(Function): @staticmethod @@ -2246,11 +2250,11 @@ def forward(self, outputs, targets): # Compute the average number of target boxes accross all nodes, for normalization purposes num_boxes = sum(len(t["class_labels"]) for t in targets) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) - # (Niels): comment out function below, distributed training to be added - # if is_dist_avail_and_initialized(): - # torch.distributed.all_reduce(num_boxes) - # (Niels) in original implementation, num_boxes is divided by get_world_size() - num_boxes = torch.clamp(num_boxes, min=1).item() + world_size = 1 + if PartialState._shared_state != {}: + num_boxes = reduce(num_boxes) + world_size = PartialState().num_processes + num_boxes = torch.clamp(num_boxes / world_size, min=1).item() # Compute all the requested losses losses = {} diff --git a/src/transformers/models/detr/modeling_detr.py b/src/transformers/models/detr/modeling_detr.py index 218d63a412b170..0fa912eb1d5192 100644 --- a/src/transformers/models/detr/modeling_detr.py +++ b/src/transformers/models/detr/modeling_detr.py @@ -30,6 +30,7 @@ ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + is_accelerate_available, is_scipy_available, is_timm_available, is_vision_available, @@ -41,6 +42,10 @@ from .configuration_detr import DetrConfig +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce + if is_scipy_available(): from scipy.optimize import linear_sum_assignment @@ -2204,11 +2209,11 @@ def forward(self, outputs, targets): # Compute the average number of target boxes across all nodes, for normalization purposes num_boxes = sum(len(t["class_labels"]) for t in targets) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) - # (Niels): comment out function below, distributed training to be added - # if is_dist_avail_and_initialized(): - # torch.distributed.all_reduce(num_boxes) - # (Niels) in original implementation, num_boxes is divided by get_world_size() - num_boxes = torch.clamp(num_boxes, min=1).item() + world_size = 1 + if PartialState._shared_state != {}: + num_boxes = reduce(num_boxes) + world_size = PartialState().num_processes + num_boxes = torch.clamp(num_boxes / world_size, min=1).item() # Compute all the requested losses losses = {} diff --git a/src/transformers/models/mask2former/modeling_mask2former.py b/src/transformers/models/mask2former/modeling_mask2former.py index 15f1759045f6a7..bf86b5ba6039e6 100644 --- a/src/transformers/models/mask2former/modeling_mask2former.py +++ b/src/transformers/models/mask2former/modeling_mask2former.py @@ -34,7 +34,7 @@ ) from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithCrossAttentions from ...modeling_utils import PreTrainedModel -from ...utils import logging +from ...utils import is_accelerate_available, logging from ...utils.backbone_utils import load_backbone from .configuration_mask2former import Mask2FormerConfig @@ -42,6 +42,10 @@ if is_scipy_available(): from scipy.optimize import linear_sum_assignment +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce + logger = logging.get_logger(__name__) @@ -788,6 +792,12 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor """ num_masks = sum([len(classes) for classes in class_labels]) num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device) + world_size = 1 + if PartialState._shared_state != {}: + num_masks_pt = reduce(num_masks_pt) + world_size = PartialState().num_processes + + num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1) return num_masks_pt diff --git a/src/transformers/models/maskformer/modeling_maskformer.py b/src/transformers/models/maskformer/modeling_maskformer.py index dd8f7ccfdf9eb1..eef31ba2799a45 100644 --- a/src/transformers/models/maskformer/modeling_maskformer.py +++ b/src/transformers/models/maskformer/modeling_maskformer.py @@ -31,6 +31,7 @@ ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + is_accelerate_available, is_scipy_available, logging, replace_return_docstrings, @@ -42,6 +43,10 @@ from .configuration_maskformer_swin import MaskFormerSwinConfig +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce + if is_scipy_available(): from scipy.optimize import linear_sum_assignment @@ -1194,6 +1199,12 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor """ num_masks = sum([len(classes) for classes in class_labels]) num_masks_pt = torch.as_tensor(num_masks, dtype=torch.float, device=device) + world_size = 1 + if PartialState._shared_state != {}: + num_masks_pt = reduce(num_masks_pt) + world_size = PartialState().num_processes + + num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1) return num_masks_pt diff --git a/src/transformers/models/oneformer/modeling_oneformer.py b/src/transformers/models/oneformer/modeling_oneformer.py index 87014d8afbf6fa..586fd7345c5645 100644 --- a/src/transformers/models/oneformer/modeling_oneformer.py +++ b/src/transformers/models/oneformer/modeling_oneformer.py @@ -31,6 +31,7 @@ ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + is_accelerate_available, is_scipy_available, logging, replace_return_docstrings, @@ -40,6 +41,10 @@ from .configuration_oneformer import OneFormerConfig +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce + logger = logging.get_logger(__name__) @@ -723,6 +728,12 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor """ num_masks = sum([len(classes) for classes in class_labels]) num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device) + world_size = 1 + if PartialState._shared_state != {}: + num_masks_pt = reduce(num_masks_pt) + world_size = PartialState().num_processes + + num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1) return num_masks_pt diff --git a/src/transformers/models/table_transformer/modeling_table_transformer.py b/src/transformers/models/table_transformer/modeling_table_transformer.py index a113c99109ba64..8a16917c3c76b8 100644 --- a/src/transformers/models/table_transformer/modeling_table_transformer.py +++ b/src/transformers/models/table_transformer/modeling_table_transformer.py @@ -30,6 +30,7 @@ ModelOutput, add_start_docstrings, add_start_docstrings_to_model_forward, + is_accelerate_available, is_scipy_available, is_timm_available, is_vision_available, @@ -50,6 +51,10 @@ if is_vision_available(): from transformers.image_transforms import center_to_corners_format +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce + logger = logging.get_logger(__name__) _CONFIG_FOR_DOC = "TableTransformerConfig" @@ -1751,11 +1756,11 @@ def forward(self, outputs, targets): # Compute the average number of target boxes across all nodes, for normalization purposes num_boxes = sum(len(t["class_labels"]) for t in targets) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) - # (Niels): comment out function below, distributed training to be added - # if is_dist_avail_and_initialized(): - # torch.distributed.all_reduce(num_boxes) - # (Niels) in original implementation, num_boxes is divided by get_world_size() - num_boxes = torch.clamp(num_boxes, min=1).item() + world_size = 1 + if PartialState._shared_state != {}: + num_boxes = reduce(num_boxes) + world_size = PartialState().num_processes + num_boxes = torch.clamp(num_boxes / world_size, min=1).item() # Compute all the requested losses losses = {} diff --git a/src/transformers/models/yolos/image_processing_yolos.py b/src/transformers/models/yolos/image_processing_yolos.py index 22d43026a27c9b..d964f6f02f4187 100644 --- a/src/transformers/models/yolos/image_processing_yolos.py +++ b/src/transformers/models/yolos/image_processing_yolos.py @@ -1297,7 +1297,7 @@ def preprocess( encoded_inputs = self.pad( images, annotations=annotations, - return_pixel_mask=True, + return_pixel_mask=False, data_format=data_format, input_data_format=input_data_format, update_bboxes=do_convert_annotations, diff --git a/src/transformers/models/yolos/modeling_yolos.py b/src/transformers/models/yolos/modeling_yolos.py index 65ffbfced4e85c..237429ae707d4c 100755 --- a/src/transformers/models/yolos/modeling_yolos.py +++ b/src/transformers/models/yolos/modeling_yolos.py @@ -33,6 +33,7 @@ add_code_sample_docstrings, add_start_docstrings, add_start_docstrings_to_model_forward, + is_accelerate_available, is_scipy_available, is_vision_available, logging, @@ -48,6 +49,9 @@ if is_vision_available(): from transformers.image_transforms import center_to_corners_format +if is_accelerate_available(): + from accelerate import PartialState + from accelerate.utils import reduce logger = logging.get_logger(__name__) @@ -1074,11 +1078,11 @@ def forward(self, outputs, targets): # Compute the average number of target boxes across all nodes, for normalization purposes num_boxes = sum(len(t["class_labels"]) for t in targets) num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device) - # (Niels): comment out function below, distributed training to be added - # if is_dist_avail_and_initialized(): - # torch.distributed.all_reduce(num_boxes) - # (Niels) in original implementation, num_boxes is divided by get_world_size() - num_boxes = torch.clamp(num_boxes, min=1).item() + world_size = 1 + if PartialState._shared_state != {}: + num_boxes = reduce(num_boxes) + world_size = PartialState().num_processes + num_boxes = torch.clamp(num_boxes / world_size, min=1).item() # Compute all the requested losses losses = {}