Skip to content

Commit

Permalink
[Fix] Fix amp bug (#2452)
Browse files Browse the repository at this point in the history
* fix amp

* init scale 4096 & fix link

* fix pre-commit

* fix interval

* fix pp & remove fp16
  • Loading branch information
sunjiahao1999 authored May 11, 2023
1 parent 19abb93 commit 106b17e
Show file tree
Hide file tree
Showing 18 changed files with 21 additions and 38 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = './pointpillars_hv_fpn_sbn-all_8xb4-2x_nus-3d.py'
train_dataloader = dict(batch_size=2, num_workers=2)
# schedule settings
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=512.)
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=4096.)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = './pointpillars_hv_secfpn_sbn-all_8xb4-2x_nus-3d.py'
train_dataloader = dict(batch_size=2, num_workers=2)
# schedule settings
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=512.)
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=4096.)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = 'second_hv_secfpn_8xb6-80e_kitti-3d-3class.py'

# schedule settings
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=512.)
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=4096.)
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_base_ = 'second_hv_secfpn_8xb6-80e_kitti-3d-car.py'

# schedule settings
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=512.)
optim_wrapper = dict(type='AmpOptimWrapper', loss_scale=4096.)
1 change: 0 additions & 1 deletion mmdet3d/models/backbones/base_pointnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ class BasePointNet(BaseModule, metaclass=ABCMeta):

def __init__(self, init_cfg=None, pretrained=None):
super(BasePointNet, self).__init__(init_cfg)
self.fp16_enabled = False
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
if isinstance(pretrained, str):
Expand Down
29 changes: 15 additions & 14 deletions mmdet3d/models/dense_heads/anchor3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import numpy as np
import torch
from mmdet.models.utils import multi_apply
from mmdet.utils.memory import cast_tensor_type
from mmengine.runner import amp
from torch import Tensor
from torch import nn as nn

