From 4184679dbdf35340a10823a7d12b9a7c0381d3e5 Mon Sep 17 00:00:00 2001
From: Yuxiao Guo
Date: Sun, 25 Jun 2023 13:12:02 +0800
Subject: [PATCH] Initial submission of Swin3D (#6)

* Initial commit
* update README
* init repos
* add KNN
* remove point ops
* update readme
* fix fp16
* fix bug of knn
* update config
* add license
* add comment
* update readme and license
* update readme
* Create codeql.yml
* update readme
* update codeql
* update codeql
* remove cpp codeql
* format code
* update model
* update readme

---------

Co-authored-by: Yukichiii <45515584+Yukichiii@users.noreply.github.com>
Co-authored-by: Yuqi Yang
Co-authored-by: Yuqi Yang
Co-authored-by: Yuxiao Guo
---
 README.md                       |  58 ++--
 Swin3D/modules/mink_layers.py   | 208 ++++++++-----
 Swin3D/modules/swin3d_layers.py | 519 +++++++++++++++++++++++---------
 examples/segmentation_mini.py   |  79 +++++
 4 files changed, 615 insertions(+), 249 deletions(-)
 create mode 100644 examples/segmentation_mini.py

diff --git a/README.md b/README.md
index ac3bef8..cf46852 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@
 Initial commits:
 
 1. Pretrained models on Structured3D are provided.
-2. The supported code and models for Semantic Segmentation on ScanNet and S3DIS are provided.
+2. The supported code for Semantic Segmentation on ScanNet and S3DIS is provided.
 
 ## Introduction
 
@@ -37,12 +37,12 @@ We pretrained our Swin3D on Structured3D, please refer to this [link](https://gi
 
 The models pretrained on Structured3D with different cRSE are provided here.
 
-|          | Pretrain     | #params | cRSE         | mIoU(val) | Model     | Log     |
-| :------- | :----------: | :------ | :----------- | :-------: | :-------: | :-----: |
-| Swin3D-S | Structured3D | 23.57M  | XYZ,RGB      | 77.69     | [model]() | [log]() |
-| Swin3D-S | Structured3D | 23.57M  | XYZ,RGB,NORM | 79.15     | [model]() | [log]() |
-| Swin3D-L | Structured3D | 60.75M  | XYZ,RGB      | 79.79     | [model]() | [log]() |
-| Swin3D-L | Structured3D | 60.75M  | XYZ,RGB,NORM | 81.04     | [model]() | [log]() |
+|          | Pretrain     | #params | cRSE         | mIoU(val) | Model                                                                                       | Log                                                                                       |
+| :------- | :----------: | :------ | :----------- | :-------: | :-----------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------: |
+| Swin3D-S | Structured3D | 23.57M  | XYZ,RGB      | 77.69     | [model](https://drive.google.com/file/d/1oezNkN3_HZvyxGxjtOpSaQUbGl3YYF90/view?usp=sharing) | [log](https://drive.google.com/file/d/1TuwZqpKm8OYj8BeMhDUhLcGqzXhgJcpC/view?usp=sharing) |
+| Swin3D-S | Structured3D | 23.57M  | XYZ,RGB,NORM | 79.15     | [model](https://drive.google.com/file/d/1FMmAgHwS__NtFldH-lFTsraKj0my62t4/view?usp=sharing) | [log](https://drive.google.com/file/d/1-0kz81X0j2Zp-mntN1GwQlsm5sLIy3JX/view?usp=sharing) |
+| Swin3D-L | Structured3D | 60.75M  | XYZ,RGB      | 79.79     | [model](https://drive.google.com/file/d/1ior8uAQRiVd2mwfYapcaF_e_R80y7DQm/view?usp=sharing) | [log](https://drive.google.com/file/d/1YYd8SOaAIqz16T7XOL54aGPC4sSoMXsW/view?usp=sharing) |
+| Swin3D-L | Structured3D | 60.75M  | XYZ,RGB,NORM | 81.04     | [model](https://drive.google.com/file/d/1ySNrP39H6m-euK-2La60-MNOp0e3Pe_4/view?usp=sharing) | [log](https://drive.google.com/file/d/1nXQCw5G2swrSksBnpGBveNSHwAqy8hAZ/view?usp=sharing) |
 
 ## Quick Start
 
Build models and load our pretrained weight, Then you can finetune your model in
             num_layers=num_layers, stem_transformer=stem_transformer, \
             upsample=upsample, first_down_stride=down_stride, \
             knn_down=knn_down, in_channels=in_channels, \
- cRSE='XYZ_RGB_NORM', fp16_mode=2) + cRSE='XYZ_RGB_NORM', fp16_mode=1) model.load_pretrained_model(ckpt_path) ## Results and models -To reproduce our results on downstream tasks, please follow the code in this [repo](https://github.com/Yukichiii/Swin3D_Task). The results and models are provided here. +To reproduce our results on downstream tasks, please follow the code in this [repo](https://github.com/Yukichiii/Swin3D_Task). The results are provided here. ### ScanNet Segmentation -| | Pretrained | mIoU(Val) | mIoU(Test) | Model | Log | -| :------- | :--------: | :-------: | :--------: | :-------: | :-----: | -| Swin3D-S | ✗ | 75.2 | - | [model]() | [log]() | -| Swin3D-S | ✓ | 75.7 | - | [model]() | [log]() | -| Swin3D-L | ✓ | 77.5 | 77.9 | [model]() | [log]() | +| | Pretrained | mIoU(Val) | mIoU(Test) | +| :------- | :--------: | :--------: | :--------: | +| Swin3D-S | ✗ | 75.2 | - | +| Swin3D-S | ✓ | 75.6(76.8) | - | +| Swin3D-L | ✓ | 76.2(77.5) | 77.9 | ### S3DIS Segmentation -| | Pretrained | Area 5 mIoU | 6-fold mIoU | Model | Log | -| :------- | :--------: | :---------: | :---------: | :-------: | :-----: | -| Swin3D-S | ✗ | 72.5 | 76.9 | [model]() | [log]() | -| Swin3D-S | ✓ | 73.0 | 78.2 | [model]() | [log]() | -| Swin3D-L | ✓ | 74.5 | 79.8 | [model]() | [log]() | +| | Pretrained | Area 5 mIoU | 6-fold mIoU | +| :------- | :--------: | :---------: | :---------: | +| Swin3D-S | ✗ | 72.5 | 76.9 | +| Swin3D-S | ✓ | 73.0 | 78.2 | +| Swin3D-L | ✓ | 74.5 | 79.8 | ### ScanNet 3D Detection -| | Pretrained | mAP@0.25 | mAP@0.50 | Model | Log | -| :----------------- | :--------: | :------: | :------: | :---: | :---: | -| Swin3D-S+FCAF3D | ✓ | 74.2 | 59.5 | model | log | -| Swin3D-L+FCAF3D | ✓ | 74.2 | 58.6 | model | log | -| Swin3D-S+CAGroup3D | ✓ | 76.4 | 62.7 | model | log | -| Swin3D-L+CAGroup3D | ✓ | 76.4 | 63.2 | model | log | +| | Pretrained | mAP@0.25 | mAP@0.50 | +| :----------------- | :--------: | :------: | :------: | +| Swin3D-S+FCAF3D | ✓ | 74.2 | 59.5 | +| Swin3D-L+FCAF3D | ✓ | 74.2 | 58.6 | +| Swin3D-S+CAGroup3D | ✓ | 76.4 | 62.7 | +| Swin3D-L+CAGroup3D | ✓ | 76.4 | 63.2 | ### S3DIS 3D Detection -| | Pretrained | mAP@0.25 | mAP@0.50 | Model | Log | -| :-------------- | :--------: | :------: | :------: | :---: | :---: | -| Swin3D-S+FCAF3D | ✓ | 69.9 | 50.2 | model | log | -| Swin3D-L+FCAF3D | ✓ | 72.1 | 54.0 | model | log | +| | Pretrained | mAP@0.25 | mAP@0.50 | +| :-------------- | :--------: | :------: | :------: | +| Swin3D-S+FCAF3D | ✓ | 69.9 | 50.2 | +| Swin3D-L+FCAF3D | ✓ | 72.1 | 54.0 | ## Citation diff --git a/Swin3D/modules/mink_layers.py b/Swin3D/modules/mink_layers.py index f13a85f..7843195 100644 --- a/Swin3D/modules/mink_layers.py +++ b/Swin3D/modules/mink_layers.py @@ -6,13 +6,28 @@ import torch.nn as nn import torch.nn.functional as F import MinkowskiEngine as ME -import numpy as np +import numpy as np + def assign_feats(sp, x): - return ME.SparseTensor(features=x.float(), coordinate_map_key=sp.coordinate_map_key, coordinate_manager=sp.coordinate_manager) + return ME.SparseTensor( + features=x.float(), + coordinate_map_key=sp.coordinate_map_key, + coordinate_manager=sp.coordinate_manager, + ) + class MinkConvBN(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=False, dimension=3): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + dilation=1, + bias=False, + dimension=3, + ): super().__init__() self.conv_layers = nn.Sequential( ME.MinkowskiConvolution( @@ -22,16 +37,27 @@ def 
__init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation= stride=stride, dilation=dilation, bias=bias, - dimension=dimension), - ME.MinkowskiBatchNorm(out_channels) + dimension=dimension, + ), + ME.MinkowskiBatchNorm(out_channels), ) def forward(self, x): x = self.conv_layers(x) return x + class MinkConvBNRelu(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=False, dimension=3): + def __init__( + self, + in_channels, + out_channels, + kernel_size=3, + stride=1, + dilation=1, + bias=False, + dimension=3, + ): super().__init__() self.conv_layers = nn.Sequential( ME.MinkowskiConvolution( @@ -41,9 +67,10 @@ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, dilation= stride=stride, dilation=dilation, bias=bias, - dimension=dimension), + dimension=dimension, + ), ME.MinkowskiBatchNorm(out_channels), - ME.MinkowskiReLU(inplace=True) + ME.MinkowskiReLU(inplace=True), ) def forward(self, x): @@ -52,8 +79,18 @@ def forward(self, x): x = assign_feats(x, x.F.float()) return x + class MinkDeConvBNRelu(nn.Module): - def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, bias=False, dimension=3): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride, + dilation=1, + bias=False, + dimension=3, + ): super().__init__() self.conv_layers = nn.Sequential( ME.MinkowskiConvolutionTranspose( @@ -63,54 +100,58 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, dilation=1, b stride=stride, dilation=dilation, bias=bias, - dimension=dimension), + dimension=dimension, + ), ME.MinkowskiBatchNorm(out_channels), - ME.MinkowskiReLU() + ME.MinkowskiReLU(), ) def forward(self, x): x = self.conv_layers(x) return x -class MinkResBlock(nn.Module): - def __init__(self, in_channels, out_channels, stride=1, dilation=1): - super(MinkResBlock, self).__init__() - self.conv1 = ME.MinkowskiConvolution( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=3, - stride=stride, - dilation=dilation, - bias=False, - dimension=3) - self.norm1 = ME.MinkowskiBatchNorm(out_channels) - self.conv2 = ME.MinkowskiConvolution( - in_channels=out_channels, - out_channels=out_channels, - kernel_size=3, - stride=1, - dilation=dilation, - bias=False, - dimension=3) +class MinkResBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, dilation=1): + super(MinkResBlock, self).__init__() + + self.conv1 = ME.MinkowskiConvolution( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + dilation=dilation, + bias=False, + dimension=3, + ) + self.norm1 = ME.MinkowskiBatchNorm(out_channels) + self.conv2 = ME.MinkowskiConvolution( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=1, + dilation=dilation, + bias=False, + dimension=3, + ) - self.norm2 = ME.MinkowskiBatchNorm(out_channels) - self.relu = ME.MinkowskiReLU(inplace=True) + self.norm2 = ME.MinkowskiBatchNorm(out_channels) + self.relu = ME.MinkowskiReLU(inplace=True) - def forward(self, x): - residual = x + def forward(self, x): + residual = x - out = self.conv1(x) - out = self.norm1(out) - out = self.relu(out) + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) - out = self.conv2(out) - out = self.norm2(out) + out = self.conv2(out) + out = self.norm2(out) - out += residual - out = self.relu(out) + out += residual + out = self.relu(out) - return out + return out class SparseTensorLinear(nn.Module): @@ -134,22 +175,33 @@ class 
MinkResBlock_v2(nn.Module): def __init__(self, in_channels, out_channels): super().__init__() d_2 = out_channels // 4 - self.conv1 = torch.nn.Sequential(SparseTensorLinear(in_channels, d_2, bias=False), ME.MinkowskiBatchNorm(d_2), ME.MinkowskiReLU()) - self.unary_2 = torch.nn.Sequential(SparseTensorLinear(d_2, out_channels, bias=False), ME.MinkowskiBatchNorm(out_channels), ME.MinkowskiReLU()) + self.conv1 = torch.nn.Sequential( + SparseTensorLinear(in_channels, d_2, bias=False), + ME.MinkowskiBatchNorm(d_2), + ME.MinkowskiReLU(), + ) + self.unary_2 = torch.nn.Sequential( + SparseTensorLinear(d_2, out_channels, bias=False), + ME.MinkowskiBatchNorm(out_channels), + ME.MinkowskiReLU(), + ) self.spconv = ME.MinkowskiConvolution( - in_channels=d_2, - out_channels=d_2, - kernel_size=5, - stride=1, - dilation=1, - bias=False, - dimension=3) + in_channels=d_2, + out_channels=d_2, + kernel_size=5, + stride=1, + dilation=1, + bias=False, + dimension=3, + ) if in_channels != out_channels: self.shortcut_op = torch.nn.Sequential( - SparseTensorLinear(in_channels, out_channels, bias=False), ME.MinkowskiBatchNorm(out_channels) + SparseTensorLinear(in_channels, out_channels, bias=False), + ME.MinkowskiBatchNorm(out_channels), ) else: self.shortcut_op = nn.Identity() + def forward(self, x): # feats: [N, C] # xyz: [N, 3] @@ -162,28 +214,32 @@ def forward(self, x): shortcut = self.shortcut_op(shortcut) x += shortcut return x - + class MinkResBlock_BottleNeck(nn.Module): - def __init__(self, in_channels, out_channels): - super(MinkResBlock_BottleNeck, self).__init__() - bottle_neck = out_channels // 4 - self.conv1x1a = MinkConvBNRelu(in_channels, bottle_neck, kernel_size=1, stride=1) - self.conv3x3 = MinkConvBNRelu(bottle_neck, bottle_neck, kernel_size=3, stride=1) - self.conv1x1b = MinkConvBN(bottle_neck, out_channels, kernel_size=1, stride=1) - if in_channels != out_channels: - self.conv1x1c = MinkConvBN(in_channels, out_channels, kernel_size=1, stride=1) - else: - self.conv1x1c = None - self.relu = ME.MinkowskiReLU(inplace=True) - - def forward(self, x): - residual = x - out = self.conv1x1a(x) - out = self.conv3x3(out) - out = self.conv1x1b(out) - if self.conv1x1c is not None: - residual = self.conv1x1c(residual) - out = self.relu(out+residual) - - return out + def __init__(self, in_channels, out_channels): + super(MinkResBlock_BottleNeck, self).__init__() + bottle_neck = out_channels // 4 + self.conv1x1a = MinkConvBNRelu( + in_channels, bottle_neck, kernel_size=1, stride=1 + ) + self.conv3x3 = MinkConvBNRelu(bottle_neck, bottle_neck, kernel_size=3, stride=1) + self.conv1x1b = MinkConvBN(bottle_neck, out_channels, kernel_size=1, stride=1) + if in_channels != out_channels: + self.conv1x1c = MinkConvBN( + in_channels, out_channels, kernel_size=1, stride=1 + ) + else: + self.conv1x1c = None + self.relu = ME.MinkowskiReLU(inplace=True) + + def forward(self, x): + residual = x + out = self.conv1x1a(x) + out = self.conv3x3(out) + out = self.conv1x1b(out) + if self.conv1x1c is not None: + residual = self.conv1x1c(residual) + out = self.relu(out + residual) + + return out diff --git a/Swin3D/modules/swin3d_layers.py b/Swin3D/modules/swin3d_layers.py index 1df6895..e96b67b 100644 --- a/Swin3D/modules/swin3d_layers.py +++ b/Swin3D/modules/swin3d_layers.py @@ -8,20 +8,37 @@ from timm.models.layers import DropPath, trunc_normal_ import MinkowskiEngine as ME from MinkowskiEngine import SparseTensor -from Swin3D.modules.mink_layers import assign_feats, SparseTensorLayerNorm, SparseTensorLinear -from 
Swin3D.sparse_dl.attn.attn_coff import SelfAttnAIOFunction, PosEmb, TableDims, IndexMode, PrecisionMode
+from Swin3D.modules.mink_layers import (
+    assign_feats,
+    SparseTensorLayerNorm,
+    SparseTensorLinear,
+)
+from Swin3D.sparse_dl.attn.attn_coff import (
+    SelfAttnAIOFunction,
+    PosEmb,
+    TableDims,
+    IndexMode,
+    PrecisionMode,
+)
 import Swin3D.sparse_dl.knn
 from Swin3D.sparse_dl.knn import KNN
 
