Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add typehint to mmdet3d/models/backbones. #2464

Merged
merged 18 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions mmdet3d/models/backbones/base_pointnet.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta
from typing import Optional, Tuple

from mmengine.model import BaseModule
from torch import Tensor

from mmdet3d.utils import OptMultiConfig


class BasePointNet(BaseModule, metaclass=ABCMeta):
"""Base class for PointNet."""

def __init__(self, init_cfg=None, pretrained=None):
def __init__(self,
init_cfg: OptMultiConfig = None,
pretrained: Optional[str] = None):
super(BasePointNet, self).__init__(init_cfg)
self.fp16_enabled = False
assert not (init_cfg and pretrained), \
Expand All @@ -19,7 +25,7 @@ def __init__(self, init_cfg=None, pretrained=None):
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

@staticmethod
def _split_point_feats(points):
def _split_point_feats(points: Tensor) -> Tuple[Tensor, Optional[Tensor]]:
"""Split coordinates and features of input points.

Args:
Expand Down
3 changes: 2 additions & 1 deletion mmdet3d/models/backbones/cylinder3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mmcv.ops import (SparseConv3d, SparseConvTensor, SparseInverseConv3d,
SubMConv3d)
from mmengine.model import BaseModule
from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType
Expand Down Expand Up @@ -456,7 +457,7 @@ def __init__(self,
indice_key='ddcm',
norm_cfg=norm_cfg)

def forward(self, voxel_features: torch.Tensor, coors: torch.Tensor,
def forward(self, voxel_features: Tensor, coors: Tensor,
batch_size: int) -> SparseConvTensor:
"""Forward pass."""
coors = coors.int()
Expand Down
23 changes: 14 additions & 9 deletions mmdet3d/models/backbones/dgcnn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union

from mmengine.model import BaseModule
from torch import Tensor
from torch import nn as nn

from mmdet3d.models.layers import DGCNNFAModule, DGCNNGFModule
from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptMultiConfig


@MODELS.register_module()
Expand All @@ -30,14 +34,15 @@ class DGCNNBackbone(BaseModule):
"""

def __init__(self,
in_channels,
num_samples=(20, 20, 20),
knn_modes=('D-KNN', 'F-KNN', 'F-KNN'),
radius=(None, None, None),
gf_channels=((64, 64), (64, 64), (64, )),
fa_channels=(1024, ),
act_cfg=dict(type='ReLU'),
init_cfg=None):
in_channels: int,
num_samples: Sequence[int] = (20, 20, 20),
knn_modes: Sequence[str] = ('D-KNN', 'F-KNN', 'F-KNN'),
radius: Sequence[Union[float, None]] = (None, None, None),
gf_channels: Sequence[Sequence[int]] = ((64, 64), (64, 64),
(64, )),
fa_channels: Sequence[int] = (1024, ),
act_cfg: ConfigType = dict(type='ReLU'),
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg=init_cfg)
self.num_gf = len(gf_channels)