Expand Down Expand Up @@ -92,7 +94,6 @@ def __init__(self,
warnings.warn(
'dir_offset and dir_limit_offset will be depressed and be '
'incorporated into box coder in the future')
self.fp16_enabled = False

# build anchor generator
self.prior_generator = TASK_UTILS.build(anchor_generator)
Expand All @@ -112,7 +113,6 @@ def __init__(self,
self.loss_cls = MODELS.build(loss_cls)
self.loss_bbox = MODELS.build(loss_bbox)
self.loss_dir = MODELS.build(loss_dir)
self.fp16_enabled = False

self._init_layers()
self._init_assigner_sampler()
Expand Down Expand Up @@ -411,17 +411,18 @@ class predictions.
num_total_pos + num_total_neg if self.sampling else num_total_pos)

# num_total_samples = None
losses_cls, losses_bbox, losses_dir = multi_apply(
self._loss_by_feat_single,
cls_scores,
bbox_preds,
dir_cls_preds,
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
dir_targets_list,
dir_weights_list,
num_total_samples=num_total_samples)
with amp.autocast(enabled=False):
losses_cls, losses_bbox, losses_dir = multi_apply(
self._loss_by_feat_single,
cast_tensor_type(cls_scores, dst_type=torch.float32),
cast_tensor_type(bbox_preds, dst_type=torch.float32),
cast_tensor_type(dir_cls_preds, dst_type=torch.float32),
labels_list,
label_weights_list,
bbox_targets_list,
bbox_weights_list,
dir_targets_list,
dir_weights_list,
num_total_samples=num_total_samples)
return dict(
loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dir=losses_dir)
1 change: 0 additions & 1 deletion mmdet3d/models/dense_heads/anchor_free_mono3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,6 @@ def __init__(
self.test_cfg = test_cfg
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.fp16_enabled = False
self.background_label = (
num_classes if background_label is None else background_label)
# background_label should be either 0 or num_classes
Expand Down
1 change: 0 additions & 1 deletion mmdet3d/models/dense_heads/centerpoint_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,6 @@ def __init__(self,
self.loss_bbox = MODELS.build(loss_bbox)
self.bbox_coder = TASK_UTILS.build(bbox_coder)
self.num_anchor_per_locs = [n for n in num_classes]
self.fp16_enabled = False

# a shared convolution
self.shared_conv = ConvModule(
Expand Down
2 changes: 0 additions & 2 deletions mmdet3d/models/dense_heads/groupfree3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,8 +214,6 @@ def __init__(self,
self.fps_module = Points_Sampler([self.num_proposal])
self.points_obj_cls = PointsObjClsModule(self.in_channels)

self.fp16_enabled = False

# initial candidate prediction
self.conv_pred = BaseConvBboxHead(
**pred_layer_cfg,
Expand Down
1 change: 0 additions & 1 deletion mmdet3d/models/dense_heads/vote_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,6 @@ def __init__(self,

self.vote_module = VoteModule(**vote_module_cfg)
self.vote_aggregation = build_sa_module(vote_aggregation_cfg)
self.fp16_enabled = False

# Bbox classification and regression
self.conv_pred = BaseConvBboxHead(
Expand Down
2 changes: 0 additions & 2 deletions mmdet3d/models/layers/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def forward(self, input: Tensor) -> Tensor:
Returns:
Tensor: Has shape (N, C) or (N, C, L), same shape as input.
"""
assert input.dtype == torch.float32, \
f'input should be in float32 type, got {input.dtype}'
using_dist = dist.is_available() and dist.is_initialized()
if (not using_dist) or dist.get_world_size() == 1 \
or not self.training:
Expand Down
1 change: 0 additions & 1 deletion mmdet3d/models/middle_encoders/pillar_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def __init__(self, in_channels, output_shape):
self.ny = output_shape[0]
self.nx = output_shape[1]
self.in_channels = in_channels
self.fp16_enabled = False

def forward(self, voxel_features, coors, batch_size=None):
"""Foraward function to scatter features."""
Expand Down
3 changes: 2 additions & 1 deletion mmdet3d/models/middle_encoders/sparse_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
from mmcv.ops import points_in_boxes_all, three_interpolate, three_nn
from mmdet.models.losses import sigmoid_focal_loss, smooth_l1_loss
from mmengine.runner import amp
from torch import Tensor
from torch import nn as nn

Expand Down Expand Up @@ -68,7 +69,6 @@ def __init__(self,
self.encoder_channels = encoder_channels
self.encoder_paddings = encoder_paddings
self.stage_num = len(self.encoder_channels)
self.fp16_enabled = False
self.return_middle_feats = return_middle_feats
# Spconv init all weight on its own

Expand Down Expand Up @@ -111,6 +111,7 @@ def __init__(self,
indice_key='spconv_down2',
conv_type='SparseConv3d')

@amp.autocast(enabled=False)
def forward(self, voxel_features, coors, batch_size):
"""Forward of SparseEncoder.
Expand Down
1 change: 0 additions & 1 deletion mmdet3d/models/middle_encoders/sparse_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ def __init__(self,
self.decoder_channels = decoder_channels
self.decoder_paddings = decoder_paddings
self.stage_num = len(self.encoder_channels)
self.fp16_enabled = False
# Spconv init all weight on its own

assert isinstance(order, tuple) and len(order) == 3
Expand Down
1 change: 0 additions & 1 deletion mmdet3d/models/necks/second_fpn.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ def __init__(self,
assert len(out_channels) == len(upsample_strides) == len(in_channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.fp16_enabled = False

deblocks = []
for i, out_channel in enumerate(out_channels):
Expand Down
2 changes: 0 additions & 2 deletions mmdet3d/models/voxel_encoders/pillar_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ def __init__(self,
self._with_distance = with_distance
self._with_cluster_center = with_cluster_center
self._with_voxel_center = with_voxel_center
self.fp16_enabled = False
# Create PillarFeatureNet layers
self.in_channels = in_channels
feat_channels = [in_channels] + list(feat_channels)
Expand Down Expand Up @@ -209,7 +208,6 @@ def __init__(self,
norm_cfg=norm_cfg,
mode=mode,
legacy=legacy)
self.fp16_enabled = False
feat_channels = [self.in_channels] + list(feat_channels)
pfn_layers = []
# TODO: currently only support one PFNLayer
Expand Down
2 changes: 0 additions & 2 deletions mmdet3d/models/voxel_encoders/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(self,
max_out=True,
cat_max=True):
super(VFELayer, self).__init__()
self.fp16_enabled = False
self.cat_max = cat_max
self.max_out = max_out
# self.units = int(out_channels / 2)
Expand Down Expand Up @@ -127,7 +126,6 @@ def __init__(self,
mode='max'):

super().__init__()
self.fp16_enabled = False
self.name = 'PFNLayer'
self.last_vfe = last_layer
if not self.last_vfe:
Expand Down
4 changes: 0 additions & 4 deletions mmdet3d/models/voxel_encoders/voxel_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ class HardSimpleVFE(nn.Module):
def __init__(self, num_features: int = 4) -> None:
super(HardSimpleVFE, self).__init__()
self.num_features = num_features
self.fp16_enabled = False

def forward(self, features: Tensor, num_points: Tensor, coors: Tensor,
*args, **kwargs) -> Tensor:
Expand Down Expand Up @@ -62,7 +61,6 @@ def __init__(self,
point_cloud_range=(0, -40, -3, 70.4, 40, 1)):
super(DynamicSimpleVFE, self).__init__()
self.scatter = DynamicScatter(voxel_size, point_cloud_range, True)
self.fp16_enabled = False

@torch.no_grad()
def forward(self, features, coors, *args, **kwargs):
Expand Down Expand Up @@ -141,7 +139,6 @@ def __init__(self,
self._with_cluster_center = with_cluster_center
self._with_voxel_center = with_voxel_center
self.return_point_feats = return_point_feats
self.fp16_enabled = False

# Need pillar (voxel) size and x/y offset in order to calculate offset
self.vx = voxel_size[0]
Expand Down Expand Up @@ -340,7 +337,6 @@ def __init__(self,
self._with_cluster_center = with_cluster_center
self._with_voxel_center = with_voxel_center
self.return_point_feats = return_point_feats
self.fp16_enabled = False

# Need pillar (voxel) size and x/y offset to calculate pillar offset
self.vx = voxel_size[0]
Expand Down

0 comments on commit 106b17e

Please sign in to comment.