-def query_knn_feature(K, src_xyz, query_xyz, src_feat, src_offset, query_offset, return_idx=False):
-    """
-    gather feature in the KNN neighborhood
+
+def query_knn_feature(
+    K, src_xyz, query_xyz, src_feat, src_offset, query_offset, return_idx=False
+):
+    """
+    gather features in the KNN neighborhood
     """
-    assert src_xyz.is_contiguous() and query_xyz.is_contiguous() and src_feat.is_contiguous()
+    assert (
+        src_xyz.is_contiguous()
+        and query_xyz.is_contiguous()
+        and src_feat.is_contiguous()
+    )
     if query_xyz is None:
         query_xyz = src_xyz
         query_offset = src_offset
-
+
     idx, _ = KNN.apply(K, src_xyz, query_xyz, src_offset, query_offset)
 
     n, m, c = src_xyz.shape[0], query_xyz.shape[0], src_feat.shape[1]
@@ -32,12 +49,19 @@ def query_knn_feature(K, src_xyz, query_xyz, src_feat, src_offset, query_offset,
     else:
         return grouped_feat
 
-def knn_linear_interpolation(src_xyz, query_xyz, src_feat, src_offset, query_offset, K=3):
-    """
-    interpolation feature using distance in KNN neighborhood
+
+def knn_linear_interpolation(
+    src_xyz, query_xyz, src_feat, src_offset, query_offset, K=3
+):
+    """
+    interpolate features using inverse-distance weights in the KNN neighborhood
     """
     N, C = query_xyz.shape[0], src_feat.shape[1]
