Make R-CNN models support Automatic Mixed Precision (AMP) #2222

Closed
Okery opened this issue May 16, 2020 · 1 comment

Okery commented May 16, 2020

🚀 Feature

Now that PyTorch 1.6.0 has torch.cuda.amp.autocast, I think we can make the R-CNN models support Automatic Mixed Precision (AMP).

Motivation

With AMP enabled, training speed may increase by ~20% on GPUs that support FP16.
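
For context, here is a minimal sketch of the kind of training step this change would enable. The model/optimizer setup is illustrative (not part of this issue) and simply follows the standard torch.cuda.amp.autocast / GradScaler recipe from PyTorch 1.6:

    import torch
    import torchvision

    # Illustrative setup: any torchvision detection model and optimizer would do.
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(num_classes=2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
    scaler = torch.cuda.amp.GradScaler()

    def train_step(images, targets):
        # images: list of CHW float tensors on the GPU;
        # targets: list of dicts with "boxes", "labels" and "masks"
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            # In training mode the R-CNN models return a dict of losses.
            loss_dict = model(images, targets)
            loss = sum(loss_dict.values())
        scaler.scale(loss).backward()  # scale the loss to avoid FP16 gradient underflow
        scaler.step(optimizer)
        scaler.update()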

Alternatives

There are two modifications:

  • In torchvision/ops/roi_align.py, function roi_align:
    rois' datatype should be the same as input's datatype, so I replace
    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
    return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
                                           output_size[0], output_size[1],
                                           sampling_ratio, aligned)

with

    check_roi_boxes_shape(boxes)
    rois = boxes
    output_size = _pair(output_size)
    if not isinstance(rois, torch.Tensor):
        rois = convert_boxes_to_roi_format(rois)
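    # make rois match the feature map's dtype (e.g. float16 under autocast)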
    rois = rois.to(input.dtype)
    return torch.ops.torchvision.roi_align(input, rois, spatial_scale,
                                           output_size[0], output_size[1],
                                           sampling_ratio, aligned)
  • In torchvision/models/detection/_utils.py, function encode_boxes:
    I'm confused by the torch.jit.script decorator: when I printed the proposals' datatype, the output was 6 instead of torch.float32 or torch.float16 (see the sketch after this list), so I removed the decorator.
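
For reference, a minimal sketch (not from this issue) reproducing that observation: under torch.jit.script a tensor's dtype is represented as an ATen ScalarType integer code rather than a torch.dtype object, and 6 is the code for torch.float32 (5 would be torch.float16):

    import torch

    @torch.jit.script
    def scripted_dtype(x: torch.Tensor) -> int:
        # In TorchScript, Tensor.dtype is an integer ScalarType code,
        # not a torch.dtype object.
        return x.dtype

    print(scripted_dtype(torch.zeros(1)))         # 6 -> torch.float32
    print(scripted_dtype(torch.zeros(1).half()))  # 5 -> torch.float16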

Additional context

I tested the speed of maskrcnn_resnet50_fpn with and without autocast().
Dataset: VOC 2012 Segmentation (1463 train images, 1444 val images).

GPU       train/test FPS without AMP   train/test FPS with AMP   increase (train/test)
2080 Ti   7.4 / 13.3                   9.3 / 15.2                25.2% / 14.6%
1080 Ti   5.1 / 8.8                    4.2 / 7.3                 --

fmassa (Member) commented Oct 21, 2020

This has been added in the 0.7.0 release of torchvision in #2384
