diff --git a/mmdet3d/models/decode_heads/__init__.py b/mmdet3d/models/decode_heads/__init__.py index f7560e5a64..6265875bba 100644 --- a/mmdet3d/models/decode_heads/__init__.py +++ b/mmdet3d/models/decode_heads/__init__.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. from .cylinder3d_head import Cylinder3DHead +from .decode_head import Base3DDecodeHead from .dgcnn_head import DGCNNHead from .minkunet_head import MinkUNetHead from .paconv_head import PAConvHead @@ -7,5 +8,5 @@ __all__ = [ 'PointNet2Head', 'DGCNNHead', 'PAConvHead', 'Cylinder3DHead', - 'MinkUNetHead' + 'Base3DDecodeHead', 'MinkUNetHead' ] diff --git a/projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-3class.py b/projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-3class.py index 22a71521d9..14bcbb9296 100644 --- a/projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-3class.py +++ b/projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-3class.py @@ -1,4 +1,4 @@ -_base_ = ['mmdet3d::_base_/default_runtime.py'] +_base_ = ['../../../configs/_base_/default_runtime.py'] custom_imports = dict( imports=['projects.CenterFormer.centerformer'], allow_failed_imports=False) diff --git a/projects/DETR3D/configs/detr3d_r101_gridmask.py b/projects/DETR3D/configs/detr3d_r101_gridmask.py index 522644ca5b..ef9a86383d 100644 --- a/projects/DETR3D/configs/detr3d_r101_gridmask.py +++ b/projects/DETR3D/configs/detr3d_r101_gridmask.py @@ -1,6 +1,6 @@ _base_ = [ # 'mmdet3d::_base_/datasets/nus-3d.py', - 'mmdet3d::_base_/default_runtime.py' + '../../../configs/_base_/default_runtime.py' ] custom_imports = dict(imports=['projects.DETR3D.detr3d']) diff --git a/projects/PETR/configs/petr_vovnet_gridmask_p4_800x320.py b/projects/PETR/configs/petr_vovnet_gridmask_p4_800x320.py index 5cec19400c..c61b36218c 100644 --- a/projects/PETR/configs/petr_vovnet_gridmask_p4_800x320.py +++ b/projects/PETR/configs/petr_vovnet_gridmask_p4_800x320.py @@ -1,6 +1,7 @@ _base_ = [ - 'mmdet3d::_base_/datasets/nus-3d.py', 'mmdet3d::_base_/default_runtime.py', - 'mmdet3d::_base_/schedules/cyclic-20e.py' + '../../../configs/_base_/datasets/nus-3d.py', + '../../../configs/_base_/default_runtime.py', + '../../../configs/_base_/schedules/cyclic-20e.py' ] backbone_norm_cfg = dict(type='LN', requires_grad=True) custom_imports = dict(imports=['projects.PETR.petr']) diff --git a/projects/TPVFormer/config/tpvformer_8xb1-2x_nus-seg.py b/projects/TPVFormer/config/tpvformer_8xb1-2x_nus-seg.py new file mode 100644 index 0000000000..7861c6f13e --- /dev/null +++ b/projects/TPVFormer/config/tpvformer_8xb1-2x_nus-seg.py @@ -0,0 +1,317 @@ +_base_ = ['../../../configs/_base_/default_runtime.py'] + +custom_imports = dict( + imports=['projects.TPVFormer.tpvformer'], allow_failed_imports=False) + +dataset_type = 'NuScenesSegDataset' +data_root = 'data/nuscenes/' +data_prefix = dict( + pts='samples/LIDAR_TOP', + pts_semantic_mask='lidarseg/v1.0-trainval', + CAM_FRONT='samples/CAM_FRONT', + CAM_FRONT_LEFT='samples/CAM_FRONT_LEFT', + CAM_FRONT_RIGHT='samples/CAM_FRONT_RIGHT', + CAM_BACK='samples/CAM_BACK', + CAM_BACK_RIGHT='samples/CAM_BACK_RIGHT', + CAM_BACK_LEFT='samples/CAM_BACK_LEFT') + +backend_args = None + +train_pipeline = [ + dict( + type='BEVLoadMultiViewImageFromFiles', + to_float32=False, + color_type='unchanged', + num_views=6, + backend_args=backend_args), + dict( + 
type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=5, + use_dim=3, + backend_args=backend_args), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=False, + with_seg_3d=True, + with_attr_label=False, + seg_3d_dtype='np.uint8'), + dict( + type='MultiViewWrapper', + transforms=dict(type='PhotoMetricDistortion3D')), + dict(type='SegLabelMapping'), + dict( + type='Pack3DDetInputs', + keys=['img', 'points', 'pts_semantic_mask'], + meta_keys=['lidar2img']) +] + +val_pipeline = [ + dict( + type='BEVLoadMultiViewImageFromFiles', + to_float32=False, + color_type='unchanged', + num_views=6, + backend_args=backend_args), + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=5, + use_dim=3, + backend_args=backend_args), + dict( + type='LoadAnnotations3D', + with_bbox_3d=False, + with_label_3d=False, + with_seg_3d=True, + with_attr_label=False, + seg_3d_dtype='np.uint8'), + dict(type='SegLabelMapping'), + dict( + type='Pack3DDetInputs', + keys=['img', 'points', 'pts_semantic_mask'], + meta_keys=['lidar2img']) +] + +test_pipeline = val_pipeline + +train_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + drop_last=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=data_prefix, + ann_file='nuscenes_infos_train.pkl', + pipeline=train_pipeline, + test_mode=False)) + +val_dataloader = dict( + batch_size=1, + num_workers=4, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=data_prefix, + ann_file='nuscenes_infos_val.pkl', + pipeline=val_pipeline, + test_mode=True)) + +test_dataloader = val_dataloader + +val_evaluator = dict(type='SegMetric') + +test_evaluator = val_evaluator + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=2e-4, weight_decay=0.01), + paramwise_cfg=dict(custom_keys={ + 'backbone': dict(lr_mult=0.1), + }), + clip_grad=dict(max_norm=35, norm_type=2), +) + +param_scheduler = [ + dict(type='LinearLR', start_factor=1e-5, by_epoch=False, begin=0, end=500), + dict( + type='CosineAnnealingLR', + begin=0, + T_max=24, + by_epoch=True, + eta_min=1e-6, + convert_to_iter_based=True) +] + +train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=24, val_interval=1) +val_cfg = dict(type='ValLoop') +test_cfg = dict(type='TestLoop') + +default_hooks = dict(checkpoint=dict(type='CheckpointHook', interval=1)) + +point_cloud_range = [-51.2, -51.2, -5.0, 51.2, 51.2, 3.0] +_dim_ = 128 +num_heads = 8 +_ffn_dim_ = _dim_ * 2 + +tpv_h_ = 200 +tpv_w_ = 200 +tpv_z_ = 16 +scale_h = 1 +scale_w = 1 +scale_z = 1 +num_points_in_pillar = [4, 32, 32] +num_points = [8, 64, 64] +hybrid_attn_anchors = 16 +hybrid_attn_points = 32 +hybrid_attn_init = 0 + +grid_shape = [tpv_h_ * scale_h, tpv_w_ * scale_w, tpv_z_ * scale_z] + +self_cross_layer = dict( + type='TPVFormerLayer', + attn_cfgs=[ + dict( + type='TPVCrossViewHybridAttention', + tpv_h=tpv_h_, + tpv_w=tpv_w_, + tpv_z=tpv_z_, + num_anchors=hybrid_attn_anchors, + embed_dims=_dim_, + num_heads=num_heads, + num_points=hybrid_attn_points, + init_mode=hybrid_attn_init, + dropout=0.1), + dict( + type='TPVImageCrossAttention', + pc_range=point_cloud_range, + num_cams=6, + dropout=0.1, + deformable_attention=dict( + 
type='TPVMSDeformableAttention3D', + embed_dims=_dim_, + num_heads=num_heads, + num_points=num_points, + num_z_anchors=num_points_in_pillar, + num_levels=4, + floor_sampling_offset=False, + tpv_h=tpv_h_, + tpv_w=tpv_w_, + tpv_z=tpv_z_), + embed_dims=_dim_, + tpv_h=tpv_h_, + tpv_w=tpv_w_, + tpv_z=tpv_z_) + ], + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 'cross_attn', 'norm', 'ffn', 'norm')) + +self_layer = dict( + type='TPVFormerLayer', + attn_cfgs=[ + dict( + type='TPVCrossViewHybridAttention', + tpv_h=tpv_h_, + tpv_w=tpv_w_, + tpv_z=tpv_z_, + num_anchors=hybrid_attn_anchors, + embed_dims=_dim_, + num_heads=num_heads, + num_points=hybrid_attn_points, + init_mode=hybrid_attn_init, + dropout=0.1) + ], + feedforward_channels=_ffn_dim_, + ffn_dropout=0.1, + operation_order=('self_attn', 'norm', 'ffn', 'norm')) + +model = dict( + type='TPVFormer', + data_preprocessor=dict( + type='TPVFormerDataPreprocessor', + pad_size_divisor=32, + mean=[103.530, 116.280, 123.675], + std=[1.0, 1.0, 1.0], + voxel=True, + voxel_type='cylindrical', + voxel_layer=dict( + grid_shape=grid_shape, + point_cloud_range=point_cloud_range, + max_num_points=-1, + max_voxels=-1, + ), + batch_augments=[ + dict( + type='GridMask', + use_h=True, + use_w=True, + rotate=1, + offset=False, + ratio=0.5, + mode=1, + prob=0.7) + ]), + backbone=dict( + type='mmdet.ResNet', + depth=101, + num_stages=4, + out_indices=(1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN2d', requires_grad=False), + norm_eval=True, + style='caffe', + dcn=dict( + type='DCNv2', deform_groups=1, fallback_on_stride=False + ), # original DCNv2 will print log when perform load_state_dict + stage_with_dcn=(False, False, True, True), + init_cfg=dict( + type='Pretrained', + checkpoint='checkpoints/tpvformer_r101_dcn_fcos3d_pretrain.pth', + prefix='backbone.')), + neck=dict( + type='mmdet.FPN', + in_channels=[512, 1024, 2048], + out_channels=_dim_, + start_level=0, + add_extra_convs='on_output', + num_outs=4, + relu_before_extra_convs=True, + init_cfg=dict( + type='Pretrained', + checkpoint='checkpoints/tpvformer_r101_dcn_fcos3d_pretrain.pth', + prefix='neck.')), + encoder=dict( + type='TPVFormerEncoder', + tpv_h=tpv_h_, + tpv_w=tpv_w_, + tpv_z=tpv_z_, + num_layers=5, + pc_range=point_cloud_range, + num_points_in_pillar=num_points_in_pillar, + num_points_in_pillar_cross_view=[16, 16, 16], + return_intermediate=False, + transformerlayers=[ + self_cross_layer, self_cross_layer, self_cross_layer, self_layer, + self_layer + ], + embed_dims=_dim_, + positional_encoding=dict( + type='TPVFormerPositionalEncoding', + num_feats=[48, 48, 32], + h=tpv_h_, + w=tpv_w_, + z=tpv_z_)), + decode_head=dict( + type='TPVFormerDecoder', + tpv_h=tpv_h_, + tpv_w=tpv_w_, + tpv_z=tpv_z_, + num_classes=17, + in_dims=_dim_, + hidden_dims=2 * _dim_, + out_dims=_dim_, + scale_h=scale_h, + scale_w=scale_w, + scale_z=scale_z, + loss_ce=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=False, + class_weight=None, + avg_non_ignore=True, + loss_weight=1.0), + loss_lovasz=dict(type='LovaszLoss', loss_weight=1.0, reduction='none'), + lovasz_input='points', + ce_input='voxel', + ignore_index=0)) diff --git a/projects/TPVFormer/tpvformer/__init__.py b/projects/TPVFormer/tpvformer/__init__.py new file mode 100644 index 0000000000..6162558cfb --- /dev/null +++ b/projects/TPVFormer/tpvformer/__init__.py @@ -0,0 +1,17 @@ +from .cross_view_hybrid_attention import TPVCrossViewHybridAttention +from .data_preprocessor import TPVFormerDataPreprocessor +from 
.image_cross_attention import TPVImageCrossAttention +from .loading import BEVLoadMultiViewImageFromFiles, SegLabelMapping +from .nuscenes_dataset import NuScenesSegDataset +from .positional_encoding import TPVFormerPositionalEncoding +from .tpvformer import TPVFormer +from .tpvformer_encoder import TPVFormerEncoder +from .tpvformer_head import TPVFormerDecoder +from .tpvformer_layer import TPVFormerLayer + +__all__ = [ + 'TPVCrossViewHybridAttention', 'TPVImageCrossAttention', + 'TPVFormerPositionalEncoding', 'TPVFormer', 'TPVFormerEncoder', + 'TPVFormerLayer', 'NuScenesSegDataset', 'BEVLoadMultiViewImageFromFiles', + 'SegLabelMapping', 'TPVFormerDecoder', 'TPVFormerDataPreprocessor' +] diff --git a/projects/TPVFormer/tpvformer/cross_view_hybrid_attention.py b/projects/TPVFormer/tpvformer/cross_view_hybrid_attention.py new file mode 100644 index 0000000000..5e881775dd --- /dev/null +++ b/projects/TPVFormer/tpvformer/cross_view_hybrid_attention.py @@ -0,0 +1,209 @@ +import math +from typing import List, Optional, Tuple + +import torch +import torch.nn as nn +from mmcv.ops.multi_scale_deform_attn import ( + MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch) +from mmengine.model import BaseModule, constant_init, xavier_init +from torch import Tensor + +from mmdet3d.registry import MODELS + + +@MODELS.register_module() +class TPVCrossViewHybridAttention(BaseModule): + """TPVFormer Cross-view Hybrid Attention Module.""" + + def __init__(self, + tpv_h: int, + tpv_w: int, + tpv_z: int, + embed_dims: int = 256, + num_heads: int = 8, + num_points: int = 4, + num_anchors: int = 2, + init_mode: int = 0, + dropout: float = 0.1, + **kwargs): + super().__init__() + self.embed_dims = embed_dims + self.num_heads = num_heads + self.num_levels = 3 + self.num_points = num_points + self.num_anchors = num_anchors + self.init_mode = init_mode + self.dropout = nn.ModuleList([nn.Dropout(dropout) for _ in range(3)]) + self.output_proj = nn.ModuleList( + [nn.Linear(embed_dims, embed_dims) for _ in range(3)]) + self.sampling_offsets = nn.ModuleList([ + nn.Linear(embed_dims, num_heads * 3 * num_points * 2) + for _ in range(3) + ]) + self.attention_weights = nn.ModuleList([ + nn.Linear(embed_dims, num_heads * 3 * (num_points + 1)) + for _ in range(3) + ]) + self.value_proj = nn.ModuleList( + [nn.Linear(embed_dims, embed_dims) for _ in range(3)]) + + self.tpv_h, self.tpv_w, self.tpv_z = tpv_h, tpv_w, tpv_z + + def init_weights(self): + """Default initialization for Parameters of Module.""" + device = next(self.parameters()).device + # self plane + theta_self = torch.arange( + self.num_heads, dtype=torch.float32, + device=device) * (2.0 * math.pi / self.num_heads) + grid_self = torch.stack( + [theta_self.cos(), theta_self.sin()], -1) # H, 2 + grid_self = grid_self.view(self.num_heads, 1, + 2).repeat(1, self.num_points, 1) + for j in range(self.num_points): + grid_self[:, j, :] *= (j + 1) / 2 + + if self.init_mode == 0: + # num_phi = 4 + phi = torch.arange( + 4, dtype=torch.float32, device=device) * (2.0 * math.pi / 4) + assert self.num_heads % 4 == 0 + num_theta = int(self.num_heads / 4) + theta = torch.arange( + num_theta, dtype=torch.float32, device=device) * ( + math.pi / num_theta) + (math.pi / num_theta / 2) # 3 + x = torch.matmul(theta.sin().unsqueeze(-1), + phi.cos().unsqueeze(0)).flatten() + y = torch.matmul(theta.sin().unsqueeze(-1), + phi.sin().unsqueeze(0)).flatten() + z = theta.cos().unsqueeze(-1).repeat(1, 4).flatten() + xyz = torch.stack([x, y, z], dim=-1) # H, 3 + + elif 
self.init_mode == 1: + + xyz = [[0, 0, 1], [0, 0, -1], [0, 1, 0], [0, -1, 0], [1, 0, 0], + [-1, 0, 0]] + xyz = torch.tensor(xyz, dtype=torch.float32, device=device) + + grid_hw = xyz[:, [0, 1]] # H, 2 + grid_zh = xyz[:, [2, 0]] + grid_wz = xyz[:, [1, 2]] + + for i in range(3): + grid = torch.stack([grid_hw, grid_zh, grid_wz], dim=1) # H, 3, 2 + grid = grid.unsqueeze(2).repeat(1, 1, self.num_points, 1) + + grid = grid.reshape(self.num_heads, self.num_levels, + self.num_anchors, -1, 2) + for j in range(self.num_points // self.num_anchors): + grid[:, :, :, j, :] *= 2 * (j + 1) + grid = grid.flatten(2, 3) + grid[:, i, :, :] = grid_self + + constant_init(self.sampling_offsets[i], 0.) + self.sampling_offsets[i].bias.data = grid.view(-1) + + constant_init(self.attention_weights[i], val=0., bias=0.) + attn_bias = torch.zeros( + self.num_heads, 3, self.num_points + 1, device=device) + attn_bias[:, i, -1] = 10 + self.attention_weights[i].bias.data = attn_bias.flatten() + xavier_init(self.value_proj[i], distribution='uniform', bias=0.) + xavier_init(self.output_proj[i], distribution='uniform', bias=0.) + + def get_sampling_offsets_and_attention( + self, queries: List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]: + offsets = [] + attns = [] + for i, (query, fc, attn) in enumerate( + zip(queries, self.sampling_offsets, self.attention_weights)): + bs, l, d = query.shape + + offset = fc(query).reshape(bs, l, self.num_heads, self.num_levels, + self.num_points, 2) + offsets.append(offset) + + attention = attn(query).reshape(bs, l, self.num_heads, 3, -1) + level_attention = attention[:, :, :, :, + -1:].softmax(-2) # bs, l, H, 3, 1 + attention = attention[:, :, :, :, :-1] + attention = attention.softmax(-1) # bs, l, H, 3, p + attention = attention * level_attention + attns.append(attention) + + offsets = torch.cat(offsets, dim=1) + attns = torch.cat(attns, dim=1) + return offsets, attns + + def reshape_output(self, output: Tensor, lens: List[int]) -> List[Tensor]: + outputs = torch.split(output, [lens[0], lens[1], lens[2]], dim=1) + return outputs + + def forward(self, + query: List[Tensor], + identity: Optional[List[Tensor]] = None, + query_pos: Optional[List[Tensor]] = None, + reference_points=None, + spatial_shapes=None, + level_start_index=None): + identity = query if identity is None else identity + if query_pos is not None: + query = [q + p for q, p in zip(query, query_pos)] + + # value proj + query_lens = [q.shape[1] for q in query] + value = [layer(q) for layer, q in zip(self.value_proj, query)] + value = torch.cat(value, dim=1) + bs, num_value, _ = value.shape + value = value.view(bs, num_value, self.num_heads, -1) + + # sampling offsets and weights + sampling_offsets, attention_weights = \ + self.get_sampling_offsets_and_attention(query) + + if reference_points.shape[-1] == 2: + """For each tpv query, it owns `num_Z_anchors` in 3D space that + having different heights. After projecting, each tpv query has + `num_Z_anchors` reference points in each 2D image. For each + referent point, we sample `num_points` sampling points. + + For `num_Z_anchors` reference points, + it has overall `num_points * num_Z_anchors` sampling points. 
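+            Concretely, the offsets predicted for each head are regrouped
+            below so that every anchor gets its own share of
+            `num_all_points // num_Z_anchors` offsets, which are added to its
+            normalized reference location.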
+ """ + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + + bs, num_query, _, num_Z_anchors, xy = reference_points.shape + reference_points = reference_points[:, :, None, :, :, None, :] + sampling_offsets = sampling_offsets / \ + offset_normalizer[None, None, None, :, None, :] + bs, num_query, num_heads, num_levels, num_all_points, xy = sampling_offsets.shape # noqa + sampling_offsets = sampling_offsets.view( + bs, num_query, num_heads, num_levels, num_Z_anchors, + num_all_points // num_Z_anchors, xy) + sampling_locations = reference_points + sampling_offsets + bs, num_query, num_heads, num_levels, num_points, num_Z_anchors, xy = sampling_locations.shape # noqa + + sampling_locations = sampling_locations.view( + bs, num_query, num_heads, num_levels, num_all_points, xy) + else: + raise ValueError( + f'Last dim of reference_points must be' + f' 2, but get {reference_points.shape[-1]} instead.') + + if torch.cuda.is_available() and value.is_cuda: + output = MultiScaleDeformableAttnFunction.apply( + value, spatial_shapes, level_start_index, sampling_locations, + attention_weights, 64) + else: + output = multi_scale_deformable_attn_pytorch( + value, spatial_shapes, sampling_locations, attention_weights) + + outputs = self.reshape_output(output, query_lens) + + results = [] + for out, layer, drop, residual in zip(outputs, self.output_proj, + self.dropout, identity): + results.append(residual + drop(layer(out))) + + return results diff --git a/projects/TPVFormer/tpvformer/data_preprocessor.py b/projects/TPVFormer/tpvformer/data_preprocessor.py new file mode 100644 index 0000000000..e340873a35 --- /dev/null +++ b/projects/TPVFormer/tpvformer/data_preprocessor.py @@ -0,0 +1,152 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import List, Tuple + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image +from torch import Tensor +from torch.nn import functional as F + +from mmdet3d.models import Det3DDataPreprocessor +from mmdet3d.models.data_preprocessors.voxelize import dynamic_scatter_3d +from mmdet3d.registry import MODELS +from mmdet3d.structures.det3d_data_sample import SampleList + + +@MODELS.register_module() +class TPVFormerDataPreprocessor(Det3DDataPreprocessor): + + @torch.no_grad() + def voxelize(self, points: List[Tensor], + data_samples: SampleList) -> List[Tensor]: + """Apply voxelization to point cloud. In TPVFormer, it will get voxel- + wise segmentation label and voxel/point coordinates. + + Args: + points (List[Tensor]): Point cloud in one data batch. + data_samples: (List[:obj:`Det3DDataSample`]): The annotation data + of every samples. Add voxel-wise annotation for segmentation. + + Returns: + List[Tensor]: Coordinates of voxels, shape is Nx3, + """ + for point, data_sample in zip(points, data_samples): + min_bound = point.new_tensor( + self.voxel_layer.point_cloud_range[:3]) + max_bound = point.new_tensor( + self.voxel_layer.point_cloud_range[3:]) + point_clamp = torch.clamp(point, min_bound, max_bound + 1e-6) + coors = torch.floor( + (point_clamp - min_bound) / + point_clamp.new_tensor(self.voxel_layer.voxel_size)).int() + self.get_voxel_seg(coors, data_sample) + data_sample.point_coors = coors + + def get_voxel_seg(self, res_coors: Tensor, data_sample: SampleList): + """Get voxel-wise segmentation label and point2voxel map. + + Args: + res_coors (Tensor): The voxel coordinates of points, Nx3. + data_sample: (:obj:`Det3DDataSample`): The annotation data of + every samples. 
Add voxel-wise annotation forsegmentation. + """ + + if self.training: + pts_semantic_mask = data_sample.gt_pts_seg.pts_semantic_mask + pts_semantic_mask = F.one_hot(pts_semantic_mask.long()).float() + voxel_semantic_mask, voxel_coors, point2voxel_map = \ + dynamic_scatter_3d(pts_semantic_mask, res_coors, 'mean', True) + voxel_semantic_mask = torch.argmax(voxel_semantic_mask, dim=-1) + data_sample.gt_pts_seg.voxel_semantic_mask = voxel_semantic_mask + data_sample.point2voxel_map = point2voxel_map + data_sample.voxel_coors = voxel_coors + else: + pseudo_tensor = res_coors.new_ones([res_coors.shape[0], 1]).float() + _, _, point2voxel_map = dynamic_scatter_3d(pseudo_tensor, + res_coors, 'mean', True) + data_sample.point2voxel_map = point2voxel_map + + +@MODELS.register_module() +class GridMask(nn.Module): + """GridMask data augmentation. + + Modified from https://github.com/dvlab-research/GridMask. + + Args: + use_h (bool): Whether to mask on height dimension. Defaults to True. + use_w (bool): Whether to mask on width dimension. Defaults to True. + rotate (int): Rotation degree. Defaults to 1. + offset (bool): Whether to mask offset. Defaults to False. + ratio (float): Mask ratio. Defaults to 0.5. + mode (int): Mask mode. if mode == 0, mask with square grid. + if mode == 1, mask the rest. Defaults to 0 + prob (float): Probability of applying the augmentation. + Defaults to 1.0. + """ + + def __init__(self, + use_h: bool = True, + use_w: bool = True, + rotate: int = 1, + offset: bool = False, + ratio: float = 0.5, + mode: int = 0, + prob: float = 1.0): + super().__init__() + self.use_h = use_h + self.use_w = use_w + self.rotate = rotate + self.offset = offset + self.ratio = ratio + self.mode = mode + self.prob = prob + + def forward(self, inputs: Tensor, + data_samples: SampleList) -> Tuple[Tensor, SampleList]: + if np.random.rand() > self.prob: + return inputs, data_samples + height, width = inputs.shape[-2:] + mask_height = int(1.5 * height) + mask_width = int(1.5 * width) + distance = np.random.randint(2, min(height, width)) + length = min(max(int(distance * self.ratio + 0.5), 1), distance - 1) + mask = np.ones((mask_height, mask_width), np.float32) + stride_on_height = np.random.randint(distance) + stride_on_width = np.random.randint(distance) + if self.use_h: + for i in range(mask_height // distance): + start = distance * i + stride_on_height + end = min(start + length, mask_height) + mask[start:end, :] *= 0 + if self.use_w: + for i in range(mask_width // distance): + start = distance * i + stride_on_width + end = min(start + length, mask_width) + mask[:, start:end] *= 0 + + # NOTE: r is the rotation radian, here is a random counterclockwise + # rotation of 1° or remain unchanged, which follows the implementation + # of the official detection version. + # https://github.com/dvlab-research/GridMask. 
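+        # The 1.5x oversized mask is rotated by `r` degrees, cropped back to
+        # the input resolution (dropping a quarter-size border on each side)
+        # and applied multiplicatively; with `mode=1` the mask is inverted so
+        # that regions outside the grid strips are zeroed instead.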
+ r = np.random.randint(self.rotate) + mask = Image.fromarray(np.uint8(mask)) + + mask = mask.rotate(r) + mask = np.array(mask) + mask = mask[int(0.25 * height):int(0.25 * height) + height, + int(0.25 * width):int(0.25 * width) + width] + + mask = inputs.new_tensor(mask) + if self.mode == 1: + mask = 1 - mask + mask = mask.expand_as(inputs) + if self.offset: + offset = inputs.new_tensor(2 * + (np.random.rand(height, width) - 0.5)) + inputs = inputs * mask + offset * (1 - mask) + else: + inputs = inputs * mask + + return inputs, data_samples diff --git a/projects/TPVFormer/tpvformer/image_cross_attention.py b/projects/TPVFormer/tpvformer/image_cross_attention.py new file mode 100644 index 0000000000..06ad6331a1 --- /dev/null +++ b/projects/TPVFormer/tpvformer/image_cross_attention.py @@ -0,0 +1,465 @@ +import math +import warnings + +import torch +import torch.nn as nn +from mmcv.ops.multi_scale_deform_attn import ( + MultiScaleDeformableAttnFunction, multi_scale_deformable_attn_pytorch) +from mmengine.model import BaseModule, constant_init, xavier_init + +from mmdet3d.registry import MODELS + + +@MODELS.register_module() +class TPVImageCrossAttention(BaseModule): + """An attention module used in TPVFormer. + + Args: + embed_dims (int): The embedding dimension of Attention. + Default: 256. + num_cams (int): The number of cameras + dropout (float): A Dropout layer on `inp_residual`. + Default: 0.1. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): Whether the first dimension of the input is batch. + deformable_attention: (dict): The config for the deformable + attention used in SCA. + tpv_h (int): The height of the TPV. + tpv_w (int): The width of the TPV. + tpv_z (int): The depth of the TPV. + """ + + def __init__(self, + embed_dims=256, + num_cams=6, + pc_range=None, + dropout=0.1, + init_cfg=None, + batch_first=True, + deformable_attention=dict( + type='MSDeformableAttention3D', + embed_dims=256, + num_levels=4), + tpv_h=None, + tpv_w=None, + tpv_z=None): + super().__init__(init_cfg) + + self.init_cfg = init_cfg + self.dropout = nn.Dropout(dropout) + self.pc_range = pc_range + self.fp16_enabled = False + self.deformable_attention = MODELS.build(deformable_attention) + self.embed_dims = embed_dims + self.num_cams = num_cams + self.output_proj = nn.Linear(embed_dims, embed_dims) + self.batch_first = batch_first + self.tpv_h, self.tpv_w, self.tpv_z = tpv_h, tpv_w, tpv_z + self.init_weight() + + def init_weight(self): + """Default initialization for Parameters of Module.""" + xavier_init(self.output_proj, distribution='uniform', bias=0.) + + def forward(self, + query, + key, + value, + residual=None, + spatial_shapes=None, + reference_points_cams=None, + tpv_masks=None, + level_start_index=None): + """Forward Function of Detr3DCrossAtten. + + Args: + query (Tensor): Query of Transformer with shape + (bs, num_query, embed_dims). + key (Tensor): The key tensor with shape + (bs, num_key, embed_dims). + value (Tensor): The value tensor with shape + (bs, num_key, embed_dims). + residual (Tensor): The tensor used for addition, with the + same shape as `x`. Default None. If None, `x` will be used. + spatial_shapes (Tensor): Spatial shape of features in + different level. With shape (num_levels, 2), + last dimension represent (h, w). + tpv_masks (List[Tensor]): The mask of each views. + level_start_index (Tensor): The start index of each level. + A tensor has shape (num_levels) and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. 
+ reference_points_cams (List[Tensor]): The reference points in + each camera. + tpv_masks (List[Tensor]): The mask of each views. + level_start_index (List[int]): The start index of each level. + + Returns: + Tensor: forwarded results with shape [num_query, bs, embed_dims]. + """ + if key is None: + key = query + if value is None: + value = key + + if residual is None: + inp_residual = query + bs, _, _ = query.size() + + queries = torch.split( + query, [ + self.tpv_h * self.tpv_w, self.tpv_z * self.tpv_h, + self.tpv_w * self.tpv_z + ], + dim=1) + if residual is None: + slots = [torch.zeros_like(q) for q in queries] + indexeses = [] + max_lens = [] + queries_rebatches = [] + reference_points_rebatches = [] + for tpv_idx, tpv_mask in enumerate(tpv_masks): + indexes = [] + for _, mask_per_img in enumerate(tpv_mask): + index_query_per_img = mask_per_img[0].sum( + -1).nonzero().squeeze(-1) + indexes.append(index_query_per_img) + max_len = max([len(each) for each in indexes]) + max_lens.append(max_len) + indexeses.append(indexes) + + reference_points_cam = reference_points_cams[tpv_idx] + D = reference_points_cam.size(3) + + queries_rebatch = queries[tpv_idx].new_zeros( + [bs * self.num_cams, max_len, self.embed_dims]) + reference_points_rebatch = reference_points_cam.new_zeros( + [bs * self.num_cams, max_len, D, 2]) + + for i, reference_points_per_img in enumerate(reference_points_cam): + for j in range(bs): + index_query_per_img = indexes[i] + queries_rebatch[j * self.num_cams + + i, :len(index_query_per_img)] = queries[ + tpv_idx][j, index_query_per_img] + reference_points_rebatch[j * self.num_cams + i, :len( + index_query_per_img)] = reference_points_per_img[ + j, index_query_per_img] + + queries_rebatches.append(queries_rebatch) + reference_points_rebatches.append(reference_points_rebatch) + + num_cams, l, bs, embed_dims = key.shape + + key = key.permute(0, 2, 1, 3).view(self.num_cams * bs, l, + self.embed_dims) + value = value.permute(0, 2, 1, 3).view(self.num_cams * bs, l, + self.embed_dims) + + queries = self.deformable_attention( + query=queries_rebatches, + key=key, + value=value, + reference_points=reference_points_rebatches, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + ) + + for tpv_idx, indexes in enumerate(indexeses): + for i, index_query_per_img in enumerate(indexes): + for j in range(bs): + slots[tpv_idx][j, index_query_per_img] += queries[tpv_idx][ + j * self.num_cams + i, :len(index_query_per_img)] + + count = tpv_masks[tpv_idx].sum(-1) > 0 + count = count.permute(1, 2, 0).sum(-1) + count = torch.clamp(count, min=1.0) + slots[tpv_idx] = slots[tpv_idx] / count[..., None] + slots = torch.cat(slots, dim=1) + slots = self.output_proj(slots) + + return self.dropout(slots) + inp_residual + + +@MODELS.register_module() +class TPVMSDeformableAttention3D(BaseModule): + """An attention module used in tpvFormer based on Deformable-Detr. + `Deformable DETR: Deformable Transformers for End-to-End Object Detection. + + `_. + Args: + embed_dims (int): The embedding dimension of Attention. + Default: 256. + num_heads (int): Parallel attention heads. Default: 64. + num_levels (int): The number of feature map used in + Attention. Default: 4. + num_points (int): The number of sampling points for + each query in each head. Default: 4. + im2col_step (int): The step used in image_to_column. + Default: 64. + dropout (float): A Dropout layer on `inp_identity`. + Default: 0.1. 
+ batch_first (bool): Key, Query and Value are shape of + (batch, n, embed_dim) + or (n, batch, embed_dim). Default to False. + norm_cfg (dict): Config dict for normalization layer. + Default: None. + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + """ + + def __init__( + self, + embed_dims=256, + num_heads=8, + num_levels=4, + num_points=[8, 64, 64], + num_z_anchors=[4, 32, 32], + pc_range=None, + im2col_step=64, + dropout=0.1, + batch_first=True, + norm_cfg=None, + init_cfg=None, + floor_sampling_offset=True, + tpv_h=None, + tpv_w=None, + tpv_z=None, + ): + super().__init__(init_cfg) + if embed_dims % num_heads != 0: + raise ValueError(f'embed_dims must be divisible by num_heads, ' + f'but got {embed_dims} and {num_heads}') + dim_per_head = embed_dims // num_heads + self.norm_cfg = norm_cfg + self.batch_first = batch_first + self.output_proj = None + self.fp16_enabled = False + + # you'd better set dim_per_head to a power of 2 + # which is more efficient in the CUDA implementation + def _is_power_of_2(n): + if (not isinstance(n, int)) or (n < 0): + raise ValueError( + 'invalid input for _is_power_of_2: {} (type: {})'.format( + n, type(n))) + return (n & (n - 1) == 0) and n != 0 + + if not _is_power_of_2(dim_per_head): + warnings.warn( + "You'd better set embed_dims in " + 'MultiScaleDeformAttention to make ' + 'the dimension of each attention head a power of 2 ' + 'which is more efficient in our CUDA implementation.') + + self.im2col_step = im2col_step + self.embed_dims = embed_dims + self.num_levels = num_levels + self.num_heads = num_heads + self.num_points = num_points + self.num_z_anchors = num_z_anchors + self.base_num_points = num_points[0] + self.base_z_anchors = num_z_anchors[0] + self.points_multiplier = [ + points // self.base_z_anchors for points in num_z_anchors + ] + self.pc_range = pc_range + self.tpv_h, self.tpv_w, self.tpv_z = tpv_h, tpv_w, tpv_z + self.sampling_offsets = nn.ModuleList([ + nn.Linear(embed_dims, num_heads * num_levels * num_points[i] * 2) + for i in range(3) + ]) + self.floor_sampling_offset = floor_sampling_offset + self.attention_weights = nn.ModuleList([ + nn.Linear(embed_dims, num_heads * num_levels * num_points[i]) + for i in range(3) + ]) + self.value_proj = nn.Linear(embed_dims, embed_dims) + + def init_weights(self): + """Default initialization for Parameters of Module.""" + device = next(self.parameters()).device + for i in range(3): + constant_init(self.sampling_offsets[i], 0.) + thetas = torch.arange( + self.num_heads, dtype=torch.float32, + device=device) * (2.0 * math.pi / self.num_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / + grid_init.abs().max(-1, keepdim=True)[0]).view( + self.num_heads, 1, 1, + 2).repeat(1, self.num_levels, self.num_points[i], + 1) + grid_init = grid_init.reshape(self.num_heads, self.num_levels, + self.num_z_anchors[i], -1, 2) + for j in range(self.num_points[i] // self.num_z_anchors[i]): + grid_init[:, :, :, j, :] *= j + 1 + + self.sampling_offsets[i].bias.data = grid_init.view(-1) + constant_init(self.attention_weights[i], val=0., bias=0.) + xavier_init(self.value_proj, distribution='uniform', bias=0.) + xavier_init(self.output_proj, distribution='uniform', bias=0.) 
+ self._is_init = True + + def get_sampling_offsets_and_attention(self, queries): + offsets = [] + attns = [] + for i, (query, fc, attn) in enumerate( + zip(queries, self.sampling_offsets, self.attention_weights)): + bs, l, d = query.shape + + offset = fc(query).reshape(bs, l, self.num_heads, self.num_levels, + self.points_multiplier[i], -1, 2) + offset = offset.permute(0, 1, 4, 2, 3, 5, 6).flatten(1, 2) + offsets.append(offset) + + attention = attn(query).reshape(bs, l, self.num_heads, -1) + attention = attention.softmax(-1) + attention = attention.view(bs, l, self.num_heads, self.num_levels, + self.points_multiplier[i], -1) + attention = attention.permute(0, 1, 4, 2, 3, 5).flatten(1, 2) + attns.append(attention) + + offsets = torch.cat(offsets, dim=1) + attns = torch.cat(attns, dim=1) + return offsets, attns + + def reshape_reference_points(self, reference_points): + reference_point_list = [] + for i, reference_point in enumerate(reference_points): + bs, l, z_anchors, _ = reference_point.shape + reference_point = reference_point.reshape( + bs, l, self.points_multiplier[i], -1, 2) + reference_point = reference_point.flatten(1, 2) + reference_point_list.append(reference_point) + return torch.cat(reference_point_list, dim=1) + + def reshape_output(self, output, lens): + bs, _, d = output.shape + outputs = torch.split( + output, [ + lens[0] * self.points_multiplier[0], lens[1] * + self.points_multiplier[1], lens[2] * self.points_multiplier[2] + ], + dim=1) + + outputs = [ + o.reshape(bs, -1, self.points_multiplier[i], d).sum(dim=2) + for i, o in enumerate(outputs) + ] + return outputs + + def forward(self, + query, + key=None, + value=None, + identity=None, + reference_points=None, + spatial_shapes=None, + level_start_index=None, + **kwargs): + """Forward Function of MultiScaleDeformAttention. + + Args: + query (Tensor): Query of Transformer with shape + ( bs, num_query, embed_dims). + key (Tensor): The key tensor with shape + `(bs, num_key, embed_dims)`. + value (Tensor): The value tensor with shape + `(bs, num_key, embed_dims)`. + identity (Tensor): The tensor used for addition, with the + same shape as `query`. Default None. If None, + `query` will be used. + reference_points (Tensor): The normalized reference + points with shape (bs, num_query, num_levels, 2), + all elements is range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area. + or (N, Length_{query}, num_levels, 4), add + additional two dimensions is (w, h) to + form reference boxes. + spatial_shapes (Tensor): Spatial shape of features in + different levels. With shape (num_levels, 2), + last dimension represents (h, w). + level_start_index (Tensor): The start index of each level. + A tensor has shape ``(num_levels, )`` and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + Returns: + Tensor: forwarded results with shape [bs, num_query, embed_dims]. 
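+        Note:
+            In this TPV variant `query` and `reference_points` are lists
+            with one entry per TPV plane (hw, zh, wz), and the returned
+            value is likewise a list of three tensors, one per plane.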
+ """ + + if value is None: + value = query + if identity is None: + identity = query + + if not self.batch_first: + # change to (bs, num_query ,embed_dims) + query = [q.permute(1, 0, 2) for q in query] + value = value.permute(1, 0, 2) + + # bs, num_query, _ = query.shape + query_lens = [q.shape[1] for q in query] + bs, num_value, _ = value.shape + assert (spatial_shapes[:, 0] * spatial_shapes[:, 1]).sum() == num_value + + value = self.value_proj(value) + value = value.view(bs, num_value, self.num_heads, -1) + + sampling_offsets, attention_weights = \ + self.get_sampling_offsets_and_attention(query) + + reference_points = self.reshape_reference_points(reference_points) + + if reference_points.shape[-1] == 2: + """For each tpv query, it owns `num_Z_anchors` in 3D space that + having different heights. After projecting, each tpv query has + `num_Z_anchors` reference points in each 2D image. For each + referent point, we sample `num_points` sampling points. + + For `num_Z_anchors` reference points, + it has overall `num_points * num_Z_anchors` sampling points. + """ + offset_normalizer = torch.stack( + [spatial_shapes[..., 1], spatial_shapes[..., 0]], -1) + + bs, num_query, num_Z_anchors, xy = reference_points.shape + reference_points = reference_points[:, :, None, None, :, None, :] + sampling_offsets = sampling_offsets / \ + offset_normalizer[None, None, None, :, None, :] + bs, num_query, num_heads, num_levels, num_all_points, xy = \ + sampling_offsets.shape + sampling_offsets = sampling_offsets.view( + bs, num_query, num_heads, num_levels, num_Z_anchors, + num_all_points // num_Z_anchors, xy) + sampling_locations = reference_points + sampling_offsets + bs, num_query, num_heads, num_levels, num_points, num_Z_anchors, \ + xy = sampling_locations.shape + assert num_all_points == num_points * num_Z_anchors + + sampling_locations = sampling_locations.view( + bs, num_query, num_heads, num_levels, num_all_points, xy) + + if self.floor_sampling_offset: + sampling_locations = sampling_locations - torch.floor( + sampling_locations) + + elif reference_points.shape[-1] == 4: + assert False + else: + raise ValueError( + f'Last dim of reference_points must be' + f' 2 or 4, but get {reference_points.shape[-1]} instead.') + + if torch.cuda.is_available() and value.is_cuda: + output = MultiScaleDeformableAttnFunction.apply( + value, spatial_shapes, level_start_index, sampling_locations, + attention_weights, self.im2col_step) + else: + output = multi_scale_deformable_attn_pytorch( + value, spatial_shapes, sampling_locations, attention_weights) + + output = self.reshape_output(output, query_lens) + if not self.batch_first: + output = [o.permute(1, 0, 2) for o in output] + + return output diff --git a/projects/TPVFormer/tpvformer/loading.py b/projects/TPVFormer/tpvformer/loading.py new file mode 100644 index 0000000000..a5c3e74fce --- /dev/null +++ b/projects/TPVFormer/tpvformer/loading.py @@ -0,0 +1,172 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import Optional, Union + +import mmcv +import numpy as np +from mmcv.transforms.base import BaseTransform +from mmengine.fileio import get + +from mmdet3d.datasets.transforms import LoadMultiViewImageFromFiles +from mmdet3d.registry import TRANSFORMS + +Number = Union[int, float] + + +@TRANSFORMS.register_module() +class BEVLoadMultiViewImageFromFiles(LoadMultiViewImageFromFiles): + """Load multi channel images from a list of separate channel files. 
+ + ``BEVLoadMultiViewImageFromFiles`` adds the following keys for the + convenience of view transforms in the forward: + - 'cam2lidar' + - 'lidar2img' + + Args: + to_float32 (bool): Whether to convert the img to float32. + Defaults to False. + color_type (str): Color type of the file. Defaults to 'unchanged'. + backend_args (dict, optional): Arguments to instantiate the + corresponding backend. Defaults to None. + num_views (int): Number of view in a frame. Defaults to 5. + num_ref_frames (int): Number of frame in loading. Defaults to -1. + test_mode (bool): Whether is test mode in loading. Defaults to False. + set_default_scale (bool): Whether to set default scale. + Defaults to True. + """ + + def transform(self, results: dict) -> Optional[dict]: + """Call function to load multi-view image from files. + + Args: + results (dict): Result dict containing multi-view image filenames. + + Returns: + dict: The result dict containing the multi-view image data. + Added keys and values are described below. + + - filename (str): Multi-view image filenames. + - img (np.ndarray): Multi-view image arrays. + - img_shape (tuple[int]): Shape of multi-view image arrays. + - ori_shape (tuple[int]): Shape of original image arrays. + - pad_shape (tuple[int]): Shape of padded image arrays. + - scale_factor (float): Scale factor. + - img_norm_cfg (dict): Normalization configuration of images. + """ + filename, cam2img, lidar2cam, lidar2img = [], [], [], [] + for _, cam_item in results['images'].items(): + filename.append(cam_item['img_path']) + lidar2cam.append(cam_item['lidar2cam']) + + lidar2cam_array = np.array(cam_item['lidar2cam']) + cam2img_array = np.eye(4).astype(np.float64) + cam2img_array[:3, :3] = np.array(cam_item['cam2img']) + cam2img.append(cam2img_array) + lidar2img.append(cam2img_array @ lidar2cam_array) + + results['img_path'] = filename + results['cam2img'] = np.stack(cam2img, axis=0) + results['lidar2cam'] = np.stack(lidar2cam, axis=0) + results['lidar2img'] = np.stack(lidar2img, axis=0) + + results['ori_cam2img'] = copy.deepcopy(results['cam2img']) + + # img is of shape (h, w, c, num_views) + # h and w can be different for different views + img_bytes = [ + get(name, backend_args=self.backend_args) for name in filename + ] + # gbr follow tpvformer + imgs = [ + mmcv.imfrombytes(img_byte, flag=self.color_type) + for img_byte in img_bytes + ] + # handle the image with different shape + img_shapes = np.stack([img.shape for img in imgs], axis=0) + img_shape_max = np.max(img_shapes, axis=0) + img_shape_min = np.min(img_shapes, axis=0) + assert img_shape_min[-1] == img_shape_max[-1] + if not np.all(img_shape_max == img_shape_min): + pad_shape = img_shape_max[:2] + else: + pad_shape = None + if pad_shape is not None: + imgs = [ + mmcv.impad(img, shape=pad_shape, pad_val=0) for img in imgs + ] + img = np.stack(imgs, axis=-1) + if self.to_float32: + img = img.astype(np.float32) + + results['filename'] = filename + # unravel to list, see `DefaultFormatBundle` in formating.py + # which will transpose each image separately and then stack into array + results['img'] = [img[..., i] for i in range(img.shape[-1])] + results['img_shape'] = img.shape[:2] + results['ori_shape'] = img.shape[:2] + # Set initial values for default meta_keys + results['pad_shape'] = img.shape[:2] + if self.set_default_scale: + results['scale_factor'] = 1.0 + num_channels = 1 if len(img.shape) < 3 else img.shape[2] + results['img_norm_cfg'] = dict( + mean=np.zeros(num_channels, dtype=np.float32), + std=np.ones(num_channels, 
dtype=np.float32), + to_rgb=False) + results['num_views'] = self.num_views + results['num_ref_frames'] = self.num_ref_frames + return results + + +@TRANSFORMS.register_module() +class SegLabelMapping(BaseTransform): + """Map original semantic class to valid category ids. + + Required Keys: + + - seg_label_mapping (np.ndarray) + - pts_semantic_mask (np.ndarray) + + Added Keys: + + - points (np.float32) + + Map valid classes as 0~len(valid_cat_ids)-1 and + others as len(valid_cat_ids). + """ + + def transform(self, results: dict) -> dict: + """Call function to map original semantic class to valid category ids. + + Args: + results (dict): Result dict containing point semantic masks. + + Returns: + dict: The result dict containing the mapped category ids. + Updated key and value are described below. + + - pts_semantic_mask (np.ndarray): Mapped semantic masks. + """ + assert 'pts_semantic_mask' in results + pts_semantic_mask = results['pts_semantic_mask'] + + assert 'seg_label_mapping' in results + label_mapping = results['seg_label_mapping'] + converted_pts_sem_mask = np.vectorize( + label_mapping.__getitem__, otypes=[np.uint8])( + pts_semantic_mask) + + results['pts_semantic_mask'] = converted_pts_sem_mask + + # 'eval_ann_info' will be passed to evaluator + if 'eval_ann_info' in results: + assert 'pts_semantic_mask' in results['eval_ann_info'] + results['eval_ann_info']['pts_semantic_mask'] = \ + converted_pts_sem_mask + + return results + + def __repr__(self) -> str: + """str: Return a string that describes the module.""" + repr_str = self.__class__.__name__ + return repr_str diff --git a/projects/TPVFormer/tpvformer/nuscenes_dataset.py b/projects/TPVFormer/tpvformer/nuscenes_dataset.py new file mode 100644 index 0000000000..763226bc10 --- /dev/null +++ b/projects/TPVFormer/tpvformer/nuscenes_dataset.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import os.path as osp +from typing import Callable, List, Union + +from mmengine.dataset import BaseDataset + +from mmdet3d.registry import DATASETS + + +@DATASETS.register_module() +class NuScenesSegDataset(BaseDataset): + r"""NuScenes Dataset. + + This class serves as the API for experiments on the NuScenes Dataset. + + Please refer to `NuScenes Dataset `_ + for data downloading. + + Args: + data_root (str): Path of dataset root. + ann_file (str): Path of annotation file. + pipeline (list[dict]): Pipeline used for data processing. + Defaults to []. + test_mode (bool): Store `True` when building test or val dataset. 
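+
+    The `label_mapping` in `METAINFO` maps the raw nuScenes-lidarseg
+    category ids onto the 16 merged classes plus `noise` (index 0), which
+    is treated as the ignore index for training and evaluation.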
+ """ + METAINFO = { + 'classes': + ('noise', 'barrier', 'bicycle', 'bus', 'car', 'construction_vehicle', + 'motorcycle', 'pedestrian', 'traffic_cone', 'trailer', 'truck', + 'driveable_surface', 'other_flat', 'sidewalk', 'terrain', 'manmade', + 'vegetation'), + 'ignore_index': + 0, + 'label_mapping': + dict([(1, 0), (5, 0), (7, 0), (8, 0), (10, 0), (11, 0), (13, 0), + (19, 0), (20, 0), (0, 0), (29, 0), (31, 0), (9, 1), (14, 2), + (15, 3), (16, 3), (17, 4), (18, 5), (21, 6), (2, 7), (3, 7), + (4, 7), (6, 7), (12, 8), (22, 9), (23, 10), (24, 11), (25, 12), + (26, 13), (27, 14), (28, 15), (30, 16)]), + 'palette': [ + [0, 0, 0], # noise + [255, 120, 50], # barrier orange + [255, 192, 203], # bicycle pink + [255, 255, 0], # bus yellow + [0, 150, 245], # car blue + [0, 255, 255], # construction_vehicle cyan + [255, 127, 0], # motorcycle dark orange + [255, 0, 0], # pedestrian red + [255, 240, 150], # traffic_cone light yellow + [135, 60, 0], # trailer brown + [160, 32, 240], # truck purple + [255, 0, 255], # driveable_surface dark pink + [139, 137, 137], # other_flat dark red + [75, 0, 75], # sidewalk dard purple + [150, 240, 80], # terrain light green + [230, 230, 250], # manmade white + [0, 175, 0], # vegetation green + ] + } + + def __init__(self, + data_root: str, + ann_file: str, + pipeline: List[Union[dict, Callable]] = [], + test_mode: bool = False, + **kwargs) -> None: + metainfo = dict(label2cat={ + i: cat_name + for i, cat_name in enumerate(self.METAINFO['classes']) + }) + super().__init__( + ann_file=ann_file, + data_root=data_root, + metainfo=metainfo, + pipeline=pipeline, + test_mode=test_mode, + **kwargs) + + def parse_data_info(self, info: dict) -> Union[List[dict], dict]: + """Process the raw data info. + + The only difference with it in `Det3DDataset` + is the specific process for `plane`. + + Args: + info (dict): Raw info dict. + + Returns: + List[dict] or dict: Has `ann_info` in training stage. And + all path has been converted to absolute path. + """ + + data_list = [] + info['lidar_points']['lidar_path'] = \ + osp.join( + self.data_prefix.get('pts', ''), + info['lidar_points']['lidar_path']) + + for cam_id, img_info in info['images'].items(): + if 'img_path' in img_info: + if cam_id in self.data_prefix: + cam_prefix = self.data_prefix[cam_id] + else: + cam_prefix = self.data_prefix.get('img', '') + img_info['img_path'] = osp.join(cam_prefix, + img_info['img_path']) + + if 'pts_semantic_mask_path' in info: + info['pts_semantic_mask_path'] = \ + osp.join(self.data_prefix.get('pts_semantic_mask', ''), + info['pts_semantic_mask_path']) + + # only be used in `PointSegClassMapping` in pipeline + # to map original semantic class to valid category ids. 
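+        # (in this project's config the mapping is consumed by the
+        # `SegLabelMapping` transform in loading.py)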
+ info['seg_label_mapping'] = self.metainfo['label_mapping'] + + # 'eval_ann_info' will be updated in loading transforms + if self.test_mode: + info['eval_ann_info'] = dict() + + data_list.append(info) + return data_list diff --git a/projects/TPVFormer/tpvformer/positional_encoding.py b/projects/TPVFormer/tpvformer/positional_encoding.py new file mode 100644 index 0000000000..8c5aa89fec --- /dev/null +++ b/projects/TPVFormer/tpvformer/positional_encoding.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +from mmengine.model import BaseModule + +from mmdet3d.registry import MODELS + + +@MODELS.register_module() +class TPVFormerPositionalEncoding(BaseModule): + + def __init__(self, + num_feats, + h, + w, + z, + init_cfg=dict(type='Uniform', layer='Embedding')): + super().__init__(init_cfg) + if not isinstance(num_feats, list): + num_feats = [num_feats] * 3 + self.h_embed = nn.Embedding(h, num_feats[0]) + self.w_embed = nn.Embedding(w, num_feats[1]) + self.z_embed = nn.Embedding(z, num_feats[2]) + self.num_feats = num_feats + self.h, self.w, self.z = h, w, z + + def forward(self, bs, device, ignore_axis='z'): + if ignore_axis == 'h': + h_embed = torch.zeros( + 1, 1, self.num_feats[0], + device=device).repeat(self.w, self.z, 1) # w, z, d + w_embed = self.w_embed(torch.arange(self.w, device=device)) + w_embed = w_embed.reshape(self.w, 1, -1).repeat(1, self.z, 1) + z_embed = self.z_embed(torch.arange(self.z, device=device)) + z_embed = z_embed.reshape(1, self.z, -1).repeat(self.w, 1, 1) + elif ignore_axis == 'w': + h_embed = self.h_embed(torch.arange(self.h, device=device)) + h_embed = h_embed.reshape(1, self.h, -1).repeat(self.z, 1, 1) + w_embed = torch.zeros( + 1, 1, self.num_feats[1], + device=device).repeat(self.z, self.h, 1) + z_embed = self.z_embed(torch.arange(self.z, device=device)) + z_embed = z_embed.reshape(self.z, 1, -1).repeat(1, self.h, 1) + elif ignore_axis == 'z': + h_embed = self.h_embed(torch.arange(self.h, device=device)) + h_embed = h_embed.reshape(self.h, 1, -1).repeat(1, self.w, 1) + w_embed = self.w_embed(torch.arange(self.w, device=device)) + w_embed = w_embed.reshape(1, self.w, -1).repeat(self.h, 1, 1) + z_embed = torch.zeros( + 1, 1, self.num_feats[2], + device=device).repeat(self.h, self.w, 1) + + pos = torch.cat((h_embed, w_embed, z_embed), + dim=-1).flatten(0, 1).unsqueeze(0).repeat(bs, 1, 1) + return pos diff --git a/projects/TPVFormer/tpvformer/tpvformer.py b/projects/TPVFormer/tpvformer/tpvformer.py new file mode 100644 index 0000000000..1cd3de3548 --- /dev/null +++ b/projects/TPVFormer/tpvformer/tpvformer.py @@ -0,0 +1,72 @@ +from typing import Optional, Union + +from torch import nn + +from mmdet3d.models import Base3DSegmentor +from mmdet3d.registry import MODELS +from mmdet3d.structures.det3d_data_sample import SampleList + + +@MODELS.register_module() +class TPVFormer(Base3DSegmentor): + + def __init__(self, + data_preprocessor: Optional[Union[dict, nn.Module]] = None, + backbone=None, + neck=None, + encoder=None, + decode_head=None): + + super().__init__(data_preprocessor=data_preprocessor) + + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + self.encoder = MODELS.build(encoder) + self.decode_head = MODELS.build(decode_head) + + def extract_feat(self, img): + """Extract features of images.""" + B, N, C, H, W = img.size() + img = img.view(B * N, C, H, W) + img_feats = self.backbone(img) + + if hasattr(self, 'neck'): + img_feats = self.neck(img_feats) + + img_feats_reshaped = [] + for img_feat in 
img_feats: + _, C, H, W = img_feat.size() + img_feats_reshaped.append(img_feat.view(B, N, C, H, W)) + return img_feats_reshaped + + def _forward(self, batch_inputs, batch_data_samples): + """Forward training function.""" + img_feats = self.extract_feat(batch_inputs['imgs']) + outs = self.encoder(img_feats, batch_data_samples) + outs = self.decode_head(outs, batch_inputs['voxels']['coors']) + return outs + + def loss(self, batch_inputs: dict, + batch_data_samples: SampleList) -> SampleList: + img_feats = self.extract_feat(batch_inputs['imgs']) + queries = self.encoder(img_feats, batch_data_samples) + losses = self.decode_head.loss(queries, batch_data_samples) + return losses + + def predict(self, batch_inputs: dict, + batch_data_samples: SampleList) -> SampleList: + """Forward predict function.""" + img_feats = self.extract_feat(batch_inputs['imgs']) + tpv_queries = self.encoder(img_feats, batch_data_samples) + seg_logits = self.decode_head.predict(tpv_queries, batch_data_samples) + seg_preds = [seg_logit.argmax(dim=1) for seg_logit in seg_logits] + + return self.postprocess_result(seg_preds, batch_data_samples) + + def aug_test(self, batch_inputs, batch_data_samples): + pass + + def encode_decode(self, batch_inputs: dict, + batch_data_samples: SampleList) -> SampleList: + pass diff --git a/projects/TPVFormer/tpvformer/tpvformer_encoder.py b/projects/TPVFormer/tpvformer/tpvformer_encoder.py new file mode 100644 index 0000000000..ea75df0b12 --- /dev/null +++ b/projects/TPVFormer/tpvformer/tpvformer_encoder.py @@ -0,0 +1,340 @@ +import numpy as np +import torch +from mmcv.cnn.bricks.transformer import TransformerLayerSequence +from mmengine.registry import MODELS +from torch import nn +from torch.nn.init import normal_ + +from .cross_view_hybrid_attention import TPVCrossViewHybridAttention +from .image_cross_attention import TPVMSDeformableAttention3D + + +@MODELS.register_module() +class TPVFormerEncoder(TransformerLayerSequence): + + def __init__(self, + tpv_h=200, + tpv_w=200, + tpv_z=16, + pc_range=[-51.2, -51.2, -5, 51.2, 51.2, 3], + num_feature_levels=4, + num_cams=6, + embed_dims=256, + num_points_in_pillar=[4, 32, 32], + num_points_in_pillar_cross_view=[32, 32, 32], + num_layers=5, + transformerlayers=None, + positional_encoding=None, + return_intermediate=False): + super().__init__(transformerlayers, num_layers) + + self.tpv_h = tpv_h + self.tpv_w = tpv_w + self.tpv_z = tpv_z + self.pc_range = pc_range + self.real_w = pc_range[3] - pc_range[0] + self.real_h = pc_range[4] - pc_range[1] + self.real_z = pc_range[5] - pc_range[2] + + self.level_embeds = nn.Parameter( + torch.Tensor(num_feature_levels, embed_dims)) + self.cams_embeds = nn.Parameter(torch.Tensor(num_cams, embed_dims)) + self.tpv_embedding_hw = nn.Embedding(tpv_h * tpv_w, embed_dims) + self.tpv_embedding_zh = nn.Embedding(tpv_z * tpv_h, embed_dims) + self.tpv_embedding_wz = nn.Embedding(tpv_w * tpv_z, embed_dims) + + ref_3d_hw = self.get_reference_points(tpv_h, tpv_w, self.real_z, + num_points_in_pillar[0]) + ref_3d_zh = self.get_reference_points(tpv_z, tpv_h, self.real_w, + num_points_in_pillar[1]) + ref_3d_zh = ref_3d_zh.permute(3, 0, 1, 2)[[2, 0, 1]] # change to x,y,z + ref_3d_zh = ref_3d_zh.permute(1, 2, 3, 0) + ref_3d_wz = self.get_reference_points(tpv_w, tpv_z, self.real_h, + num_points_in_pillar[2]) + ref_3d_wz = ref_3d_wz.permute(3, 0, 1, 2)[[1, 2, 0]] # change to x,y,z + ref_3d_wz = ref_3d_wz.permute(1, 2, 3, 0) + self.register_buffer('ref_3d_hw', ref_3d_hw) + self.register_buffer('ref_3d_zh', ref_3d_zh) + 
self.register_buffer('ref_3d_wz', ref_3d_wz) + + cross_view_ref_points = self.get_cross_view_ref_points( + tpv_h, tpv_w, tpv_z, num_points_in_pillar_cross_view) + self.register_buffer('cross_view_ref_points', cross_view_ref_points) + + # positional encoding + self.positional_encoding = MODELS.build(positional_encoding) + self.return_intermediate = return_intermediate + + def init_weights(self): + """Initialize the transformer weights.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, TPVMSDeformableAttention3D) or isinstance( + m, TPVCrossViewHybridAttention): + m.init_weights() + normal_(self.level_embeds) + normal_(self.cams_embeds) + + @staticmethod + def get_cross_view_ref_points(tpv_h, tpv_w, tpv_z, num_points_in_pillar): + # ref points generating target: (#query)hw+zh+wz, (#level)3, #p, 2 + # generate points for hw and level 1 + h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h + w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w + h_ranges = h_ranges.unsqueeze(-1).expand(-1, tpv_w).flatten() + w_ranges = w_ranges.unsqueeze(0).expand(tpv_h, -1).flatten() + hw_hw = torch.stack([w_ranges, h_ranges], dim=-1) # hw, 2 + hw_hw = hw_hw.unsqueeze(1).expand(-1, num_points_in_pillar[2], + -1) # hw, #p, 2 + # generate points for hw and level 2 + z_ranges = torch.linspace(0.5, tpv_z - 0.5, + num_points_in_pillar[2]) / tpv_z # #p + z_ranges = z_ranges.unsqueeze(0).expand(tpv_h * tpv_w, -1) # hw, #p + h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h + h_ranges = h_ranges.reshape(-1, 1, 1).expand( + -1, tpv_w, num_points_in_pillar[2]).flatten(0, 1) + hw_zh = torch.stack([h_ranges, z_ranges], dim=-1) # hw, #p, 2 + # generate points for hw and level 3 + z_ranges = torch.linspace(0.5, tpv_z - 0.5, + num_points_in_pillar[2]) / tpv_z # #p + z_ranges = z_ranges.unsqueeze(0).expand(tpv_h * tpv_w, -1) # hw, #p + w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w + w_ranges = w_ranges.reshape(1, -1, 1).expand( + tpv_h, -1, num_points_in_pillar[2]).flatten(0, 1) + hw_wz = torch.stack([z_ranges, w_ranges], dim=-1) # hw, #p, 2 + + # generate points for zh and level 1 + w_ranges = torch.linspace(0.5, tpv_w - 0.5, + num_points_in_pillar[1]) / tpv_w + w_ranges = w_ranges.unsqueeze(0).expand(tpv_z * tpv_h, -1) + h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h + h_ranges = h_ranges.reshape(1, -1, 1).expand( + tpv_z, -1, num_points_in_pillar[1]).flatten(0, 1) + zh_hw = torch.stack([w_ranges, h_ranges], dim=-1) + # generate points for zh and level 2 + z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z + z_ranges = z_ranges.reshape(-1, 1, 1).expand( + -1, tpv_h, num_points_in_pillar[1]).flatten(0, 1) + h_ranges = torch.linspace(0.5, tpv_h - 0.5, tpv_h) / tpv_h + h_ranges = h_ranges.reshape(1, -1, 1).expand( + tpv_z, -1, num_points_in_pillar[1]).flatten(0, 1) + zh_zh = torch.stack([h_ranges, z_ranges], dim=-1) # zh, #p, 2 + # generate points for zh and level 3 + w_ranges = torch.linspace(0.5, tpv_w - 0.5, + num_points_in_pillar[1]) / tpv_w + w_ranges = w_ranges.unsqueeze(0).expand(tpv_z * tpv_h, -1) + z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z + z_ranges = z_ranges.reshape(-1, 1, 1).expand( + -1, tpv_h, num_points_in_pillar[1]).flatten(0, 1) + zh_wz = torch.stack([z_ranges, w_ranges], dim=-1) + + # generate points for wz and level 1 + h_ranges = torch.linspace(0.5, tpv_h - 0.5, + num_points_in_pillar[0]) / tpv_h + h_ranges = h_ranges.unsqueeze(0).expand(tpv_w * tpv_z, -1) + 
w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w + w_ranges = w_ranges.reshape(-1, 1, 1).expand( + -1, tpv_z, num_points_in_pillar[0]).flatten(0, 1) + wz_hw = torch.stack([w_ranges, h_ranges], dim=-1) + # generate points for wz and level 2 + h_ranges = torch.linspace(0.5, tpv_h - 0.5, + num_points_in_pillar[0]) / tpv_h + h_ranges = h_ranges.unsqueeze(0).expand(tpv_w * tpv_z, -1) + z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z + z_ranges = z_ranges.reshape(1, -1, 1).expand( + tpv_w, -1, num_points_in_pillar[0]).flatten(0, 1) + wz_zh = torch.stack([h_ranges, z_ranges], dim=-1) + # generate points for wz and level 3 + w_ranges = torch.linspace(0.5, tpv_w - 0.5, tpv_w) / tpv_w + w_ranges = w_ranges.reshape(-1, 1, 1).expand( + -1, tpv_z, num_points_in_pillar[0]).flatten(0, 1) + z_ranges = torch.linspace(0.5, tpv_z - 0.5, tpv_z) / tpv_z + z_ranges = z_ranges.reshape(1, -1, 1).expand( + tpv_w, -1, num_points_in_pillar[0]).flatten(0, 1) + wz_wz = torch.stack([z_ranges, w_ranges], dim=-1) + + reference_points = torch.cat([ + torch.stack([hw_hw, hw_zh, hw_wz], dim=1), + torch.stack([zh_hw, zh_zh, zh_wz], dim=1), + torch.stack([wz_hw, wz_zh, wz_wz], dim=1) + ], + dim=0) # hw+zh+wz, 3, #p, 2 + + return reference_points + + @staticmethod + def get_reference_points(H, + W, + Z=8, + num_points_in_pillar=4, + dim='3d', + bs=1, + device='cuda', + dtype=torch.float): + """Get the reference points used in SCA and TSA. + + Args: + H, W: spatial shape of tpv. + Z: height of pillar. + device (obj:`device`): The device where + reference_points should be. + Returns: + Tensor: reference points used in decoder, has \ + shape (bs, num_keys, num_levels, 2). + """ + + # reference points in 3D space, used in spatial cross-attention (SCA) + zs = torch.linspace( + 0.5, Z - 0.5, num_points_in_pillar, + dtype=dtype, device=device).view(-1, 1, 1).expand( + num_points_in_pillar, H, W) / Z + xs = torch.linspace( + 0.5, W - 0.5, W, dtype=dtype, device=device).view(1, 1, -1).expand( + num_points_in_pillar, H, W) / W + ys = torch.linspace( + 0.5, H - 0.5, H, dtype=dtype, device=device).view(1, -1, 1).expand( + num_points_in_pillar, H, W) / H + ref_3d = torch.stack((xs, ys, zs), -1) + ref_3d = ref_3d.permute(0, 3, 1, 2).flatten(2).permute(0, 2, 1) + ref_3d = ref_3d[None].repeat(bs, 1, 1, 1) + return ref_3d + + def point_sampling(self, reference_points, pc_range, batch_data_smaples): + + lidar2img = [] + for data_sample in batch_data_smaples: + lidar2img.append(data_sample.lidar2img) + lidar2img = np.asarray(lidar2img) + lidar2img = reference_points.new_tensor(lidar2img) # (B, N, 4, 4) + reference_points = reference_points.clone() + + reference_points[..., 0:1] = reference_points[..., 0:1] * \ + (pc_range[3] - pc_range[0]) + pc_range[0] + reference_points[..., 1:2] = reference_points[..., 1:2] * \ + (pc_range[4] - pc_range[1]) + pc_range[1] + reference_points[..., 2:3] = reference_points[..., 2:3] * \ + (pc_range[5] - pc_range[2]) + pc_range[2] + + reference_points = torch.cat( + (reference_points, torch.ones_like(reference_points[..., :1])), -1) + + reference_points = reference_points.permute(1, 0, 2, 3) + D, B, num_query = reference_points.size()[:3] + num_cam = lidar2img.size(1) + + reference_points = reference_points.view(D, B, 1, num_query, 4).repeat( + 1, 1, num_cam, 1, 1).unsqueeze(-1) + + lidar2img = lidar2img.view(1, B, num_cam, 1, 4, + 4).repeat(D, 1, 1, num_query, 1, 1) + + reference_points_cam = torch.matmul( + lidar2img.to(torch.float32), + reference_points.to(torch.float32)).squeeze(-1) + eps = 
1e-5 + + tpv_mask = (reference_points_cam[..., 2:3] > eps) + reference_points_cam = reference_points_cam[..., 0:2] / torch.maximum( + reference_points_cam[..., 2:3], + torch.ones_like(reference_points_cam[..., 2:3]) * eps) + + reference_points_cam[..., 0] /= data_sample.batch_input_shape[1] + reference_points_cam[..., 1] /= data_sample.batch_input_shape[0] + + tpv_mask = ( + tpv_mask & (reference_points_cam[..., 1:2] > 0.0) + & (reference_points_cam[..., 1:2] < 1.0) + & (reference_points_cam[..., 0:1] < 1.0) + & (reference_points_cam[..., 0:1] > 0.0)) + + tpv_mask = torch.nan_to_num(tpv_mask) + + reference_points_cam = reference_points_cam.permute(2, 1, 3, 0, 4) + tpv_mask = tpv_mask.permute(2, 1, 3, 0, 4).squeeze(-1) + + return reference_points_cam, tpv_mask + + def forward(self, mlvl_feats, batch_data_samples): + """Forward function. + + Args: + mlvl_feats (tuple[Tensor]): Features from the upstream + network, each is a 5D-tensor with shape + (B, N, C, H, W). + """ + bs = mlvl_feats[0].shape[0] + dtype = mlvl_feats[0].dtype + device = mlvl_feats[0].device + + # tpv queries and pos embeds + tpv_queries_hw = self.tpv_embedding_hw.weight.to(dtype) + tpv_queries_zh = self.tpv_embedding_zh.weight.to(dtype) + tpv_queries_wz = self.tpv_embedding_wz.weight.to(dtype) + tpv_queries_hw = tpv_queries_hw.unsqueeze(0).repeat(bs, 1, 1) + tpv_queries_zh = tpv_queries_zh.unsqueeze(0).repeat(bs, 1, 1) + tpv_queries_wz = tpv_queries_wz.unsqueeze(0).repeat(bs, 1, 1) + tpv_query = [tpv_queries_hw, tpv_queries_zh, tpv_queries_wz] + + tpv_pos_hw = self.positional_encoding(bs, device, 'z') + tpv_pos_zh = self.positional_encoding(bs, device, 'w') + tpv_pos_wz = self.positional_encoding(bs, device, 'h') + tpv_pos = [tpv_pos_hw, tpv_pos_zh, tpv_pos_wz] + + # flatten image features of different scales + feat_flatten = [] + spatial_shapes = [] + for lvl, feat in enumerate(mlvl_feats): + bs, num_cam, c, h, w = feat.shape + spatial_shape = (h, w) + feat = feat.flatten(3).permute(1, 0, 3, 2) # num_cam, bs, hw, c + feat = feat + self.cams_embeds[:, None, None, :].to(dtype) + feat = feat + self.level_embeds[None, None, + lvl:lvl + 1, :].to(dtype) + spatial_shapes.append(spatial_shape) + feat_flatten.append(feat) + + feat_flatten = torch.cat(feat_flatten, 2) # num_cam, bs, hw++, c + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + feat_flatten = feat_flatten.permute( + 0, 2, 1, 3) # (num_cam, H*W, bs, embed_dims) + + reference_points_cams, tpv_masks = [], [] + ref_3ds = [self.ref_3d_hw, self.ref_3d_zh, self.ref_3d_wz] + for ref_3d in ref_3ds: + reference_points_cam, tpv_mask = self.point_sampling( + ref_3d, self.pc_range, + batch_data_samples) # num_cam, bs, hw++, #p, 2 + reference_points_cams.append(reference_points_cam) + tpv_masks.append(tpv_mask) + + ref_cross_view = self.cross_view_ref_points.clone().unsqueeze( + 0).expand(bs, -1, -1, -1, -1) + + intermediate = [] + for layer in self.layers: + output = layer( + tpv_query, + feat_flatten, + feat_flatten, + tpv_pos=tpv_pos, + ref_2d=ref_cross_view, + tpv_h=self.tpv_h, + tpv_w=self.tpv_w, + tpv_z=self.tpv_z, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + reference_points_cams=reference_points_cams, + tpv_masks=tpv_masks) + tpv_query = output + if self.return_intermediate: + intermediate.append(output) + + if self.return_intermediate: + return torch.stack(intermediate) + + return output diff --git 
a/projects/TPVFormer/tpvformer/tpvformer_head.py b/projects/TPVFormer/tpvformer/tpvformer_head.py new file mode 100644 index 0000000000..2c477f1605 --- /dev/null +++ b/projects/TPVFormer/tpvformer/tpvformer_head.py @@ -0,0 +1,298 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmengine.model import BaseModule + +from mmdet3d.registry import MODELS + + +@MODELS.register_module() +class TPVFormerDecoder(BaseModule): + + def __init__(self, + tpv_h, + tpv_w, + tpv_z, + num_classes=20, + in_dims=64, + hidden_dims=128, + out_dims=None, + scale_h=2, + scale_w=2, + scale_z=2, + ignore_index=0, + loss_lovasz=None, + loss_ce=None, + lovasz_input='points', + ce_input='voxel'): + super().__init__() + self.tpv_h = tpv_h + self.tpv_w = tpv_w + self.tpv_z = tpv_z + self.scale_h = scale_h + self.scale_w = scale_w + self.scale_z = scale_z + + out_dims = in_dims if out_dims is None else out_dims + self.in_dims = in_dims + self.decoder = nn.Sequential( + nn.Linear(in_dims, hidden_dims), nn.Softplus(), + nn.Linear(hidden_dims, out_dims)) + + self.classifier = nn.Linear(out_dims, num_classes) + self.loss_lovasz = MODELS.build(loss_lovasz) + self.loss_ce = MODELS.build(loss_ce) + self.ignore_index = ignore_index + self.lovasz_input = lovasz_input + self.ce_input = ce_input + + def forward(self, tpv_list, points=None): + """ + tpv_list[0]: bs, h*w, c + tpv_list[1]: bs, z*h, c + tpv_list[2]: bs, w*z, c + """ + tpv_hw, tpv_zh, tpv_wz = tpv_list[0], tpv_list[1], tpv_list[2] + bs, _, c = tpv_hw.shape + tpv_hw = tpv_hw.permute(0, 2, 1).reshape(bs, c, self.tpv_h, self.tpv_w) + tpv_zh = tpv_zh.permute(0, 2, 1).reshape(bs, c, self.tpv_z, self.tpv_h) + tpv_wz = tpv_wz.permute(0, 2, 1).reshape(bs, c, self.tpv_w, self.tpv_z) + + if self.scale_h != 1 or self.scale_w != 1: + tpv_hw = F.interpolate( + tpv_hw, + size=(self.tpv_h * self.scale_h, self.tpv_w * self.scale_w), + mode='bilinear') + if self.scale_z != 1 or self.scale_h != 1: + tpv_zh = F.interpolate( + tpv_zh, + size=(self.tpv_z * self.scale_z, self.tpv_h * self.scale_h), + mode='bilinear') + if self.scale_w != 1 or self.scale_z != 1: + tpv_wz = F.interpolate( + tpv_wz, + size=(self.tpv_w * self.scale_w, self.tpv_z * self.scale_z), + mode='bilinear') + + if points is not None: + # points: bs, n, 3 + _, n, _ = points.shape + points = points.reshape(bs, 1, n, 3).float() + points[..., + 0] = points[..., 0] / (self.tpv_w * self.scale_w) * 2 - 1 + points[..., + 1] = points[..., 1] / (self.tpv_h * self.scale_h) * 2 - 1 + points[..., + 2] = points[..., 2] / (self.tpv_z * self.scale_z) * 2 - 1 + sample_loc = points[:, :, :, [0, 1]] + tpv_hw_pts = F.grid_sample(tpv_hw, + sample_loc).squeeze(2) # bs, c, n + sample_loc = points[:, :, :, [1, 2]] + tpv_zh_pts = F.grid_sample(tpv_zh, sample_loc).squeeze(2) + sample_loc = points[:, :, :, [2, 0]] + tpv_wz_pts = F.grid_sample(tpv_wz, sample_loc).squeeze(2) + + tpv_hw_vox = tpv_hw.unsqueeze(-1).permute(0, 1, 3, 2, 4).expand( + -1, -1, -1, -1, self.scale_z * self.tpv_z) + tpv_zh_vox = tpv_zh.unsqueeze(-1).permute(0, 1, 4, 3, 2).expand( + -1, -1, self.scale_w * self.tpv_w, -1, -1) + tpv_wz_vox = tpv_wz.unsqueeze(-1).permute(0, 1, 2, 4, 3).expand( + -1, -1, -1, self.scale_h * self.tpv_h, -1) + + fused_vox = (tpv_hw_vox + tpv_zh_vox + tpv_wz_vox).flatten(2) + fused_pts = tpv_hw_pts + tpv_zh_pts + tpv_wz_pts + fused = torch.cat([fused_vox, fused_pts], dim=-1) # bs, c, whz+n + + fused = fused.permute(0, 2, 1) + if self.use_checkpoint: + fused = torch.utils.checkpoint.checkpoint(self.decoder, fused) + logits = 
torch.utils.checkpoint.checkpoint( + self.classifier, fused) + else: + fused = self.decoder(fused) + logits = self.classifier(fused) + logits = logits.permute(0, 2, 1) + logits_vox = logits[:, :, :(-n)].reshape(bs, self.classes, + self.scale_w * self.tpv_w, + self.scale_h * self.tpv_h, + self.scale_z * self.tpv_z) + logits_pts = logits[:, :, (-n):].reshape(bs, self.classes, n, 1, 1) + return logits_vox, logits_pts + + else: + tpv_hw = tpv_hw.unsqueeze(-1).permute(0, 1, 3, 2, 4).expand( + -1, -1, -1, -1, self.scale_z * self.tpv_z) + tpv_zh = tpv_zh.unsqueeze(-1).permute(0, 1, 4, 3, 2).expand( + -1, -1, self.scale_w * self.tpv_w, -1, -1) + tpv_wz = tpv_wz.unsqueeze(-1).permute(0, 1, 2, 4, 3).expand( + -1, -1, -1, self.scale_h * self.tpv_h, -1) + + fused = tpv_hw + tpv_zh + tpv_wz + fused = fused.permute(0, 2, 3, 4, 1) + if self.use_checkpoint: + fused = torch.utils.checkpoint.checkpoint(self.decoder, fused) + logits = torch.utils.checkpoint.checkpoint( + self.classifier, fused) + else: + fused = self.decoder(fused) + logits = self.classifier(fused) + logits = logits.permute(0, 4, 1, 2, 3) + + return logits + + def predict(self, tpv_list, batch_data_samples): + """ + tpv_list[0]: bs, h*w, c + tpv_list[1]: bs, z*h, c + tpv_list[2]: bs, w*z, c + """ + tpv_hw, tpv_zh, tpv_wz = tpv_list + bs, _, c = tpv_hw.shape + tpv_hw = tpv_hw.permute(0, 2, 1).reshape(bs, c, self.tpv_h, self.tpv_w) + tpv_zh = tpv_zh.permute(0, 2, 1).reshape(bs, c, self.tpv_z, self.tpv_h) + tpv_wz = tpv_wz.permute(0, 2, 1).reshape(bs, c, self.tpv_w, self.tpv_z) + + if self.scale_h != 1 or self.scale_w != 1: + tpv_hw = F.interpolate( + tpv_hw, + size=(self.tpv_h * self.scale_h, self.tpv_w * self.scale_w), + mode='bilinear') + if self.scale_z != 1 or self.scale_h != 1: + tpv_zh = F.interpolate( + tpv_zh, + size=(self.tpv_z * self.scale_z, self.tpv_h * self.scale_h), + mode='bilinear') + if self.scale_w != 1 or self.scale_z != 1: + tpv_wz = F.interpolate( + tpv_wz, + size=(self.tpv_w * self.scale_w, self.tpv_z * self.scale_z), + mode='bilinear') + + logits = [] + for i, data_sample in enumerate(batch_data_samples): + point_coors = data_sample.point_coors.reshape(1, 1, -1, 3).float() + point_coors[ + ..., + 0] = point_coors[..., 0] / (self.tpv_w * self.scale_w) * 2 - 1 + point_coors[ + ..., + 1] = point_coors[..., 1] / (self.tpv_h * self.scale_h) * 2 - 1 + point_coors[ + ..., + 2] = point_coors[..., 2] / (self.tpv_z * self.scale_z) * 2 - 1 + sample_loc = point_coors[..., [0, 1]] + tpv_hw_pts = F.grid_sample( + tpv_hw[i:i + 1], sample_loc, align_corners=False) + sample_loc = point_coors[..., [1, 2]] + tpv_zh_pts = F.grid_sample( + tpv_zh[i:i + 1], sample_loc, align_corners=False) + sample_loc = point_coors[..., [2, 0]] + tpv_wz_pts = F.grid_sample( + tpv_wz[i:i + 1], sample_loc, align_corners=False) + + fused_pts = tpv_hw_pts + tpv_zh_pts + tpv_wz_pts + + fused_pts = fused_pts.squeeze(0).squeeze(1).transpose(0, 1) + fused_pts = self.decoder(fused_pts) + logit = self.classifier(fused_pts) + logits.append(logit) + + return logits + + def loss(self, tpv_list, batch_data_samples): + tpv_hw, tpv_zh, tpv_wz = tpv_list + bs, _, c = tpv_hw.shape + tpv_hw = tpv_hw.permute(0, 2, 1).reshape(bs, c, self.tpv_h, self.tpv_w) + tpv_zh = tpv_zh.permute(0, 2, 1).reshape(bs, c, self.tpv_z, self.tpv_h) + tpv_wz = tpv_wz.permute(0, 2, 1).reshape(bs, c, self.tpv_w, self.tpv_z) + + if self.scale_h != 1 or self.scale_w != 1: + tpv_hw = F.interpolate( + tpv_hw, + size=(self.tpv_h * self.scale_h, self.tpv_w * self.scale_w), + mode='bilinear') + if 
self.scale_z != 1 or self.scale_h != 1: + tpv_zh = F.interpolate( + tpv_zh, + size=(self.tpv_z * self.scale_z, self.tpv_h * self.scale_h), + mode='bilinear') + if self.scale_w != 1 or self.scale_z != 1: + tpv_wz = F.interpolate( + tpv_wz, + size=(self.tpv_w * self.scale_w, self.tpv_z * self.scale_z), + mode='bilinear') + + batch_pts, batch_vox = [], [] + for i, data_sample in enumerate(batch_data_samples): + point_coors = data_sample.point_coors.reshape(1, 1, -1, 3).float() + point_coors[ + ..., + 0] = point_coors[..., 0] / (self.tpv_w * self.scale_w) * 2 - 1 + point_coors[ + ..., + 1] = point_coors[..., 1] / (self.tpv_h * self.scale_h) * 2 - 1 + point_coors[ + ..., + 2] = point_coors[..., 2] / (self.tpv_z * self.scale_z) * 2 - 1 + sample_loc = point_coors[..., [0, 1]] + tpv_hw_pts = F.grid_sample( + tpv_hw[i:i + 1], sample_loc, align_corners=False) + sample_loc = point_coors[..., [1, 2]] + tpv_zh_pts = F.grid_sample( + tpv_zh[i:i + 1], sample_loc, align_corners=False) + sample_loc = point_coors[..., [2, 0]] + tpv_wz_pts = F.grid_sample( + tpv_wz[i:i + 1], sample_loc, align_corners=False) + fused_pts = (tpv_hw_pts + tpv_zh_pts + + tpv_wz_pts).squeeze(0).squeeze(1) + batch_pts.append(fused_pts) + + tpv_hw_vox = tpv_hw.unsqueeze(-1).permute(0, 1, 3, 2, 4).expand( + -1, -1, -1, -1, self.scale_z * self.tpv_z) + tpv_zh_vox = tpv_zh.unsqueeze(-1).permute(0, 1, 4, 3, 2).expand( + -1, -1, self.scale_w * self.tpv_w, -1, -1) + tpv_wz_vox = tpv_wz.unsqueeze(-1).permute(0, 1, 2, 4, 3).expand( + -1, -1, -1, self.scale_h * self.tpv_h, -1) + fused_vox = tpv_hw_vox + tpv_zh_vox + tpv_wz_vox + voxel_coors = data_sample.voxel_coors.long() + fused_vox = fused_vox[:, :, voxel_coors[:, 0], voxel_coors[:, 1], + voxel_coors[:, 2]] + fused_vox = fused_vox.squeeze(0) + batch_vox.append(fused_vox) + batch_pts = torch.cat(batch_pts, dim=1) + batch_vox = torch.cat(batch_vox, dim=1) + num_points = batch_pts.shape[1] + + logits = self.decoder( + torch.cat([batch_pts, batch_vox], dim=1).transpose(0, 1)) + logits = self.classifier(logits) + pts_logits = logits[:num_points, :] + vox_logits = logits[num_points:, :] + + pts_seg_label = torch.cat([ + data_sample.gt_pts_seg.pts_semantic_mask + for data_sample in batch_data_samples + ]) + voxel_seg_label = torch.cat([ + data_sample.gt_pts_seg.voxel_semantic_mask + for data_sample in batch_data_samples + ]) + if self.ce_input == 'voxel': + ce_input = vox_logits + ce_label = voxel_seg_label + else: + ce_input = pts_logits + ce_label = pts_seg_label + if self.lovasz_input == 'voxel': + lovasz_input = vox_logits + lovasz_label = voxel_seg_label + else: + lovasz_input = pts_logits + lovasz_label = pts_seg_label + + loss = dict() + loss['loss_ce'] = self.loss_ce( + ce_input, ce_label, ignore_index=self.ignore_index) + loss['loss_lovasz'] = self.loss_lovasz( + lovasz_input, lovasz_label, ignore_index=self.ignore_index) + return loss diff --git a/projects/TPVFormer/tpvformer/tpvformer_layer.py b/projects/TPVFormer/tpvformer/tpvformer_layer.py new file mode 100644 index 0000000000..03569fdd12 --- /dev/null +++ b/projects/TPVFormer/tpvformer/tpvformer_layer.py @@ -0,0 +1,223 @@ +import copy +import warnings + +import torch +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import (build_attention, + build_feedforward_network) +from mmengine.config import ConfigDict +from mmengine.model import BaseModule, ModuleList +from mmengine.registry import MODELS + + +@MODELS.register_module() +class TPVFormerLayer(BaseModule): + """Base `TPVFormerLayer` for vision 
transformer. + + It can be built from `mmcv.ConfigDict` and support more flexible + customization, for example, using any number of `FFN or LN ` and + use different kinds of `attention` by specifying a list of `ConfigDict` + named `attn_cfgs`. It is worth mentioning that it supports `prenorm` + when you specifying `norm` as the first element of `operation_order`. + More details about the `prenorm`: `On Layer Normalization in the + Transformer Architecture `_ . + Args: + attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): + Configs for `self_attention` or `cross_attention` modules, + The order of the configs in the list should be consistent with + corresponding attentions in operation_order. + If it is a dict, all of the attention modules in operation_order + will be built with this config. Default: None. + ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )): + Configs for FFN, The order of the configs in the list should be + consistent with corresponding ffn in operation_order. + If it is a dict, all of the attention modules in operation_order + will be built with this config. + operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). + Support `prenorm` when you specifying first element as `norm`. + Default: None. + norm_cfg (dict): Config dict for normalization layer. + Default: dict(type='LN'). + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Default: None. + batch_first (bool): Key, Query and Value are shape + of (batch, n, embed_dim) + or (n, batch, embed_dim). Default to False. + """ + + def __init__(self, + attn_cfgs=None, + ffn_cfgs=dict( + type='FFN', + feedforward_channels=1024, + num_fcs=2, + ffn_drop=0., + act_cfg=dict(type='ReLU', inplace=True), + ), + operation_order=None, + norm_cfg=dict(type='LN'), + init_cfg=None, + batch_first=True, + **kwargs): + deprecated_args = dict( + feedforward_channels='feedforward_channels', + ffn_dropout='ffn_drop', + ffn_num_fcs='num_fcs') + for ori_name, new_name in deprecated_args.items(): + if ori_name in kwargs: + warnings.warn( + f'The arguments `{ori_name}` in BaseTransformerLayer ' + f'has been deprecated, now you should set `{new_name}` ' + f'and other FFN related arguments ' + f'to a dict named `ffn_cfgs`. ') + ffn_cfgs[new_name] = kwargs[ori_name] + + super().__init__(init_cfg) + + self.batch_first = batch_first + + num_attn = operation_order.count('self_attn') + operation_order.count( + 'cross_attn') + if isinstance(attn_cfgs, dict): + attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)] + else: + assert num_attn == len(attn_cfgs), f'The length ' \ + f'of attn_cfg {num_attn} is ' \ + f'not consistent with the number of attention' \ + f'in operation_order {operation_order}.' + + self.num_attn = num_attn + self.operation_order = operation_order + self.norm_cfg = norm_cfg + self.pre_norm = operation_order[0] == 'norm' + self.attentions = ModuleList() + + index = 0 + for operation_name in operation_order: + if operation_name in ['self_attn', 'cross_attn']: + if 'batch_first' in attn_cfgs[index]: + assert self.batch_first == attn_cfgs[index]['batch_first'] + else: + attn_cfgs[index]['batch_first'] = self.batch_first + attention = build_attention(attn_cfgs[index]) + # Some custom attentions used as `self_attn` + # or `cross_attn` can have different behavior. 
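# The tag assigned just below records whether the attention module was registered as
# 'self_attn' (the cross-view hybrid attention) or 'cross_attn' (the image
# cross-attention) in `operation_order`, so the built module can later tell which role
# it plays inside this layer.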
+ attention.operation_name = operation_name + self.attentions.append(attention) + index += 1 + + self.embed_dims = self.attentions[0].embed_dims + + self.ffns = ModuleList() + num_ffns = operation_order.count('ffn') + if isinstance(ffn_cfgs, dict): + ffn_cfgs = ConfigDict(ffn_cfgs) + if isinstance(ffn_cfgs, dict): + ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)] + assert len(ffn_cfgs) == num_ffns + for ffn_index in range(num_ffns): + if 'embed_dims' not in ffn_cfgs[ffn_index]: + ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims + else: + assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims + + self.ffns.append(build_feedforward_network(ffn_cfgs[ffn_index])) + + self.norms = ModuleList() + num_norms = operation_order.count('norm') + for _ in range(num_norms): + self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1]) + + def forward(self, + query, + key=None, + value=None, + tpv_pos=None, + ref_2d=None, + tpv_h=None, + tpv_w=None, + tpv_z=None, + reference_points_cams=None, + tpv_masks=None, + spatial_shapes=None, + level_start_index=None, + **kwargs): + """ + **kwargs contains some specific arguments of attentions. + + Args: + query (Tensor): The input query with shape + [num_queries, bs, embed_dims] if + self.batch_first is False, else + [bs, num_queries embed_dims]. + key (Tensor): The key tensor with shape [num_keys, bs, + embed_dims] if self.batch_first is False, else + [bs, num_keys, embed_dims] . + value (Tensor): The value tensor with same shape as `key`. + tpv_pos (Tensor): The positional encoding for self attn. + Returns: + Tensor: forwarded results with shape + [[bs, num_queries, embed_dims] * 3] for 3 tpv planes. + """ + + norm_index = 0 + attn_index = 0 + ffn_index = 0 + if self.operation_order[0] == 'cross_attn': + query = torch.cat(query, dim=1) + identity = query + + for layer in self.operation_order: + # cross view hybrid-attention + if layer == 'self_attn': + ss = torch.tensor( + [[tpv_h, tpv_w], [tpv_z, tpv_h], [tpv_w, tpv_z]], + device=query[0].device) + lsi = torch.tensor( + [0, tpv_h * tpv_w, tpv_h * tpv_w + tpv_z * tpv_h], + device=query[0].device) + + if not isinstance(query, (list, tuple)): + query = torch.split( + query, [tpv_h * tpv_w, tpv_z * tpv_h, tpv_w * tpv_z], + dim=1) + + query = self.attentions[attn_index]( + query, + identity if self.pre_norm else None, + query_pos=tpv_pos, + reference_points=ref_2d, + spatial_shapes=ss, + level_start_index=lsi, + **kwargs) + attn_index += 1 + query = torch.cat(query, dim=1) + identity = query + + elif layer == 'norm': + query = self.norms[norm_index](query) + norm_index += 1 + + # image cross attention + elif layer == 'cross_attn': + query = self.attentions[attn_index]( + query, + key, + value, + identity if self.pre_norm else None, + reference_points_cams=reference_points_cams, + tpv_masks=tpv_masks, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + **kwargs) + attn_index += 1 + identity = query + + elif layer == 'ffn': + query = self.ffns[ffn_index]( + query, identity if self.pre_norm else None) + ffn_index += 1 + query = torch.split( + query, [tpv_h * tpv_w, tpv_z * tpv_h, tpv_w * tpv_z], dim=1) + return query diff --git a/projects/TR3D/configs/tr3d.py b/projects/TR3D/configs/tr3d.py index 3352857789..c203f0d88a 100644 --- a/projects/TR3D/configs/tr3d.py +++ b/projects/TR3D/configs/tr3d.py @@ -1,4 +1,4 @@ -_base_ = ['mmdet3d::_base_/default_runtime.py'] +_base_ = ['../../../configs/_base_/default_runtime.py'] custom_imports = dict(imports=['projects.TR3D.tr3d']) 
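# Note on the relative `_base_` path above (sketch, not prescriptive): it resolves against
# this repository's `configs/_base_` directory instead of the installed `mmdet3d::` package
# scope, so the project config can be loaded directly from a source checkout, e.g.
#   from mmengine.config import Config
#   cfg = Config.fromfile('projects/TR3D/configs/tr3d.py')  # assumes running from the repo root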
model = dict( diff --git a/requirements/tests.txt b/requirements/tests.txt index 563fc4682d..e358587a49 100644 --- a/requirements/tests.txt +++ b/requirements/tests.txt @@ -1,4 +1,3 @@ -asynctest codecov flake8 interrogate diff --git a/tools/dataset_converters/nuscenes_converter.py b/tools/dataset_converters/nuscenes_converter.py index 5c76239530..c53a2871f3 100644 --- a/tools/dataset_converters/nuscenes_converter.py +++ b/tools/dataset_converters/nuscenes_converter.py @@ -267,6 +267,11 @@ def _fill_trainval_infos(nusc, [a['num_radar_pts'] for a in annotations]) info['valid_flag'] = valid_flag + if 'lidarseg' in nusc.table_names: + info['pts_semantic_mask_path'] = osp.join( + nusc.dataroot, + nusc.get('lidarseg', lidar_token)['filename']) + if sample['scene_token'] in train_scenes: train_nusc_infos.append(info) else: diff --git a/tools/dataset_converters/update_infos_to_v2.py b/tools/dataset_converters/update_infos_to_v2.py index 200dffff47..a2ddbd0688 100644 --- a/tools/dataset_converters/update_infos_to_v2.py +++ b/tools/dataset_converters/update_infos_to_v2.py @@ -365,6 +365,9 @@ def update_nuscenes_infos(pkl_path, out_dir): temp_data_info[ 'cam_instances'] = generate_nuscenes_camera_instances( ori_info_dict, nusc) + if 'pts_semantic_mask_path' in ori_info_dict: + temp_data_info['pts_semantic_mask_path'] = Path( + ori_info_dict['pts_semantic_mask_path']).name temp_data_info, _ = clear_data_info_unused_keys(temp_data_info) converted_list.append(temp_data_info) pkl_name = Path(pkl_path).name
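# A minimal standalone sketch (illustration only, dummy file name): the v2 info updater
# above keeps only the lidarseg file name; the segmentation dataset is expected to re-join
# it with its `pts_semantic_mask` data prefix when loading labels.
from pathlib import Path

full_path = 'data/nuscenes/lidarseg/v1.0-trainval/0001_lidarseg.bin'
assert Path(full_path).name == '0001_lidarseg.bin'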