-    assert src_xyz.is_contiguous() and query_xyz.is_contiguous() and src_feat.is_contiguous()
+    assert (
+        src_xyz.is_contiguous()
+        and query_xyz.is_contiguous()
+        and src_feat.is_contiguous()
+    )
     # (N, K)
     idx, dist = KNN.apply(K, src_xyz, query_xyz, src_offset, query_offset)
     weight = 1.0 / (dist + 1e-8)
@@ -49,14 +73,16 @@ def knn_linear_interpolation(src_xyz, query_xyz, src_feat, src_offset, query_off,
     return query_feat
 
 
-def sparse_self_attention(w_w_id: torch.Tensor, w_sizes: torch.Tensor, protocol: str = 'v1'):
+def sparse_self_attention(
+    w_w_id: torch.Tensor, w_sizes: torch.Tensor, protocol: str = "v1"
+):
     """
     Args:
         indices [torch.Tensor]: sparse window index with shape [N, 2], N is the total
             number of non-empty voxels with indices (window_id, within_window_id). window_id
-            is ordered and starts from 0; within_window_id is a sparse index to indicate the
+            is ordered and starts from 0; within_window_id is a sparse index to indicate the
             offset of kernel_size ** 3.
-        feats [torch.Tensor]: sprase features of each non-empty voxel with shape [N, C]
+        feats [torch.Tensor]: sparse features of each non-empty voxel with shape [N, C]
     Outputs:
         [M, 3]: sparse indices of coefficient matrix (window_id, att_a_id, att_b_id).
att_a_id and att_b_id are the within_window_id @@ -69,28 +95,38 @@ def sparse_self_attention(w_w_id: torch.Tensor, w_sizes: torch.Tensor, protocol: """ w_sizes_2 = w_sizes**2 - # w2n_indices - [W], mapping window index to window global offset in input + # w2n_indices - [W], mapping window index to window global offset in input # space w_cumsum = torch.cumsum(w_sizes, dim=-1) - w2n_indices = torch.cat([torch.zeros(1, dtype=w_cumsum.dtype, device=w_cumsum.device), w_cumsum[:-1]]) + w2n_indices = torch.cat( + [torch.zeros(1, dtype=w_cumsum.dtype, device=w_cumsum.device), w_cumsum[:-1]] + ) # w2m indices - [W], mapping window index to window global offset in output # space w2_cumsum = torch.cumsum(w_sizes_2, dim=-1) - w2m_indices = torch.cat([torch.zeros(1, dtype=w2_cumsum.dtype, device=w2_cumsum.device), w2_cumsum[:-1]]) + w2m_indices = torch.cat( + [torch.zeros(1, dtype=w2_cumsum.dtype, device=w2_cumsum.device), w2_cumsum[:-1]] + ) # m2w indices - [M], mapping element global offset to the window index - m2w_indices = torch.zeros([w2_cumsum[-1]], dtype=w_sizes.dtype, device=w_sizes.device) - m2w_offset = torch.zeros([w2_cumsum[-1]], dtype=w_sizes.dtype, device=w_sizes.device) + m2w_indices = torch.zeros( + [w2_cumsum[-1]], dtype=w_sizes.dtype, device=w_sizes.device + ) + m2w_offset = torch.zeros( + [w2_cumsum[-1]], dtype=w_sizes.dtype, device=w_sizes.device + ) m2w_indices[w2m_indices[1:]] = 1 m2w_offset[w2m_indices[1:]] = w_sizes_2[:-1] m2w_indices = torch.cumsum(m2w_indices, dim=-1) m2w_offset = torch.cumsum(m2w_offset, dim=-1) # m_indices = [M], element global offset in output space - m_indices = torch.arange(0, w2_cumsum[-1], dtype=w_sizes.dtype, device=w_sizes.device) + m_indices = torch.arange( + 0, w2_cumsum[-1], dtype=w_sizes.dtype, device=w_sizes.device + ) - # m2n_indices - [M], mapping element global offset to the window global offset + # m2n_indices - [M], mapping element global offset to the window global offset # in input space m2n_indices = w2n_indices[m2w_indices] @@ -101,20 +137,28 @@ def sparse_self_attention(w_w_id: torch.Tensor, w_sizes: torch.Tensor, protocol: # print_log_main("m2n_indices:", m2n_indices, m2n_indices.shape) y_offset = m2n_indices + m_offset % m2w_sizes - x_offset = m2n_indices + torch.div(m_offset, m2w_sizes, rounding_mode='floor') + x_offset = m2n_indices + torch.div(m_offset, m2w_sizes, rounding_mode="floor") # print_log_main("=================================") # print_log_main(w_sizes[:5]) # print_log_main(x_offset[:50]) # print_log_main(y_offset[:50]) # coord = torch.stack([m2w_indices, w_w_id[x_offset], w_w_id[y_offset]], axis=-1) - if protocol == 'v1': + if protocol == "v1": return x_offset, y_offset - elif protocol == 'v2': + elif protocol == "v2": return x_offset, y_offset, m2w_indices, w_sizes, w2n_indices, w2m_indices + class Mlp(nn.Module): - def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): super().__init__() out_features = out_features or in_features hidden_features = hidden_features or in_features @@ -131,18 +175,25 @@ def forward(self, x): x = self.drop(x) return x + class GridCoordsDown(nn.Module): - """ - downsample the grid coordinates - keep the nearest point to the average point of the downsampled grid + """ + downsample the grid coordinates + keep the nearest point to the average point of the downsampled grid """ def __init__(self, stride): super().__init__() 
self.stride = stride - self.avg_pool = ME.MinkowskiAvgPooling(kernel_size=self.stride, stride=self.stride, dimension=3) - self.unpool = ME.MinkowskiPoolingTranspose(kernel_size=stride, stride=stride, dimension=3) - self.max_pool = ME.MinkowskiMaxPooling(kernel_size=self.stride, stride=self.stride, dimension=3) + self.avg_pool = ME.MinkowskiAvgPooling( + kernel_size=self.stride, stride=self.stride, dimension=3 + ) + self.unpool = ME.MinkowskiPoolingTranspose( + kernel_size=stride, stride=stride, dimension=3 + ) + self.max_pool = ME.MinkowskiMaxPooling( + kernel_size=self.stride, stride=self.stride, dimension=3 + ) def forward(self, coords_sp, sp, return_map=False): device = sp.C.device @@ -153,18 +204,24 @@ def forward(self, coords_sp, sp, return_map=False): avg_coords_sp = self.avg_pool(coords_sp) dist_sp = self.unpool(avg_coords_sp) - coords_sp dist = dist_sp.F - dist = -torch.sqrt((dist ** 2).sum(dim=1)).unsqueeze(1) + dist = -torch.sqrt((dist**2).sum(dim=1)).unsqueeze(1) dist_sp = assign_feats(dist_sp, dist) min_dist_sp = self.max_pool(dist_sp) - map_pair = sp.coordinate_manager.kernel_map(dist_sp.coordinate_map_key, min_dist_sp.coordinate_map_key, stride=self.stride, kernel_size=self.stride, is_pool=True)[0] + map_pair = sp.coordinate_manager.kernel_map( + dist_sp.coordinate_map_key, + min_dist_sp.coordinate_map_key, + stride=self.stride, + kernel_size=self.stride, + is_pool=True, + )[0] in_map, out_map = map_pair broad_min_dist_sp = self.unpool(min_dist_sp) - mask = (broad_min_dist_sp.F==dist_sp.F).squeeze(1) + mask = (broad_min_dist_sp.F == dist_sp.F).squeeze(1) in_map = in_map[mask].long() out_map = out_map[mask].long() downsample_map = torch.zeros(N, dtype=torch.long, device=device) - 1 downsample_map[out_map] = in_map - assert (downsample_map>=0).all() + assert (downsample_map >= 0).all() assert (dist_sp.F[downsample_map] == min_dist_sp.F).all() new_coords = coords_sp.F[downsample_map] new_coords_sp = assign_feats(sp, new_coords) @@ -173,27 +230,32 @@ def forward(self, coords_sp, sp, return_map=False): else: return new_coords_sp + def get_offset(batch): offset = [] - bs = batch.max()+1 + bs = batch.max() + 1 for i in range(bs): - offset.append(torch.sum(batch==i)) + offset.append(torch.sum(batch == i)) offset = torch.cuda.IntTensor(offset) offset = offset.cumsum(dim=0).int() return offset + class GridDownsample(nn.Module): """ - use stride to downsample voxel - use grid maxpooling with kernel_size + use stride to downsample voxel + use grid maxpooling with kernel_size """ + def __init__(self, in_channels, out_channels, kernel_size=2, stride=2): super().__init__() self.kernel_size = kernel_size self.stride = stride self.in_channels = in_channels self.out_channels = out_channels - self.sp_pool = ME.MinkowskiMaxPooling(kernel_size=kernel_size, stride=stride, dimension=3) + self.sp_pool = ME.MinkowskiMaxPooling( + kernel_size=kernel_size, stride=stride, dimension=3 + ) self.coords_pool = GridCoordsDown(stride=stride) self.norm = SparseTensorLayerNorm(in_channels) self.linear = SparseTensorLinear(in_channels, out_channels) @@ -202,29 +264,33 @@ def forward(self, sp, coords_sp): sp_down = self.sp_pool(self.linear(self.norm(sp))) coords_sp_down = self.coords_pool(coords_sp, sp_down) return sp_down, coords_sp_down + def extra_repr(self) -> str: return f"kernel_size={self.kernel_size}, stride={self.stride}, in_channels={self.in_channels}, out_channels={self.out_channels}" + class GridKNNDownsample(nn.Module): """ - use stride to downsample voxel - use KNN to do maxpooling + use stride to 
downsample voxels
+    use KNN to do maxpooling
     """
+
     def __init__(self, in_channels, out_channels, kernel_size=2, stride=2):
         super().__init__()
         self.stride = stride
         self.in_channels = in_channels
         self.out_channels = out_channels
         self.k = 16
-        self.sp_pool = ME.MinkowskiMaxPooling(kernel_size=stride, stride=stride, dimension=3)
+        self.sp_pool = ME.MinkowskiMaxPooling(
+            kernel_size=stride, stride=stride, dimension=3
+        )
         self.coords_pool = GridCoordsDown(stride=stride)
         self.norm = nn.LayerNorm(in_channels)
         self.linear = nn.Linear(in_channels, out_channels, bias=False)
         self.pool = nn.MaxPool1d(self.k)
-
     def forward(self, sp, coords_sp):
-        # calculate the voxel
+        # calculate the voxel
         sp_down = self.sp_pool(sp)
         # for downsampled cRSE
         coords_sp_down = self.coords_pool(coords_sp, sp_down)
@@ -235,7 +301,11 @@ def forward(self, sp, coords_sp):
         n_xyz = coords_sp_down.F[:, 1:4].detach().contiguous()
         feats = query_knn_feature(self.k, xyz, n_xyz, sp.F, offset, n_offset)
         m, k, c = feats.shape
-        feats = self.linear(self.norm(feats.view(m*k, c)).view(m, k, c)).transpose(1, 2).contiguous()
+        feats = (
+            self.linear(self.norm(feats.view(m * k, c)).view(m, k, c))
+            .transpose(1, 2)
+            .contiguous()
+        )
         feats = self.pool(feats).squeeze(-1)
         sp = assign_feats(sp_down, feats.float())
         coords_sp = coords_sp_down
@@ -247,31 +317,47 @@ def extra_repr(self) -> str:
 
 class Upsample(nn.Module):
     """
-    upsample using trilinear interpolation
-    follower by attn block according to self.attn
+    upsample using trilinear interpolation
+    followed by an attn block according to self.attn
     """
-    def __init__(self, in_channels, out_channels, num_heads, window_size, quant_size, attn=True, up_k=3, cRSE='XYZ_RGB', fp16_mode=0):
+
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        num_heads,
+        window_size,
+        quant_size,
+        attn=True,
+        up_k=3,
+        cRSE="XYZ_RGB",
+        fp16_mode=0,
+    ):
         super().__init__()
         self.in_channels = in_channels
         self.out_channels = out_channels
