Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add typehint to mmdet3d/models/backbones. #2464

Merged
merged 18 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions mmdet3d/models/backbones/base_pointnet.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from abc import ABCMeta
from typing import Optional, Tuple

import torch
from mmengine.model import BaseModule

from mmdet3d.utils import OptMultiConfig


class BasePointNet(BaseModule, metaclass=ABCMeta):
"""Base class for PointNet."""

def __init__(self, init_cfg=None, pretrained=None):
def __init__(self,
init_cfg: OptMultiConfig = None,
pretrained: Optional[str] = None):
super(BasePointNet, self).__init__(init_cfg)
self.fp16_enabled = False
assert not (init_cfg and pretrained), \
Expand All @@ -19,7 +25,9 @@ def __init__(self, init_cfg=None, pretrained=None):
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

@staticmethod
def _split_point_feats(points):
def _split_point_feats(
points: torch.Tensor
Xiangxu-0103 marked this conversation as resolved.
Show resolved Hide resolved
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
"""Split coordinates and features of input points.

Args:
Expand Down
23 changes: 14 additions & 9 deletions mmdet3d/models/backbones/dgcnn.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence, Union

import torch
from mmengine.model import BaseModule
from torch import nn as nn

from mmdet3d.models.layers import DGCNNFAModule, DGCNNGFModule
from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptMultiConfig


@MODELS.register_module()
Expand All @@ -30,14 +34,15 @@ class DGCNNBackbone(BaseModule):
"""

def __init__(self,
in_channels,
num_samples=(20, 20, 20),
knn_modes=('D-KNN', 'F-KNN', 'F-KNN'),
radius=(None, None, None),
gf_channels=((64, 64), (64, 64), (64, )),
fa_channels=(1024, ),
act_cfg=dict(type='ReLU'),
init_cfg=None):
in_channels: int,
num_samples: Sequence[int] = (20, 20, 20),
knn_modes: Sequence[str] = ('D-KNN', 'F-KNN', 'F-KNN'),
radius: Sequence[Union[float, None]] = (None, None, None),
gf_channels: Sequence[Sequence[int]] = ((64, 64), (64, 64),
(64, )),
fa_channels: Sequence[int] = (1024, ),
act_cfg: ConfigType = dict(type='ReLU'),
init_cfg: OptMultiConfig = None):
super().__init__(init_cfg=init_cfg)
self.num_gf = len(gf_channels)

Expand Down Expand Up @@ -71,7 +76,7 @@ def __init__(self,
self.FA_module = DGCNNFAModule(
mlp_channels=cur_fa_mlps, act_cfg=act_cfg)

def forward(self, points):
def forward(self, points: torch.Tensor) -> dict:
Xiangxu-0103 marked this conversation as resolved.
Show resolved Hide resolved
"""Forward pass.

Args:
Expand Down
108 changes: 57 additions & 51 deletions mmdet3d/models/backbones/dla.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Any, Dict, List, Optional, Tuple

import torch
from mmcv.cnn import build_conv_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import nn
from torch import Tensor, nn

from mmdet3d.registry import MODELS
from mmdet3d.utils import ConfigType, OptMultiConfig


def dla_build_norm_layer(cfg, num_features):
def dla_build_norm_layer(cfg: ConfigType,
num_features: int) -> Tuple[str, nn.Module]:
"""Build normalization layer specially designed for DLANet.

Args:
Expand Down Expand Up @@ -53,13 +56,13 @@ class BasicBlock(BaseModule):
"""

def __init__(self,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
stride=1,
dilation=1,
init_cfg=None):
in_channels: int,
out_channels: int,
norm_cfg: ConfigType,
conv_cfg: Dict[str, Any],
Xiangxu-0103 marked this conversation as resolved.
Show resolved Hide resolved
stride: int = 1,
dilation: int = 1,
init_cfg: OptMultiConfig = None):
super(BasicBlock, self).__init__(init_cfg)
self.conv1 = build_conv_layer(
conv_cfg,
Expand All @@ -84,7 +87,7 @@ def __init__(self,
self.norm2 = dla_build_norm_layer(norm_cfg, out_channels)[1]
self.stride = stride

def forward(self, x, identity=None):
def forward(self, x: Tensor, identity: Optional[Tensor] = None) -> Tensor:
"""Forward function."""

if identity is None:
Expand Down Expand Up @@ -117,13 +120,13 @@ class Root(BaseModule):
"""

def __init__(self,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
kernel_size,
add_identity,
init_cfg=None):
in_channels: int,
out_channels: int,
norm_cfg: Dict[str, Any],
Xiangxu-0103 marked this conversation as resolved.
Show resolved Hide resolved
Xiangxu-0103 marked this conversation as resolved.
Show resolved Hide resolved
conv_cfg: Dict[str, Any],
Xiangxu-0103 marked this conversation as resolved.
Show resolved Hide resolved
kernel_size: int,
add_identity: bool,
init_cfg: OptMultiConfig = None):
super(Root, self).__init__(init_cfg)
self.conv = build_conv_layer(
conv_cfg,
Expand All @@ -137,7 +140,7 @@ def __init__(self,
self.relu = nn.ReLU(inplace=True)
self.add_identity = add_identity

def forward(self, feat_list):
def forward(self, feat_list: List[Tensor]) -> Tensor:
"""Forward function.

Args:
Expand Down Expand Up @@ -181,19 +184,19 @@ class Tree(BaseModule):
"""

def __init__(self,
levels,
block,
in_channels,
out_channels,
norm_cfg,
conv_cfg,
stride=1,
level_root=False,
root_dim=None,
root_kernel_size=1,
dilation=1,
add_identity=False,
init_cfg=None):
levels: int,
block: nn.Module,
in_channels: int,
out_channels: int,
norm_cfg: dict,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dict -> ConfigType

conv_cfg: dict,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dict -> ConfigType

stride: int = 1,
level_root: bool = False,
root_dim: Optional[int] = None,
root_kernel_size: int = 1,
dilation: int = 1,
add_identity: bool = False,
init_cfg: OptMultiConfig = None):
super(Tree, self).__init__(init_cfg)
if root_dim is None:
root_dim = 2 * out_channels
Expand Down Expand Up @@ -258,7 +261,10 @@ def __init__(self,
bias=False),
dla_build_norm_layer(norm_cfg, out_channels)[1])

