Skip to content

Commit

Permalink
torch.cuda.amp bug fix (ultralytics#2750)
Browse files Browse the repository at this point in the history
PR ultralytics#2725 introduced a very specific bug that only affects multi-GPU trainings. Apparently the cause was using the torch.cuda.amp decorator in the autoShape forward method. I've implemented amp more traditionally in this PR, and the bug is resolved.

(cherry picked from commit b5de52c)
  • Loading branch information
glenn-jocher authored and Lechtr committed Apr 12, 2021
1 parent bfe2fb8 commit e28ebae
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import torch
import torch.nn as nn
from PIL import Image
from torch.cuda import amp

from utils.datasets import letterbox
from utils.general import non_max_suppression, make_divisible, scale_coords, increment_path, xyxy2xywh
Expand Down Expand Up @@ -237,7 +238,6 @@ def autoshape(self):
return self

@torch.no_grad()
@torch.cuda.amp.autocast(torch.cuda.is_available())
def forward(self, imgs, size=640, augment=False, profile=False):
# Inference from various sources. For height=640, width=1280, RGB images example inputs are:
# filename: imgs = 'data/samples/zidane.jpg'
Expand All @@ -251,7 +251,8 @@ def forward(self, imgs, size=640, augment=False, profile=False):
t = [time_synchronized()]
p = next(self.model.parameters()) # for device and type
if isinstance(imgs, torch.Tensor): # torch
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference
with amp.autocast(enabled=p.device.type != 'cpu'):
return self.model(imgs.to(p.device).type_as(p), augment, profile) # inference

# Pre-process
n, imgs = (len(imgs), imgs) if isinstance(imgs, list) else (1, [imgs]) # number of images, list of images
Expand All @@ -278,17 +279,18 @@ def forward(self, imgs, size=640, augment=False, profile=False):
x = torch.from_numpy(x).to(p.device).type_as(p) / 255. # uint8 to fp16/32
t.append(time_synchronized())

# Inference
y = self.model(x, augment, profile)[0] # forward
t.append(time_synchronized())
with amp.autocast(enabled=p.device.type != 'cpu'):
# Inference
y = self.model(x, augment, profile)[0] # forward
t.append(time_synchronized())

# Post-process
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])
# Post-process
y = non_max_suppression(y, conf_thres=self.conf, iou_thres=self.iou, classes=self.classes) # NMS
for i in range(n):
scale_coords(shape1, y[i][:, :4], shape0[i])

t.append(time_synchronized())
return Detections(imgs, y, files, t, self.names, x.shape)
t.append(time_synchronized())
return Detections(imgs, y, files, t, self.names, x.shape)


class Detections:
Expand Down

0 comments on commit e28ebae

Please sign in to comment.