Skip to content

Commit

Permalink
Add typehint for code under mmdet3d/models/backbones (#2464)
Browse files Browse the repository at this point in the history
* [Docs] Update link of registry tutorial (#2442)

Registry docs are now under Advanced Tutorials.

* add typehint to the backbones of model.

* Update typehint with OptMultiConfig and Tensor to torch.Tensor

* Update mmdet3d/models/backbones/dgcnn.py

Co-authored-by: Xiang Xu <xuxiang0103@gmail.com>

* Update mmdet3d/models/backbones/dgcnn.py

Co-authored-by: Xiang Xu <xuxiang0103@gmail.com>

* Update mmdet3d/models/backbones/dgcnn.py

Co-authored-by: Xiang Xu <xuxiang0103@gmail.com>

* Update mmdet3d/models/backbones/dgcnn.py

Co-authored-by: Xiang Xu <xuxiang0103@gmail.com>

* Update mmdet3d/models/backbones/dgcnn.py

Co-authored-by: Xiang Xu <xuxiang0103@gmail.com>

* Update mmdet3d/models/backbones/dgcnn.py

Co-authored-by: Xiang Xu <xuxiang0103@gmail.com>

* Update mmdet3d/models/backbones/dgcnn.py

Co-authored-by: Xiang Xu <xuxiang0103@gmail.com>

* Update mmdet3d/models/backbones/dla.py

Co-authored-by: Xiang Xu <xuxiang0103@gmail.com>

* Update mmdet3d/models/backbones/dla.py

Co-authored-by: Xiang Xu <xuxiang0103@gmail.com>

* Update mmdet3d/models/backbones/dgcnn.py

Co-authored-by: Xiang Xu <xuxiang0103@gmail.com>

* Update typehint.

* add typehint to function mono_cam_box2vis and remove assert keyword.

* Update typehint.

* Fixed some problems.

---------

Co-authored-by: pd-michaelstanley <88335018+pd-michaelstanley@users.noreply.github.com>
Co-authored-by: Xiang Xu <xuxiang0103@gmail.com>
  • Loading branch information
3 people authored May 11, 2023
1 parent 52fe5ba commit d99dbce
Show file tree
Hide file tree
Showing 10 changed files with 142 additions and 102 deletions.
10 changes: 8 additions & 2 deletions mmdet3d/models/backbones/base_pointnet.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta
from typing import Optional, Tuple

from mmengine.model import BaseModule
from torch import Tensor

from mmdet3d.utils import OptMultiConfig


class BasePointNet(BaseModule, metaclass=ABCMeta):
"""Base class for PointNet."""

def __init__(self, init_cfg=None, pretrained=None):
def __init__(self,
init_cfg: OptMultiConfig = None,
pretrained: Optional[str] = None):
super(BasePointNet, self).__init__(init_cfg)
assert not (init_cfg and pretrained), \
'init_cfg and pretrained cannot be setting at the same time'
Expand All @@ -18,7 +24,7 @@ def __init__(self, init_cfg=None, pretrained=None):
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

@staticmethod
def _split_point_feats(points):
def _split_point_feats(points: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
"""Split coordinates and features of input points.
Args:
Expand Down
3 changes: 2 additions & 1 deletion mmdet3d/models/backbones/cylinder3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mmcv.ops import (SparseConv3d, SparseConvTensor, SparseInverseConv3d,
SubMConv3d)
from mmengine.model import BaseModule
from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType
Expand Down Expand Up @@ -456,7 +457,7 @@ def __init__(self,
indice_key='ddcm',
norm_cfg=norm_cfg)

def forward(self, voxel_features: torch.Tensor, coors: torch.Tensor,
def forward(self, voxel_features: Tensor, coors: Tensor,
batch_size: int) -> SparseConvTensor:
"""Forward pass."""
coors = coors.int()
Expand Down
23 changes: 14 additions & 9 deletions mmdet3d/models/backbones/dgcnn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union

from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn

from mmdet3d.models.layers import DGCNNFAModule, DGCNNGFModule
from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptMultiConfig


@MODELS.register_module()
Expand All @@ -30,14 +34,15 @@ class DGCNNBackbone(BaseModule):
"""

def __init__(self,
in_channels,
num_samples=(20, 20, 20),
knn_modes=('D-KNN', 'F-KNN', 'F-KNN'),
radius=(None, None, None),
gf_channels=((64, 64), (64, 64), (64, )),
fa_channels=(1024, ),
act_cfg=dict(type='ReLU'),
init_cfg=None):
in_channels: int,
num_samples: Sequence[int] = (20, 20, 20),
knn_modes: Sequence[str] = ('D-KNN', 'F-KNN', 'F-KNN'),
radius: Sequence[Union[float, None]] = (None, None, None),
gf_channels: Sequence[Sequence[int]] = ((64, 64), (64, 64),
(64, )),
fa_channels: Sequence[int] = (1024, ),
act_cfg: ConfigType = dict(type='ReLU'),
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg=init_cfg)
self.num_gf = len(gf_channels)

Expand Down Expand Up @@ -71,7 +76,7 @@ def __init__(self,
self.FA_module = DGCNNFAModule(
mlp_channels=cur_fa_mlps, act_cfg=act_cfg)

def forward(self, points):
def forward(self, points: Tensor) -> dict:
"""Forward pass.
Args:
Expand Down
109 changes: 58 additions & 51 deletions mmdet3d/models/backbones/dla.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Optional, Sequence, Tuple

import torch
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import nn
from torch import Tensor, nn

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig


def dla_build_norm_layer(cfg, num_features):
def dla_build_norm_layer(cfg: ConfigType,
num_features: int) -> Tuple[str, nn.Module]:
"""Build normalization layer specially designed for DLANet.
Args:
Expand Down Expand Up @@ -53,13 +56,13 @@ class BasicBlock(BaseModule):
"""

def __init__(self,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
stride=1,
dilation=1,
init_cfg=None):
in_channels: int,
out_channels: int,
norm_cfg: ConfigType,
conv_cfg: ConfigType,
stride: int = 1,
dilation: int = 1,
init_cfg: OptMultiConfig = None):
super(BasicBlock, self).__init__(init_cfg)
self.conv1 = build_conv_layer(
conv_cfg,
Expand All @@ -84,7 +87,7 @@ def __init__(self,
self.norm2 = dla_build_norm_layer(norm_cfg, out_channels)[1]
self.stride = stride

def forward(self, x, identity=None):
def forward(self, x: Tensor, identity: Optional[Tensor] = None) -> Tensor:
"""Forward function."""

if identity is None:
Expand Down Expand Up @@ -117,13 +120,13 @@ class Root(BaseModule):
"""

def __init__(self,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
kernel_size,
add_identity,
init_cfg=None):
in_channels: int,
out_channels: int,
norm_cfg: ConfigType,
conv_cfg: ConfigType,
kernel_size: int,
add_identity: bool,
init_cfg: OptMultiConfig = None):
super(Root, self).__init__(init_cfg)
self.conv = build_conv_layer(
conv_cfg,
Expand All @@ -137,7 +140,7 @@ def __init__(self,
self.relu = nn.ReLU(inplace=True)
self.add_identity = add_identity

def forward(self, feat_list):
def forward(self, feat_list: List[Tensor]) -> Tensor:
"""Forward function.
Args:
Expand Down Expand Up @@ -181,19 +184,19 @@ class Tree(BaseModule):
"""

def __init__(self,
levels,
block,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
stride=1,
level_root=False,
root_dim=None,
root_kernel_size=1,
dilation=1,
add_identity=False,
init_cfg=None):
levels: int,
block: nn.Module,
in_channels: int,
out_channels: int,
norm_cfg: ConfigType,
conv_cfg: ConfigType,
stride: int = 1,
level_root: bool = False,
root_dim: Optional[int] = None,
root_kernel_size: int = 1,
dilation: int = 1,
add_identity: bool = False,
init_cfg: OptMultiConfig = None):
super(Tree, self).__init__(init_cfg)
if root_dim is None:
root_dim = 2 * out_channels
Expand Down Expand Up @@ -258,7 +261,10 @@ def __init__(self,
bias=False),
dla_build_norm_layer(norm_cfg, out_channels)[1])

def forward(self, x, identity=None, children=None):
def forward(self,
x: Tensor,
identity: Optional[Tensor] = None,
children: Optional[List[Tensor]] = None) -> Tensor:
children = [] if children is None else children
bottom = self.downsample(x) if self.downsample else x
identity = self.project(bottom) if self.project else bottom
Expand Down Expand Up @@ -302,16 +308,17 @@ class DLANet(BaseModule):
}

def __init__(self,
depth,
in_channels=3,
out_indices=(0, 1, 2, 3, 4, 5),
frozen_stages=-1,
norm_cfg=None,
conv_cfg=None,
layer_with_level_root=(False, True, True, True),
with_identity_root=False,
pretrained=None,
init_cfg=None):
depth: int,
in_channels: int = 3,
out_indices: Sequence[int] = (0, 1, 2, 3, 4, 5),
frozen_stages: int = -1,
norm_cfg: OptConfigType = None,
conv_cfg: OptConfigType = None,
layer_with_level_root: Sequence[bool] = (False, True, True,
True),
with_identity_root: bool = False,
pretrained: Optional[str] = None,
init_cfg: OptMultiConfig = None):
super(DLANet, self).__init__(init_cfg)
if depth not in self.arch_settings:
raise KeyError(f'invalida depth {depth} for DLA')
Expand Down Expand Up @@ -380,13 +387,13 @@ def __init__(self,
self._freeze_stages()

def _make_conv_level(self,
in_channels,
out_channels,
num_convs,
norm_cfg,
conv_cfg,
stride=1,
dilation=1):
in_channels: int,
out_channels: int,
num_convs: int,
norm_cfg: ConfigType,
conv_cfg: ConfigType,
stride: int = 1,
dilation: int = 1) -> nn.Sequential:
"""Conv modules.
Args:
Expand Down Expand Up @@ -418,7 +425,7 @@ def _make_conv_level(self,
in_channels = out_channels
return nn.Sequential(*modules)

def _freeze_stages(self):
def _freeze_stages(self) -> None:
if self.frozen_stages >= 0:
self.base_layer.eval()
for param in self.base_layer.parameters():
Expand All @@ -436,7 +443,7 @@ def _freeze_stages(self):
for param in m.parameters():
param.requires_grad = False

def forward(self, x):
def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
outs = []
x = self.base_layer(x)
for i in range(self.num_levels):
Expand Down
27 changes: 15 additions & 12 deletions mmdet3d/models/backbones/multi_backbone.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import nn as nn
from torch import Tensor, nn

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptMultiConfig


@MODELS.register_module()
Expand All @@ -27,16 +29,17 @@ class MultiBackbone(BaseModule):
"""