Expand Down Expand Up @@ -71,7 +76,7 @@ def __init__(self,
self.FA_module = DGCNNFAModule(
mlp_channels=cur_fa_mlps, act_cfg=act_cfg)

def forward(self, points):
def forward(self, points: Tensor) -> dict:
"""Forward pass.

Args:
Expand Down
109 changes: 58 additions & 51 deletions mmdet3d/models/backbones/dla.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import List, Optional, Sequence, Tuple

import torch
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import nn
from torch import Tensor, nn

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptConfigType, OptMultiConfig


def dla_build_norm_layer(cfg, num_features):
def dla_build_norm_layer(cfg: ConfigType,
num_features: int) -> Tuple[str, nn.Module]:
"""Build normalization layer specially designed for DLANet.

Args:
Expand Down Expand Up @@ -53,13 +56,13 @@ class BasicBlock(BaseModule):
"""

def __init__(self,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
stride=1,
dilation=1,
init_cfg=None):
in_channels: int,
out_channels: int,
norm_cfg: ConfigType,
conv_cfg: ConfigType,
stride: int = 1,
dilation: int = 1,
init_cfg: OptMultiConfig = None):
super(BasicBlock, self).__init__(init_cfg)
self.conv1 = build_conv_layer(
conv_cfg,
Expand All @@ -84,7 +87,7 @@ def __init__(self,
self.norm2 = dla_build_norm_layer(norm_cfg, out_channels)[1]
self.stride = stride

def forward(self, x, identity=None):
def forward(self, x: Tensor, identity: Optional[Tensor] = None) -> Tensor:
"""Forward function."""

if identity is None:
Expand Down Expand Up @@ -117,13 +120,13 @@ class Root(BaseModule):
"""

def __init__(self,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
kernel_size,
add_identity,
init_cfg=None):
in_channels: int,
out_channels: int,
norm_cfg: ConfigType,
conv_cfg: ConfigType,
kernel_size: int,
add_identity: bool,
init_cfg: OptMultiConfig = None):
super(Root, self).__init__(init_cfg)
self.conv = build_conv_layer(
conv_cfg,
Expand All @@ -137,7 +140,7 @@ def __init__(self,
self.relu = nn.ReLU(inplace=True)
self.add_identity = add_identity

def forward(self, feat_list):
def forward(self, feat_list: List[Tensor]) -> Tensor:
"""Forward function.

Args:
Expand Down Expand Up @@ -181,19 +184,19 @@ class Tree(BaseModule):
"""

def __init__(self,
levels,
block,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
stride=1,
level_root=False,
root_dim=None,
root_kernel_size=1,
dilation=1,
add_identity=False,
init_cfg=None):
levels: int,
block: nn.Module,
in_channels: int,
out_channels: int,
norm_cfg: ConfigType,
conv_cfg: ConfigType,
stride: int = 1,
level_root: bool = False,
root_dim: Optional[int] = None,
root_kernel_size: int = 1,
dilation: int = 1,
add_identity: bool = False,
init_cfg: OptMultiConfig = None):
super(Tree, self).__init__(init_cfg)
if root_dim is None:
root_dim = 2 * out_channels
Expand Down Expand Up @@ -258,7 +261,10 @@ def __init__(self,
bias=False),
dla_build_norm_layer(norm_cfg, out_channels)[1])

def forward(self, x, identity=None, children=None):
def forward(self,
x: Tensor,
identity: Optional[Tensor] = None,
children: Optional[List[Tensor]] = None) -> Tensor:
children = [] if children is None else children
bottom = self.downsample(x) if self.downsample else x
identity = self.project(bottom) if self.project else bottom
Expand Down Expand Up @@ -302,16 +308,17 @@ class DLANet(BaseModule):
}