def forward(self, x, identity=None, children=None):
def forward(self,
x: Tensor,
identity: Optional[Tensor] = None,
children: Optional[List[Tensor]] = None) -> Tensor:
children = [] if children is None else children
bottom = self.downsample(x) if self.downsample else x
identity = self.project(bottom) if self.project else bottom
Expand Down Expand Up @@ -302,16 +308,16 @@ class DLANet(BaseModule):
}

def __init__(self,
depth,
in_channels=3,
out_indices=(0, 1, 2, 3, 4, 5),
frozen_stages=-1,
norm_cfg=None,
conv_cfg=None,
layer_with_level_root=(False, True, True, True),
with_identity_root=False,
pretrained=None,
init_cfg=None):
depth: int,
in_channels: int = 3,
out_indices: Tuple[int] = (0, 1, 2, 3, 4, 5),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sequence[int]

frozen_stages: int = -1,
norm_cfg: OptMultiConfig = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OptConfigType

conv_cfg: OptMultiConfig = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OptConfigType

layer_with_level_root: List[bool] = (False, True, True, True),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sequence[bool]

with_identity_root: bool = False,
pretrained: Optional[str] = None,
init_cfg: OptMultiConfig = None):
super(DLANet, self).__init__(init_cfg)
if depth not in self.arch_settings:
raise KeyError(f'invalida depth {depth} for DLA')
Expand Down Expand Up @@ -380,13 +386,13 @@ def __init__(self,
self._freeze_stages()

def _make_conv_level(self,
in_channels,
out_channels,
num_convs,
norm_cfg,
conv_cfg,
stride=1,
dilation=1):
in_channels: int,
out_channels: int,
num_convs: int,
norm_cfg: dict,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dict -> ConfigType

conv_cfg: dict,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dict ->ConfigType

stride: int = 1,
dilation: int = 1) -> nn.Sequential:
"""Conv modules.

Args:
Expand Down Expand Up @@ -418,7 +424,7 @@ def _make_conv_level(self,
in_channels = out_channels
return nn.Sequential(*modules)

def _freeze_stages(self):
def _freeze_stages(self) -> None:
if self.frozen_stages >= 0:
self.base_layer.eval()
for param in self.base_layer.parameters():
Expand All @@ -436,7 +442,7 @@ def _freeze_stages(self):
for param in m.parameters():
param.requires_grad = False

def forward(self, x):
def forward(self, x: Tensor) -> Tuple[Tensor, ...]:
outs = []
x = self.base_layer(x)
for i in range(self.num_levels):
Expand Down
23 changes: 12 additions & 11 deletions mmdet3d/models/backbones/multi_backbone.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings
from typing import Any, Dict, List, Optional, Tuple, Union

import torch
from mmcv.cnn import ConvModule
Expand All @@ -27,16 +28,16 @@ class MultiBackbone(BaseModule):
"""

def __init__(self,
num_streams,
backbones,
aggregation_mlp_channels=None,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d', eps=1e-5, momentum=0.01),
act_cfg=dict(type='ReLU'),
suffixes=('net0', 'net1'),
init_cfg=None,
pretrained=None,
**kwargs):
num_streams: int,
backbones: Union[List, Dict],
JingweiZhang12 marked this conversation as resolved.
Show resolved Hide resolved
aggregation_mlp_channels: Optional[List[int]] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Optional[Sequence[int]]