def __init__(self,
num_streams,
backbones,
aggregation_mlp_channels=None,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
act_cfg=dict(type='ReLU'),
suffixes=('net0', 'net1'),
init_cfg=None,
pretrained=None,
**kwargs):
num_streams: int,
backbones: Union[List[dict], Dict],
aggregation_mlp_channels: Optional[Sequence[int]] = None,
conv_cfg: ConfigType = dict(type='Conv1d'),
norm_cfg: ConfigType = dict(
type='BN1d', eps=1e-5, momentum=0.01),
act_cfg: ConfigType = dict(type='ReLU'),
suffixes: Tuple[str] = ('net0', 'net1'),
init_cfg: OptMultiConfig = None,
pretrained: Optional[str] = None,
**kwargs) -> None:
super().__init__(init_cfg=init_cfg)
assert isinstance(backbones, dict) or isinstance(backbones, list)
if isinstance(backbones, dict):
Expand Down Expand Up @@ -89,7 +92,7 @@ def __init__(self,
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

def forward(self, points):
def forward(self, points: Tensor) -> dict:
"""Forward pass.
Args:
Expand Down
15 changes: 12 additions & 3 deletions mmdet3d/models/backbones/nostem_regnet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch.nn as nn
from mmdet.models.backbones import RegNet
from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.utils import OptMultiConfig


@MODELS.register_module()
Expand Down Expand Up @@ -59,15 +64,19 @@ class NoStemRegNet(RegNet):
(1, 1008, 1, 1)
"""

def __init__(self, arch, init_cfg=None, **kwargs):
def __init__(self,
arch: dict,
init_cfg: OptMultiConfig = None,
**kwargs) -> None:
super(NoStemRegNet, self).__init__(arch, init_cfg=init_cfg, **kwargs)

def _make_stem_layer(self, in_channels, base_channels):
def _make_stem_layer(self, in_channels: int,
base_channels: int) -> nn.Module:
"""Override the original function that do not initialize a stem layer
since 3D detector's voxel encoder works like a stem layer."""
return

def forward(self, x):
def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
"""Forward function of backbone.
Args:
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/models/backbones/pointnet2_sa_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
from mmcv.cnn import ConvModule
from torch import nn as nn
from torch import Tensor, nn

from mmdet3d.models.layers.pointnet_modules import build_sa_module
from mmdet3d.registry import MODELS
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(self,
bias=True))
sa_in_channel = cur_aggregation_channel

def forward(self, points: torch.Tensor):
def forward(self, points: Tensor):
"""Forward pass.
Args:
Expand Down
Loading

0 comments on commit d99dbce

Please sign in to comment.