def __init__(self,
depth,
in_channels=3,
out_indices=(0, 1, 2, 3, 4, 5),
frozen_stages=-1,
norm_cfg=None,
conv_cfg=None,
layer_with_level_root=(False, True, True, True),
with_identity_root=False,
pretrained=None,
init_cfg=None):
depth: int,
in_channels: int = 3,
out_indices: Sequence[int] = (0, 1, 2, 3, 4, 5),
frozen_stages: int = -1,
norm_cfg: OptConfigType = None,
conv_cfg: OptConfigType = None,
layer_with_level_root: Sequence[bool] = (False, True, True,
True),
with_identity_root: bool = False,
pretrained: Optional[str] = None,
init_cfg: OptMultiConfig = None):
super(DLANet, self).__init__(init_cfg)
if depth not in self.arch_settings:
raise KeyError(f'invalida depth {depth} for DLA')
Expand Down Expand Up @@ -380,13 +387,13 @@ def __init__(self,
self._freeze_stages()

def _make_conv_level(self,
in_channels,
out_channels,
num_convs,
norm_cfg,
conv_cfg,
stride=1,
dilation=1):
in_channels: int,
out_channels: int,
num_convs: int,
norm_cfg: ConfigType,
conv_cfg: ConfigType,
stride: int = 1,
dilation: int = 1) -> nn.Sequential:
"""Conv modules.

Args:
Expand Down Expand Up @@ -418,7 +425,7 @@ def _make_conv_level(self,
in_channels = out_channels
return nn.Sequential(*modules)

def _freeze_stages(self):
def _freeze_stages(self) -> None:
if self.frozen_stages >= 0:
self.base_layer.eval()
for param in self.base_layer.parameters():
Expand All @@ -436,7 +443,7 @@ def _freeze_stages(self):
for param in m.parameters():
param.requires_grad = False

def forward(self, x):
def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
outs = []
x = self.base_layer(x)
for i in range(self.num_levels):
Expand Down
27 changes: 15 additions & 12 deletions mmdet3d/models/backbones/multi_backbone.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import nn as nn
from torch import Tensor, nn

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptMultiConfig


@MODELS.register_module()
Expand All @@ -27,16 +29,17 @@ class MultiBackbone(BaseModule):
"""

def __init__(self,
num_streams,
backbones,
aggregation_mlp_channels=None,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
act_cfg=dict(type='ReLU'),
suffixes=('net0', 'net1'),
init_cfg=None,
pretrained=None,
**kwargs):
num_streams: int,
backbones: Union[List[dict], Dict],
aggregation_mlp_channels: Optional[Sequence[int]] = None,
conv_cfg: ConfigType = dict(type='Conv1d'),
norm_cfg: ConfigType = dict(
type='BN1d', eps=1e-5, momentum=0.01),
act_cfg: ConfigType = dict(type='ReLU'),
suffixes: Tuple[str] = ('net0', 'net1'),
init_cfg: OptMultiConfig = None,
pretrained: Optional[str] = None,
**kwargs) -> None:
super().__init__(init_cfg=init_cfg)
assert isinstance(backbones, dict) or isinstance(backbones, list)
if isinstance(backbones, dict):
Expand Down Expand Up @@ -89,7 +92,7 @@ def __init__(self,
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

def forward(self, points):
def forward(self, points: Tensor) -> dict:
"""Forward pass.

Args:
Expand Down
15 changes: 12 additions & 3 deletions mmdet3d/models/backbones/nostem_regnet.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch.nn as nn
from mmdet.models.backbones import RegNet
from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.utils import OptMultiConfig


@MODELS.register_module()
Expand Down Expand Up @@ -59,15 +64,19 @@ class NoStemRegNet(RegNet):
(1, 1008, 1, 1)
"""

def __init__(self, arch, init_cfg=None, **kwargs):
def __init__(self,
arch: dict,
init_cfg: OptMultiConfig = None,
**kwargs) -> None:
super(NoStemRegNet, self).__init__(arch, init_cfg=init_cfg, **kwargs)

def _make_stem_layer(self, in_channels, base_channels):
def _make_stem_layer(self, in_channels: int,
base_channels: int) -> nn.Module:
"""Override the original function that do not initialize a stem layer
since 3D detector's voxel encoder works like a stem layer."""
return

def forward(self, x):
def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
"""Forward function of backbone.

Args:
Expand Down
4 changes: 2 additions & 2 deletions mmdet3d/models/backbones/pointnet2_sa_msg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
from mmcv.cnn import ConvModule
from torch import nn as nn
from torch import Tensor, nn

from mmdet3d.models.layers.pointnet_modules import build_sa_module
from mmdet3d.registry import MODELS
Expand Down Expand Up @@ -141,7 +141,7 @@ def __init__(self,
bias=True))
sa_in_channel = cur_aggregation_channel

def forward(self, points: torch.Tensor):
def forward(self, points: Tensor):
"""Forward pass.

Args:
Expand Down
Loading