Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Accelerate support added to Object Detection & Segmentation Models #28315

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
from accelerate import PartialState
from accelerate.utils import reduce
from torch import Tensor, nn

from ...activations import ACT2FN
Expand Down Expand Up @@ -2507,11 +2509,14 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()

# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes

num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

# Compute all the requested losses
losses = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import torch
import torch.nn.functional as F
from accelerate import PartialState
from accelerate.utils import reduce
from torch import Tensor, nn
from torch.autograd import Function
from torch.autograd.function import once_differentiable
Expand Down Expand Up @@ -2226,11 +2228,14 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes accross all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()

# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes

num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

# Compute all the requested losses
losses = {}
Expand Down
15 changes: 10 additions & 5 deletions src/transformers/models/deta/modeling_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import torch
import torch.nn.functional as F
from accelerate import PartialState
from accelerate.utils import reduce
from torch import Tensor, nn

from ...activations import ACT2FN
Expand Down Expand Up @@ -2203,11 +2205,14 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes accross all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()

# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes

num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

# Compute all the requested losses
losses = {}
Expand Down
15 changes: 10 additions & 5 deletions src/transformers/models/detr/modeling_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
from accelerate import PartialState
from accelerate.utils import reduce
from torch import Tensor, nn

from ...activations import ACT2FN
Expand Down Expand Up @@ -2204,11 +2206,14 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()

# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes

num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

# Compute all the requested losses
losses = {}
Expand Down
11 changes: 11 additions & 0 deletions src/transformers/models/mask2former/modeling_mask2former.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import numpy as np
import torch
from accelerate import PartialState
from accelerate.utils import reduce
from torch import Tensor, nn

from ... import AutoBackbone
Expand Down Expand Up @@ -788,6 +790,15 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)

# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
world_size = PartialState().num_processes

num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)

return num_masks_pt


Expand Down
11 changes: 11 additions & 0 deletions src/transformers/models/maskformer/modeling_maskformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import numpy as np
import torch
from accelerate import PartialState
from accelerate.utils import reduce
from torch import Tensor, nn

from ... import AutoBackbone
Expand Down Expand Up @@ -1194,6 +1196,15 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)

# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
world_size = PartialState().num_processes

num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)

return num_masks_pt


Expand Down
11 changes: 11 additions & 0 deletions src/transformers/models/oneformer/modeling_oneformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

import numpy as np
import torch
from accelerate import PartialState
from accelerate.utils import reduce
from torch import Tensor, nn
from torch.cuda.amp import autocast

Expand Down Expand Up @@ -723,6 +725,15 @@ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> tor
"""
num_masks = sum([len(classes) for classes in class_labels])
num_masks_pt = torch.as_tensor([num_masks], dtype=torch.float, device=device)

# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_masks_pt = reduce(num_masks_pt)
world_size = PartialState().num_processes

num_masks_pt = torch.clamp(num_masks_pt / world_size, min=1)

return num_masks_pt


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from typing import Dict, List, Optional, Tuple, Union

import torch
from accelerate import PartialState
from accelerate.utils import reduce
from torch import Tensor, nn

from ...activations import ACT2FN
Expand Down Expand Up @@ -1751,11 +1753,14 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()

# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes

num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

# Compute all the requested losses
losses = {}
Expand Down
15 changes: 10 additions & 5 deletions src/transformers/models/yolos/modeling_yolos.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

import torch
import torch.utils.checkpoint
from accelerate import PartialState
from accelerate.utils import reduce
from torch import Tensor, nn

from ...activations import ACT2FN
Expand Down Expand Up @@ -1074,11 +1076,14 @@ def forward(self, outputs, targets):
# Compute the average number of target boxes across all nodes, for normalization purposes
num_boxes = sum(len(t["class_labels"]) for t in targets)
num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
# (Niels): comment out function below, distributed training to be added
# if is_dist_avail_and_initialized():
# torch.distributed.all_reduce(num_boxes)
# (Niels) in original implementation, num_boxes is divided by get_world_size()
num_boxes = torch.clamp(num_boxes, min=1).item()

# Check that we have initialized the distributed state
world_size = 1
if PartialState._shared_state != {}:
num_boxes = reduce(num_boxes)
world_size = PartialState().num_processes

num_boxes = torch.clamp(num_boxes / world_size, min=1).item()

# Compute all the requested losses
losses = {}
Expand Down