From 62feea59e8a6d45c0de13bc9b592646e1e5cb4ca Mon Sep 17 00:00:00 2001 From: jbwang1997 Date: Wed, 9 Mar 2022 19:50:36 +0800 Subject: [PATCH] [Fix] Force the inputs of `get_bboxes` in yolox_head to float32. (#7324) * Fix softnms bug * Add force_fp32 in corner_head and centripetal_head --- mmdet/models/dense_heads/centripetal_head.py | 3 +++ mmdet/models/dense_heads/corner_head.py | 5 ++++- mmdet/models/dense_heads/yolox_head.py | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/mmdet/models/dense_heads/centripetal_head.py b/mmdet/models/dense_heads/centripetal_head.py index fe85794e87a..ebc721b7623 100644 --- a/mmdet/models/dense_heads/centripetal_head.py +++ b/mmdet/models/dense_heads/centripetal_head.py @@ -2,6 +2,7 @@ import torch.nn as nn from mmcv.cnn import ConvModule, normal_init from mmcv.ops import DeformConv2d +from mmcv.runner import force_fp32 from mmdet.core import multi_apply from ..builder import HEADS, build_loss @@ -203,6 +204,7 @@ def forward_single(self, x, lvl_ind): ] return result_list + @force_fp32() def loss(self, tl_heats, br_heats, @@ -361,6 +363,7 @@ def loss_single(self, tl_hmp, br_hmp, tl_off, br_off, tl_guiding_shift, return det_loss, off_loss, guiding_loss, centripetal_loss + @force_fp32() def get_bboxes(self, tl_heats, br_heats, diff --git a/mmdet/models/dense_heads/corner_head.py b/mmdet/models/dense_heads/corner_head.py index 327094bad67..c6a2866f94a 100644 --- a/mmdet/models/dense_heads/corner_head.py +++ b/mmdet/models/dense_heads/corner_head.py @@ -6,7 +6,7 @@ import torch.nn as nn from mmcv.cnn import ConvModule, bias_init_with_prob from mmcv.ops import CornerPool, batched_nms -from mmcv.runner import BaseModule +from mmcv.runner import BaseModule, force_fp32 from mmdet.core import multi_apply from ..builder import HEADS, build_loss @@ -152,6 +152,7 @@ def __init__(self, self.train_cfg = train_cfg self.test_cfg = test_cfg + self.fp16_enabled = False self._init_layers() def _make_layers(self, out_channels, in_channels=256, feat_channels=256): @@ -509,6 +510,7 @@ def get_targets(self, return target_result + @force_fp32() def loss(self, tl_heats, br_heats, @@ -649,6 +651,7 @@ def loss_single(self, tl_hmp, br_hmp, tl_emb, br_emb, tl_off, br_off, return det_loss, pull_loss, push_loss, off_loss + @force_fp32() def get_bboxes(self, tl_heats, br_heats, diff --git a/mmdet/models/dense_heads/yolox_head.py b/mmdet/models/dense_heads/yolox_head.py index a1811c9415d..de3f93ccd36 100644 --- a/mmdet/models/dense_heads/yolox_head.py +++ b/mmdet/models/dense_heads/yolox_head.py @@ -212,6 +212,7 @@ def forward(self, feats): self.multi_level_conv_reg, self.multi_level_conv_obj) + @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'objectnesses')) def get_bboxes(self, cls_scores, bbox_preds,