conv_cfg: Dict = dict(type='Conv1d'),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ConfigType. The same below.

norm_cfg: Dict = dict(type='BN1d', eps=1e-5, momentum=0.01),
act_cfg: Dict = dict(type='ReLU'),
suffixes: Tuple = ('net0', 'net1'),
JingweiZhang12 marked this conversation as resolved.
Show resolved Hide resolved
init_cfg: Optional[Dict] = None,
JingweiZhang12 marked this conversation as resolved.
Show resolved Hide resolved
pretrained: Optional[str] = None,
**kwargs: Any) -> None:
Lum1104 marked this conversation as resolved.
Show resolved Hide resolved
super().__init__(init_cfg=init_cfg)
assert isinstance(backbones, dict) or isinstance(backbones, list)
if isinstance(backbones, dict):
Expand Down Expand Up @@ -89,7 +90,7 @@ def __init__(self,
'please use "init_cfg" instead')
self.init_cfg = dict(type='Pretrained', checkpoint=pretrained)

def forward(self, points):
def forward(self, points: torch.Tensor) -> Dict[str, List[torch.Tensor]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
def forward(self, points: torch.Tensor) -> Dict[str, List[torch.Tensor]]:
def forward(self, points: Tensor) -> dict:

"""Forward pass.

Args:
Expand Down
14 changes: 11 additions & 3 deletions mmdet3d/models/backbones/nostem_regnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
from mmdet.models.backbones import RegNet

from mmdet3d.registry import MODELS
Expand Down Expand Up @@ -59,15 +63,19 @@ class NoStemRegNet(RegNet):
(1, 1008, 1, 1)
"""

def __init__(self, arch, init_cfg=None, **kwargs):
def __init__(self,
arch: Dict[str, Union[int, float]],
JingweiZhang12 marked this conversation as resolved.
Show resolved Hide resolved
init_cfg: Optional[Dict] = None,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OptMultiConfig

**kwargs: Any) -> None:
Lum1104 marked this conversation as resolved.
Show resolved Hide resolved
super(NoStemRegNet, self).__init__(arch, init_cfg=init_cfg, **kwargs)

def _make_stem_layer(self, in_channels, base_channels):
def _make_stem_layer(self, in_channels: int,
base_channels: int) -> nn.Module:
"""Override the original function that do not initialize a stem layer
since 3D detector's voxel encoder works like a stem layer."""
return

def forward(self, x):
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.Tensor -> Tensor

"""Forward function of backbone.

Args:
Expand Down
26 changes: 15 additions & 11 deletions mmdet3d/models/backbones/pointnet2_sa_ssg.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Dict, List, Optional, Tuple

import torch
from torch import nn as nn

Expand Down Expand Up @@ -31,20 +33,22 @@ class PointNet2SASSG(BasePointNet):
"""

def __init__(self,
in_channels,
num_points=(2048, 1024, 512, 256),
radius=(0.2, 0.4, 0.8, 1.2),
num_samples=(64, 32, 16, 16),
sa_channels=((64, 64, 128), (128, 128, 256), (128, 128, 256),
(128, 128, 256)),
fp_channels=((256, 256), (256, 256)),
norm_cfg=dict(type='BN2d'),
sa_cfg=dict(
in_channels: int,
num_points: Tuple[int] = (2048, 1024, 512, 256),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tuple ->Sequence

radius: Tuple[float] = (0.2, 0.4, 0.8, 1.2),
num_samples: Tuple[int] = (64, 32, 16, 16),
sa_channels: Tuple[Tuple[int]] = ((64, 64, 128), (128, 128,
256),
(128, 128, 256), (128, 128,
256)),
fp_channels: Tuple[Tuple[int]] = ((256, 256), (256, 256)),
norm_cfg: dict = dict(type='BN2d'),
JingweiZhang12 marked this conversation as resolved.
Show resolved Hide resolved
sa_cfg: dict = dict(
JingweiZhang12 marked this conversation as resolved.
Show resolved Hide resolved
type='PointSAModule',
pool_mod='max',
use_xyz=True,
normalize_xyz=True),
init_cfg=None):
init_cfg: Optional[Dict[str, Any]] = None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OptMultiConfig

super().__init__(init_cfg=init_cfg)
self.num_sa = len(sa_channels)
self.num_fp = len(fp_channels)
Expand Down Expand Up @@ -85,7 +89,7 @@ def __init__(self,
fp_source_channel = cur_fp_mlps[-1]
fp_target_channel = skip_channel_list.pop()

def forward(self, points):
def forward(self, points: torch.Tensor) -> Dict[str, List[torch.Tensor]]:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

torch.Tensor -> Tensor

"""Forward pass.

Args:
Expand Down
Loading