-        self.linear1 = nn.Sequential(nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels))
-        self.linear2 = nn.Sequential(nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels))
+        self.linear1 = nn.Sequential(
+            nn.LayerNorm(out_channels), nn.Linear(out_channels, out_channels)
+        )
+        self.linear2 = nn.Sequential(
+            nn.LayerNorm(in_channels), nn.Linear(in_channels, out_channels)
+        )
         self.up_k = up_k
-        self.attn = attn and window_size>0
+        self.attn = attn and window_size > 0
         if self.attn:
             self.block = BasicLayer(
-                dim=out_channels,
-                depth=1,
-                num_heads=num_heads,
+                dim=out_channels,
+                depth=1,
+                num_heads=num_heads,
                 window_size=window_size,
                 quant_size=quant_size,
-                drop_path=0.1,
+                drop_path=0.1,
                 downsample=None,
                 out_channels=None,
                 cRSE=cRSE,
-                fp16_mode=fp16_mode
-            )
+                fp16_mode=fp16_mode,
+            )
 
     def forward(self, sp, coords_sp, sp_up, coords_sp_up):
         feats = sp.F
@@ -281,16 +367,20 @@ def forward(self, sp, coords_sp, sp_up, coords_sp_up):
         offset = get_offset(sp.C[:, 0])
         support_offset = get_offset(sp_up.C[:, 0])
 
