Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Update pre-commit-config-zh-cn.yaml and add typehints for PointNet2SAMSG #2396

Merged
merged 1 commit into from
Apr 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 4 additions & 10 deletions .pre-commit-config-zh-cn.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
exclude: ^tests/data/
ZwwWayne marked this conversation as resolved.
Show resolved Hide resolved
repos:
- repo: https://gitee.com/openmmlab/mirrors-flake8
rev: 5.0.4
Expand All @@ -25,6 +24,10 @@ repos:
args: ["--remove"]
- id: mixed-line-ending
args: ["--fix=lf"]
- repo: https://gitee.com/openmmlab/mirrors-codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://gitee.com/openmmlab/mirrors-mdformat
rev: 0.7.9
hooks:
Expand All @@ -34,20 +37,11 @@ repos:
- mdformat-openmmlab
- mdformat_frontmatter
- linkify-it-py
- repo: https://gitee.com/openmmlab/mirrors-codespell
rev: v2.2.1
hooks:
- id: codespell
- repo: https://gitee.com/openmmlab/mirrors-docformatter
rev: v1.3.1
hooks:
- id: docformatter
args: ["--in-place", "--wrap-descriptions", "79"]
- repo: https://gitee.com/openmmlab/mirrors-pyupgrade
rev: v3.0.0
hooks:
- id: pyupgrade
args: ["--py36-plus"]
- repo: https://gitee.com/openmmlab/pre-commit-hooks
rev: v0.2.0
hooks:
Expand Down
54 changes: 36 additions & 18 deletions mmdet3d/models/backbones/pointnet2_sa_msg.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
from mmcv.cnn import ConvModule
from torch import nn as nn

from mmdet3d.models.layers.pointnet_modules import build_sa_module
from mmdet3d.registry import MODELS
from mmdet3d.utils import OptConfigType
from .base_pointnet import BasePointNet

ThreeTupleIntType = Tuple[Tuple[Tuple[int, int, int]]]
TwoTupleIntType = Tuple[Tuple[int, int, int]]
TwoTupleStrType = Tuple[Tuple[str]]


@MODELS.register_module()
class PointNet2SAMSG(BasePointNet):
Expand All @@ -22,7 +29,7 @@ class PointNet2SAMSG(BasePointNet):
sa_channels (tuple[tuple[int]]): Out channels of each mlp in SA module.
aggregation_channels (tuple[int]): Out channels of aggregation
multi-scale grouping features.
fps_mods (tuple[int]): Mod of FPS for each SA module.
fps_mods Sequence[Tuple[str]]: Mod of FPS for each SA module.
fps_sample_range_lists (tuple[tuple[int]]): The number of sampling
points which each SA module samples.
dilated_group (tuple[bool]): Whether to use dilated ball query for
Expand All @@ -38,26 +45,37 @@ class PointNet2SAMSG(BasePointNet):
"""

def __init__(self,
in_channels,
num_points=(2048, 1024, 512, 256),
radii=((0.2, 0.4, 0.8), (0.4, 0.8, 1.6), (1.6, 3.2, 4.8)),
num_samples=((32, 32, 64), (32, 32, 64), (32, 32, 32)),
sa_channels=(((16, 16, 32), (16, 16, 32), (32, 32, 64)),
((64, 64, 128), (64, 64, 128), (64, 96, 128)),
((128, 128, 256), (128, 192, 256), (128, 256,
256))),
aggregation_channels=(64, 128, 256),
fps_mods=(('D-FPS'), ('FS'), ('F-FPS', 'D-FPS')),
fps_sample_range_lists=((-1), (-1), (512, -1)),
dilated_group=(True, True, True),
out_indices=(2, ),
norm_cfg=dict(type='BN2d'),
sa_cfg=dict(
in_channels: int,
num_points: Tuple[int] = (2048, 1024, 512, 256),
radii: Tuple[Tuple[float, float, float]] = (
(0.2, 0.4, 0.8),
(0.4, 0.8, 1.6),
(1.6, 3.2, 4.8),
),
num_samples: TwoTupleIntType = ((32, 32, 64), (32, 32, 64),
(32, 32, 32)),
sa_channels: ThreeTupleIntType = (((16, 16, 32), (16, 16, 32),
(32, 32, 64)),
((64, 64, 128),
(64, 64, 128), (64, 96,
128)),
((128, 128, 256),
(128, 192, 256), (128, 256,
256))),
aggregation_channels: Tuple[int] = (64, 128, 256),
fps_mods: TwoTupleStrType = (('D-FPS'), ('FS'), ('F-FPS',
'D-FPS')),
fps_sample_range_lists: TwoTupleIntType = ((-1), (-1), (512,
-1)),
dilated_group: Tuple[bool] = (True, True, True),
out_indices: Tuple[int] = (2, ),
norm_cfg: dict = dict(type='BN2d'),
sa_cfg: dict = dict(
type='PointSAModuleMSG',
pool_mod='max',
use_xyz=True,
normalize_xyz=False),
init_cfg=None):
init_cfg: OptConfigType = None):
super().__init__(init_cfg=init_cfg)
self.num_sa = len(sa_channels)
self.out_indices = out_indices
Expand Down Expand Up @@ -123,7 +141,7 @@ def __init__(self,
bias=True))
sa_in_channel = cur_aggregation_channel

def forward(self, points):
def forward(self, points: torch.Tensor):
"""Forward pass.

Args:
Expand Down
4 changes: 3 additions & 1 deletion mmdet3d/models/layers/pointnet_modules/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from mmengine.registry import Registry
from torch import nn as nn

SA_MODULES = Registry('point_sa_module')
SA_MODULES = Registry(
name='point_sa_module',
locations=['mmdet3d.models.layers.pointnet_modules'])


def build_sa_module(cfg: Union[dict, None], *args, **kwargs) -> nn.Module:
Expand Down