Skip to content

Commit

Permalink
[Enhance] GroupFree3d inherits BaseModule From MMCV (#704)
Browse files Browse the repository at this point in the history
  • Loading branch information
xiliu8006 authored Jul 21, 2021
1 parent 26c1807 commit cbf194f
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions mmdet3d/models/dense_heads/groupfree3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import numpy as np
import torch
from mmcv import ConfigDict
from mmcv.cnn import ConvModule
from mmcv.cnn import ConvModule, xavier_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer)
from mmcv.runner import force_fp32
from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn
from torch.nn import functional as F

Expand All @@ -19,7 +19,7 @@
EPS = 1e-6


class PointsObjClsModule(nn.Module):
class PointsObjClsModule(BaseModule):
"""object candidate point prediction from seed point features.
Args:
Expand All @@ -39,8 +39,9 @@ def __init__(self,
num_convs=3,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU')):
super().__init__()
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
conv_channels = [in_channel for _ in range(num_convs - 1)]
conv_channels.append(1)

Expand Down Expand Up @@ -104,7 +105,7 @@ def forward(self, xyz, features, sample_inds):


@HEADS.register_module()
class GroupFree3DHead(nn.Module):
class GroupFree3DHead(BaseModule):
r"""Bbox head of `Group-Free 3D <https://arxiv.org/abs/2104.00678>`_.
Args:
Expand Down Expand Up @@ -162,8 +163,9 @@ def __init__(self,
size_class_loss=None,
size_res_loss=None,
size_reg_loss=None,
semantic_loss=None):
super(GroupFree3DHead, self).__init__()
semantic_loss=None,
init_cfg=None):
super(GroupFree3DHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.train_cfg = train_cfg
self.test_cfg = test_cfg
Expand Down Expand Up @@ -251,15 +253,13 @@ def init_weights(self):
# initialize transformer
for m in self.decoder_layers.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)

xavier_init(m, distribution='uniform')
for m in self.decoder_self_posembeds.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)

xavier_init(m, distribution='uniform')
for m in self.decoder_cross_posembeds.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)
xavier_init(m, distribution='uniform')

def _get_cls_out_channels(self):
"""Return the channel number of classification outputs."""
Expand Down

0 comments on commit cbf194f

Please sign in to comment.