-        feats = self.linear1(support_feats) + knn_linear_interpolation(xyz, support_xyz, self.linear2(feats), offset, support_offset, K=self.up_k)
+        feats = self.linear1(support_feats) + knn_linear_interpolation(
+            xyz, support_xyz, self.linear2(feats), offset, support_offset, K=self.up_k
+        )
         sp_up = assign_feats(sp_up, feats)
         if self.attn:
             sp_up, _, _ = self.block(sp_up, coords_sp_up)
         return sp_up
+
     def extra_repr(self) -> str:
         return f"up_k={self.up_k}, in_channels={self.in_channels}, out_channels={self.out_channels}, attn={self.attn}"
 
+
 class
WindowAttention(nn.Module):
-    """
+    """
     Window based multi-head self attention (W-MSA) module with cRSE.
     Designed for sparse structures.
     It supports both shifted and non-shifted windows.
 
     Args:
         dim (int): Number of input channels.
@@ -311,14 +401,25 @@ class WindowAttention(nn.Module):
             2: fp16 forward and fp16 backward
     """
 
-    def __init__(self, dim, window_size, quant_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., cRSE='XYZ_RGB', fp16_mode=0):
-
+    def __init__(
+        self,
+        dim,
+        window_size,
+        quant_size,
+        num_heads,
+        qkv_bias=True,
+        qk_scale=None,
+        attn_drop=0.0,
+        proj_drop=0.0,
+        cRSE="XYZ_RGB",
+        fp16_mode=0,
+    ):
         super().__init__()
         self.dim = dim
         self.window_size = window_size
         self.num_heads = num_heads
         head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim ** -0.5
+        self.scale = qk_scale or head_dim**-0.5
 
         # color in [-1, 1], color_windowsize = 2
         # normal in [-1, 1], normal_windowsize = 2
@@ -329,47 +430,45 @@ def __init__(self, dim, window_size, quant_size, num_heads, qkv_bias=True, qk_sc
         table_offsets = []
         self.cRSE = cRSE
-        if 'XYZ' in cRSE:
+        if "XYZ" in cRSE:
             self.xyz_quant_size = quant_size
             quant_grid_length_xyz = window_size * self.xyz_quant_size
             table_shape_xyz = (3, 2 * quant_grid_length_xyz, num_heads, head_dim)
             self.query_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
-            trunc_normal_(self.query_xyz_table, std=.02)
+            trunc_normal_(self.query_xyz_table, std=0.02)
             self.key_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
-            trunc_normal_(self.key_xyz_table, std=.02)
+            trunc_normal_(self.key_xyz_table, std=0.02)
             self.value_xyz_table = nn.Parameter(torch.zeros(table_shape_xyz))
-            trunc_normal_(self.value_xyz_table, std=.02)
-            table_offsets += [np.prod(table_shape_xyz[1:])]*3
+            trunc_normal_(self.value_xyz_table, std=0.02)
+            table_offsets += [np.prod(table_shape_xyz[1:])] * 3
 
-        if 'RGB' in cRSE:
+        if "RGB" in cRSE:
             self.color_quant_size = quant_size * 2
             quant_grid_length_rgb = self.color_windowsize * self.color_quant_size
-            table_shape_rgb = (3, 2*quant_grid_length_rgb, num_heads, head_dim)
+            table_shape_rgb = (3, 2 * quant_grid_length_rgb, num_heads, head_dim)
             self.query_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
-            trunc_normal_(self.query_rgb_table, std=.02)
+            trunc_normal_(self.query_rgb_table, std=0.02)
             self.key_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
-            trunc_normal_(self.key_rgb_table, std=.02)
+            trunc_normal_(self.key_rgb_table, std=0.02)
             self.value_rgb_table = nn.Parameter(torch.zeros(table_shape_rgb))
-            trunc_normal_(self.value_rgb_table, std=.02)
-            table_offsets += [np.prod(table_shape_rgb[1:])]*3
-
-        if 'NORM' in cRSE:
+            trunc_normal_(self.value_rgb_table, std=0.02)
+            table_offsets += [np.prod(table_shape_rgb[1:])] * 3
+
+        if "NORM" in cRSE:
             self.normal_quant_size = quant_size * 2
             quant_grid_length_norm = self.normal_windowsize * self.normal_quant_size
-            table_shape_norm = (3, 2*quant_grid_length_norm, num_heads, head_dim)
+            table_shape_norm = (3, 2 * quant_grid_length_norm, num_heads, head_dim)
             self.query_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
-            trunc_normal_(self.query_norm_table, std=.02)
+            trunc_normal_(self.query_norm_table, std=0.02)
             self.key_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
-            trunc_normal_(self.key_norm_table, std=.02)
+            trunc_normal_(self.key_norm_table, std=0.02)
             self.value_norm_table = nn.Parameter(torch.zeros(table_shape_norm))
-            trunc_normal_(self.value_norm_table, std=.02)
-            table_offsets += [np.prod(table_shape_norm[1:])]*3
+            trunc_normal_(self.value_norm_table, std=0.02)
+
table_offsets += [np.prod(table_shape_norm[1:])] * 3 self.table_offsets = table_offsets self.quant_size = quant_size - self.quant_grid_length_xyz = quant_grid_length_xyz - self.quant_grid_length_rgb = quant_grid_length_rgb self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.attn_drop = nn.Dropout(attn_drop, inplace=True) @@ -379,7 +478,7 @@ def __init__(self, dim, window_size, quant_size, num_heads, qkv_bias=True, qk_sc self.softmax = nn.Softmax(dim=-1) def forward(self, feats: torch.Tensor, attn_args): - """ Forward function. + """Forward function. Args: feats: N, C @@ -387,34 +486,46 @@ def forward(self, feats: torch.Tensor, attn_args): """ num_v, _ = feats.shape num_sc = self.dim // self.num_heads - - x_offset, y_offset, m2w_indices, w_sizes, w2n_indices, n2n_indices, w2m_indices,\ - n_coords = attn_args - + + ( + x_offset, + y_offset, + m2w_indices, + w_sizes, + w2n_indices, + n2n_indices, + w2m_indices, + n_coords, + ) = attn_args + # Query, Key, Value qkv = self.qkv(feats) - qkv = qkv.reshape(num_v, 3, self.num_heads, num_sc).permute(1, 0, 2, 3).contiguous() + qkv = ( + qkv.reshape(num_v, 3, self.num_heads, num_sc) + .permute(1, 0, 2, 3) + .contiguous() + ) query, key, value = qkv[0], qkv[1], qkv[2] # [N, num_heads, C//num_heads] query = query * self.scale - + table_offsets = torch.IntTensor(self.table_offsets).cuda() - query_table, key_table, value_table = [], [] ,[] + query_table, key_table, value_table = [], [], [] n_cRSE = [] - if 'XYZ' in self.cRSE: + if "XYZ" in self.cRSE: n_xyz = n_coords[:, 0:3] n_xyz = n_xyz * self.quant_size n_cRSE.append(n_xyz) query_table.append(self.query_xyz_table.view(-1)) key_table.append(self.key_xyz_table.view(-1)) value_table.append(self.value_xyz_table.view(-1)) - if 'RGB' in self.cRSE: + if "RGB" in self.cRSE: n_rgb = n_coords[:, 3:6] n_rgb = n_rgb * self.color_quant_size n_cRSE.append(n_rgb) query_table.append(self.query_rgb_table.view(-1)) key_table.append(self.key_rgb_table.view(-1)) value_table.append(self.value_rgb_table.view(-1)) - if 'NORM' in self.cRSE: + if "NORM" in self.cRSE: n_norm = n_coords[:, 6:9] n_norm = n_norm * self.normal_quant_size n_cRSE.append(n_norm) @@ -429,19 +540,31 @@ def forward(self, feats: torch.Tensor, attn_args): key_table = torch.cat(key_table) value_table = torch.cat(value_table) - if self.fp16_mode==0: + if self.fp16_mode == 0: # do not use fp16 # cast q,k,v to fp32 in forward and backward fp16_mode = PrecisionMode.HALF_NONE - elif self.fp16_mode==1: + elif self.fp16_mode == 1: # use fp16 only in forward fp16_mode = PrecisionMode.HALF_FORWARD - elif self.fp16_mode==2: + elif self.fp16_mode == 2: # use fp16 both in forward and backward fp16_mode = PrecisionMode.HALF_ALL - updated_values = SelfAttnAIOFunction.apply(query, key, value, query_table, key_table, value_table, table_offsets, indices, \ - PosEmb.SEPARATE, TableDims.D0, IndexMode.INDIRECT, fp16_mode) + updated_values = SelfAttnAIOFunction.apply( + query, + key, + value, + query_table, + key_table, + value_table, + table_offsets, + indices, + PosEmb.SEPARATE, + TableDims.D0, + IndexMode.INDIRECT, + fp16_mode, + ) updated_values = updated_values.flatten(1) updated_feats = updated_values.view(num_v, self.dim) @@ -451,34 +574,59 @@ def forward(self, feats: torch.Tensor, attn_args): return updated_feats + class SwinTransformerBlock(nn.Module): - def __init__(self, dim, num_heads, window_size, quant_size, drop_path=0.0,\ - mlp_ratio=4.0, qkv_bias=True, qk_scale=None, act_layer=nn.GELU, norm_layer=nn.LayerNorm, cRSE='XYZ_RGB', fp16_mode=0): + def __init__( + 
self, + dim, + num_heads, + window_size, + quant_size, + drop_path=0.0, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + cRSE="XYZ_RGB", + fp16_mode=0, + ): super().__init__() self.window_size = window_size self.norm1 = norm_layer(dim) - self.attn = WindowAttention(dim, window_size=self.window_size, quant_size=quant_size, num_heads=num_heads, \ - qkv_bias=qkv_bias, qk_scale=qk_scale, cRSE=cRSE, fp16_mode=fp16_mode) - - self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.attn = WindowAttention( + dim, + window_size=self.window_size, + quant_size=quant_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + cRSE=cRSE, + fp16_mode=fp16_mode, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) - self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer) + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer + ) def forward(self, feats, attn_args): # feats: [N, c] short_cut = feats feats = self.norm1(feats) - feats = self.attn(feats, attn_args) # [N, c] - + feats = self.attn(feats, attn_args) # [N, c] + feats = short_cut + self.drop_path(feats) feats = feats + self.drop_path(self.mlp(self.norm2(feats))) return feats + class BasicLayer(nn.Module): - """ A basic Swin3D layer for one stage. + """A basic Swin3D layer for one stage. Args: dim (int): Number of input channels. @@ -499,9 +647,24 @@ class BasicLayer(nn.Module): 2: fp16 forward and fp16 backward """ - def __init__(self, dim, depth, num_heads, window_size, quant_size, out_channels=None, - mlp_ratio=4., qkv_bias=True, qk_scale=None, - drop_path=0., norm_layer=nn.LayerNorm, downsample=None, down_stride=2, cRSE='XYZ_RGB', fp16_mode=0): + def __init__( + self, + dim, + depth, + num_heads, + window_size, + quant_size, + out_channels=None, + mlp_ratio=4.0, + qkv_bias=True, + qk_scale=None, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + down_stride=2, + cRSE="XYZ_RGB", + fp16_mode=0, + ): super().__init__() self.window_size = window_size self.depth = depth @@ -511,22 +674,39 @@ def __init__(self, dim, depth, num_heads, window_size, quant_size, out_channels= self.cRSE = cRSE self.fp16_mode = fp16_mode - - self.shift_size = window_size//2 + self.shift_size = window_size // 2 # build blocks - self.blocks = nn.ModuleList([ - SwinTransformerBlock( - dim, num_heads, window_size, quant_size, - drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,\ - mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, norm_layer=norm_layer, cRSE=cRSE, fp16_mode=fp16_mode) - for i in range(depth)]) + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim, + num_heads, + window_size, + quant_size, + drop_path=drop_path[i] + if isinstance(drop_path, list) + else drop_path, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + norm_layer=norm_layer, + cRSE=cRSE, + fp16_mode=fp16_mode, + ) + for i in range(depth) + ] + ) - self.pool = ME.MinkowskiMaxPooling(kernel_size=self.window_size, stride=self.window_size,dimension=3) + self.pool = ME.MinkowskiMaxPooling( + kernel_size=self.window_size, stride=self.window_size, dimension=3 + ) if downsample is not None: if out_channels is None: out_channels = dim * 2 - self.downsample = downsample(dim, out_channels, kernel_size=down_stride, stride=down_stride) + self.downsample = downsample( + dim, out_channels, 
kernel_size=down_stride, stride=down_stride + ) else: self.downsample = None @@ -534,21 +714,26 @@ def get_map_pair(self, sp): """ use minkowski pool to calculate windows get the mapping from voxel to window - """ - window_size = [self.window_size]*3 - pool_sp = self.pool(sp) + """ + window_size = [self.window_size] * 3 + pool_sp = self.pool(sp) windows = pool_sp.C - window_N = windows.shape[0] + window_N = windows.shape[0] stride_in = sp.coordinate_map_key.get_tensor_stride() - x, y, z = [torch.arange(window_size[i], device=self.device)*stride_in[i] for i in range(3)] - x, y, z = torch.meshgrid(x,y,z) + x, y, z = [ + torch.arange(window_size[i], device=self.device) * stride_in[i] + for i in range(3) + ] + x, y, z = torch.meshgrid(x, y, z) i = torch.zeros_like(x, device=self.device) local_window = torch.stack([i, x, y, z], dim=-1).flatten(0, -2) all_windows = windows.unsqueeze(1) + local_window.unsqueeze(0) all_windows = all_windows.flatten(0, -2).int() cm = sp.coordinate_manager - query_key, (map, inverse_map) = cm.insert_and_map(all_windows, tensor_stride=stride_in) + query_key, (map, inverse_map) = cm.insert_and_map( + all_windows, tensor_stride=stride_in + ) map_pair = cm.kernel_map(query_key, sp.coordinate_map_key, kernel_size=1)[0] return map_pair, window_N @@ -560,42 +745,81 @@ def get_window_mapping(self, sp): nempty_num: non-empty voxel number in each window sort_idx: sort voxel according to window_id, to gather the point inside the same window inv_sort_idx: inverse sort index - """ + """ map_pair, window_N = self.get_map_pair(sp) window_size = self.window_size - nW = window_size ** 3 + nW = window_size**3 in_map, out_map = map_pair in_map, sort_idx = torch.sort(in_map) # assert out_map == arange(out_map.shape[0]) out_map = out_map[sort_idx] sort_idx = out_map.long() inv_sort_idx = torch.zeros_like(sort_idx) - inv_sort_idx[sort_idx] = torch.arange(sort_idx.shape[0], dtype=sort_idx.dtype, device=self.device) + inv_sort_idx[sort_idx] = torch.arange( + sort_idx.shape[0], dtype=sort_idx.dtype, device=self.device + ) N = window_N * nW v2w_mask = torch.zeros(N, dtype=torch.bool, device=self.device) - w_id = torch.arange(window_N, dtype=torch.long, device=self.device).unsqueeze(1).repeat(1, nW).view(-1) - w_w_id = torch.arange(nW, dtype=torch.long, device=self.device).unsqueeze(0).repeat(window_N, 1).view(-1) + w_id = ( + torch.arange(window_N, dtype=torch.long, device=self.device) + .unsqueeze(1) + .repeat(1, nW) + .view(-1) + ) + w_w_id = ( + torch.arange(nW, dtype=torch.long, device=self.device) + .unsqueeze(0) + .repeat(window_N, 1) + .view(-1) + ) v2w_mask[in_map.long()] = True nempty_num = v2w_mask.view(-1, nW).sum(dim=-1) w_id = w_id[in_map.long()] w_w_id = w_w_id[in_map.long()] - w_w_xyz = torch.stack([w_w_id//window_size//window_size, w_w_id//window_size%window_size, w_w_id%window_size], dim=-1) + w_w_xyz = torch.stack( + [ + w_w_id // window_size // window_size, + w_w_id // window_size % window_size, + w_w_id % window_size, + ], + dim=-1, + ) return w_w_id, w_w_xyz, nempty_num, sort_idx, inv_sort_idx def get_index01(self, sp, local_xyz, colors): """ calculate the arguments for sparse attention """ - w_w_id, w_w_xyz, nempty_num, n2n_indices, inv_sort_idx = self.get_window_mapping(sp) + ( + w_w_id, + w_w_xyz, + nempty_num, + n2n_indices, + inv_sort_idx, + ) = self.get_window_mapping(sp) local_xyz = local_xyz[n2n_indices] colors = colors[n2n_indices] # recover the relative pos in the voxel n_coords = w_w_xyz + local_xyz n_coords = torch.cat([n_coords, colors], dim=1) - x_offset, 
y_offset, m2w_indices, w_sizes, w2n_indices, w2m_indices =\ - sparse_self_attention(w_w_id, nempty_num, protocol='v2') - return x_offset, y_offset, m2w_indices, w_sizes, w2n_indices, n2n_indices,\ - w2m_indices, n_coords + ( + x_offset, + y_offset, + m2w_indices, + w_sizes, + w2n_indices, + w2m_indices, + ) = sparse_self_attention(w_w_id, nempty_num, protocol="v2") + return ( + x_offset, + y_offset, + m2w_indices, + w_sizes, + w2n_indices, + n2n_indices, + w2m_indices, + n_coords, + ) def get_shifted_sp(self, sp): """ @@ -605,12 +829,17 @@ def get_shifted_sp(self, sp): shift_size = self.shift_size * stride_in[0] shifted_C = sp.C.clone() shifted_C[:, 1:] += shift_size - shifted_sp = SparseTensor(features=sp.F, coordinates=shifted_C, device=self.device, tensor_stride=stride_in) + shifted_sp = SparseTensor( + features=sp.F, + coordinates=shifted_C, + device=self.device, + tensor_stride=stride_in, + ) return shifted_sp def get_window_pos(self, sp): stride_in = sp.coordinate_map_key.get_tensor_stride() - return (sp.C[:, 1:]/stride_in[0]) % self.window_size + return (sp.C[:, 1:] / stride_in[0]) % self.window_size def forward(self, sp, coords_sp): """ @@ -620,7 +849,9 @@ def forward(self, sp, coords_sp): """ colors = coords_sp.F[:, 4:] xyz = coords_sp.F[:, :4] - local_xyz = (xyz - coords_sp.C)[:, 1:] / coords_sp.coordinate_map_key.get_tensor_stride()[0] + local_xyz = (xyz - coords_sp.C)[ + :, 1: + ] / coords_sp.coordinate_map_key.get_tensor_stride()[0] self.device = sp.device sp_shift = self.get_shifted_sp(sp) @@ -630,7 +861,7 @@ def forward(self, sp, coords_sp): feats = sp.F for i, blk in enumerate(self.blocks): attn_args_blk = attn_args if i % 2 == 0 else attn_args_shift - feats = blk(feats, attn_args_blk) #[N, C] + feats = blk(feats, attn_args_blk) # [N, C] sp = assign_feats(sp, feats) if self.downsample is not None: diff --git a/examples/segmentation_mini.py b/examples/segmentation_mini.py new file mode 100644 index 0000000..2bbb2c9 --- /dev/null +++ b/examples/segmentation_mini.py @@ -0,0 +1,79 @@ +""" +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+""" +import numpy as np +import torch +import torch.nn as nn +from Swin3D.models import Swin3DUNet +from easydict import EasyDict +from MinkowskiEngine import SparseTensor + +args = EasyDict({ + 'in_channels': 9, + 'num_layers': 5, + 'depths': [2, 2, 2, 2, 2], + 'channels': [16, 16, 32, 64, 64] , + 'num_heads': [2, 2, 4, 8, 8], + 'window_sizes': [5, 7, 7, 7, 7], + 'quant_size': 4, + 'down_stride': 3, + 'knn_down': True, + 'stem_transformer': True, + 'upsample': 'linear', + 'up_k': 3, + 'drop_path_rate': 0.3, + 'num_classes': 20, + 'ignore_label': -100, + 'base_lr': 0.001, + 'transformer_lr_scale': 0.1, + 'weight_decay': 0.0001, +}) +model = Swin3DUNet(args.depths, args.channels, args.num_heads, \ + args.window_sizes, args.quant_size, up_k=args.up_k, drop_path_rate=args.drop_path_rate, num_classes=args.num_classes, \ + num_layers=args.num_layers, stem_transformer=args.stem_transformer, upsample=args.upsample, first_down_stride=args.down_stride, + knn_down=args.knn_down, in_channels=args.in_channels, cRSE='XYZ_RGB_NORM', fp16_mode=2) +print(model) +print('#Model parameters: {}'.format(sum([x.nelement() for x in model.parameters()]))) +model = model.cuda() +criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label).cuda() +param_dicts = [ + {"params": [p for n, p in model.named_parameters() if "blocks" not in n and p.requires_grad]}, + { + "params": [p for n, p in model.named_parameters() if "blocks" in n and p.requires_grad], + "lr": args.base_lr * args.transformer_lr_scale, + }, +] +optimizer = torch.optim.AdamW(param_dicts, lr=args.base_lr, weight_decay=args.weight_decay) + +data = np.load("examples/input.npz") +feat, xyz, batch, target = data["feat"], data["xyz"], data["batch"], data["target"] +# feats: [N, 6], RGB, Normal +# xyz: [N, 3], +# batch: [N], +# target: [N], +feat, xyz, batch, target = torch.from_numpy(feat).cuda(), torch.from_numpy(xyz).cuda(), torch.from_numpy(batch).cuda(), torch.from_numpy(target).cuda() +coords = torch.cat([batch.unsqueeze(-1), xyz], dim=-1) +feat = torch.cat([feat, xyz], dim=1) +sp = SparseTensor(feat.float(), torch.floor(coords).int(), device=feat.device) +colors = feat[:, 0:3] +normals = feat[:, 3:6] +coords_sp = SparseTensor(features=torch.cat([coords, colors, normals], dim=1), coordinate_map_key=sp.coordinate_map_key, +coordinate_manager=sp.coordinate_manager) + +use_amp = True +scaler = torch.cuda.amp.GradScaler() +with torch.cuda.amp.autocast(enabled=use_amp): + output = model(sp, coords_sp) +loss = criterion(output, target) +optimizer.zero_grad() + +if use_amp: + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() +else: + loss.backward() + optimizer.step() +print("FINISHED!") +