From 0c8cd33a845ffc2ca3e8369a7fecd6f321174656 Mon Sep 17 00:00:00 2001 From: bdf <36697723+defei-coder@users.noreply.github.com> Date: Wed, 19 Apr 2023 10:42:07 +0800 Subject: [PATCH] Pick MLU modifications from master (1.x) to main (2.x) (#2704) * [Feature] Support Voxelization with cambricon MLU device (#2500) * [Feature] Support hard_voxelize with cambricon MLU backend * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Feature](bangc-ops): add voxelization op * [Enhance] Optimize the performace of ms_deform_attn for MLU device (#2510) * ms_opt * ms_opt * ms_opt * ms_opt * ms_opt * [Feature] ms_deform_attn performance optimization * [Feature] ms_deform_attn performance optimization * [Feature] ms_deform_attn performance optimization * [Feature] Support ball_query with cambricon MLU backend and mlu-ops library. (#2520) * [Feature] Support ball_query with cambricon MLU backend and mlu-ops library. * [Fix] update operator data layout setting. * [Fix] add cxx compile option to avoid symbol conflict. * [Fix] fix lint errors. * [Fix] update ops.md with info of ball_query support by MLU backend. * [Feature] Fix typo. * [Fix] Remove print. * [Fix] get mlu-ops from MMCV_MLU_OPS_PATH env. * [Fix] update MMCV_MLU_OPS_PATH check logic. * [Fix] update error info when failed to download mlu-ops. * [Fix] check mlu-ops version matching info in mmcv. * [Fix] revise wrong filename. * [Fix] remove f.close and re. * [Docs] Steps to compile mmcv-full on MLU machine (#2571) * [Docs] Steps to compile mmcv-full on MLU machine * [Docs] Adjust paragraph order * Update docs/zh_cn/get_started/build.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/zh_cn/get_started/build.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/en/get_started/build.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update docs/en/get_started/build.md Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * [Docs] Modify the format --------- Co-authored-by: budefei Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * [Fix] Fix tensor descriptor setting in MLU ball_query. 
(#2579) * [Feature] Add MLU support for Sparse Convolution op (#2589) * [Feature] Add sparse convolution MLU API * [Feature] update cpp code style * end-of-file * delete libext.a * code style * update ops.md --------- Co-authored-by: budefei * [Enhancement] Replace the implementation of deform_roi_pool with mlu-ops (#2598) * [Feature] Replace the implementation of deform_roi_pool with mlu-ops * [Feature] Modify code --------- Co-authored-by: budefei * [Enhancement] ms_deform_attn performance optimization (#2616) * ms_opt_v2 * ms_opt_v2_1 * optimize MultiScaleDeformableAttention ops for MLU * ms_opt_v2_1 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 * [Feature] ms_deform_attn performance optimization V2 --------- Co-authored-by: dongchengwei * [Feature] Support NmsRotated with cambricon MLU backend (#2643) * [Feature] Support NmsRotated with cambricon MLU backend * [Feature] remove foolproofs in nms_rotated_mlu.cpp * [Feature] fix lint in test_nms_rotated.py * [Feature] fix kMLU not found in nms_rotated.cpp * [Feature] modify mlu support in nms.py * [Feature] modify nms_rotated support in ops.md * [Feature] modify ops/nms.py * [Enhance] Add a default value for MMCV_MLU_ARGS (#2688) * add mlu_args * add mlu_args * Modify the code --------- Co-authored-by: budefei * [Enhance] Ignore mlu-ops files (#2691) Co-authored-by: budefei --------- Co-authored-by: ZShaopeng <108382403+ZShaopeng@users.noreply.github.com> Co-authored-by: BinZheng <38182684+Wickyzheng@users.noreply.github.com> Co-authored-by: liuduanhui <103939338+DanieeelLiu@users.noreply.github.com> Co-authored-by: budefei Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: duzekun <108381389+duzekunKTH@users.noreply.github.com> Co-authored-by: dongchengwei Co-authored-by: liuyuan1-v <125547457+liuyuan1-v@users.noreply.github.com> --- .gitignore | 2 + docs/en/get_started/build.md | 57 + docs/en/understand_mmcv/ops.md | 8 +- docs/zh_cn/get_started/build.md | 56 + docs/zh_cn/understand_mmcv/ops.md | 8 +- .../common/mlu/deform_roi_pool_mlu_kernel.mlu | 712 --------- .../common/mlu/ms_deform_attn_mlu_kernel.mlu | 1417 ++++++++++++++++- .../common/mlu/voxelization_mlu_kernel.mlu | 532 +++++++ mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp | 47 + .../csrc/pytorch/mlu/deform_roi_pool_mlu.cpp | 322 +--- .../csrc/pytorch/mlu/mlu_common_helper.cpp | 136 ++ mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h | 99 ++ .../csrc/pytorch/mlu/ms_deform_attn_mlu.cpp | 145 +- mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp | 53 + mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp | 446 ++++++ .../ops/csrc/pytorch/mlu/voxelization_mlu.cpp | 268 ++++ mmcv/ops/csrc/pytorch/nms_rotated.cpp | 9 + mmcv/ops/csrc/pytorch/spconv_ops.cpp | 26 + mmcv/ops/nms.py | 9 +- setup.py | 90 +- tests/test_ops/test_ball_query.py | 84 +- tests/test_ops/test_nms_rotated.py | 14 +- tests/test_ops/test_spconv.py | 50 +- tests/test_ops/test_voxelization.py | 42 +- 24 files changed, 3475 insertions(+), 1157 deletions(-) delete mode 100644 mmcv/ops/csrc/common/mlu/deform_roi_pool_mlu_kernel.mlu create mode 100644 mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu create mode 100644 mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp create mode 100644 
mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp create mode 100644 mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h create mode 100644 mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp create mode 100644 mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp create mode 100644 mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp diff --git a/.gitignore b/.gitignore index 65fc532494..1769eff2de 100644 --- a/.gitignore +++ b/.gitignore @@ -27,6 +27,8 @@ wheels/ .installed.cfg *.egg MANIFEST +mlu-ops/ +mlu-ops.* # PyInstaller # Usually these files are written by a python script from a template diff --git a/docs/en/get_started/build.md b/docs/en/get_started/build.md index e3d48ec7cf..742e60c356 100644 --- a/docs/en/get_started/build.md +++ b/docs/en/get_started/build.md @@ -290,3 +290,60 @@ If you need to use PyTorch-related modules, make sure PyTorch has been successfu ```bash python -c 'import mmcv;print(mmcv.__version__)' ``` + +### Build mmcv-full on Cambricon MLU Devices + +#### Install torch_mlu + +##### Option 1: Install mmcv-full based on the Cambricon docker image + +First, pull the Cambricon docker image (please email service@cambricon.com for the latest released docker image): + +```bash +docker pull ${docker image} +``` + +Run and attach to the docker, then [install mmcv-full on the MLU device](#install-mmcv\-full-on-cambricon-mlu-device) and [make sure you've installed mmcv-full on the MLU device successfully](#test-code). + +##### Option 2: Install mmcv-full by compiling the Cambricon PyTorch source code + +Please email service@cambricon.com or contact Cambricon engineers for a suitable version of the CATCH package. After you get the suitable version of the CATCH package, please follow the steps in ${CATCH-path}/CONTRIBUTING.md to install Cambricon PyTorch. + +#### Install mmcv-full on Cambricon MLU device + +Clone the repo + +```bash +git clone https://github.com/open-mmlab/mmcv.git +``` + +The mlu-ops library will be downloaded to the default directory (mmcv/mlu-ops) while building MMCV. You can also set `MMCV_MLU_OPS_PATH` to an existing mlu-ops library before building as follows: + +```bash +export MMCV_MLU_OPS_PATH=/xxx/xxx/mlu-ops +``` + +Install mmcv-full + +```bash +cd mmcv +export MMCV_WITH_OPS=1 +export FORCE_MLU=1 +python setup.py install +``` + +#### Test Code + +After finishing the previous steps, you can run the following Python code to make sure that you've installed mmcv-full on the MLU device successfully: + +```python +import torch +import torch_mlu +from mmcv.ops import sigmoid_focal_loss +x = torch.randn(3, 10).mlu() +x.requires_grad = True +y = torch.tensor([1, 5, 3]).mlu() +w = torch.ones(10).float().mlu() +output = sigmoid_focal_loss(x, y, 2.0, 0.25, w, 'none') +print(output) +``` diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index e60f77c772..258adffba6 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -6,7 +6,7 @@ We implement common ops used in detection, segmentation, etc. | ---------------------------- | --- | ---- | --- | --- | ------ | | ActiveRotatedFilter | √ | √ | | | | | AssignScoreWithK | | √ | | | | -| BallQuery | | √ | | | | +| BallQuery | | √ | √ | | | | BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | | BoxIouRotated | √ | √ | | | | @@ -35,7 +35,7 @@ We implement common ops used in detection, segmentation, etc.
| ModulatedDeformConv2d | √ | √ | | | √ | | MultiScaleDeformableAttn | | √ | √ | | | | NMS | √ | √ | √ | | √ | -| NMSRotated | √ | √ | | | √ | +| NMSRotated | √ | √ | √ | | √ | | NMSQuadri | √ | √ | | | | | PixelGroup | √ | | | | | | PointsInBoxes | √ | √ | | | | @@ -52,13 +52,13 @@ We implement common ops used in detection, segmentation, etc. | SigmoidFocalLoss | | √ | √ | | √ | | SoftmaxFocalLoss | | √ | | | √ | | SoftNMS | | √ | | | | -| Sparse Convolution | | √ | | | | +| Sparse Convolution | | √ | √ | | | | Synchronized BatchNorm | | √ | | | | | ThreeInterpolate | | √ | | | | | ThreeNN | | √ | √ | | | | TINShift | | √ | √ | | | | UpFirDn2d | | √ | | | | -| Voxelization | √ | √ | | | √ | +| Voxelization | √ | √ | √ | | √ | | PrRoIPool | | √ | | | | | BezierAlign | √ | √ | | | | | BiasAct | | √ | | | | diff --git a/docs/zh_cn/get_started/build.md b/docs/zh_cn/get_started/build.md index 95f611bc2e..99d2214e90 100644 --- a/docs/zh_cn/get_started/build.md +++ b/docs/zh_cn/get_started/build.md @@ -298,3 +298,59 @@ mmcv 有两个版本: ```bash python -c 'import mmcv;print(mmcv.__version__)' ``` + +### 在寒武纪 MLU 机器编译 mmcv-full + +#### 安装 torch_mlu + +##### 选项1: 基于寒武纪 docker image 安装 + +首先请下载并且拉取寒武纪 docker (请向 service@cambricon.com 发邮件以获得最新的寒武纪 pytorch 发布 docker)。 + +``` +docker pull ${docker image} +``` + +进入 docker, [编译 MMCV MLU](#编译mmcv-mlu) 并[进行验证](#验证是否成功安装)。 + +##### 选项2:基于 cambricon pytorch 源码编译安装 + +请向 service@cambricon.com 发送邮件或联系 Cambricon 工程师以获取合适版本的 CATCH 软件包,在您获得合适版本的 CATCH 软件包后,请参照 ${CATCH-path}/CONTRIBUTING.md 中的步骤安装 CATCH。 + +#### 编译 MMCV + +克隆代码仓库 + +```bash +git clone https://github.com/open-mmlab/mmcv.git +``` + +算子库 mlu-ops 在编译 MMCV 时自动下载到默认路径(mmcv/mlu-ops),你也可以在编译前设置环境变量 MMCV_MLU_OPS_PATH 指向已经存在的 mlu-ops 算子库路径。 + +```bash +export MMCV_MLU_OPS_PATH=/xxx/xxx/mlu-ops +``` + +开始编译 + +```bash +cd mmcv +export MMCV_WITH_OPS=1 +export FORCE_MLU=1 +python setup.py install +``` + +#### 验证是否成功安装 + +完成上述安装步骤之后,您可以尝试运行下面的 Python 代码以测试您是否成功在 MLU 设备上安装了 mmcv-full + +```python +import torch +import torch_mlu +from mmcv.ops import sigmoid_focal_loss +x = torch.randn(3, 10).mlu() +x.requires_grad = True +y = torch.tensor([1, 5, 3]).mlu() +w = torch.ones(10).float().mlu() +output = sigmoid_focal_loss(x, y, 2.0, 0.25, w, 'none') +``` diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 11b885d37c..4fd19eedb4 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -6,7 +6,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | ---------------------------- | --- | ---- | --- | --- | ------ | | ActiveRotatedFilter | √ | √ | | | | | AssignScoreWithK | | √ | | | | -| BallQuery | | √ | | | | +| BallQuery | | √ | √ | | | | BBoxOverlaps | | √ | √ | √ | √ | | BorderAlign | | √ | | | | | BoxIouRotated | √ | √ | | | | @@ -35,7 +35,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | ModulatedDeformConv2d | √ | √ | | | √ | | MultiScaleDeformableAttn | | √ | √ | | | | NMS | √ | √ | √ | | √ | -| NMSRotated | √ | √ | | | √ | +| NMSRotated | √ | √ | √ | | √ | | NMSQuadri | √ | √ | | | | | PixelGroup | √ | | | | | | PointsInBoxes | √ | √ | | | | @@ -52,13 +52,13 @@ MMCV 提供了检测、分割等任务中常用的算子 | SigmoidFocalLoss | | √ | √ | | √ | | SoftmaxFocalLoss | | √ | | | √ | | SoftNMS | | √ | | | | -| Sparse Convolution | | √ | | | | +| Sparse Convolution | | √ | √ | | | | Synchronized BatchNorm | | √ | | | | | ThreeInterpolate | | √ | | | | | ThreeNN | | √ | √ | | | | TINShift | | √ | √ | | | | UpFirDn2d | | √ | | | | -| Voxelization | √ | √ | | | √ | +| Voxelization | √ | √ | √ | | √ | | PrRoIPool | | √ | | | | 
| BezierAlign | √ | √ | | | | | BiasAct | | √ | | | | diff --git a/mmcv/ops/csrc/common/mlu/deform_roi_pool_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/deform_roi_pool_mlu_kernel.mlu deleted file mode 100644 index 6c765e3eaa..0000000000 --- a/mmcv/ops/csrc/common/mlu/deform_roi_pool_mlu_kernel.mlu +++ /dev/null @@ -1,712 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ -#include - -#include "common_mlu_helper.hpp" - -#define ROI_OFFSET 5 -#define FOURSPLIT 4 -#define FIVESPLIT 5 -#define NINESPLIT 9 -#define THIRTEENSPLIT 13 - -__nram__ char nram_buffer[MAX_NRAM_SIZE]; - -template -static __mlu_func__ void bilinearInterpolate(const int input_width, T y, T x, - T *w1, T *w2, T *w3, T *w4, - int *x_low, int *x_high, - const int y_low, bool *is_empty) { - if (x < -1.0 || x > input_width) { - *is_empty = true; - return; - } - - if (x <= 0) x = 0; - - *x_low = int(x); - - if (*x_low >= input_width - 1) { - *x_high = *x_low = input_width - 1; - x = T(*x_low); - } else { - *x_high = *x_low + 1; - } - - T ly = y - y_low; - T lx = x - *x_low; - T hy = 1.0 - ly; - T hx = 1.0 - lx; - *w1 = hy * hx; - *w2 = hy * lx; - *w3 = ly * hx; - *w4 = ly * lx; -} - -template -__mlu_func__ void MLUUnion1DeformRoIPoolForward( - const T *input, const T *rois, const T *offset, T *output, - const int channels, const int height, const int width, const int num_rois, - const int pooled_height, const int pooled_width, const T spatial_scale, - const int sampling_ratio, const T gamma) { - for (int bin_index = taskId; - bin_index < num_rois * pooled_width * pooled_height; - bin_index += taskDim) { - int out_batch = bin_index / pooled_width / pooled_height; - int out_height = bin_index / pooled_width % pooled_height; - int out_width = bin_index % pooled_width; - const T *cur_roi = rois + out_batch * ROI_OFFSET; - T *nram_rois = (T *)nram_buffer; - __memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T), - GDRAM2NRAM); - const int roi_batch = nram_rois[0]; - T roi_x_min = nram_rois[1] * spatial_scale - 0.5; - T roi_y_min = nram_rois[2] * spatial_scale - 0.5; - const T roi_x_max = nram_rois[3] * spatial_scale - 0.5; - const T roi_y_max = nram_rois[4] * spatial_scale - 0.5; - const T roi_width = roi_x_max - roi_x_min; - const T roi_height = roi_y_max - roi_y_min; - const T bin_width = roi_width / static_cast(pooled_width); - const T bin_height = roi_height / static_cast(pooled_height); - const T *offset_input = input + roi_batch * height * width * channels; - int roi_bin_grid_height = - (sampling_ratio > 0) - ? sampling_ratio - : static_cast(ceilf(roi_height / pooled_height)); - int roi_bin_grid_width = - (sampling_ratio > 0) - ? 
sampling_ratio - : static_cast(ceilf(roi_width / pooled_width)); - if (offset != NULL) { - const T *offset_cur = offset + - out_batch * pooled_width * pooled_height * 2 + - out_height * pooled_width + out_width; - roi_x_min += gamma * roi_width * offset_cur[0]; - roi_y_min += - gamma * roi_height * offset_cur[pooled_width * pooled_height]; - } - int type_align = NFU_ALIGN_SIZE / sizeof(T); - int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T); - int channels_nram_split = - channels_max_num_nram / NINESPLIT / type_align * type_align; - int channel_rem = channels % channels_nram_split; - int channel_loops = - channels / channels_nram_split + (channel_rem != 0 ? 1 : 0); - for (int channel_loop_index = 0; channel_loop_index < channel_loops; - ++channel_loop_index) { - int channels_num = - channels_nram_split >= channels ? channels : channels_nram_split; - const int channel_offset = channel_loop_index * channels_num; - if (channel_loop_index + 1 == channel_loops && channel_rem != 0) { - channels_num = channel_rem; - } - int channels_align = CEIL_ALIGN(channels_num, type_align); - int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - channels_align) >> 1; - int c_slice = nram_limit / FOURSPLIT / type_align * type_align; - int c_slice_align = 0; - - /* NRAM partition - * - * | | ping | pong | - * |----------|-------------------|-------------------| - * | nram_out | p1 | p2 | p3 | p4 | p1 | p2 | p3 | p4 | - * - */ - - T *nram_out = (T *)nram_buffer; - T *nram_ping = nram_out + channels_align; - T *nram_pong = nram_ping + nram_limit; - __bang_write_value((T *)nram_out, channels_align, (T)0); - __bang_write_value((T *)nram_ping, FOURSPLIT * c_slice, (T)0); - __bang_write_value((T *)nram_pong, FOURSPLIT * c_slice, (T)0); - const T num_bins = - static_cast(max(roi_bin_grid_height * roi_bin_grid_width, 1)); - const T value_div = 1.0f / num_bins; - bool is_ping_empty = true; - for (int iy = 0; iy < roi_bin_grid_height; ++iy) { - T y = roi_y_min + out_height * bin_height + - static_cast(iy + .5f) * bin_height / - static_cast(roi_bin_grid_height); - if (y < -1.0 || y > height) { - is_ping_empty = true; - continue; - } - if (y <= 0) { - y = 0; - } - int y_low = 0, y_high = 0; - y_low = int(y); - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = T(y_low); - } else { - y_high = y_low + 1; - } - for (int ix = 0; ix < roi_bin_grid_width; ++ix) { - T x = roi_x_min + out_width * bin_width + - static_cast(ix + .5f) * bin_width / - static_cast(roi_bin_grid_width); - const int sample_index = iy * roi_bin_grid_width + ix; - int c_rem = channels_num; - c_slice = nram_limit / FOURSPLIT / type_align * type_align; - c_slice_align = 0; - bool is_empty = false; - T w1, w2, w3, w4; - int x_low = 0, x_high = 0; - bilinearInterpolate(width, y, x, &w1, &w2, &w3, &w4, &x_low, &x_high, - y_low, &is_empty); - if (is_empty) { - is_ping_empty = true; - continue; - } - if (is_ping_empty) { - c_slice = c_slice > c_rem ? 
c_rem : c_slice; - c_slice_align = CEIL_ALIGN(c_slice, type_align); - __bang_write_value(nram_ping, FOURSPLIT * c_slice_align, (T)0); - __asm__ volatile("sync;"); - __memcpy(nram_ping, - offset_input + y_low * width * channels + - x_low * channels + channel_offset, - c_slice * sizeof(T), GDRAM2NRAM); - __memcpy(nram_ping + c_slice_align, - offset_input + y_low * width * channels + - x_high * channels + channel_offset, - c_slice * sizeof(T), GDRAM2NRAM); - __memcpy(nram_ping + 2 * c_slice_align, - offset_input + y_high * width * channels + - x_low * channels + channel_offset, - c_slice * sizeof(T), GDRAM2NRAM); - __memcpy(nram_ping + 3 * c_slice_align, - offset_input + y_high * width * channels + - x_high * channels + channel_offset, - c_slice * sizeof(T), GDRAM2NRAM); - is_ping_empty = false; - } - int c_offset = 0; - int pongc_slice = 0; - int pongc_slice_align = 0; - while (c_rem > 0) { - c_slice = c_slice > c_rem ? c_rem : c_slice; - c_slice_align = CEIL_ALIGN(c_slice, type_align); - if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) { - int iy_tmp = (sample_index + 1) / roi_bin_grid_width; - int ix_tmp = (sample_index + 1) % roi_bin_grid_width; - y = roi_y_min + out_height * bin_height + - static_cast(iy_tmp + .5f) * bin_height / - static_cast(roi_bin_grid_height); - x = roi_x_min + out_width * bin_width + - static_cast(ix_tmp + .5f) * bin_width / - static_cast(roi_bin_grid_width); - if (y < -1.0 || y > height) { - is_empty = true; - } else { - T w1_tmp, w2_tmp, w3_tmp, w4_tmp; - if (y <= 0) { - y = 0; - } - y_low = int(y); - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y = T(y_low); - } else { - y_high = y_low + 1; - } - bilinearInterpolate(width, y, x, &w1_tmp, &w2_tmp, &w3_tmp, - &w4_tmp, &x_low, &x_high, y_low, &is_empty); - } - pongc_slice = nram_limit / FOURSPLIT / type_align * type_align; - pongc_slice = - pongc_slice > channels_num ? 
channels_num : pongc_slice; - pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align); - __bang_write_value(nram_pong, FOURSPLIT * pongc_slice_align, - (T)0); - __asm__ volatile("sync;"); - if (!is_empty) { - __memcpy_async(nram_pong, - offset_input + y_low * width * channels + - x_low * channels + channel_offset, - pongc_slice * sizeof(T), GDRAM2NRAM); - __memcpy_async(nram_pong + pongc_slice_align, - offset_input + y_low * width * channels + - x_high * channels + channel_offset, - pongc_slice * sizeof(T), GDRAM2NRAM); - __memcpy_async(nram_pong + 2 * pongc_slice_align, - offset_input + y_high * width * channels + - x_low * channels + channel_offset, - pongc_slice * sizeof(T), GDRAM2NRAM); - __memcpy_async(nram_pong + 3 * pongc_slice_align, - offset_input + y_high * width * channels + - x_high * channels + channel_offset, - pongc_slice * sizeof(T), GDRAM2NRAM); - } - } - __bang_mul_scalar(nram_ping, nram_ping, w1, c_slice_align); - __bang_mul_scalar(nram_ping + c_slice_align, - nram_ping + c_slice_align, w2, c_slice_align); - __bang_add(nram_ping, nram_ping, nram_ping + c_slice_align, - c_slice_align); - __bang_mul_scalar(nram_ping + 2 * c_slice_align, - nram_ping + 2 * c_slice_align, w3, c_slice_align); - __bang_add(nram_ping, nram_ping, nram_ping + 2 * c_slice_align, - c_slice_align); - __bang_mul_scalar(nram_ping + 3 * c_slice_align, - nram_ping + 3 * c_slice_align, w4, c_slice_align); - __bang_add(nram_ping, nram_ping, nram_ping + 3 * c_slice_align, - c_slice_align); - __bang_add(nram_out + c_offset, nram_out + c_offset, nram_ping, - c_slice_align); - T *nram_tmp = nram_ping; - nram_ping = nram_pong; - nram_pong = nram_tmp; - c_rem -= c_slice; - c_offset += c_slice; - __asm__ volatile("sync;"); - } - } - } - __bang_mul_scalar(nram_out, nram_out, value_div, channels_align); - __memcpy(output + channels * bin_index + channel_offset, nram_out, - channels_num * sizeof(T), NRAM2GDRAM); - } - } -} - -__mlu_global__ void MLUKernelDeformRoIPoolForward( - cnrtDataType_t data_type, const void *input, const void *rois, - const void *offset, void *output, const int channels, const int height, - const int width, const int num_rois, const int pooled_height, - const int pooled_width, const float spatial_scale, const int sampling_ratio, - const float gamma) { - switch (data_type) { - case CNRT_FLOAT16: { - MLUUnion1DeformRoIPoolForward((half *)input, (half *)rois, (half *)offset, - (half *)output, channels, height, width, - num_rois, pooled_height, pooled_width, - static_cast(spatial_scale), - sampling_ratio, static_cast(gamma)); - }; break; - case CNRT_FLOAT32: { - MLUUnion1DeformRoIPoolForward( - (float *)input, (float *)rois, (float *)offset, (float *)output, - channels, height, width, num_rois, pooled_height, pooled_width, - static_cast(spatial_scale), sampling_ratio, - static_cast(gamma)); - }; break; - default: { - break; - } - } -} - -void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, cnrtDataType_t data_type, - const void *input, const void *rois, - const void *offset, void *output, - const int channels, const int height, - const int width, const int num_rois, - const int pooled_height, const int pooled_width, - const float spatial_scale, - const int sampling_ratio, const float gamma) { - MLUKernelDeformRoIPoolForward<<>>( - data_type, input, rois, offset, output, channels, height, width, num_rois, - pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma); -} - -template -__mlu_func__ void MLUUnion1DeformRoIPoolBackward( - const T 
*grad_output, const T *input, const T *rois, const T *offset, - T *grad_input, T *grad_offset, const int channels, const int height, - const int width, const int num_rois, const int pooled_height, - const int pooled_width, const T spatial_scale, const int sampling_ratio, - const T gamma) { - for (int bin_index = taskId; - bin_index < num_rois * pooled_width * pooled_height; - bin_index += taskDim) { - int out_batch = bin_index / pooled_width / pooled_height; - int out_height = bin_index / pooled_width % pooled_height; - int out_width = bin_index % pooled_width; - const T *cur_roi = rois + out_batch * ROI_OFFSET; - T *nram_rois = (T *)nram_buffer; - __memcpy((void *)nram_rois, (void *)cur_roi, ROI_OFFSET * sizeof(T), - GDRAM2NRAM); - const int roi_batch = nram_rois[0]; - T roi_x_min = nram_rois[1] * spatial_scale - 0.5; - T roi_y_min = nram_rois[2] * spatial_scale - 0.5; - const T roi_x_max = nram_rois[3] * spatial_scale - 0.5; - const T roi_y_max = nram_rois[4] * spatial_scale - 0.5; - const T roi_width = roi_x_max - roi_x_min; - const T roi_height = roi_y_max - roi_y_min; - const T bin_width = roi_width / static_cast(pooled_width); - const T bin_height = roi_height / static_cast(pooled_height); - const T *offset_input = input + roi_batch * height * width * channels; - T *offset_grad_input = grad_input + roi_batch * height * width * channels; - int roi_bin_grid_height = - (sampling_ratio > 0) - ? sampling_ratio - : static_cast(ceilf(roi_height / pooled_height)); - int roi_bin_grid_width = - (sampling_ratio > 0) - ? sampling_ratio - : static_cast(ceilf(roi_width / pooled_width)); - if (offset != NULL) { - const T *offset_cur = offset + - out_batch * pooled_width * pooled_height * 2 + - out_height * pooled_width + out_width; - roi_x_min += gamma * roi_width * offset_cur[0]; - roi_y_min += - gamma * roi_height * offset_cur[pooled_width * pooled_height]; - } - - /* NRAM partition - * - * If offset != NULL, NRAM partition belows. - * | | - * ping | pong | - * |---------------------------------------------------------------------|-----------|-----------| - * |nram_tmp1|nram_tmp2|nram_tmp3|nram_tmp4|nram_grad_output|nram_sum_tmp|p1|p2|p3|p4|p1|p2|p3|p4| - * - * If offset == NULL, ping and pang will not be needed. - * | | - * |----------------------------------------------------------------------------------| - * | nram_tmp1 | nram_tmp2 | nram_tmp3 | nram_tmp4 | nram_grad_output | - * - */ - - int type_align = NFU_ALIGN_SIZE / sizeof(T); - int channels_max_num_nram = MAX_NRAM_SIZE / sizeof(T); - int channels_nram_split = - channels_max_num_nram / FIVESPLIT / type_align * type_align; - int channel_rem = channels % channels_nram_split; - int channel_loops = - channels / channels_nram_split + (channel_rem != 0 ? 1 : 0); - if (offset != NULL) { - channels_nram_split = - channels_max_num_nram / THIRTEENSPLIT / type_align * type_align; - channel_rem = channels % channels_nram_split; - channel_loops = - channels / channels_nram_split + (channel_rem != 0 ? 1 : 0); - } - - for (int channel_loop_index = 0; channel_loop_index < channel_loops; - ++channel_loop_index) { - int channels_num = - channels_nram_split >= channels ? 
channels : channels_nram_split; - const int channel_offset = channel_loop_index * channels_num; - if (channel_loop_index + 1 == channel_loops && channel_rem != 0) { - channels_num = channel_rem; - } - int channels_align = CEIL_ALIGN(channels_num, type_align); - const int32_t nram_sum_tmp_channel = NFU_ALIGN_SIZE / sizeof(T); - int nram_limit = (MAX_NRAM_SIZE / sizeof(T) - 5 * channels_align - - nram_sum_tmp_channel) >> - 1; - int c_slice = 0; - int c_slice_align = 0; - T *nram_tmp1 = (T *)nram_buffer; - T *nram_tmp2 = (T *)nram_buffer + channels_align; - T *nram_tmp3 = (T *)nram_buffer + 2 * channels_align; - T *nram_tmp4 = (T *)nram_buffer + 3 * channels_align; - T *nram_grad_output = nram_tmp4 + channels_align; - T *nram_sum_tmp = NULL; - T *nram_ping_input = NULL; - T *nram_pong_input = NULL; - __bang_write_value((T *)nram_grad_output, channels_align, (T)0); - __asm__ volatile("sync;"); - - if (offset != NULL) { - c_slice = nram_limit / FOURSPLIT / type_align * type_align; - nram_sum_tmp = nram_grad_output + channels_align; - nram_ping_input = nram_sum_tmp + nram_sum_tmp_channel; - nram_pong_input = nram_ping_input + FOURSPLIT * c_slice; - __bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0); - __bang_write_value((T *)nram_ping_input, FOURSPLIT * c_slice, (T)0); - __bang_write_value((T *)nram_pong_input, FOURSPLIT * c_slice, (T)0); - __asm__ volatile("sync;"); - } - const T num_bins = - static_cast(max(roi_bin_grid_height * roi_bin_grid_width, 1)); - const T value_div = 1.0f / num_bins; - bool is_ping_empty = true; - __memcpy(nram_grad_output, - grad_output + channels * bin_index + channel_offset, - channels_num * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(nram_grad_output, nram_grad_output, value_div, - channels_align); - for (int iy = 0; iy < roi_bin_grid_height; ++iy) { - T y = roi_y_min + out_height * bin_height + - static_cast(iy + .5f) * bin_height / - static_cast(roi_bin_grid_height); - T y_tmp = y; - if (y_tmp < -1.0 || y_tmp > height) { - is_ping_empty = true; - continue; - } - if (y_tmp <= 0) { - y_tmp = 0; - } - int y_low = 0, y_high = 0; - y_low = int(y_tmp); - if (y_low >= height - 1) { - y_high = y_low = height - 1; - y_tmp = T(y_low); - } else { - y_high = y_low + 1; - } - for (int ix = 0; ix < roi_bin_grid_width; ++ix) { - T x = roi_x_min + out_width * bin_width + - static_cast(ix + .5f) * bin_width / - static_cast(roi_bin_grid_width); - const int sample_index = iy * roi_bin_grid_width + ix; - int c_rem = channels_num; - bool is_empty = false; - T w1, w2, w3, w4; - int x_low = 0, x_high = 0; - bilinearInterpolate(width, y_tmp, x, &w1, &w2, &w3, &w4, &x_low, - &x_high, y_low, &is_empty); - if (is_empty) { - is_ping_empty = true; - continue; - } - __bang_mul_scalar((T *)nram_tmp1, (T *)nram_grad_output, w1, - channels_align); - __bang_mul_scalar((T *)nram_tmp2, (T *)nram_grad_output, w2, - channels_align); - __bang_mul_scalar((T *)nram_tmp3, (T *)nram_grad_output, w3, - channels_align); - __bang_mul_scalar((T *)nram_tmp4, (T *)nram_grad_output, w4, - channels_align); - __asm__ volatile("sync;"); - __bang_atomic_add( - (T *)nram_tmp1, - (T *)(offset_grad_input + (y_low * width + x_low) * channels + - channel_offset), - (T *)nram_tmp1, channels_num); - __bang_atomic_add( - (T *)nram_tmp2, - (T *)(offset_grad_input + (y_low * width + x_high) * channels + - channel_offset), - (T *)nram_tmp2, channels_num); - __bang_atomic_add( - (T *)nram_tmp3, - (T *)(offset_grad_input + (y_high * width + x_low) * channels + - channel_offset), - (T *)nram_tmp3, channels_num); - 
__bang_atomic_add( - (T *)nram_tmp4, - (T *)(offset_grad_input + (y_high * width + x_high) * channels + - channel_offset), - (T *)nram_tmp4, channels_num); - if (offset != NULL) { - c_slice = nram_limit / FOURSPLIT / type_align * type_align; - c_slice_align = 0; - if (is_ping_empty) { - c_slice = c_slice > c_rem ? c_rem : c_slice; - c_slice_align = CEIL_ALIGN(c_slice, type_align); - __bang_write_value(nram_ping_input, FOURSPLIT * c_slice_align, - (T)0); - __asm__ volatile("sync;"); - const T *src_offset1 = offset_input + y_low * width * channels + - x_low * channels + channel_offset; - const T *src_offset2 = offset_input + y_low * width * channels + - x_high * channels + channel_offset; - const T *src_offset3 = offset_input + y_high * width * channels + - x_low * channels + channel_offset; - const T *src_offset4 = offset_input + y_high * width * channels + - x_high * channels + channel_offset; - __memcpy(nram_ping_input, src_offset1, c_slice * sizeof(T), - GDRAM2NRAM); - __memcpy(nram_ping_input + c_slice_align, src_offset2, - c_slice * sizeof(T), GDRAM2NRAM); - __memcpy(nram_ping_input + 2 * c_slice_align, src_offset3, - c_slice * sizeof(T), GDRAM2NRAM); - __memcpy(nram_ping_input + 3 * c_slice_align, src_offset4, - c_slice * sizeof(T), GDRAM2NRAM); - is_ping_empty = false; - } - int c_offset = 0; - int pongc_slice = 0; - int pongc_slice_align = 0; - while (c_rem > 0) { - c_slice = c_slice > c_rem ? c_rem : c_slice; - c_slice_align = CEIL_ALIGN(c_slice, type_align); - if (sample_index + 1 < roi_bin_grid_height * roi_bin_grid_width) { - int iy_tmp = (sample_index + 1) / roi_bin_grid_width; - int ix_tmp = (sample_index + 1) % roi_bin_grid_width; - T y_tmp = roi_y_min + out_height * bin_height + - static_cast(iy_tmp + .5f) * bin_height / - static_cast(roi_bin_grid_height); - T x_tmp = roi_x_min + out_width * bin_width + - static_cast(ix_tmp + .5f) * bin_width / - static_cast(roi_bin_grid_width); - int x_low_tmp = 0, x_high_tmp = 0, y_low_tmp = 0, - y_high_tmp = 0; - if (y_tmp < -1.0 || y_tmp > height) { - is_empty = true; - } else { - T w1_tmp, w2_tmp, w3_tmp, w4_tmp; - if (y_tmp <= 0) { - y_tmp = 0; - } - y_low_tmp = int(y_tmp); - if (y_low_tmp >= height - 1) { - y_high_tmp = y_low_tmp = height - 1; - y_tmp = T(y_low_tmp); - } else { - y_high_tmp = y_low_tmp + 1; - } - bilinearInterpolate(width, y_tmp, x_tmp, &w1_tmp, &w2_tmp, - &w3_tmp, &w4_tmp, &x_low_tmp, &x_high_tmp, - y_low_tmp, &is_empty); - } - pongc_slice = nram_limit / FOURSPLIT / type_align * type_align; - pongc_slice = - pongc_slice > channels_num ? 
channels_num : pongc_slice; - pongc_slice_align = CEIL_ALIGN(pongc_slice, type_align); - __bang_write_value(nram_pong_input, - FOURSPLIT * pongc_slice_align, (T)0); - __asm__ volatile("sync;"); - if (!is_empty) { - const T *src_offset1 = offset_input + - y_low_tmp * width * channels + - x_low_tmp * channels + channel_offset; - const T *src_offset2 = offset_input + - y_low_tmp * width * channels + - x_high_tmp * channels + channel_offset; - const T *src_offset3 = offset_input + - y_high_tmp * width * channels + - x_low_tmp * channels + channel_offset; - const T *src_offset4 = offset_input + - y_high_tmp * width * channels + - x_high_tmp * channels + channel_offset; - __memcpy_async(nram_pong_input, src_offset1, - pongc_slice * sizeof(T), GDRAM2NRAM); - __memcpy_async(nram_pong_input + pongc_slice_align, - src_offset2, pongc_slice * sizeof(T), - GDRAM2NRAM); - __memcpy_async(nram_pong_input + 2 * pongc_slice_align, - src_offset3, pongc_slice * sizeof(T), - GDRAM2NRAM); - __memcpy_async(nram_pong_input + 3 * pongc_slice_align, - src_offset4, pongc_slice * sizeof(T), - GDRAM2NRAM); - } - } - - __bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align, - y - y_low, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align, - y_high - y, c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align, - y_low - y, c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input, y - y_high, - c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_width, - c_slice_align); - __bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align); - const int32_t kernel_width = - c_slice_align / nram_sum_tmp_channel + - (int32_t)(c_slice_align % nram_sum_tmp_channel > 0); - __bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1, - kernel_width, 1, kernel_width, kernel_width, 1); - __bang_reduce_sum(nram_sum_tmp, nram_sum_tmp, - nram_sum_tmp_channel); - __bang_atomic_add( - (T *)nram_sum_tmp, - (T *)(grad_offset + - out_batch * pooled_width * pooled_height * 2 + - out_height * pooled_width + out_width), - (T *)nram_sum_tmp, 1); - __bang_write_value((T *)nram_sum_tmp, nram_sum_tmp_channel, (T)0); - __bang_mul_scalar(nram_tmp1, nram_ping_input + 3 * c_slice_align, - x - x_low, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input + 2 * c_slice_align, - x_high - x, c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input + c_slice_align, - x_low - x, c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp2, nram_ping_input, x - x_high, - c_slice_align); - __bang_add(nram_tmp1, nram_tmp1, nram_tmp2, c_slice_align); - __bang_mul_scalar(nram_tmp1, nram_tmp1, gamma * roi_height, - c_slice_align); - __bang_mul(nram_tmp1, nram_grad_output, nram_tmp1, c_slice_align); - __bang_sumpool(nram_sum_tmp, nram_tmp1, nram_sum_tmp_channel, 1, - kernel_width, 1, kernel_width, kernel_width, 1); - __bang_reduce_sum(nram_sum_tmp, nram_sum_tmp, - NFU_ALIGN_SIZE / sizeof(T)); - __bang_atomic_add( - (T *)nram_sum_tmp, - (T *)(grad_offset + - out_batch * pooled_width * pooled_height * 2 + - pooled_width * pooled_height + - out_height * pooled_width + out_width), - (T *)nram_sum_tmp, 1); - - T *nram_tmp = nram_ping_input; - nram_ping_input = nram_pong_input; 
- nram_pong_input = nram_tmp; - c_rem -= c_slice; - c_offset += c_slice; - __asm__ volatile("sync;"); - } - } - } - } - } - } -} - -__mlu_global__ void MLUKernelDeformRoIPoolBackward( - cnrtDataType_t data_type, const void *grad_output, const void *input, - const void *rois, const void *offset, void *grad_input, void *grad_offset, - const int channels, const int height, const int width, const int num_rois, - const int pooled_height, const int pooled_width, const float spatial_scale, - const int sampling_ratio, const float gamma) { - switch (data_type) { - case CNRT_FLOAT16: { - MLUUnion1DeformRoIPoolBackward( - (half *)grad_output, (half *)input, (half *)rois, (half *)offset, - (half *)grad_input, (half *)grad_offset, channels, height, width, - num_rois, pooled_height, pooled_width, - static_cast(spatial_scale), sampling_ratio, - static_cast(gamma)); - }; break; - case CNRT_FLOAT32: { - MLUUnion1DeformRoIPoolBackward( - (float *)grad_output, (float *)input, (float *)rois, (float *)offset, - (float *)grad_input, (float *)grad_offset, channels, height, width, - num_rois, pooled_height, pooled_width, - static_cast(spatial_scale), sampling_ratio, - static_cast(gamma)); - }; break; - default: { - break; - } - } -} - -void KernelDeformRoIPoolBackward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - cnrtDataType_t data_type, const void *grad_output, const void *input, - const void *rois, const void *offset, void *grad_input, void *grad_offset, - const int channels, const int height, const int width, const int num_rois, - const int pooled_height, const int pooled_width, const float spatial_scale, - const int sampling_ratio, const float gamma) { - MLUKernelDeformRoIPoolBackward<<>>( - data_type, grad_output, input, rois, offset, grad_input, grad_offset, - channels, height, width, num_rois, pooled_height, pooled_width, - spatial_scale, sampling_ratio, gamma); -} diff --git a/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu index 7899e52cd3..40ad6396a6 100644 --- a/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu +++ b/mmcv/ops/csrc/common/mlu/ms_deform_attn_mlu_kernel.mlu @@ -32,6 +32,7 @@ /**************************************************************************************** * * NRAM partition backward: + * default kernel * | grad_output_nram | grad_output_nram_temp | grad_weight | * | grad_h_weight | grad_w_weight | top_grad | * | top_grad_temp | spatial_shapes_nram | sampling_loc_nram | @@ -39,18 +40,34 @@ * | deal_size | deal_size | deal_size | * | deal_size | deal_size | 64bytes | * + * small channel kernel + * | nram_grad_output_tl | nram_grad_output_tr | nram_grad_output_bl | + * | nram_grad_output_br | grad_temp1 | grad_temp2 | + * | grad_temp3 | grad_temp4 | nram_loc_w | + * | nram_loc_h | nram_h_low | nram_w_low | + * | nram_h_high | nram_w_high | nram_h_low_temp | + * | nram_h_high_temp | nram_hw | nram_hh | + * | nram_lw | nram_lh | nram_h_low_ptr_offset | + * | nram_h_high_ptr_offset | nram_w_low_ptr_offset | nram_w_high_ptr_offset | + * | nram_w1 | nram_w2 | nram_w3 | + * | nram_w4 | nram_grad_weight | nram_base_ptr | + * | nram_offset_temp | nram_offset1 | nram_offset2 | + * | nram_offset3 | nram_offset4 | nram_w_low_temp | + * | nram_spatial_shapes | nram_level_start_index | nram_h_stride | ****************************************************************************************/ #define TWELVE_SPLIT 12 -#define ALIGN_NUM 64 +#define ALIGN_NUM 32 #define ALIGN_NUM_FOR_REDUCE 32 +#define 
ELE_COUNT 32 +#define LEN_FLOAT sizeof(float) __nram__ char nram_buffer[MAX_NRAM_SIZE]; template __mlu_func__ void loadNeighborPointsData( const T *data_value_gdram, T *data_value_p1_nram, T *data_value_p2_nram, - T *data_value_p3_nram, T *data_value_p4_nram, const size_t deal_num, + T *data_value_p3_nram, T *data_value_p4_nram, const size_t &deal_num, const int32_t &width, const int32_t &height, const int32_t &num_heads, const int32_t &channels, const T &x, const T &y, const int32_t &head_idx) { const int32_t w_low = floorf(x); @@ -100,11 +117,11 @@ __mlu_func__ void loadNeighborPointsData( } template -__mlu_func__ void bilinearInterpolation( +__mlu_func__ void computeMsDeformAttn( T *data_value_p1_nram, T *data_value_p2_nram, T *data_value_p3_nram, T *data_value_p4_nram, T *sample_point_value, T *auxiliary_b, - const size_t deal_num, const int32_t &width, const int32_t &height, - const T &x, const T &y) { + T *data_col_nram, const T &weight, const size_t &deal_num, + const int32_t &width, const int32_t &height, const T &x, const T &y) { const int32_t w_low = floorf(x); const int32_t h_low = floorf(y); const int32_t w_high = w_low + 1; @@ -156,10 +173,15 @@ __mlu_func__ void bilinearInterpolation( __bang_add((T *)sample_point_value, (T *)sample_point_value, (T *)auxiliary_b, deal_num); } + + __bang_mul_scalar((T *)sample_point_value, (T *)sample_point_value, (T)weight, + deal_num); + __bang_add((T *)data_col_nram, (T *)data_col_nram, (T *)sample_point_value, + deal_num); } template -__mlu_global__ void MLUKernelMsDeformAttnForward( +__mlu_global__ void MLUKernelMsDeformAttnForwardDefault( const char *data_value_gdram, const char *data_spatial_shapes_gdram, const char *data_level_start_index_gdram, const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, @@ -346,7 +368,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForward( // compute if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { - bilinearInterpolation( + computeMsDeformAttn( (T *)(ping_data_value_p1_nram + ((level_idx * num_points + point_idx) % 2) * ping_pong_gap), @@ -359,15 +381,10 @@ __mlu_global__ void MLUKernelMsDeformAttnForward( (T *)(ping_data_value_p4_nram + ((level_idx * num_points + point_idx) % 2) * ping_pong_gap), - (T *)auxiliary_a, (T *)auxiliary_b, span_num_deal, spatial_w, - spatial_h, x, y); - __bang_mul_scalar((T *)auxiliary_a, (T *)auxiliary_a, (T)weight, - span_num_deal); - __bang_add((T *)(ping_data_col_nram + - data_col_ping_pong_idx * ping_pong_gap), - (T *)(ping_data_col_nram + - data_col_ping_pong_idx * ping_pong_gap), - (T *)auxiliary_a, span_num_deal); + (T *)auxiliary_a, (T *)auxiliary_b, + (T *)(ping_data_col_nram + + data_col_ping_pong_idx * ping_pong_gap), + weight, span_num_deal, spatial_w, spatial_h, x, y); } spatial_w = spatial_w_next_point; @@ -500,7 +517,7 @@ __mlu_global__ void MLUKernelMsDeformAttnForward( // compute if (y > -1 && x > -1 && y < spatial_h && x < spatial_w) { - bilinearInterpolation( + computeMsDeformAttn( (T *)(ping_data_value_p1_nram + ((level_idx * num_points + point_idx) % 2) * ping_pong_gap), @@ -513,15 +530,10 @@ __mlu_global__ void MLUKernelMsDeformAttnForward( (T *)(ping_data_value_p4_nram + ((level_idx * num_points + point_idx) % 2) * ping_pong_gap), - (T *)auxiliary_a, (T *)auxiliary_b, channels_align_rem, - spatial_w, spatial_h, x, y); - __bang_mul_scalar((T *)auxiliary_a, (T *)auxiliary_a, (T)weight, - channels_align_rem); - __bang_add((T *)(ping_data_col_nram + - data_col_ping_pong_idx * ping_pong_gap), - (T *)(ping_data_col_nram + - 
data_col_ping_pong_idx * ping_pong_gap), - (T *)auxiliary_a, channels_align_rem); + (T *)auxiliary_a, (T *)auxiliary_b, + (T *)(ping_data_col_nram + + data_col_ping_pong_idx * ping_pong_gap), + weight, channels_align_rem, spatial_w, spatial_h, x, y); } spatial_w = spatial_w_next_point; @@ -544,7 +556,494 @@ __mlu_global__ void MLUKernelMsDeformAttnForward( return; } -template __mlu_global__ void MLUKernelMsDeformAttnForward( +__mlu_func__ void genMask0101(float *mask_ram, int32_t size) { + int32_t align_num = NFU_ALIGN_SIZE / sizeof(float); + for (int32_t i = 0; i < align_num; ++i) { + mask_ram[i] = i % 2; + } + __asm__ volatile("sync;"); + __memcpy(mask_ram + align_num, mask_ram, NFU_ALIGN_SIZE, NRAM2NRAM, + NFU_ALIGN_SIZE, 0, size / align_num - 2); + __asm__ volatile("sync;"); +} + +template +__mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( + const char *data_value_gdram, const char *data_spatial_shapes_gdram, + const char *data_level_start_index_gdram, + const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char *data_col_gdram) { +#if __BANG_ARCH__ >= 300 + if (coreId == 0x80) { + return; + } + + size_t block_num_per_core, batch_start, deal_g, offset_g; + size_t block_num_rem = 0; + const size_t grid_total = num_queries * num_heads * num_levels * num_points; + if (batch_size >= taskDim) { + block_num_rem = batch_size % taskDim; + block_num_per_core = taskId < block_num_rem ? batch_size / taskDim + 1 + : batch_size / taskDim; + batch_start = taskId < block_num_rem + ? taskId * block_num_per_core + : taskId * block_num_per_core + block_num_rem; + deal_g = grid_total; + offset_g = 0; + } else { + size_t skip_n = taskDim / batch_size; + batch_start = taskId / skip_n; + block_num_per_core = batch_start >= batch_size ? 0 : 1; + deal_g = PAD_UP(grid_total / skip_n, num_levels * num_points); + size_t id = taskId % skip_n; + offset_g = id * deal_g; + deal_g = id < (skip_n - 1) ? 
deal_g : grid_total - deal_g * (skip_n - 1); + } + + const int32_t float_align = NFU_ALIGN_SIZE / sizeof(float); + int32_t deal_num; + int32_t cut_channel_iter = 2; + + const size_t spatial_size = + PAD_UP(num_levels * 2 * sizeof(int32_t), NFU_ALIGN_SIZE); + const size_t level_start_index_size = + PAD_UP(num_levels * sizeof(int32_t), NFU_ALIGN_SIZE); + + int32_t channel = channels; + int32_t mult; + while (true) { + deal_num = (MAX_NRAM_SIZE - spatial_size - level_start_index_size) / + (8 * channel + 7) / sizeof(T); + deal_num = PAD_DOWN(deal_num, float_align); + deal_num = PAD_DOWN(deal_num, num_levels * num_points); + if (deal_num > 0) { + break; + } else { + channel = channels / cut_channel_iter; + cut_channel_iter += 2; + } + } + mult = channel; + + const int32_t c_rep = channels / channel; + const int32_t c_rem = channels % channel; + + const int32_t g_rep = deal_g / deal_num; + const int32_t g_rem = deal_g % deal_num; + + // nram buffer alloc + char *data_spatial_shapes_nram = nram_buffer; + char *data_level_start_index_nram = data_spatial_shapes_nram + spatial_size; + char *input_tl = data_level_start_index_nram + level_start_index_size; + char *input_tr = input_tl + deal_num * mult * sizeof(T); + char *input_bl = input_tr + deal_num * mult * sizeof(T); + char *input_br = input_bl + deal_num * mult * sizeof(T); + char *weight_tl = input_tl + 4 * deal_num * mult * sizeof(T); + char *weight_tr = weight_tl + deal_num * mult * sizeof(T); + char *weight_bl = weight_tr + deal_num * mult * sizeof(T); + char *weight_br = weight_bl + deal_num * mult * sizeof(T); + char *mask_tl = weight_br + deal_num * mult * sizeof(T); + char *mask_tr = mask_tl + deal_num * sizeof(T); + char *mask_bl = mask_tr + deal_num * sizeof(T); + char *mask_br = mask_bl + deal_num * sizeof(T); + char *point_ram = mask_br + deal_num * sizeof(T); + char *index_tl = point_ram + deal_num * sizeof(T); + char *index_bl = index_tl + deal_num * sizeof(T); + + // nram space reuse + char *grid_ram = weight_tl; + char *mask_ram = weight_bl; + char *coord_x = input_bl; + char *coord_y = coord_x + deal_num * sizeof(T); + char *coord_x_low = input_tl; + char *coord_y_low = coord_x_low + deal_num * sizeof(T); + char *coord_x_low_int = weight_tl; + char *coord_y_low_int = weight_tr; + char *spatial_x = mask_tl; + char *spatial_y = mask_tr; + char *spatial_x_float = weight_bl; + char *spatial_y_float = weight_br; + char *spatial_x_temp = mask_bl; + char *spatial_y_temp = mask_br; + char *base_ptr_offset = weight_tl; + char *auxiliary_a = point_ram; + char *auxiliary_b = weight_bl; + + __memcpy_async(data_spatial_shapes_nram, data_spatial_shapes_gdram, + num_levels * 2 * sizeof(int32_t), GDRAM2NRAM); + __memcpy_async(data_level_start_index_nram, data_level_start_index_gdram, + num_levels * sizeof(int32_t), GDRAM2NRAM); + __asm__ volatile("sync;"); + + for (int32_t batch_idx = batch_start; + batch_idx < batch_start + block_num_per_core; ++batch_idx) { + for (int32_t grid_iter = 0; grid_iter <= g_rep; ++grid_iter) { + int32_t io_data_num = deal_num; + const int32_t grid_off_base = + batch_idx * grid_total + offset_g + grid_iter * deal_num; + if (grid_iter == g_rep) { + if (g_rem == 0) { + continue; + } else { + io_data_num = g_rem; + } + } + + char *data_col_gdram_start = + data_col_gdram + (batch_idx * num_queries * num_heads * channels + + (offset_g + grid_iter * deal_num) / + (num_levels * num_points) * channels) * + sizeof(float); + + // load data_sampling_loc + __memcpy_async( + grid_ram, data_sampling_loc_gdram + grid_off_base * 2 
* sizeof(float), + io_data_num * 2 * sizeof(float), GDRAM2NRAM); + genMask0101((float *)mask_ram, deal_num * 2); + __asm__ volatile("sync;"); + + // generate x and y coordinate vector + // generate spatial_x and spatial_y spatial vector + __bang_collect((float *)coord_y, (float *)grid_ram, (float *)mask_ram, + deal_num * 2); // y + __bang_collect((float *)spatial_x_temp, (float *)data_spatial_shapes_nram, + (float *)mask_ram, + num_levels * 2); // spatial_x + __bang_not((float *)mask_ram, (float *)mask_ram, deal_num * 2); + __bang_collect((float *)coord_x, (float *)grid_ram, (float *)mask_ram, + deal_num * 2); // x + __bang_collect((float *)spatial_y_temp, (float *)data_spatial_shapes_nram, + (float *)mask_ram, + num_levels * 2); // spatial_y + + for (int32_t i = 0; i < num_levels; i++) { + __bang_write_value((int32_t *)spatial_x + i * num_points, num_points, + ((int32_t *)spatial_x_temp)[i]); + __bang_write_value((int32_t *)spatial_y + i * num_points, num_points, + ((int32_t *)spatial_y_temp)[i]); + } + + __bang_int322float_rd((float *)spatial_x_float, (int32_t *)spatial_x, + num_levels * num_points, 0); + __bang_int322float_rd((float *)spatial_y_float, (int32_t *)spatial_y, + num_levels * num_points, 0); + + // map x from [0, 1] to [0, spatial_x]; map y from [0, 1] to [0, + // spatial_y] + __bang_cycle_mul((float *)coord_x, (float *)coord_x, + (float *)spatial_x_float, deal_num, + num_levels * num_points); + __bang_sub_scalar((float *)coord_x, (float *)coord_x, (float)0.5, + deal_num); + __bang_cycle_mul((float *)coord_y, (float *)coord_y, + (float *)spatial_y_float, deal_num, + num_levels * num_points); + __bang_sub_scalar((float *)coord_y, (float *)coord_y, (float)0.5, + deal_num); + + __bang_floor((float *)coord_x_low, (float *)coord_x, deal_num); + __bang_floor((float *)coord_y_low, (float *)coord_y, deal_num); + + // calc index_tl + const int32_t w_stride = num_heads * channels; + __bang_float2int32_rd((int32_t *)coord_x_low_int, (float *)coord_x_low, + deal_num, 0); + __bang_float2int32_rd((int32_t *)coord_y_low_int, (float *)coord_y_low, + deal_num, 0); + __bang_cycle_mul((int32_t *)index_tl, (int32_t *)coord_y_low_int, + (int32_t *)spatial_x, deal_num, num_levels * num_points); + __bang_add((int32_t *)index_tl, (int32_t *)index_tl, + (int32_t *)coord_x_low_int, deal_num); + __bang_mul_scalar((int32_t *)index_tl, (int32_t *)index_tl, w_stride, + deal_num); + + const int32_t deal_lp_num = deal_num / (num_levels * num_points); + const int32_t h_rep = deal_lp_num / num_heads; + const int32_t h_rem = deal_lp_num % num_heads; + const int32_t head_start = + ((offset_g + grid_iter * deal_num) / (num_levels * num_points)) % + num_heads; + for (int32_t iter = 0; iter < num_heads; ++iter) { + ((int32_t *)base_ptr_offset)[iter] = + ((head_start + iter) % num_heads) * channels; + } + if (h_rep > 0) { + __memcpy((int32_t *)base_ptr_offset + num_heads, + (int32_t *)base_ptr_offset, num_heads * sizeof(int32_t), + NRAM2NRAM, num_heads * sizeof(int32_t), 0, h_rep - 1); + } + if (h_rep > 0 && h_rem > 0) { + __memcpy((int32_t *)base_ptr_offset + h_rep * num_heads, + (int32_t *)base_ptr_offset, h_rem * sizeof(int32_t), + NRAM2NRAM); + } + __bang_transpose((int32_t *)auxiliary_a, (int32_t *)index_tl, deal_lp_num, + num_levels * num_points); + __bang_cycle_add((int32_t *)auxiliary_a, (int32_t *)auxiliary_a, + (int32_t *)base_ptr_offset, deal_num, deal_lp_num); + __bang_transpose((int32_t *)index_tl, (int32_t *)auxiliary_a, + num_levels * num_points, deal_lp_num); + + // calc index_bl + 
__bang_mul_scalar((int32_t *)auxiliary_a, (int32_t *)spatial_x, w_stride, + deal_num); + __bang_cycle_add((int32_t *)index_bl, (int32_t *)index_tl, + (int32_t *)auxiliary_a, deal_num, + num_levels * num_points); + + // calc mask_tl, mask_tr, mask_bl, mask_br + __bang_sub_scalar((float *)spatial_x_float, (float *)spatial_x_float, + (float)1.0, deal_num); + __bang_sub_scalar((float *)spatial_y_float, (float *)spatial_y_float, + (float)1.0, deal_num); + // mask_tl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_low < spatial_y + __bang_ge_scalar((float *)mask_bl, (float *)coord_x_low, (float)0, + deal_num); + __bang_cycle_le((float *)mask_br, (float *)coord_x_low, + (float *)spatial_x_float, deal_num, + num_levels * num_points); + __bang_and((float *)mask_bl, (float *)mask_bl, (float *)mask_br, + deal_num); + + __bang_ge_scalar((float *)mask_tr, (float *)coord_y_low, (float)0, + deal_num); + __bang_cycle_le((float *)mask_br, (float *)coord_y_low, + (float *)spatial_y_float, deal_num, + num_levels * num_points); + __bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br, + deal_num); + __bang_and((float *)mask_tl, (float *)mask_tr, (float *)mask_bl, + deal_num); + + // mask_tr : 0 <= coord_x_high < spatial_x && 0 <= coord_y_low < spatial_y + __bang_ge_scalar((float *)mask_br, (float *)coord_x_low, (float)(-1.0), + deal_num); + __bang_cycle_lt((float *)auxiliary_a, (float *)coord_x_low, + (float *)spatial_x_float, deal_num, + num_levels * num_points); + __bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a, + deal_num); + __bang_and((float *)mask_tr, (float *)mask_tr, (float *)mask_br, + deal_num); + + // mask_bl : 0 <= coord_x_low < spatial_x && 0 <= coord_y_high < spatial_y + __bang_ge_scalar((float *)auxiliary_a, (float *)coord_y_low, + (float)(-1.0), deal_num); + __bang_cycle_lt((float *)auxiliary_b, (float *)coord_y_low, + (float *)spatial_y_float, deal_num, + num_levels * num_points); + __bang_and((float *)auxiliary_a, (float *)auxiliary_a, + (float *)auxiliary_b, deal_num); + __bang_and((float *)mask_bl, (float *)mask_bl, (float *)auxiliary_a, + deal_num); + + // mask_br : 0 <= coord_x_high < spatial_x && 0 <= coord_y_high < + // spatial_y + __bang_and((float *)mask_br, (float *)mask_br, (float *)auxiliary_a, + deal_num); + + // calc inner point num + __bang_mul_scalar((float *)weight_tl, (float *)mask_tl, (float)7.0, + deal_num); + __bang_mul_scalar((float *)weight_tr, (float *)mask_tr, (float)5.0, + deal_num); + __bang_add((float *)weight_tl, (float *)weight_tl, (float *)weight_tr, + deal_num); + __bang_mul_scalar((float *)weight_tr, (float *)mask_bl, (float)3.0, + deal_num); + __bang_add((float *)point_ram, (float *)weight_tr, (float *)mask_br, + deal_num); + __bang_add((float *)point_ram, (float *)point_ram, (float *)weight_tl, + deal_num); + + // calc interpolation weight + __bang_sub((float *)weight_bl, (float *)coord_x_low, (float *)coord_x, + deal_num); + __bang_sub((float *)weight_br, (float *)coord_y_low, (float *)coord_y, + deal_num); + __bang_add_scalar((float *)weight_bl, (float *)weight_bl, (float)1.0, + deal_num); + __bang_add_scalar((float *)weight_br, (float *)weight_br, (float)1.0, + deal_num); + + __bang_sub((float *)weight_tl, (float *)coord_x, (float *)coord_x_low, + deal_num); + __bang_sub((float *)weight_tr, (float *)coord_y, (float *)coord_y_low, + deal_num); + __bang_mul((float *)input_tl, (float *)weight_bl, (float *)weight_br, + deal_num); + __bang_mul((float *)input_tl + deal_num, (float *)weight_br, + (float *)weight_tl, deal_num); + 
__bang_mul((float *)input_tl + 2 * deal_num, (float *)weight_bl, + (float *)weight_tr, deal_num); + __bang_mul((float *)input_tl + 3 * deal_num, (float *)weight_tl, + (float *)weight_tr, deal_num); + + __asm__ volatile("sync;"); + + // extend weight + const int32_t w_rep = channel / ELE_COUNT * ELE_COUNT; + const int32_t w_rem = channel % ELE_COUNT; + if (w_rem != 0) { + const int32_t data_sz = 1 * sizeof(float); + const int32_t dst_str = channel * sizeof(float); + for (int32_t iter = w_rep; iter < channel; ++iter) { + __memcpy_async((float *)weight_tl + iter, (float *)input_tl, data_sz, + NRAM2NRAM, dst_str, data_sz, 4 * deal_num - 1); + } + } + if (w_rep != 0) { + for (int32_t i = 0; i < 4 * deal_num; i++) { + __bang_write_value((float *)weight_tl + i * channel, w_rep, + ((float *)input_tl)[i]); + } + } + + __asm__ volatile("sync;"); + + const char *data_value_gdram_start = + data_value_gdram + + batch_idx * num_keys * num_heads * channels * sizeof(float); + const int32_t c_str = deal_num * channel * sizeof(float); + const int32_t cs_str = num_heads * channels * sizeof(float); + + for (int32_t c_iter = 0; c_iter <= c_rep; ++c_iter) { + int32_t c_real_num = channel; + if (c_iter == c_rep) { + if (c_rem == 0) { + continue; + } else { + c_real_num = c_rem; + } + } + + __bang_write_zero((float *)input_tl, 4 * deal_num * channel); + __asm__ volatile("sync;"); + + // load data_value + for (int32_t p_idx = 0; p_idx < io_data_num; ++p_idx) { + const int32_t inner_point_num = (int32_t)((float *)point_ram)[p_idx]; + const int32_t tl_offset = ((int32_t *)index_tl)[p_idx]; + const int32_t bl_offset = ((int32_t *)index_bl)[p_idx]; + const int32_t level_start_id = + ((int32_t *)data_level_start_index_nram)[(p_idx / num_points) % + num_levels]; + const char *data_value_ptr = + data_value_gdram_start + + (level_start_id * num_heads * channels + c_iter * channel) * + sizeof(float); + + switch (inner_point_num) { + case 16: // 4 points are cached. + __memcpy_async((float *)input_tl + p_idx * channel, + (float *)data_value_ptr + tl_offset, + c_real_num * sizeof(float), GDRAM2NRAM, c_str, + cs_str, 1); + __memcpy_async((float *)input_bl + p_idx * channel, + (float *)data_value_ptr + bl_offset, + c_real_num * sizeof(float), GDRAM2NRAM, c_str, + cs_str, 1); + break; + case 12: // 2 points are cached. (top_left, top_right) + __memcpy_async((float *)input_tl + p_idx * channel, + (float *)data_value_ptr + tl_offset, + c_real_num * sizeof(float), GDRAM2NRAM, c_str, + cs_str, 1); + break; + case 4: // 2 points are cached. (bottom_left, bottom_right) + __memcpy_async((float *)input_bl + p_idx * channel, + (float *)data_value_ptr + bl_offset, + c_real_num * sizeof(float), GDRAM2NRAM, c_str, + cs_str, 1); + break; + case 10: // 2 points are cached. (top_left, bottom_left) + __memcpy_async((float *)input_tl + p_idx * channel, + (float *)data_value_ptr + tl_offset, + c_real_num * sizeof(float), GDRAM2NRAM); + __memcpy_async((float *)input_bl + p_idx * channel, + (float *)data_value_ptr + bl_offset, + c_real_num * sizeof(float), GDRAM2NRAM); + break; + case 6: // 2 points are cached. (top_right, bottom_right) + __memcpy_async( + (float *)input_tr + p_idx * channel, + (float *)data_value_ptr + tl_offset + num_heads * channels, + c_real_num * sizeof(float), GDRAM2NRAM); + __memcpy_async( + (float *)input_br + p_idx * channel, + (float *)data_value_ptr + bl_offset + num_heads * channels, + c_real_num * sizeof(float), GDRAM2NRAM); + break; + case 7: // 1 point is cached. 
(top_left) + __memcpy_async((float *)input_tl + p_idx * channel, + (float *)data_value_ptr + tl_offset, + c_real_num * sizeof(float), GDRAM2NRAM); + break; + case 5: // 1 point is cached. (top_right) + __memcpy_async( + (float *)input_tr + p_idx * channel, + (float *)data_value_ptr + tl_offset + num_heads * channels, + c_real_num * sizeof(float), GDRAM2NRAM); + break; + case 3: // 1 point is cached. (bottom_left) + __memcpy_async((float *)input_bl + p_idx * channel, + (float *)data_value_ptr + bl_offset, + c_real_num * sizeof(float), GDRAM2NRAM); + break; + case 1: // 1 point is cached. (bottom_right) + __memcpy_async( + (float *)input_br + p_idx * channel, + (float *)data_value_ptr + bl_offset + num_heads * channels, + c_real_num * sizeof(float), GDRAM2NRAM); + break; + default: + continue; + } + } + + __asm__ volatile("sync;"); + + // interpolation + __bang_mul((float *)input_tl, (float *)input_tl, (float *)weight_tl, + 4 * deal_num * channel); + __bang_add((float *)input_tl, (float *)input_tl, (float *)input_bl, + 2 * deal_num * channel); + __bang_add((float *)input_tl, (float *)input_tl, (float *)input_tr, + deal_num * channel); + + // load attention weight + void *attn_weight = mask_tl; + __memcpy((float *)attn_weight, + (float *)data_attn_weight_gdram + grid_off_base, + io_data_num * sizeof(float), GDRAM2NRAM); + + // calc data_col, muladd attention weight + __bang_transpose((float *)input_tr, (float *)input_tl, deal_num, + channel); + __bang_cycle_mul((float *)input_tr, (float *)input_tr, + (float *)attn_weight, deal_num * channel, deal_num); + __bang_transpose((float *)input_tl, (float *)input_tr, channel, + deal_num); + __bang_sumpool((float *)input_bl, (float *)input_tl, channel, 1, + io_data_num, 1, num_levels * num_points, + num_levels * num_points, 1); + + // store + __memcpy((float *)data_col_gdram_start + c_iter * channel, + (float *)input_bl, c_real_num * sizeof(float), NRAM2GDRAM, + channels * sizeof(float), channel * sizeof(float), + (io_data_num / (num_levels * num_points)) - 1); + } + } + } + __asm__ volatile("sync;"); +#endif + return; +} + +template __mlu_global__ void MLUKernelMsDeformAttnForwardDefault( const char *data_value_gdram, const char *data_spatial_shapes_gdram, const char *data_level_start_index_gdram, const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, @@ -552,7 +1051,7 @@ template __mlu_global__ void MLUKernelMsDeformAttnForward( const int32_t channels, const int32_t num_levels, const int32_t num_queries, const int32_t num_points, char *data_col_gdram); -void KernelMsDeformAttnForward( +void KernelMsDeformAttnForwardDefault( cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const cnrtDataType_t d_type, const char *data_value_gdram, const char *data_spatial_shapes_gdram, @@ -561,7 +1060,30 @@ void KernelMsDeformAttnForward( const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, const int32_t channels, const int32_t num_levels, const int32_t num_queries, const int32_t num_points, char *data_col_gdram) { - MLUKernelMsDeformAttnForward<<>>( + MLUKernelMsDeformAttnForwardDefault<<>>( + data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram, + data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys, + num_heads, channels, num_levels, num_queries, num_points, data_col_gdram); +} + +template __mlu_global__ void MLUKernelMsDeformAttnForwardSmallChannel( + const char *data_value_gdram, const char *data_spatial_shapes_gdram, + const char *data_level_start_index_gdram, + const 
char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char *data_col_gdram); + +void KernelMsDeformAttnForwardSmallChannel( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const char *data_value_gdram, + const char *data_spatial_shapes_gdram, + const char *data_level_start_index_gdram, + const char *data_sampling_loc_gdram, const char *data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char *data_col_gdram) { + MLUKernelMsDeformAttnForwardSmallChannel<<>>( data_value_gdram, data_spatial_shapes_gdram, data_level_start_index_gdram, data_sampling_loc_gdram, data_attn_weight_gdram, batch_size, num_keys, num_heads, channels, num_levels, num_queries, num_points, data_col_gdram); @@ -584,15 +1106,15 @@ void __mlu_func__ msDeformAttnCol2imBilinear( int32_t offset1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr; __memcpy(grad_output_nram, data_value_ptr + offset1, deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num); - __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num); - __bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num); - __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num); + __bang_mul_scalar(grad_weight, grad_output_nram, hw, deal_num_real); + __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); + __bang_mul_scalar(grad_weight, grad_output_nram, hh, deal_num_real); + __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num); + __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); + __bang_mul_scalar(top_grad_temp, top_grad_temp, w1, deal_num_real); // for calc grad_attn_weight - __bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num); + __bang_mul_scalar(grad_output_nram, grad_output_nram, w1, deal_num_real); __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset1), (T *)top_grad_temp, deal_num_real); } @@ -600,18 +1122,18 @@ void __mlu_func__ msDeformAttnCol2imBilinear( int32_t offset2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr; __memcpy(grad_output_nram_temp, data_value_ptr + offset2, deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num); - __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num); - __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real); + __bang_sub(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, hh, deal_num_real); + __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num); + __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); + __bang_mul_scalar(top_grad_temp, top_grad_temp, w2, deal_num_real); 
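// For reference, a minimal scalar sketch of the gradients this routine
// accumulates for one sampling point whose four corners are all in range
// (names and the flat [4 * channels] layout below are illustrative, not the
// kernel's actual buffers): v holds the corner values tl, tr, bl, br; lh/lw
// are the fractional offsets and hh = 1 - lh, hw = 1 - lw.
static void ms_deform_attn_backward_point_ref(
    const float *v, const float *top_grad, int32_t channels, float attn_w,
    float lh, float lw, float width, float height,
    float *grad_value /* 4 * channels */, float *grad_attn_weight,
    float *grad_loc /* {x, y} */) {
  const float hh = 1.f - lh, hw = 1.f - lw;
  const float w[4] = {hh * hw, hh * lw, lh * hw, lh * lw};
  float g_attn = 0.f, g_lw = 0.f, g_lh = 0.f;
  for (int32_t c = 0; c < channels; ++c) {
    const float v1 = v[0 * channels + c], v2 = v[1 * channels + c];
    const float v3 = v[2 * channels + c], v4 = v[3 * channels + c];
    const float tg = top_grad[c];
    for (int32_t k = 0; k < 4; ++k) {
      grad_value[k * channels + c] += attn_w * w[k] * tg;  // atomic in the kernel
    }
    g_attn += tg * (w[0] * v1 + w[1] * v2 + w[2] * v3 + w[3] * v4);
    g_lw += tg * (hh * (v2 - v1) + lh * (v4 - v3));
    g_lh += tg * (hw * (v3 - v1) + lw * (v4 - v2));
  }
  *grad_attn_weight += g_attn;
  grad_loc[0] += width * attn_w * g_lw;   // x component of grad_sampling_loc
  grad_loc[1] += height * attn_w * g_lh;  // y component of grad_sampling_loc
}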
__bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w2, - deal_num); + deal_num_real); __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp, - deal_num); + deal_num_real); __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset2), (T *)top_grad_temp, deal_num_real); } @@ -619,18 +1141,18 @@ void __mlu_func__ msDeformAttnCol2imBilinear( int32_t offset3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr; __memcpy(grad_output_nram_temp, data_value_ptr + offset3, deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num); - __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num); - __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, hw, deal_num_real); + __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real); + __bang_sub(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num); + __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); + __bang_mul_scalar(top_grad_temp, top_grad_temp, w3, deal_num_real); // for calc grad_attn_weight __bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w3, - deal_num); + deal_num_real); __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp, - deal_num); + deal_num_real); __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset3), (T *)top_grad_temp, deal_num_real); } @@ -638,63 +1160,61 @@ void __mlu_func__ msDeformAttnCol2imBilinear( int32_t offset4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr; __memcpy(grad_output_nram_temp, data_value_ptr + offset4, deal_num_real * sizeof(T), GDRAM2NRAM); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num); - __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num); - __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num); - __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, lw, deal_num_real); + __bang_add(grad_h_weight, grad_h_weight, grad_weight, deal_num_real); + __bang_mul_scalar(grad_weight, grad_output_nram_temp, lh, deal_num_real); + __bang_add(grad_w_weight, grad_w_weight, grad_weight, deal_num_real); - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num); - __bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num); + __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); + __bang_mul_scalar(top_grad_temp, top_grad_temp, w4, deal_num_real); // for calc grad_attn_weight __bang_mul_scalar(grad_output_nram_temp, grad_output_nram_temp, w4, - deal_num); + deal_num_real); __bang_add(grad_output_nram, grad_output_nram, grad_output_nram_temp, - deal_num); + deal_num_real); __bang_atomic_add((T *)top_grad_temp, (T *)(grad_value + offset4), (T *)top_grad_temp, deal_num_real); } - __bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num); + __bang_mul(grad_output_nram, grad_output_nram, top_grad, deal_num_real); #if __BANG_ARCH__ >= 322 recursiveSumPool(grad_output_nram, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE); #else - const int32_t align_num_on_200 = NFU_ALIGN_SIZE / sizeof(float); + const int32_t align_num_on_200 = NFU_ALIGN_SIZE / LEN_FLOAT; 
recursiveSumPool(grad_output_nram, align_num_on_200, deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE); __bang_reduce_sum(grad_output_nram, grad_output_nram, - NFU_ALIGN_SIZE / sizeof(float)); + NFU_ALIGN_SIZE / LEN_FLOAT); #endif __bang_atomic_add((T *)grad_output_nram, (T *)grad_attn_weight, (T *)grad_output_nram, 1); - __bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num); - __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num); - __bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num); + __bang_mul_scalar(grad_w_weight, grad_w_weight, width, deal_num_real); + __bang_mul_scalar(top_grad_temp, top_grad, data_attn_weight, deal_num_real); + __bang_mul(grad_w_weight, grad_w_weight, top_grad_temp, deal_num_real); #if __BANG_ARCH__ >= 322 recursiveSumPool(grad_w_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE); #else recursiveSumPool(grad_w_weight, align_num_on_200, deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE); - __bang_reduce_sum(grad_w_weight, grad_w_weight, - NFU_ALIGN_SIZE / sizeof(float)); + __bang_reduce_sum(grad_w_weight, grad_w_weight, NFU_ALIGN_SIZE / LEN_FLOAT); #endif __bang_atomic_add((T *)grad_w_weight, (T *)(grad_sampling_loc), (T *)grad_w_weight, 1); - __bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num); - __bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num); + __bang_mul_scalar(grad_h_weight, grad_h_weight, height, deal_num_real); + __bang_mul(grad_h_weight, grad_h_weight, top_grad_temp, deal_num_real); #if __BANG_ARCH__ >= 322 recursiveSumPool(grad_h_weight, 1, deal_num_real, ALIGN_NUM_FOR_REDUCE); #else recursiveSumPool(grad_h_weight, align_num_on_200, deal_num / align_num_on_200, ALIGN_NUM_FOR_REDUCE); - __bang_reduce_sum(grad_h_weight, grad_h_weight, - NFU_ALIGN_SIZE / sizeof(float)); + __bang_reduce_sum(grad_h_weight, grad_h_weight, NFU_ALIGN_SIZE / LEN_FLOAT); #endif __bang_atomic_add((T *)grad_h_weight, (T *)(grad_sampling_loc + 1), (T *)grad_h_weight, 1); } -__mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( +__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel( const float *data_value, const int32_t *spatial_shapes, const int32_t *data_level_start_index, const float *data_sampling_loc, const float *data_attn_weight, const float *grad_output, @@ -708,8 +1228,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( const int32_t split_num = 8; const int32_t spatial_shapes_size = 64; int32_t deal_num = PAD_DOWN( - (MAX_NRAM_SIZE - spatial_shapes_size) / split_num / sizeof(float), - ALIGN_NUM); + (MAX_NRAM_SIZE - spatial_shapes_size) / split_num / LEN_FLOAT, ALIGN_NUM); float *grad_output_nram = (float *)nram_buffer; float *grad_output_nram_temp = (float *)nram_buffer + deal_num; float *grad_weight = (float *)nram_buffer + 2 * deal_num; @@ -725,10 +1244,8 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( int32_t num_per_core = total_num / taskDim; int32_t num_rem = total_num % taskDim; num_per_core = num_per_core + int32_t(taskId < num_rem); - int32_t start_per_core = - num_rem > taskId - ? (taskId * num_per_core) - : ((num_per_core + 1) * num_rem + (taskId - num_rem) * num_per_core); + int32_t start_per_core = num_rem > taskId ? 
(taskId * num_per_core) + : (num_rem + taskId * num_per_core); int32_t end_per_core = start_per_core + num_per_core; const int32_t C_repeat = channels / deal_num; const int32_t C_tail = channels % deal_num; @@ -758,7 +1275,7 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( const int32_t grad_sampling_loc_out = num_loop * num_points * 2; for (int32_t p_col = 0; p_col < num_points; ++p_col) { __memcpy(sampling_loc_nram, data_sampling_loc + data_loc_w_ptr, - 2 * sizeof(float), GDRAM2NRAM); + 2 * LEN_FLOAT, GDRAM2NRAM); const float loc_w = sampling_loc_nram[0]; const float loc_h = sampling_loc_nram[1]; const float weight = data_attn_weight[data_weight_ptr]; @@ -789,11 +1306,12 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( for (int32_t C_loop = 0; C_loop < C_repeat; ++C_loop) { base_ptr = m_col * channels + C_loop * deal_num; - __bang_write_zero(grad_weight, 3 * deal_num); - __bang_write_zero(grad_output_nram, deal_num); + __bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM)); + __bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM)); + __bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM)); __memcpy(top_grad, grad_output + grad_output_offset + C_loop * deal_num, - deal_num * sizeof(float), GDRAM2NRAM); + deal_num * LEN_FLOAT, GDRAM2NRAM); msDeformAttnCol2imBilinear( top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low, h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset, @@ -806,10 +1324,12 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( } if (C_tail != 0) { base_ptr = m_col * channels + C_repeat * deal_num; - __bang_write_zero(grad_output_nram, 8 * deal_num); + __bang_write_zero(grad_h_weight, PAD_UP(channels, ALIGN_NUM)); + __bang_write_zero(grad_w_weight, PAD_UP(channels, ALIGN_NUM)); + __bang_write_zero(grad_output_nram, PAD_UP(channels, ALIGN_NUM)); __memcpy(top_grad, grad_output + grad_output_offset + C_repeat * deal_num, - C_tail * sizeof(float), GDRAM2NRAM); + C_tail * LEN_FLOAT, GDRAM2NRAM); msDeformAttnCol2imBilinear( top_grad_temp, spatial_h, spatial_w, w1, w2, w3, w4, h_low, w_low, h_high, w_high, base_ptr, h_low_ptr_offset, w_low_ptr_offset, @@ -827,7 +1347,711 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( } } -__mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( +void __mlu_func__ computeGridMaskAndOffset( + float *nram_grad_output_tl, float *nram_grad_output_tr, float *nram_loc_w, + float *nram_loc_h, float *nram_h_stride, int32_t *nram_spatial_shapes, + float *nram_w_low_temp, float *nram_h_high_temp, float *nram_w_low, + float *nram_h_low, float *nram_h_high, float *nram_w_high, float *nram_lh, + float *nram_lw, float *nram_hh, float *nram_hw, + float *nram_h_low_ptr_offset, float *nram_h_high_ptr_offset, + float *nram_w_low_ptr_offset, float *nram_w_high_ptr_offset, float *nram_w1, + float *nram_w2, float *nram_w3, float *nram_w4, float *nram_offset_temp, + float *nram_offset1, float *nram_offset2, float *nram_offset3, + float *nram_offset4, float *nram_base_ptr, float *nram_h_low_temp, + int32_t num_deal_grid, int32_t num_per_time_real, const int32_t num_heads, + const int32_t num_levels, const int32_t num_points, const int32_t w_stride, + const int32_t qid_stride) { +#if __BANG_ARCH__ >= 322 + // [num_levels, 2] --> [2, num_levels] + __bang_transpose(nram_grad_output_tl, nram_loc_w, num_deal_grid, 2); + __bang_transpose(nram_loc_w, nram_grad_output_tl, + num_per_time_real * num_heads * num_levels, num_points); + __bang_transpose(nram_loc_h, nram_grad_output_tl + 
num_deal_grid, + num_per_time_real * num_heads * num_levels, num_points); + __bang_int322float((float *)nram_spatial_shapes, + (int32_t *)nram_spatial_shapes, num_levels * 2, 0); + __bang_transpose(nram_grad_output_tr, (float *)nram_spatial_shapes, + num_levels, 2); + __bang_mul_scalar(nram_h_stride, nram_grad_output_tr + num_levels, w_stride, + num_levels); + __memcpy_async(nram_spatial_shapes, nram_grad_output_tr, + num_levels * 2 * sizeof(float), NRAM2NRAM); + __bang_cycle_mul(nram_loc_w, nram_loc_w, + (float *)nram_spatial_shapes + num_levels, num_deal_grid, + num_levels); + __bang_cycle_mul(nram_loc_h, nram_loc_h, (float *)(nram_spatial_shapes), + num_deal_grid, num_levels); + __bang_sub_scalar(nram_loc_w, nram_loc_w, 0.5, num_deal_grid); + __bang_sub_scalar(nram_loc_h, nram_loc_h, 0.5, num_deal_grid); + // get mask. (h_im > -1 && w_im > -1 && + // h_im < spatial_h && w_im < spatial_w) + __bang_cycle_lt(nram_w_low_temp, nram_loc_w, + (float *)(nram_spatial_shapes + num_levels), num_deal_grid, + num_levels); + __bang_cycle_lt(nram_h_high_temp, nram_loc_h, (float *)(nram_spatial_shapes), + num_deal_grid, num_levels); + __bang_and(nram_w_low_temp, nram_w_low_temp, nram_h_high_temp, num_deal_grid); + __bang_gt_scalar(nram_h_high_temp, nram_loc_h, -1, num_deal_grid); + __bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp, + num_deal_grid); + __bang_gt_scalar(nram_w_low_temp, nram_loc_w, -1, num_deal_grid); + __bang_and(nram_h_high_temp, nram_h_high_temp, nram_w_low_temp, + num_deal_grid); + __bang_transpose(nram_w_low_temp, nram_h_high_temp, num_points, + num_per_time_real * num_heads * num_levels); + __memcpy_async(nram_h_high_temp, nram_w_low_temp, + num_deal_grid * sizeof(float), NRAM2NRAM); + __bang_transpose(nram_grad_output_tl, nram_loc_w, num_points, + num_per_time_real * num_heads * num_levels); + __memcpy_async(nram_loc_w, nram_grad_output_tl, num_deal_grid * sizeof(float), + NRAM2NRAM); + __bang_transpose(nram_grad_output_tl, nram_loc_h, num_points, + num_per_time_real * num_heads * num_levels); + __memcpy_async(nram_loc_h, nram_grad_output_tl, num_deal_grid * sizeof(float), + NRAM2NRAM); + __bang_floor(nram_w_low, nram_loc_w, num_deal_grid); + __bang_floor(nram_h_low, nram_loc_h, num_deal_grid); + __bang_add_scalar(nram_h_high, nram_h_low, 1, num_deal_grid); + __bang_add_scalar(nram_w_high, nram_w_low, 1, num_deal_grid); + __bang_sub(nram_lh, nram_loc_h, nram_h_low, num_deal_grid); + __bang_sub(nram_lw, nram_loc_w, nram_w_low, num_deal_grid); + __bang_fusion(FUSION_FMA, nram_hh, nram_lh, (float)(-1), 1, num_deal_grid); + __bang_fusion(FUSION_FMA, nram_hw, nram_lw, (float)(-1), 1, num_deal_grid); + __bang_transpose(nram_h_low_ptr_offset, nram_h_low, + num_per_time_real * num_heads * num_levels, num_points); + __bang_cycle_mul(nram_h_low_ptr_offset, nram_h_low_ptr_offset, nram_h_stride, + num_deal_grid, num_levels); + __bang_cycle_add(nram_h_high_ptr_offset, nram_h_low_ptr_offset, nram_h_stride, + num_deal_grid, num_levels); + __bang_transpose(nram_w_low_ptr_offset, nram_h_low_ptr_offset, num_points, + num_per_time_real * num_heads * num_levels); + __memcpy_async(nram_h_low_ptr_offset, nram_w_low_ptr_offset, + num_deal_grid * sizeof(float), NRAM2NRAM); + __bang_transpose(nram_w_low_ptr_offset, nram_h_high_ptr_offset, num_points, + num_per_time_real * num_heads * num_levels); + __memcpy_async(nram_h_high_ptr_offset, nram_w_low_ptr_offset, + num_deal_grid * sizeof(float), NRAM2NRAM); + __bang_mul_scalar(nram_w_low_ptr_offset, nram_w_low, qid_stride, + num_deal_grid); + 
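  // Scalar form of the four corner offsets being assembled here (sketch only;
  // spatial_w is the level width and qid_stride == w_stride ==
  // num_heads * channels). The per-batch / per-level base, i.e.
  // b * spatial_size * qid_stride + level_start * qid_stride, is added later
  // in loadValue and computeGradValue.
  static inline void corner_offsets_ref(int32_t h_low, int32_t w_low,
                                        int32_t spatial_w, int32_t qid_stride,
                                        int32_t head, int32_t channels,
                                        int32_t out[4]) {
    const int32_t h_stride = spatial_w * qid_stride;
    const int32_t base = head * channels;
    out[0] = h_low * h_stride + w_low * qid_stride + base;              // top left
    out[1] = h_low * h_stride + (w_low + 1) * qid_stride + base;        // top right
    out[2] = (h_low + 1) * h_stride + w_low * qid_stride + base;        // bottom left
    out[3] = (h_low + 1) * h_stride + (w_low + 1) * qid_stride + base;  // bottom right
  }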
__bang_add_scalar(nram_w_high_ptr_offset, nram_w_low_ptr_offset, qid_stride, + num_deal_grid); + __bang_mul(nram_w1, nram_hh, nram_hw, num_deal_grid); + __bang_mul(nram_w2, nram_hh, nram_lw, num_deal_grid); + __bang_mul(nram_w3, nram_lh, nram_hw, num_deal_grid); + __bang_mul(nram_w4, nram_lh, nram_lw, num_deal_grid); + __bang_add(nram_offset1, nram_h_low_ptr_offset, nram_w_low_ptr_offset, + num_deal_grid); + __bang_transpose(nram_offset_temp, nram_offset1, + num_per_time_real * num_heads, num_levels * num_points); + __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr, + num_deal_grid, num_heads); + __bang_transpose(nram_offset1, nram_offset_temp, num_levels * num_points, + num_per_time_real * num_heads); + __bang_add(nram_offset2, nram_h_low_ptr_offset, nram_w_high_ptr_offset, + num_deal_grid); + __bang_transpose(nram_offset_temp, nram_offset2, + num_per_time_real * num_heads, num_levels * num_points); + __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr, + num_deal_grid, num_heads); + __bang_transpose(nram_offset2, nram_offset_temp, num_levels * num_points, + num_per_time_real * num_heads); + __bang_add(nram_offset3, nram_h_high_ptr_offset, nram_w_low_ptr_offset, + num_deal_grid); + __bang_transpose(nram_offset_temp, nram_offset3, + num_per_time_real * num_heads, num_levels * num_points); + __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr, + num_deal_grid, num_heads); + __bang_transpose(nram_offset3, nram_offset_temp, num_levels * num_points, + num_per_time_real * num_heads); + __bang_add(nram_offset4, nram_h_high_ptr_offset, nram_w_high_ptr_offset, + num_deal_grid); + __bang_transpose(nram_offset_temp, nram_offset4, + num_per_time_real * num_heads, num_levels * num_points); + __bang_cycle_add(nram_offset_temp, nram_offset_temp, nram_base_ptr, + num_deal_grid, num_heads); + __bang_transpose(nram_offset4, nram_offset_temp, num_levels * num_points, + num_per_time_real * num_heads); + // h_low >= 0 && w_low >= 0 mask2 + float *mask1 = nram_h_low_ptr_offset; + float *mask2 = nram_h_high_ptr_offset; + float *mask3 = nram_w_low_ptr_offset; + float *mask4 = nram_w_high_ptr_offset; + __bang_ge_scalar(mask1, nram_h_low, 0, num_deal_grid); + __bang_ge_scalar(mask2, nram_w_low, 0, num_deal_grid); + __bang_and(mask2, mask1, mask2, num_deal_grid); + __bang_and(mask2, nram_h_high_temp, mask2, num_deal_grid); + // h_low >= 0 && w_high <= width - 1 mask1 + __bang_transpose(mask3, nram_w_high, + num_per_time_real * num_heads * num_levels, num_points); + __bang_sub_scalar(nram_spatial_shapes, nram_spatial_shapes, 1, + num_levels * 2); + __bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes + num_levels), + num_deal_grid, num_levels); + __bang_transpose(mask4, mask3, num_points, + num_per_time_real * num_heads * num_levels); + __bang_and(mask1, mask1, mask4, num_deal_grid); + __bang_and(mask1, nram_h_high_temp, mask1, num_deal_grid); + // h_high <= height - 1 && w_high <= width - 1 mask3 + __bang_transpose(mask3, nram_h_high, + num_per_time_real * num_heads * num_levels, num_points); + __bang_cycle_le(mask3, mask3, (float *)(nram_spatial_shapes), num_deal_grid, + num_levels); + + __bang_transpose(nram_h_low_temp, mask3, num_points, + num_per_time_real * num_heads * num_levels); + __bang_and(mask4, mask4, nram_h_low_temp, num_deal_grid); + __bang_and(mask3, mask4, nram_h_high_temp, num_deal_grid); + // h_high <= height - 1 && w_low >= 0 mask4 + __bang_ge_scalar(nram_w_low_temp, nram_w_low, 0, num_deal_grid); + __bang_and(mask4, nram_h_low_temp, 
nram_w_low_temp, num_deal_grid); + __bang_and(mask4, mask4, nram_h_high_temp, num_deal_grid); +#endif +} + +void __mlu_func__ loadValue( + float *nram_grad_output_tl, float *nram_grad_output_tr, + float *nram_grad_output_bl, float *nram_grad_output_br, + const float *data_value, const float *grad_output, float *grad_temp1, + float *grad_temp2, float *mask1, float *mask2, float *mask3, float *mask4, + float *nram_offset1, float *nram_offset2, float *nram_offset3, + float *nram_offset4, float *nram_grad_weight, + int32_t *nram_level_start_index, int32_t offset_nram, + int32_t start_per_core, int32_t grid_loop, int32_t num_per_time_theory, + int32_t num_heads, int32_t deal_num_real, int32_t num_per_time_real, + int32_t num_deal_grid, const int32_t num_query, const int32_t num_levels, + const int32_t num_points, int32_t grid_offset, const int32_t spatial_size, + const int32_t qid_stride) { +#if __BANG_ARCH__ >= 322 + int32_t value_offset_temp = 0; + __bang_write_zero(nram_grad_output_tl, 4 * offset_nram); + __sync_io_move_compute(); + __memcpy_async( + grad_temp2, + grad_output + (start_per_core + grid_loop * num_per_time_theory) * + num_heads * deal_num_real, + num_per_time_real * num_heads * deal_num_real * sizeof(float), + GDRAM2NRAM); + for (int32_t loop = 0; loop < num_deal_grid; ++loop) { + const int32_t b_col = + (grid_offset + loop) / num_query / num_heads / num_levels / num_points; + const int32_t l_col = (grid_offset + loop) / num_points % num_levels; + const int32_t level_start_id = nram_level_start_index[l_col]; + value_offset_temp = + b_col * spatial_size * qid_stride + level_start_id * qid_stride; + if (mask2[loop]) { + __memcpy_async( + nram_grad_output_tl + loop * deal_num_real, + data_value + value_offset_temp + int32_t(nram_offset1[loop]), + deal_num_real * sizeof(float), GDRAM2NRAM); + } + if (mask1[loop]) { + __memcpy_async( + nram_grad_output_tr + loop * deal_num_real, + data_value + value_offset_temp + int32_t(nram_offset2[loop]), + deal_num_real * sizeof(float), GDRAM2NRAM); + } + if (mask4[loop]) { + __memcpy_async( + nram_grad_output_bl + loop * deal_num_real, + data_value + value_offset_temp + int32_t(nram_offset3[loop]), + deal_num_real * sizeof(float), GDRAM2NRAM); + } + if (mask3[loop]) { + __memcpy_async( + nram_grad_output_br + loop * deal_num_real, + data_value + value_offset_temp + int32_t(nram_offset4[loop]), + deal_num_real * sizeof(float), GDRAM2NRAM); + } + } + for (int32_t m = 0; m < deal_num_real; ++m) { + __memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight, + num_deal_grid * sizeof(float), NRAM2NRAM); + } + __sync_io_move_compute(); +#endif +} + +void __mlu_func__ computeGradValue( + float *grad_temp1, float *grad_temp2, float *grad_temp3, float *grad_temp4, + float *mask1, float *mask2, float *mask3, float *mask4, float *nram_offset1, + float *nram_offset2, float *nram_offset3, float *nram_offset4, + int32_t *nram_level_start_index, int32_t deal_num_real, + const float *grad_value, float *nram_w1, float *nram_w2, float *nram_w3, + float *nram_w4, int32_t num_per_time_real, const int32_t num_heads, + const int32_t num_levels, const int32_t num_points, const int32_t num_query, + int32_t num_deal_grid, int32_t grid_offset, const int32_t spatial_size, + const int32_t qid_stride, float *nram_grid_offset1, + float *nram_grid_offset2) { +#if __BANG_ARCH__ >= 322 + __bang_transpose(grad_temp3, grad_temp1, + deal_num_real * num_per_time_real * num_heads, + num_levels * num_points); + __bang_transpose(grad_temp1, grad_temp2, num_per_time_real * 
num_heads, + deal_num_real); + __bang_cycle_mul(grad_temp3, grad_temp3, grad_temp1, + num_deal_grid * deal_num_real, + deal_num_real * num_per_time_real * num_heads); + __bang_transpose(grad_temp4, grad_temp3, num_levels * num_points, + deal_num_real * num_per_time_real * num_heads); + __bang_cycle_mul(grad_temp1, grad_temp4, nram_w1, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid); + for (int32_t loop = 0; loop < num_deal_grid; ++loop) { + nram_grid_offset1[loop] = ((loop + grid_offset) / num_query / num_heads / + num_levels / num_points) * + spatial_size * qid_stride; + } + __bang_transpose(nram_grid_offset2, nram_grid_offset1, + num_per_time_real * num_heads * num_levels, num_points); + __bang_int322float((float *)nram_level_start_index, nram_level_start_index, + num_levels, 0); + __bang_mul_scalar(nram_grid_offset1, (float *)nram_level_start_index, + qid_stride, num_levels); + __bang_cycle_add(nram_grid_offset2, nram_grid_offset2, nram_grid_offset1, + num_deal_grid, num_levels); + __bang_transpose(nram_grid_offset1, nram_grid_offset2, num_points, + num_per_time_real * num_heads * num_levels); + __bang_add(nram_offset1, nram_offset1, nram_grid_offset1, num_deal_grid); + __bang_add(nram_offset2, nram_offset2, nram_grid_offset1, num_deal_grid); + __bang_add(nram_offset3, nram_offset3, nram_grid_offset1, num_deal_grid); + __bang_add(nram_offset4, nram_offset4, nram_grid_offset1, num_deal_grid); + for (int32_t loop = 0; loop < num_deal_grid; ++loop) { + if (mask2[loop]) { + __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real), + (float *)(grad_value + int32_t(nram_offset1[loop])), + (float *)(grad_temp3 + loop * deal_num_real), + deal_num_real); + } + } + __bang_cycle_mul(grad_temp1, grad_temp4, nram_w2, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid); + for (int32_t loop = 0; loop < num_deal_grid; ++loop) { + if (mask1[loop]) { + __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real), + (float *)(grad_value + int32_t(nram_offset2[loop])), + (float *)(grad_temp3 + loop * deal_num_real), + deal_num_real); + } + } + __bang_cycle_mul(grad_temp1, grad_temp4, nram_w3, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid); + for (int32_t loop = 0; loop < num_deal_grid; ++loop) { + if (mask4[loop]) { + __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real), + (float *)(grad_value + int32_t(nram_offset3[loop])), + (float *)(grad_temp3 + loop * deal_num_real), + deal_num_real); + } + } + + __bang_cycle_mul(grad_temp1, grad_temp4, nram_w4, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_transpose(grad_temp3, grad_temp1, deal_num_real, num_deal_grid); + for (int32_t loop = 0; loop < num_deal_grid; ++loop) { + if (mask3[loop]) { + __bang_atomic_add((float *)(grad_temp3 + loop * deal_num_real), + (float *)(grad_value + int32_t(nram_offset4[loop])), + (float *)(grad_temp3 + loop * deal_num_real), + deal_num_real); + } + } +#endif +} + +void __mlu_func__ computeGradAttnWeight( + float *grad_w_weight, float *grad_weight, float *nram_grad_output_tl, + float *nram_grad_output_tr, float *nram_grad_output_bl, + float *nram_grad_output_br, float *grad_temp1, float *grad_temp2, + const float *grad_attn_weight, float *nram_hw, float *nram_hh, + float *nram_lw, float *nram_lh, float *grad_h_weight, float *nram_w1, + float *nram_w2, float *nram_w3, float *nram_w4, 
int32_t offset_nram, + int32_t num_deal_grid, int32_t deal_num_real, int32_t num_per_time_real, + const int32_t num_heads, const int32_t num_levels, const int32_t num_points, + int32_t grid_offset, float *nram_h_high_temp) { +#if __BANG_ARCH__ >= 322 + __bang_write_zero(grad_w_weight, 2 * offset_nram); + + // grad_output_nram_tl + __bang_transpose(grad_weight, nram_grad_output_tl, num_deal_grid, + deal_num_real); + __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hw, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tl, + num_deal_grid * deal_num_real); + __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_hh, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_tl, + num_deal_grid * deal_num_real); + __bang_cycle_mul(nram_grad_output_tl, grad_weight, nram_w1, + num_deal_grid * deal_num_real, num_deal_grid); + // nram_grad_output_tr + __bang_transpose(grad_weight, nram_grad_output_tr, num_deal_grid, + deal_num_real); + __bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_lw, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_sub(grad_h_weight, grad_h_weight, nram_grad_output_tr, + num_deal_grid * deal_num_real); + __bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_hh, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_add(grad_w_weight, grad_w_weight, nram_grad_output_tr, + num_deal_grid * deal_num_real); + __bang_cycle_mul(nram_grad_output_tr, grad_weight, nram_w2, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_tr, + num_deal_grid * deal_num_real); + // nram_grad_output_tl + __bang_transpose(grad_weight, nram_grad_output_bl, num_deal_grid, + deal_num_real); + __bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_hw, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_add(grad_h_weight, grad_h_weight, nram_grad_output_bl, + num_deal_grid * deal_num_real); + __bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_lh, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_sub(grad_w_weight, grad_w_weight, nram_grad_output_bl, + num_deal_grid * deal_num_real); + __bang_cycle_mul(nram_grad_output_bl, grad_weight, nram_w3, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_bl, + num_deal_grid * deal_num_real); + // nram_grad_output_br + __bang_transpose(grad_weight, nram_grad_output_br, num_deal_grid, + deal_num_real); + __bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lw, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_add(grad_h_weight, grad_h_weight, nram_grad_output_br, + num_deal_grid * deal_num_real); + __bang_cycle_mul(nram_grad_output_br, grad_weight, nram_lh, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_add(grad_w_weight, grad_w_weight, nram_grad_output_br, + num_deal_grid * deal_num_real); + __bang_cycle_mul(nram_grad_output_br, grad_weight, nram_w4, + num_deal_grid * deal_num_real, num_deal_grid); + __bang_add(nram_grad_output_tl, nram_grad_output_tl, nram_grad_output_br, + num_deal_grid * deal_num_real); + __bang_transpose(nram_grad_output_br, nram_grad_output_tl, deal_num_real, + num_deal_grid); + __bang_transpose(nram_grad_output_tr, nram_grad_output_br, + num_per_time_real * num_heads, + num_points * num_levels * deal_num_real); + __bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads, + deal_num_real); + 
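  // The transpose / __bang_cycle_mul pairs in this routine implement a
  // broadcast multiply: __bang_cycle_mul(dst, a, b, N, M) behaves like an
  // elementwise product in which the shorter M-element operand b is repeated
  // cyclically across the N-element operand a, and the surrounding transposes
  // only arrange the data so that the quantity to broadcast is contiguous.
  // Rough scalar equivalent (sketch only, name is illustrative):
  static inline void cycle_mul_ref(float *dst, const float *a, const float *b,
                                   int32_t n, int32_t m) {
    for (int32_t i = 0; i < n; ++i) {
      dst[i] = a[i] * b[i % m];
    }
  }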
__bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1, + num_deal_grid * deal_num_real, + num_per_time_real * num_heads * deal_num_real); + __bang_transpose(nram_grad_output_br, nram_grad_output_tr, + num_points * num_levels * deal_num_real, + num_per_time_real * num_heads); + + __bang_transpose((float *)nram_grad_output_tr, (float *)nram_grad_output_br, + num_deal_grid, deal_num_real); + recursiveSumPool(nram_grad_output_tr, num_deal_grid, deal_num_real, + ALIGN_NUM); + __bang_float2int32((int *)nram_h_high_temp, nram_h_high_temp, num_deal_grid, + 0); + __nram__ int table[2] = {0, (int)0xffffffff}; + __bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table, + num_deal_grid, 64); + __bang_band((char *)nram_grad_output_tr, (char *)nram_grad_output_tr, + (char *)nram_h_high_temp, num_deal_grid * sizeof(float)); + + __bang_atomic_add((float *)nram_grad_output_tr, + (float *)grad_attn_weight + grid_offset, + (float *)nram_grad_output_tr, num_deal_grid); +#endif +} + +void __mlu_func__ computeGradSampingLoc( + const float *grad_sampling_loc, float *nram_grad_output_tl, + float *nram_grad_output_tr, float *grad_h_weight, float *grad_w_weight, + int32_t *nram_spatial_shapes, float *grad_temp1, float *grad_temp2, + float *nram_grad_weight, int32_t num_deal_grid, int32_t deal_num_real, + int32_t num_per_time_real, const int32_t num_heads, + const int32_t num_levels, const int32_t num_points, int32_t grid_offset, + float *nram_h_high_temp) { +#if __BANG_ARCH__ >= 322 + __bang_transpose(nram_grad_output_tl, grad_h_weight, + num_per_time_real * num_heads * num_levels * deal_num_real, + num_points); + __bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl, + (float *)nram_spatial_shapes, num_deal_grid * deal_num_real, + num_levels); + __bang_transpose(grad_h_weight, nram_grad_output_tl, + num_points * deal_num_real, + num_per_time_real * num_heads * num_levels); + for (int32_t m = 0; m < deal_num_real; ++m) { + __memcpy_async(grad_temp1 + m * num_deal_grid, nram_grad_weight, + num_deal_grid * sizeof(float), NRAM2NRAM); + } + __sync_io_move_compute(); + __bang_transpose(nram_grad_output_tr, grad_temp1, + deal_num_real * num_per_time_real * num_heads, + num_levels * num_points); + __bang_transpose(grad_temp1, grad_temp2, num_per_time_real * num_heads, + deal_num_real); + __bang_cycle_mul(nram_grad_output_tr, nram_grad_output_tr, grad_temp1, + num_deal_grid * deal_num_real, + deal_num_real * num_per_time_real * num_heads); + __bang_transpose(grad_temp1, nram_grad_output_tr, + num_levels * num_points * deal_num_real, + num_per_time_real * num_heads); + __bang_mul(grad_h_weight, grad_h_weight, grad_temp1, + num_deal_grid * deal_num_real); + __bang_transpose(nram_grad_output_tl, grad_h_weight, num_deal_grid, + deal_num_real); + __memcpy_async(grad_h_weight, nram_grad_output_tl, + num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM); + recursiveSumPool(grad_h_weight, num_deal_grid, deal_num_real, ALIGN_NUM); + __nram__ int table[2] = {0, (int)0xffffffff}; + __bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table, + num_deal_grid, 64); + __bang_band((char *)grad_h_weight, (char *)grad_h_weight, + (char *)nram_h_high_temp, num_deal_grid * sizeof(float)); + __bang_transpose(nram_grad_output_tl, grad_w_weight, + num_per_time_real * num_heads * num_levels * deal_num_real, + num_points); + __bang_cycle_mul(nram_grad_output_tl, nram_grad_output_tl, + (float *)(nram_spatial_shapes + num_levels), + num_deal_grid * deal_num_real, num_levels); + 
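  // The table / __bang_lut_s32 / __bang_band sequences used in this backward
  // path implement masking without a select: the 0/1 validity mask is mapped
  // to 0x00000000 / 0xFFFFFFFF and bitwise-ANDed with the float gradients,
  // leaving valid lanes untouched and turning invalid lanes into +0.0f. A
  // scalar sketch of the same idea (illustrative only; uses the GCC/Clang
  // __builtin_memcpy to stay include-free):
  static inline float mask_grad_ref(float grad, int32_t valid /* 0 or 1 */) {
    uint32_t bits;
    __builtin_memcpy(&bits, &grad, sizeof(bits));
    bits &= valid ? 0xFFFFFFFFu : 0x0u;
    float out;
    __builtin_memcpy(&out, &bits, sizeof(out));
    return out;  // grad when valid, +0.0f otherwise
  }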
__bang_transpose(grad_w_weight, nram_grad_output_tl, + num_points * deal_num_real, + num_per_time_real * num_heads * num_levels); + __bang_mul(grad_w_weight, grad_w_weight, grad_temp1, + num_deal_grid * deal_num_real); + __bang_transpose(nram_grad_output_tl, grad_w_weight, num_deal_grid, + deal_num_real); + __memcpy(grad_w_weight, nram_grad_output_tl, + num_deal_grid * deal_num_real * sizeof(float), NRAM2NRAM); + recursiveSumPool(grad_w_weight, num_deal_grid, deal_num_real, ALIGN_NUM); + __bang_lut_s32((int *)nram_h_high_temp, (int *)nram_h_high_temp, (int *)table, + num_deal_grid, 64); + __bang_band((char *)grad_w_weight, (char *)grad_w_weight, + (char *)nram_h_high_temp, num_deal_grid * sizeof(float)); + + __memcpy(grad_w_weight + num_deal_grid, grad_h_weight, + num_deal_grid * sizeof(float), NRAM2NRAM); + __bang_transpose(nram_grad_output_tl, grad_w_weight, 2, num_deal_grid); + __bang_atomic_add((float *)nram_grad_output_tl, + (float *)grad_sampling_loc + grid_offset * 2, + (float *)nram_grad_output_tl, 2 * num_deal_grid); + +#endif +} + +__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel( + const float *data_value, const int32_t *spatial_shapes, + const int32_t *data_level_start_index, const float *data_sampling_loc, + const float *data_attn_weight, const float *grad_output, + const int32_t batch, const int32_t spatial_size, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_query, + const int32_t num_points, float *grad_value, float *grad_sampling_loc, + float *grad_attn_weight) { +#if __BANG_ARCH__ > 322 + const int32_t split_grid_num = 28; + const int32_t split_num_c = 8; + const int32_t C_align = PAD_UP(channels, ALIGN_NUM); + + const int32_t num_hlp = num_heads * num_levels * num_points; + int32_t num_per_time_theory = (MAX_NRAM_SIZE - num_levels * sizeof(float) - + 3 * num_levels * sizeof(int32_t)) / + sizeof(float) / + (split_num_c * C_align + split_grid_num) / + PAD_UP((num_hlp), ALIGN_NUM); + + int32_t deal_grid_num_theory = num_per_time_theory * num_hlp; + + const int32_t offset_nram = num_per_time_theory * C_align * num_hlp; + const int32_t offset_nram_calc = PAD_UP(deal_grid_num_theory, ALIGN_NUM); + float *nram_grad_output_tl = (float *)nram_buffer; + float *nram_grad_output_tr = (float *)nram_buffer + offset_nram; + float *nram_grad_output_bl = (float *)nram_buffer + 2 * offset_nram; + float *nram_grad_output_br = (float *)nram_buffer + 3 * offset_nram; + + float *grad_temp1 = (float *)nram_buffer + 4 * offset_nram; + float *grad_temp2 = (float *)nram_buffer + 5 * offset_nram; + float *grad_temp3 = (float *)nram_buffer + 6 * offset_nram; + float *grad_temp4 = (float *)nram_buffer + 7 * offset_nram; + + float *nram_loc_w = (float *)nram_buffer + split_num_c * offset_nram; + float *nram_loc_h = + (float *)nram_buffer + split_num_c * offset_nram + offset_nram_calc; + float *nram_h_low = + (float *)nram_buffer + split_num_c * offset_nram + 2 * offset_nram_calc; + float *nram_w_low = + (float *)nram_buffer + split_num_c * offset_nram + 3 * offset_nram_calc; + float *nram_h_high = + (float *)nram_buffer + split_num_c * offset_nram + 4 * offset_nram_calc; + float *nram_w_high = + (float *)nram_buffer + split_num_c * offset_nram + 5 * offset_nram_calc; + float *nram_h_low_temp = + (float *)nram_buffer + split_num_c * offset_nram + 6 * offset_nram_calc; + float *nram_h_high_temp = + (float *)nram_buffer + split_num_c * offset_nram + 7 * offset_nram_calc; + + float *nram_hw = + (float *)nram_buffer + split_num_c * 
offset_nram + 8 * offset_nram_calc; + float *nram_hh = + (float *)nram_buffer + split_num_c * offset_nram + 9 * offset_nram_calc; + float *nram_lw = + (float *)nram_buffer + split_num_c * offset_nram + 10 * offset_nram_calc; + float *nram_lh = + (float *)nram_buffer + split_num_c * offset_nram + 11 * offset_nram_calc; + + float *nram_h_low_ptr_offset = + (float *)nram_buffer + split_num_c * offset_nram + 12 * offset_nram_calc; + float *nram_h_high_ptr_offset = + (float *)nram_buffer + split_num_c * offset_nram + 13 * offset_nram_calc; + float *nram_w_low_ptr_offset = + (float *)nram_buffer + split_num_c * offset_nram + 14 * offset_nram_calc; + float *nram_w_high_ptr_offset = + (float *)nram_buffer + split_num_c * offset_nram + 15 * offset_nram_calc; + + float *nram_w1 = + (float *)nram_buffer + split_num_c * offset_nram + 16 * offset_nram_calc; + float *nram_w2 = + (float *)nram_buffer + split_num_c * offset_nram + 17 * offset_nram_calc; + float *nram_w3 = + (float *)nram_buffer + split_num_c * offset_nram + 18 * offset_nram_calc; + float *nram_w4 = + (float *)nram_buffer + split_num_c * offset_nram + 19 * offset_nram_calc; + + float *nram_grad_weight = + (float *)nram_buffer + split_num_c * offset_nram + 20 * offset_nram_calc; + float *nram_base_ptr = + (float *)nram_buffer + split_num_c * offset_nram + 21 * offset_nram_calc; + float *nram_offset_temp = + (float *)nram_buffer + split_num_c * offset_nram + 22 * offset_nram_calc; + + float *nram_offset1 = + (float *)nram_buffer + split_num_c * offset_nram + 23 * offset_nram_calc; + float *nram_offset2 = + (float *)nram_buffer + split_num_c * offset_nram + 24 * offset_nram_calc; + float *nram_offset3 = + (float *)nram_buffer + split_num_c * offset_nram + 25 * offset_nram_calc; + float *nram_offset4 = + (float *)nram_buffer + split_num_c * offset_nram + 26 * offset_nram_calc; + + float *nram_w_low_temp = + (float *)nram_buffer + split_num_c * offset_nram + 27 * offset_nram_calc; + int32_t *nram_spatial_shapes = + (int32_t *)((float *)nram_buffer + split_num_c * offset_nram + + 28 * offset_nram_calc); + int32_t *nram_level_start_index = + (int32_t *)(nram_spatial_shapes + 2 * num_levels); + float *nram_h_stride = (float *)(nram_level_start_index + 3 * num_levels); + const int32_t total_num = batch * num_query; + int32_t num_per_core = total_num / taskDim; + int32_t num_rem = total_num % taskDim; + num_per_core = num_per_core + int32_t(taskId < num_rem); + num_per_time_theory = + num_per_core > num_per_time_theory ? num_per_time_theory : num_per_core; + int32_t num_deal_grid = num_per_time_theory * num_hlp; + + if (num_per_core == 0) return; + int32_t start_per_core = num_rem > taskId ? 
(taskId * num_per_core) + : (num_rem + taskId * num_per_core); + + const int32_t qid_stride = num_heads * channels; + int32_t deal_num_real = channels; + + const int32_t repeat_times = num_per_core / num_per_time_theory; + const int32_t tail_num = num_per_core % num_per_time_theory; + + int32_t num_per_time_real = num_per_time_theory; + + for (int32_t loop = 0; loop < num_heads; ++loop) { + nram_base_ptr[loop] = loop * channels; + } + const int32_t w_stride = num_heads * channels; + for (int32_t grid_loop = 0; grid_loop < repeat_times + 1; grid_loop += 1) { + int32_t grid_offset = + (start_per_core + grid_loop * num_per_time_theory) * num_hlp; + if (grid_loop == repeat_times) { + if (tail_num == 0) { + continue; + } else { + grid_offset = + (start_per_core + repeat_times * num_per_time_theory) * num_hlp; + num_per_time_real = tail_num; + num_deal_grid = tail_num * num_hlp; + } + } + + __memcpy_async(nram_spatial_shapes, spatial_shapes, + num_levels * 2 * sizeof(int32_t), GDRAM2NRAM); + __memcpy_async(nram_level_start_index, data_level_start_index, + num_levels * sizeof(int32_t), GDRAM2NRAM); + __memcpy_async(nram_loc_w, data_sampling_loc + grid_offset * 2, + num_deal_grid * 2 * sizeof(float), GDRAM2NRAM); + __memcpy(nram_grad_weight, data_attn_weight + grid_offset, + num_deal_grid * sizeof(float), GDRAM2NRAM); + computeGridMaskAndOffset( + nram_grad_output_tl, nram_grad_output_tr, nram_loc_w, nram_loc_h, + nram_h_stride, nram_spatial_shapes, nram_w_low_temp, nram_h_high_temp, + nram_w_low, nram_h_low, nram_h_high, nram_w_high, nram_lh, nram_lw, + nram_hh, nram_hw, nram_h_low_ptr_offset, nram_h_high_ptr_offset, + nram_w_low_ptr_offset, nram_w_high_ptr_offset, nram_w1, nram_w2, + nram_w3, nram_w4, nram_offset_temp, nram_offset1, nram_offset2, + nram_offset3, nram_offset4, nram_base_ptr, nram_h_low_temp, + num_deal_grid, num_per_time_real, num_heads, num_levels, num_points, + w_stride, qid_stride); + float *mask1 = nram_h_low_ptr_offset; + float *mask2 = nram_h_high_ptr_offset; + float *mask3 = nram_w_low_ptr_offset; + float *mask4 = nram_w_high_ptr_offset; + loadValue(nram_grad_output_tl, nram_grad_output_tr, nram_grad_output_bl, + nram_grad_output_br, data_value, grad_output, grad_temp1, + grad_temp2, mask1, mask2, mask3, mask4, nram_offset1, + nram_offset2, nram_offset3, nram_offset4, nram_grad_weight, + nram_level_start_index, offset_nram, start_per_core, grid_loop, + num_per_time_theory, num_heads, deal_num_real, num_per_time_real, + num_deal_grid, num_query, num_levels, num_points, grid_offset, + spatial_size, qid_stride); + float *nram_grid_offset1 = nram_loc_h; + float *nram_grid_offset2 = nram_loc_w; + computeGradValue( + grad_temp1, grad_temp2, grad_temp3, grad_temp4, mask1, mask2, mask3, + mask4, nram_offset1, nram_offset2, nram_offset3, nram_offset4, + nram_level_start_index, deal_num_real, grad_value, nram_w1, nram_w2, + nram_w3, nram_w4, num_per_time_real, num_heads, num_levels, num_points, + num_query, num_deal_grid, grid_offset, spatial_size, qid_stride, + nram_grid_offset1, nram_grid_offset2); + + // compute grad_weight + float *grad_weight = grad_temp1; + float *grad_h_weight = grad_temp4; + float *grad_w_weight = grad_temp3; + computeGradAttnWeight( + grad_w_weight, grad_weight, nram_grad_output_tl, nram_grad_output_tr, + nram_grad_output_bl, nram_grad_output_br, grad_temp1, grad_temp2, + grad_attn_weight, nram_hw, nram_hh, nram_lw, nram_lh, grad_h_weight, + nram_w1, nram_w2, nram_w3, nram_w4, offset_nram, num_deal_grid, + deal_num_real, num_per_time_real, num_heads, 
num_levels, num_points, + grid_offset, nram_h_high_temp); + + // compute grad_sampling_loc + computeGradSampingLoc(grad_sampling_loc, nram_grad_output_tl, + nram_grad_output_tr, grad_h_weight, grad_w_weight, + nram_spatial_shapes, grad_temp1, grad_temp2, + nram_grad_weight, num_deal_grid, deal_num_real, + num_per_time_real, num_heads, num_levels, num_points, + grid_offset, nram_h_high_temp); + } +#endif +} + +__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwarDefaultKernel( + const float *data_value, const int32_t *spatial_shapes, + const int32_t *data_level_start_index, const float *data_sampling_loc, + const float *data_attn_weight, const float *grad_output, + const int32_t batch, const int32_t spatial_size, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_query, + const int32_t num_points, float *grad_value, float *grad_sampling_loc, + float *grad_attn_weight); + +__mlu_global__ void MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel( const float *data_value, const int32_t *spatial_shapes, const int32_t *data_level_start_index, const float *data_sampling_loc, const float *data_attn_weight, const float *grad_output, @@ -836,7 +2060,23 @@ __mlu_global__ void MLUUnion1KernelMsDeformAttnBackward( const int32_t num_points, float *grad_value, float *grad_sampling_loc, float *grad_attn_weight); -void KernelMsDeformAttnBackward( +void KernelMsDeformAttnBackwardDefaultKernel( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const float *data_value, + const int32_t *spatial_shapes, const int32_t *data_level_start_index, + const float *data_sampling_loc, const float *data_attn_weight, + const float *grad_output, const int32_t batch, const int32_t spatial_size, + const int32_t num_heads, const int32_t channels, const int32_t num_levels, + const int32_t num_query, const int32_t num_points, float *grad_value, + float *grad_sampling_loc, float *grad_attn_weight) { + MLUUnion1KernelMsDeformAttnBackwarDefaultKernel<<>>( + data_value, spatial_shapes, data_level_start_index, data_sampling_loc, + data_attn_weight, grad_output, batch, spatial_size, num_heads, channels, + num_levels, num_query, num_points, grad_value, grad_sampling_loc, + grad_attn_weight); +} + +void KernelMsDeformAttnBackwardSmallChannelsKernel( cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const cnrtDataType_t d_type, const float *data_value, const int32_t *spatial_shapes, const int32_t *data_level_start_index, @@ -845,7 +2085,8 @@ void KernelMsDeformAttnBackward( const int32_t num_heads, const int32_t channels, const int32_t num_levels, const int32_t num_query, const int32_t num_points, float *grad_value, float *grad_sampling_loc, float *grad_attn_weight) { - MLUUnion1KernelMsDeformAttnBackward<<>>( + MLUUnion1KernelMsDeformAttnBackwardSmallChannelsKernel<<>>( data_value, spatial_shapes, data_level_start_index, data_sampling_loc, data_attn_weight, grad_output, batch, spatial_size, num_heads, channels, num_levels, num_query, num_points, grad_value, grad_sampling_loc, diff --git a/mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu new file mode 100644 index 0000000000..d7c57da4f4 --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/voxelization_mlu_kernel.mlu @@ -0,0 +1,532 @@ +/************************************************************************* + * Copyright (C) 2022 by Cambricon. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ + +#include "common_mlu_helper.hpp" + +__nram__ char nram_buffer[MAX_NRAM_SIZE]; + +#if __BANG_ARCH__ >= 322 +__mlu_func__ void computeDynamicVoxelize( + char *points_x, char *points_y, char *points_z, char *auxiliary_a, + char *auxiliary_b, char *auxiliary_c, const float coors_x_min, + const float coors_y_min, const float coors_z_min, const float voxel_x, + const float voxel_y, const float voxel_z, const int32_t grid_x, + const int32_t grid_y, const int32_t grid_z, const int32_t deal_num) { + // x - coors_x_min + __bang_sub_scalar((float *)points_x, (float *)points_x, coors_x_min, + deal_num); + // y - coors_y_min + __bang_sub_scalar((float *)points_y, (float *)points_y, coors_y_min, + deal_num); + // z - coors_z_min + __bang_sub_scalar((float *)points_z, (float *)points_z, coors_z_min, + deal_num); + // (x - coors_x_min) / voxel_x + __bang_mul_scalar((float *)points_x, (float *)points_x, 1.0 / voxel_x, + deal_num); + // (y - coors_y_min) / voxel_y + __bang_mul_scalar((float *)points_y, (float *)points_y, 1.0 / voxel_y, + deal_num); + // (z - coors_z_min) / voxel_z + __bang_mul_scalar((float *)points_z, (float *)points_z, 1.0 / voxel_z, + deal_num); + // c_x = floor((x - coors_x_min) / voxel_x) + __bang_floor((float *)auxiliary_a, (float *)points_x, deal_num); + __bang_float2int32((int32_t *)points_x, (float *)auxiliary_a, deal_num, 0); + // c_y = floor((y - coors_y_min) / voxel_y) + __bang_floor((float *)auxiliary_a, (float *)points_y, deal_num); + __bang_float2int32((int32_t *)points_y, (float *)auxiliary_a, deal_num, 0); + // c_z = floor((z - coors_z_min) / voxel_z) + __bang_floor((float *)auxiliary_a, (float *)points_z, deal_num); + __bang_float2int32((int32_t *)points_z, (float *)auxiliary_a, deal_num, 0); + // c_x >= 0 + __bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_x, (int32_t)0, + deal_num); + // c_x < grid_x + __bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_x, grid_x, + deal_num); + // 0 <= c_x < grid_x + __bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_b, + (int32_t *)auxiliary_c, deal_num); + // c_y >= 0 + __bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_y, (int32_t)0, + deal_num); + // c_y < grid_y + __bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_y, grid_y, + deal_num); + // 0 <= c_y < grid_y + __bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b, + (int32_t *)auxiliary_c, deal_num); + // c_x >= 0 && c_x < grid_x && c_y >= 0 && c_y < grid_y + __bang_mul((int32_t *)auxiliary_a, (int32_t *)auxiliary_a, + (int32_t *)auxiliary_b, deal_num); + // c_z >= 0 + __bang_ge_scalar((int32_t *)auxiliary_b, (int32_t *)points_z, (int32_t)0, + deal_num); + // c_z < grid_z + __bang_lt_scalar((int32_t *)auxiliary_c, (int32_t *)points_z, grid_z, + deal_num); + // 0 <= c_z < grid_z + __bang_mul((int32_t *)auxiliary_b, (int32_t *)auxiliary_b, + (int32_t *)auxiliary_c, deal_num); + // 0 <= c_x < grid_x && 0 <= c_y < grid_y && 0 <= c_z < grid_z + __bang_mul((int32_t *)auxiliary_a, 
(int32_t *)auxiliary_a, + (int32_t *)auxiliary_b, deal_num); + __bang_not((int32_t *)auxiliary_c, (int32_t *)auxiliary_a, deal_num); + + __bang_mul((int32_t *)points_x, (int32_t *)points_x, (int32_t *)auxiliary_a, + deal_num); + __bang_mul_scalar((int32_t *)auxiliary_b, (int32_t *)auxiliary_c, + (int32_t)(-1), deal_num); + __bang_add((int32_t *)points_x, (int32_t *)points_x, (int32_t *)auxiliary_b, + deal_num); + __bang_mul((int32_t *)points_y, (int32_t *)points_y, (int32_t *)auxiliary_a, + deal_num); + __bang_add((int32_t *)points_y, (int32_t *)points_y, (int32_t *)auxiliary_b, + deal_num); + __bang_mul((int32_t *)points_z, (int32_t *)points_z, (int32_t *)auxiliary_a, + deal_num); + __bang_add((int32_t *)points_z, (int32_t *)points_z, (int32_t *)auxiliary_b, + deal_num); +} + +__mlu_func__ void computePoint2Voxel(char *coors_x, char *coors_y, + char *coors_z, const int32_t c_x, + const int32_t c_y, const int32_t c_z, + const int32_t max_points, int32_t *num, + int32_t *first_point, + const int32_t deal_idx, + const int32_t deal_num) { + __bang_eq_scalar((int32_t *)coors_x, (int32_t *)coors_x, c_x, deal_num); + __bang_eq_scalar((int32_t *)coors_y, (int32_t *)coors_y, c_y, deal_num); + __bang_eq_scalar((int32_t *)coors_z, (int32_t *)coors_z, c_z, deal_num); + __bang_mul((int32_t *)coors_x, (int32_t *)coors_x, (int32_t *)coors_y, + deal_num); + __bang_mul((int32_t *)coors_x, (int32_t *)coors_x, (int32_t *)coors_z, + deal_num); + if (*num == 0) { + *num = (int32_t)__bang_count((float *)coors_x, deal_num); + if (*num > 0) { + *first_point = + (int32_t)__bang_findfirst1((float *)coors_x, deal_num) + deal_idx; + } + } else { + *num += (int32_t)__bang_count((float *)coors_x, deal_num); + } +} +#endif + +__mlu_global__ void MLUUnion1KernelDynamicVoxelize( + const float *points, int32_t *coors, const float voxel_x, + const float voxel_y, const float voxel_z, const float coors_x_min, + const float coors_y_min, const float coors_z_min, const float coors_x_max, + const float coors_y_max, const float coors_z_max, const int32_t grid_x, + const int32_t grid_y, const int32_t grid_z, const int32_t num_points, + const int32_t num_features) { +#if __BANG_ARCH__ >= 322 + if (coreId == 0x80) { + return; + } + + const int32_t points_rem = num_points % taskDim; + const int32_t points_per_core = + taskId < points_rem ? num_points / taskDim + 1 : num_points / taskDim; + const int32_t points_start = taskId < points_rem + ? 
taskId * points_per_core + : taskId * points_per_core + points_rem; + + const int32_t split_num = 9; + const int32_t deal_num = + PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(float), NFU_ALIGN_SIZE); + const int32_t repeat = points_per_core / deal_num; + const int32_t rem = points_per_core % deal_num; + const int32_t ping_pong_gap = 3 * deal_num * sizeof(float); + + char *points_x = nram_buffer; + char *points_y = points_x + deal_num * sizeof(float); + char *points_z = points_y + deal_num * sizeof(float); + char *auxiliary_a = points_x + 2 * ping_pong_gap; + char *auxiliary_b = auxiliary_a + deal_num * sizeof(float); + char *auxiliary_c = auxiliary_b + deal_num * sizeof(float); + + int32_t *coors_z_start = coors + points_start; + int32_t *coors_y_start = coors + num_points + points_start; + int32_t *coors_x_start = coors + num_points * 2 + points_start; + + if (repeat > 0) { + __memcpy_async(points_x, points + points_start * num_features, + sizeof(float), GDRAM2NRAM, sizeof(float), + num_features * sizeof(float), deal_num - 1); + __memcpy_async(points_y, points + points_start * num_features + 1, + sizeof(float), GDRAM2NRAM, sizeof(float), + num_features * sizeof(float), deal_num - 1); + __memcpy_async(points_z, points + points_start * num_features + 2, + sizeof(float), GDRAM2NRAM, sizeof(float), + num_features * sizeof(float), deal_num - 1); + __asm__ volatile("sync;"); + } + if (repeat > 1) { + __memcpy_async(points_x + ping_pong_gap, + points + (points_start + deal_num) * num_features, + sizeof(float), GDRAM2NRAM, sizeof(float), + num_features * sizeof(float), deal_num - 1); + __memcpy_async(points_y + ping_pong_gap, + points + (points_start + deal_num) * num_features + 1, + sizeof(float), GDRAM2NRAM, sizeof(float), + num_features * sizeof(float), deal_num - 1); + __memcpy_async(points_z + ping_pong_gap, + points + (points_start + deal_num) * num_features + 2, + sizeof(float), GDRAM2NRAM, sizeof(float), + num_features * sizeof(float), deal_num - 1); + computeDynamicVoxelize(points_x, points_y, points_z, auxiliary_a, + auxiliary_b, auxiliary_c, coors_x_min, coors_y_min, + coors_z_min, voxel_x, voxel_y, voxel_z, grid_x, + grid_y, grid_z, deal_num); + __asm__ volatile("sync;"); + } + + for (int32_t i = 0; i < repeat - 2; ++i) { + __memcpy_async(coors_x_start + i * deal_num, + points_x + (i % 2) * ping_pong_gap, + deal_num * sizeof(int32_t), NRAM2GDRAM); + __memcpy_async(coors_y_start + i * deal_num, + points_y + (i % 2) * ping_pong_gap, + deal_num * sizeof(int32_t), NRAM2GDRAM); + __memcpy_async(coors_z_start + i * deal_num, + points_z + (i % 2) * ping_pong_gap, + deal_num * sizeof(int32_t), NRAM2GDRAM); + __memcpy_async(points_x + (i % 2) * ping_pong_gap, + points + (points_start + (i + 2) * deal_num) * num_features, + sizeof(float), GDRAM2NRAM, sizeof(float), + num_features * sizeof(float), deal_num - 1); + __memcpy_async( + points_y + (i % 2) * ping_pong_gap, + points + (points_start + (i + 2) * deal_num) * num_features + 1, + sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float), + deal_num - 1); + __memcpy_async( + points_z + (i % 2) * ping_pong_gap, + points + (points_start + (i + 2) * deal_num) * num_features + 2, + sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float), + deal_num - 1); + computeDynamicVoxelize(points_x + ((i + 1) % 2) * ping_pong_gap, + points_y + ((i + 1) % 2) * ping_pong_gap, + points_z + ((i + 1) % 2) * ping_pong_gap, + auxiliary_a, auxiliary_b, auxiliary_c, coors_x_min, + coors_y_min, coors_z_min, voxel_x, voxel_y, voxel_z, + 
grid_x, grid_y, grid_z, deal_num); + __asm__ volatile("sync;"); + } + + if (repeat >= 2) { + __memcpy_async(coors_x_start + (repeat - 2) * deal_num, + points_x + (repeat % 2) * ping_pong_gap, + deal_num * sizeof(int32_t), NRAM2GDRAM); + __memcpy_async(coors_y_start + (repeat - 2) * deal_num, + points_y + (repeat % 2) * ping_pong_gap, + deal_num * sizeof(int32_t), NRAM2GDRAM); + __memcpy_async(coors_z_start + (repeat - 2) * deal_num, + points_z + (repeat % 2) * ping_pong_gap, + deal_num * sizeof(int32_t), NRAM2GDRAM); + } + if (rem > 0) { + __memcpy_async(points_x + (repeat % 2) * ping_pong_gap, + points + (points_start + repeat * deal_num) * num_features, + sizeof(float), GDRAM2NRAM, sizeof(float), + num_features * sizeof(float), rem - 1); + __memcpy_async( + points_y + (repeat % 2) * ping_pong_gap, + points + (points_start + repeat * deal_num) * num_features + 1, + sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float), + rem - 1); + __memcpy_async( + points_z + (repeat % 2) * ping_pong_gap, + points + (points_start + repeat * deal_num) * num_features + 2, + sizeof(float), GDRAM2NRAM, sizeof(float), num_features * sizeof(float), + rem - 1); + } + if (repeat > 0) { + computeDynamicVoxelize(points_x + ((repeat - 1) % 2) * ping_pong_gap, + points_y + ((repeat - 1) % 2) * ping_pong_gap, + points_z + ((repeat - 1) % 2) * ping_pong_gap, + auxiliary_a, auxiliary_b, auxiliary_c, coors_x_min, + coors_y_min, coors_z_min, voxel_x, voxel_y, voxel_z, + grid_x, grid_y, grid_z, deal_num); + } + __asm__ volatile("sync;"); + + if (repeat > 0) { + __memcpy_async(coors_x_start + (repeat - 1) * deal_num, + points_x + ((repeat - 1) % 2) * ping_pong_gap, + deal_num * sizeof(int32_t), NRAM2GDRAM); + __memcpy_async(coors_y_start + (repeat - 1) * deal_num, + points_y + ((repeat - 1) % 2) * ping_pong_gap, + deal_num * sizeof(int32_t), NRAM2GDRAM); + __memcpy_async(coors_z_start + (repeat - 1) * deal_num, + points_z + ((repeat - 1) % 2) * ping_pong_gap, + deal_num * sizeof(int32_t), NRAM2GDRAM); + } + if (rem > 0) { + computeDynamicVoxelize(points_x + (repeat % 2) * ping_pong_gap, + points_y + (repeat % 2) * ping_pong_gap, + points_z + (repeat % 2) * ping_pong_gap, auxiliary_a, + auxiliary_b, auxiliary_c, coors_x_min, coors_y_min, + coors_z_min, voxel_x, voxel_y, voxel_z, grid_x, + grid_y, grid_z, rem); + __asm__ volatile("sync;"); + __memcpy_async(coors_x_start + repeat * deal_num, + points_x + (repeat % 2) * ping_pong_gap, + rem * sizeof(int32_t), NRAM2GDRAM); + __memcpy_async(coors_y_start + repeat * deal_num, + points_y + (repeat % 2) * ping_pong_gap, + rem * sizeof(int32_t), NRAM2GDRAM); + __memcpy_async(coors_z_start + repeat * deal_num, + points_z + (repeat % 2) * ping_pong_gap, + rem * sizeof(int32_t), NRAM2GDRAM); + } +#endif +} + +__mlu_global__ void MLUUnion1KernelPoint2Voxel(int32_t *coors, + int32_t *point_to_pointidx, + int32_t *point_to_voxelidx, + const int32_t num_points, + const int32_t max_points) { +#if __BANG_ARCH__ >= 322 + if (coreId == 0x80) { + return; + } + + const int32_t split_num = 6; + const int32_t deal_num = + PAD_DOWN(MAX_NRAM_SIZE / split_num / sizeof(int32_t), NFU_ALIGN_SIZE); + const int32_t ping_pong_gap = 3 * deal_num * sizeof(int32_t); + + char *coors_x = nram_buffer; + char *coors_y = coors_x + deal_num * sizeof(int32_t); + char *coors_z = coors_y + deal_num * sizeof(int32_t); + + int32_t *coors_z_start = coors; + int32_t *coors_y_start = coors + num_points; + int32_t *coors_x_start = coors + num_points * 2; + + for (int32_t point_idx = taskId; point_idx < 
num_points; + point_idx += taskDim) { + if (coors_x_start[point_idx] == -1) { + point_to_pointidx[point_idx] = -1; + point_to_voxelidx[point_idx] = -1; + continue; + } + + int32_t c_x = coors_x_start[point_idx]; + int32_t c_y = coors_y_start[point_idx]; + int32_t c_z = coors_z_start[point_idx]; + + int32_t deal_total_num = point_idx; + int32_t repeat = deal_total_num / deal_num; + int32_t rem = deal_total_num % deal_num; + int32_t num = 0; + int32_t first_point = -1; + + if (repeat > 0) { + __memcpy_async(coors_x, coors_x_start, deal_num * sizeof(int32_t), + GDRAM2NRAM); + __memcpy_async(coors_y, coors_y_start, deal_num * sizeof(int32_t), + GDRAM2NRAM); + __memcpy_async(coors_z, coors_z_start, deal_num * sizeof(int32_t), + GDRAM2NRAM); + __asm__ volatile("sync;"); + } + + for (int32_t i = 0; i < repeat - 1; ++i) { + __memcpy_async(coors_x + ((i + 1) % 2) * ping_pong_gap, + coors_x_start + (i + 1) * deal_num, + deal_num * sizeof(int32_t), GDRAM2NRAM); + __memcpy_async(coors_y + ((i + 1) % 2) * ping_pong_gap, + coors_y_start + (i + 1) * deal_num, + deal_num * sizeof(int32_t), GDRAM2NRAM); + __memcpy_async(coors_z + ((i + 1) % 2) * ping_pong_gap, + coors_z_start + (i + 1) * deal_num, + deal_num * sizeof(int32_t), GDRAM2NRAM); + computePoint2Voxel( + coors_x + (i % 2) * ping_pong_gap, coors_y + (i % 2) * ping_pong_gap, + coors_z + (i % 2) * ping_pong_gap, c_x, c_y, c_z, max_points, &num, + &first_point, i * deal_num, deal_num); + __asm__ volatile("sync;"); + } + + if (rem > 0) { + __memcpy_async(coors_x + (repeat % 2) * ping_pong_gap, + coors_x_start + repeat * deal_num, rem * sizeof(int32_t), + GDRAM2NRAM); + __memcpy_async(coors_y + (repeat % 2) * ping_pong_gap, + coors_y_start + repeat * deal_num, rem * sizeof(int32_t), + GDRAM2NRAM); + __memcpy_async(coors_z + (repeat % 2) * ping_pong_gap, + coors_z_start + repeat * deal_num, rem * sizeof(int32_t), + GDRAM2NRAM); + } + if (repeat > 0) { + computePoint2Voxel(coors_x + ((repeat - 1) % 2) * ping_pong_gap, + coors_y + ((repeat - 1) % 2) * ping_pong_gap, + coors_z + ((repeat - 1) % 2) * ping_pong_gap, c_x, c_y, + c_z, max_points, &num, &first_point, + (repeat - 1) * deal_num, deal_num); + } + __asm__ volatile("sync;"); + + if (rem > 0) { + computePoint2Voxel(coors_x + (repeat % 2) * ping_pong_gap, + coors_y + (repeat % 2) * ping_pong_gap, + coors_z + (repeat % 2) * ping_pong_gap, c_x, c_y, c_z, + max_points, &num, &first_point, repeat * deal_num, + rem); + __asm__ volatile("sync;"); + } + + if (num == 0) { + point_to_pointidx[point_idx] = point_idx; + } else if (num > 0) { + point_to_pointidx[point_idx] = first_point; + } + + if (num < max_points) { + point_to_voxelidx[point_idx] = num; + } else { + point_to_voxelidx[point_idx] = -1; + } + } +#endif +} + +__mlu_global__ void MLUUnion1KernelCalcPointsPerVoxel( + int32_t *point_to_pointidx, int32_t *point_to_voxelidx, + int32_t *coor_to_voxelidx, int32_t *num_points_per_voxel, + int32_t *voxel_num, const int32_t max_voxels, const int32_t num_points) { +#if __BANG_ARCH__ >= 322 + if (coreId == 0) { + int32_t voxel_num_temp = 0; + for (int32_t point_idx = 0; point_idx < num_points; ++point_idx) { + int32_t point_pos_in_voxel = point_to_voxelidx[point_idx]; + coor_to_voxelidx[point_idx] = -1; + if (point_pos_in_voxel == -1) { + continue; + } else if (point_pos_in_voxel == 0) { + int32_t voxel_idx = voxel_num_temp; + if (voxel_num_temp >= max_voxels) { + continue; + } + voxel_num_temp += 1; + coor_to_voxelidx[point_idx] = voxel_idx; + num_points_per_voxel[voxel_idx] = 1; + } else { + int32_t 
point_idx_temp = point_to_pointidx[point_idx]; + int32_t voxel_idx = coor_to_voxelidx[point_idx_temp]; + if (voxel_idx != -1) { + coor_to_voxelidx[point_idx] = voxel_idx; + num_points_per_voxel[voxel_idx] += 1; + } + } + } + *voxel_num = voxel_num_temp; + } +#endif +} + +__mlu_global__ void MLUUnion1KernelAssignVoxelsCoors( + const float *points, int32_t *temp_coors, int32_t *point_to_voxelidx, + int32_t *coor_to_voxelidx, float *voxels, int32_t *coors, + const int32_t max_points, const int32_t num_points, + const int32_t num_features) { +#if __BANG_ARCH__ >= 322 + if (coreId == 0x80) { + return; + } + + int32_t points_per_core = num_points / taskDim; + int32_t points_rem = num_points % taskDim; + int32_t points_start = taskId < points_rem + ? taskId * (points_per_core + 1) + : taskId * points_per_core + points_rem; + int32_t points_end = taskId < points_rem ? points_start + points_per_core + 1 + : points_start + points_per_core; + + for (int32_t point_idx = points_start; point_idx < points_end; ++point_idx) { + int32_t num = point_to_voxelidx[point_idx]; + int32_t voxel_idx = coor_to_voxelidx[point_idx]; + if (num > -1 && voxel_idx > -1) { + float *voxels_offset = + voxels + voxel_idx * max_points * num_features + num * num_features; + const float *points_offset = points + point_idx * num_features; + __memcpy_async(voxels_offset, points_offset, num_features * sizeof(float), + GDRAM2GDRAM); + + if (num == 0) { + int32_t *coors_offset = coors + voxel_idx * 3; + __memcpy_async(coors_offset, temp_coors + point_idx, sizeof(int32_t), + GDRAM2GDRAM, sizeof(int32_t), + num_points * sizeof(int32_t), 2); + } + } + } + __asm__ volatile("sync;"); +#endif +} + +void KernelDynamicVoxelize(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, const void *points, void *coors, + const float voxel_x, const float voxel_y, + const float voxel_z, const float coors_x_min, + const float coors_y_min, const float coors_z_min, + const float coors_x_max, const float coors_y_max, + const float coors_z_max, const int32_t grid_x, + const int32_t grid_y, const int32_t grid_z, + const int32_t num_points, + const int32_t num_features) { + MLUUnion1KernelDynamicVoxelize<<>>( + (float *)points, (int32_t *)coors, voxel_x, voxel_y, voxel_z, coors_x_min, + coors_y_min, coors_z_min, coors_x_max, coors_y_max, coors_z_max, grid_x, + grid_y, grid_z, num_points, num_features); +} + +void KernelPoint2Voxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, void *coors, void *point_to_pointidx, + void *point_to_voxelidx, const int32_t num_points, + const int32_t max_points) { + MLUUnion1KernelPoint2Voxel<<>>( + (int32_t *)coors, (int32_t *)point_to_pointidx, + (int32_t *)point_to_voxelidx, num_points, max_points); +} + +void KernelCalcPointsPerVoxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, void *point_to_pointidx, + void *point_to_voxelidx, void *coor_to_voxelidx, + void *num_points_per_voxel, void *voxel_num, + const int32_t max_voxels, + const int32_t num_points) { + MLUUnion1KernelCalcPointsPerVoxel<<>>( + (int32_t *)point_to_pointidx, (int32_t *)point_to_voxelidx, + (int32_t *)coor_to_voxelidx, (int32_t *)num_points_per_voxel, + (int32_t *)voxel_num, max_voxels, num_points); +} + +void KernelAssignVoxelsCoors(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, const void *points, + void *temp_coors, void *point_to_voxelidx, + void *coor_to_voxelidx, void *voxels, void *coors, + const int32_t max_points, const int32_t num_points, + const int32_t num_features) { + 
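+  // Launch the kernel that copies each valid point's features into its voxel
+  // slot (voxel_idx, num) and, for the first point of every voxel, writes the
+  // voxel's (z, y, x) coordinates from temp_coors into coors.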
MLUUnion1KernelAssignVoxelsCoors<<>>( + (float *)points, (int32_t *)temp_coors, (int32_t *)point_to_voxelidx, + (int32_t *)coor_to_voxelidx, (float *)voxels, (int32_t *)coors, + max_points, num_points, num_features); +} diff --git a/mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp new file mode 100644 index 0000000000..000f8882b1 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/ball_query_mlu.cpp @@ -0,0 +1,47 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "mlu_common_helper.h" + +void ball_query_forward_mlu(int b, int n, int m, float min_radius, + float max_radius, int nsample, const Tensor new_xyz, + const Tensor xyz, Tensor idx) { + auto new_xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + new_xyz, new_xyz.suggest_memory_format()); + auto xyz_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + xyz, new_xyz.suggest_memory_format()); + auto idx_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + idx, new_xyz.suggest_memory_format()); + + MluOpTensorDescriptor new_xyz_desc, xyz_desc, idx_desc; + new_xyz_desc.set(new_xyz_contiguous); + xyz_desc.set(xyz_contiguous); + idx_desc.set(idx_contiguous); + + auto new_xyz_impl = torch_mlu::getMluTensorImpl(new_xyz_contiguous); + auto xyz_impl = torch_mlu::getMluTensorImpl(xyz_contiguous); + auto idx_impl = torch_mlu::getMluTensorImpl(idx_contiguous); + auto new_xyz_ptr = new_xyz_impl->cnnlMalloc(); + auto xyz_ptr = xyz_impl->cnnlMalloc(); + auto idx_ptr = idx_impl->cnnlMalloc(); + + auto handle = mluOpGetCurrentHandle(); + mluOpBallQuery(handle, new_xyz_desc.desc(), new_xyz_ptr, xyz_desc.desc(), + xyz_ptr, min_radius, max_radius, nsample, idx_desc.desc(), + idx_ptr); +} + +void ball_query_forward_impl(int b, int n, int m, float min_radius, + float max_radius, int nsample, + const Tensor new_xyz, const Tensor xyz, + Tensor idx); + +REGISTER_DEVICE_IMPL(ball_query_forward_impl, MLU, ball_query_forward_mlu); diff --git a/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp index 4d73cbbe59..90a625c4a2 100644 --- a/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/deform_roi_pool_mlu.cpp @@ -9,254 +9,59 @@ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*************************************************************************/ -#include "pytorch_device_registry.hpp" -#include "pytorch_mlu_helper.hpp" - -void KernelDeformRoIPoolForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, cnrtDataType_t data_type, - const void *input, const void *rois, - const void *offset, void *output, - const int channels, const int height, - const int width, const int num_rois, - const int pooled_height, const int pooled_width, - const float spatial_scale, - const int sampling_ratio, const float gamma); - -void KernelDeformRoIPoolBackward( - cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, - cnrtDataType_t data_type, const void *grad_output, const void *input, - const void *rois, const void *offset, void *grad_input, void *grad_offset, - const int channels, const int height, const int width, const int num_rois, - const int pooled_height, const int pooled_width, const float spatial_scale, - const int sampling_ratio, const float gamma); - -// policy function for forward and backward -static void policyFunc(const int bin_num, cnrtDim3_t *k_dim, - cnrtFunctionType_t *k_type) { - const size_t cluster_limit = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - ; - const size_t core_limit = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - const size_t bin_num_align = CEIL_ALIGN(bin_num, core_limit); - k_dim->x = core_limit; - k_dim->y = (bin_num_align / core_limit) > cluster_limit - ? cluster_limit - : (bin_num_align / core_limit); - k_dim->z = 1; - *k_type = CNRT_FUNC_TYPE_UNION1; -} +#include "mlu_common_helper.h" void DeformRoIPoolForwardMLUKernelLauncher(Tensor input, Tensor rois, Tensor offset, Tensor output, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, float gamma) { - // Check dtype. - TORCH_CHECK( - input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf, - "input type should be Float or Half, got ", input.scalar_type()); - TORCH_CHECK(input.scalar_type() == rois.scalar_type(), - "rois should have the same type as input"); - - // Check shape. - TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(), - "D."); - TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(), - "D."); - if (offset.defined() && offset.numel() > 0) { - TORCH_CHECK(input.scalar_type() == offset.scalar_type(), - "offset should have the same type as input"); - TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ", - offset.dim(), "D."); - TORCH_CHECK( - (offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0), - "while rois.size(0)) = ", rois.size(0), ". They should be the same."); - TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ", - "but now offset.size(1) = ", offset.size(1), "."); - TORCH_CHECK((offset.size(2) == output.size(2)), - "offset.size(2) = ", offset.size(2), - "while output.size(2)) = ", output.size(2), - ". They should be the same."); - TORCH_CHECK((offset.size(3) == output.size(3)), - "offset.size(3) = ", offset.size(3), - "while output.size(3)) = ", output.size(3), - ". 
They should be the same."); - } - - TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1, - "spatial_scale should be within (0, 1], got ", spatial_scale, - "."); - - // compute kernel params - auto height = input.size(2); - auto width = input.size(3); - auto channels = input.size(1); - auto num_rois = output.size(0); - - if (output.numel() == 0) { - output = at::zeros({num_rois, channels, pooled_height, pooled_width}, - input.options()); - return; - } - - // zero element check - TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ", - input.size(0)); - TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ", - rois.numel()); - if (input.numel() == 0 || output.numel() == 0) { - return; - } - - // large tensor check - const size_t max_input_num = 2147483648; // 2^31, 2G num - TORCH_CHECK(input.numel() < max_input_num, - "input.numel() should be less than 2147483648, got ", - input.numel()); - TORCH_CHECK(rois.numel() < max_input_num, - "rois.numel() should be less than 2147483648, got ", - rois.numel()); - TORCH_CHECK(output.numel() < max_input_num, - "output.numel() should be less than 2147483648, got ", - output.numel()); - TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num, - "offset.numel() should be less than 2147483648, got ", - offset.numel()); - auto memory_format = torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim()); auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format); - - at::Tensor output_ = - at::empty({num_rois, channels, pooled_height, pooled_width}, - input.options(), memory_format); - - // calculate task dimension - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type); - - // get compute queue - auto queue = torch_mlu::getCurQueue(); + auto rois_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format()); + auto output_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format); + + MluOpTensorDescriptor input_desc, rois_desc, offset_desc, output_desc; + input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC); + rois_desc.set(rois_contiguous); + output_desc.set_with_layout(output_contiguous, MLUOP_LAYOUT_NHWC); + + mluOpTensorDescriptor_t offset_real_desc = NULL; + void *offset_ptr = NULL; + if (offset.defined() && offset.numel() > 0) { + auto offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + offset, offset.suggest_memory_format()); + offset_desc.set(offset_contiguous); + offset_real_desc = offset_desc.desc(); + auto offset_impl = torch_mlu::getMluTensorImpl(offset_contiguous); + offset_ptr = offset_impl->cnnlMalloc(); + } // get ptr of tensors auto input_impl = torch_mlu::getMluTensorImpl(input_); auto input_ptr = input_impl->cnnlMalloc(); - auto rois_impl = torch_mlu::getMluTensorImpl(rois); + auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous); auto rois_ptr = rois_impl->cnnlMalloc(); - auto offset_impl = torch_mlu::getMluTensorImpl(offset); - auto offset_ptr = offset_impl->cnnlMalloc(); - auto output_impl = torch_mlu::getMluTensorImpl(output_); + auto output_impl = torch_mlu::getMluTensorImpl(output_contiguous); auto output_ptr = output_impl->cnnlMalloc(); - // get comput dtype of input - cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input_.dtype()); + // get compute handle + auto handle = mluOpGetCurrentHandle(); + mluOpDeformRoiPoolForward( + handle, input_desc.desc(), input_ptr, rois_desc.desc(), rois_ptr, + offset_real_desc, offset_ptr, pooled_height, 
pooled_width, spatial_scale, + sampling_ratio, gamma, output_desc.desc(), output_ptr); - // launch kernel - CNLOG(INFO) << "Launch Kernel MLUKernelDeformRoIPoolForward<<<" << k_dim.x - << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - - KernelDeformRoIPoolForward(k_dim, k_type, queue, data_type, input_ptr, - rois_ptr, offset_ptr, output_ptr, channels, height, - width, num_rois, pooled_height, pooled_width, - spatial_scale, sampling_ratio, gamma); - - output.copy_(output_); + output.copy_(output_contiguous); } void DeformRoIPoolBackwardMLUKernelLauncher( Tensor grad_output, Tensor input, Tensor rois, Tensor offset, Tensor grad_input, Tensor grad_offset, int pooled_height, int pooled_width, float spatial_scale, int sampling_ratio, float gamma) { - // Check dtype. - TORCH_CHECK( - input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf, - "input type should be Float or Half, got ", input.scalar_type()); - TORCH_CHECK(input.scalar_type() == grad_output.scalar_type(), - "grad_output should have the same type as input"); - TORCH_CHECK(input.scalar_type() == rois.scalar_type(), - "rois should have the same type as input"); - TORCH_CHECK(input.scalar_type() == grad_input.scalar_type(), - "grad_input should have the same type as input"); - - // Check shape. - TORCH_CHECK(grad_output.dim() == 4, "grad_output should be 4d tensor, got ", - grad_output.dim(), "D."); - TORCH_CHECK(input.dim() == 4, "input should be 4d tensor, got ", input.dim(), - "D."); - TORCH_CHECK(rois.dim() == 2, "rois should be 2d tensor, got ", rois.dim(), - "D."); - if (offset.defined() && offset.numel() > 0) { - TORCH_CHECK(input.scalar_type() == offset.scalar_type(), - "offset should have the same type as input"); - TORCH_CHECK(offset.dim() == 4, "offset should be 4d tensor, got ", - offset.dim(), "D."); - TORCH_CHECK( - (offset.size(0) == rois.size(0)), "offset.size(0) = ", offset.size(0), - "while rois.size(0)) = ", rois.size(0), ". They should be the same."); - TORCH_CHECK((offset.size(1) == 2), "offset.size(1) should be 2, ", - "but now offset.size(1) = ", offset.size(1), "."); - TORCH_CHECK((offset.size(2) == grad_output.size(2)), - "offset.size(2) = ", offset.size(2), - "while grad_output.size(2)) = ", grad_output.size(2), - ". They should be the same."); - TORCH_CHECK((offset.size(3) == grad_output.size(3)), - "offset.size(3) = ", offset.size(3), - "while grad_output.size(3)) = ", grad_output.size(3), - ". They should be the same."); - } - - TORCH_CHECK(spatial_scale > 0 && spatial_scale <= 1, - "spatial_scale should be within (0, 1], got ", spatial_scale); - - // Check relationship between tensor. - TORCH_CHECK((grad_output.size(0) == rois.size(0)), - "grad_output.size(0) = ", grad_output.size(0), - "while rois.size(0)) = ", rois.size(0), - ". They should be the same."); - TORCH_CHECK((grad_output.size(1) == input.size(1)), - "grad_output.size(1) = ", grad_output.size(1), - "while input.size(1)) = ", input.size(1), - ". They should be the same."); - TORCH_CHECK((grad_output.size(2) == pooled_height), - "grad_output.size(2) = ", grad_output.size(2), - "while pooled_height = ", pooled_height, - ". They should be the same."); - TORCH_CHECK((grad_output.size(3) == pooled_width), - "grad_output.size(3) = ", grad_output.size(3), - "while pooled_width = ", pooled_width, - ". 
They should be the same."); - - // compute kernel params - auto batch = input.size(0); - auto channels = input.size(1); - auto height = input.size(2); - auto width = input.size(3); - auto num_rois = grad_output.size(0); - - // zero element check - TORCH_CHECK(input.size(0) != 0, "input.size(0) should not be zero, got ", - input.size(0)); - TORCH_CHECK(rois.numel() != 0, "rois.numel() should not be zero, got ", - rois.numel()); - if (input.numel() == 0 || grad_output.numel() == 0) { - return; - } - - // large tensor check - const size_t max_input_num = 2147483648; // 2^31, 2G num - TORCH_CHECK(input.numel() < max_input_num, - "input.numel() should be less than 2147483648, got ", - input.numel()); - TORCH_CHECK(rois.numel() < max_input_num, - "rois.numel() should be less than 2147483648, got ", - rois.numel()); - TORCH_CHECK(grad_output.numel() < max_input_num, - "grad_output.numel() should be less than 2147483648, got ", - grad_output.numel()); - TORCH_CHECK(!offset.defined() || offset.numel() < max_input_num, - "offset.numel() should be less than 2147483648, got ", - offset.numel()); - auto memory_format = torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim()); auto grad_output_ = @@ -264,45 +69,56 @@ void DeformRoIPoolBackwardMLUKernelLauncher( memory_format = torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim()); auto input_ = torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format); - at::Tensor grad_input_ = at::empty({batch, channels, height, width}, - input.options(), memory_format) - .zero_(); - - // calculate task dimension - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - policyFunc(num_rois * pooled_height * pooled_width, &k_dim, &k_type); - - // get compute queue - auto queue = torch_mlu::getCurQueue(); + auto rois_contiguous = + torch_mlu::cnnl::ops::cnnl_contiguous(rois, rois.suggest_memory_format()); + auto grad_input_ = + torch_mlu::cnnl::ops::cnnl_contiguous(grad_input, memory_format); // get ptr of tensors auto grad_output_impl = torch_mlu::getMluTensorImpl(grad_output_); auto grad_output_ptr = grad_output_impl->cnnlMalloc(); auto input_impl = torch_mlu::getMluTensorImpl(input_); auto input_ptr = input_impl->cnnlMalloc(); - auto rois_impl = torch_mlu::getMluTensorImpl(rois); + auto rois_impl = torch_mlu::getMluTensorImpl(rois_contiguous); auto rois_ptr = rois_impl->cnnlMalloc(); - auto offset_impl = torch_mlu::getMluTensorImpl(offset); - auto offset_ptr = offset_impl->cnnlMalloc(); auto grad_input_impl = torch_mlu::getMluTensorImpl(grad_input_); auto grad_input_ptr = grad_input_impl->cnnlMalloc(); - auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset); - auto grad_offset_ptr = grad_offset_impl->cnnlMalloc(); - - // get comput dtype of input - cnrtDataType_t data_type = torch_mlu::toCnrtDtype(input.dtype()); - // launch kernel - CNLOG(INFO) << "Launch Kernel KernelDeformRoIPoolBackward<<<" << k_dim.x - << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - - KernelDeformRoIPoolBackward(k_dim, k_type, queue, data_type, grad_output_ptr, - input_ptr, rois_ptr, offset_ptr, grad_input_ptr, - grad_offset_ptr, channels, height, width, - num_rois, pooled_height, pooled_width, - spatial_scale, sampling_ratio, gamma); + MluOpTensorDescriptor grad_output_desc, input_desc, rois_desc, offset_desc, + grad_input_desc, grad_offset_desc; + grad_output_desc.set_with_layout(grad_output_, MLUOP_LAYOUT_NHWC); + input_desc.set_with_layout(input_, MLUOP_LAYOUT_NHWC); + rois_desc.set(rois_contiguous); + grad_input_desc.set_with_layout(grad_input_, 
MLUOP_LAYOUT_NHWC); + mluOpTensorDescriptor_t offset_real_desc = NULL; + void *offset_ptr = NULL; + if (offset.defined() && offset.numel() > 0) { + auto offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + offset, offset.suggest_memory_format()); + offset_desc.set(offset_contiguous); + offset_real_desc = offset_desc.desc(); + auto offset_impl = torch_mlu::getMluTensorImpl(offset_contiguous); + offset_ptr = offset_impl->cnnlMalloc(); + } + mluOpTensorDescriptor_t grad_offset_real_desc = NULL; + void *grad_offset_ptr = NULL; + if (grad_offset.defined() && grad_offset.numel() > 0) { + auto grad_offset_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + grad_offset, grad_offset.suggest_memory_format()); + grad_offset_desc.set(grad_offset_contiguous); + grad_offset_real_desc = grad_offset_desc.desc(); + auto grad_offset_impl = torch_mlu::getMluTensorImpl(grad_offset_contiguous); + grad_offset_ptr = grad_offset_impl->cnnlMalloc(); + } + // get compute handle + auto handle = mluOpGetCurrentHandle(); + mluOpDeformRoiPoolBackward( + handle, grad_output_desc.desc(), grad_output_ptr, input_desc.desc(), + input_ptr, rois_desc.desc(), rois_ptr, offset_real_desc, offset_ptr, + pooled_height, pooled_width, spatial_scale, sampling_ratio, gamma, + grad_input_desc.desc(), grad_input_ptr, grad_offset_real_desc, + grad_offset_ptr); grad_input.copy_(grad_input_); } diff --git a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp new file mode 100644 index 0000000000..3a76b49715 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.cpp @@ -0,0 +1,136 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ *************************************************************************/ +#include "mlu_common_helper.h" + +// Descriptors +mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type) { + const std::map mapping_type = { + {std::string("c10::Half"), MLUOP_DTYPE_HALF}, + {std::string("float"), MLUOP_DTYPE_FLOAT}, + {std::string("double"), MLUOP_DTYPE_DOUBLE}, + {std::string("int8"), MLUOP_DTYPE_INT8}, + {std::string("signed char"), MLUOP_DTYPE_INT8}, + {std::string("short int"), MLUOP_DTYPE_INT16}, + {std::string("short"), MLUOP_DTYPE_INT16}, + {std::string("int"), MLUOP_DTYPE_INT32}, + {std::string("long int"), MLUOP_DTYPE_INT64}, + {std::string("long"), MLUOP_DTYPE_INT64}, + {std::string("unsigned char"), MLUOP_DTYPE_UINT8}, + {std::string("bool"), MLUOP_DTYPE_BOOL}, + {std::string("c10::complex"), MLUOP_DTYPE_COMPLEX_HALF}, + {std::string("c10::complex"), MLUOP_DTYPE_COMPLEX_FLOAT}}; + + if (mapping_type.find(std::string(data_type.name())) != mapping_type.end()) { + return mapping_type.find(std::string(data_type.name()))->second; + } + return MLUOP_DTYPE_INVALID; +} + +// laytout +mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input) { + auto suggest_memory_format = input.suggest_memory_format(); + mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY; + switch (input.dim()) { + case 4: + layout = (suggest_memory_format == at::MemoryFormat::ChannelsLast) + ? MLUOP_LAYOUT_NHWC + : MLUOP_LAYOUT_NCHW; + break; + case 5: + layout = (suggest_memory_format == at::MemoryFormat::ChannelsLast3d) + ? MLUOP_LAYOUT_NDHWC + : MLUOP_LAYOUT_NCDHW; + break; + default: + layout = MLUOP_LAYOUT_ARRAY; + } + return layout; +} + +void MluOpTensorDescriptor::set(Tensor t) { + mluOpDataType_t data_type = getMluOpDataType(t.dtype()); + mluOpTensorLayout_t layout = getMluOpSuggestLayout(t); + int t_dim = t.dim(); + std::vector dim_array; + if (t_dim == 0) { + dim_array.push_back( + 1); // ScalarTensor(0-dim 1-item Tensor) view like size = 1 as default; + } else { + for (int i = 0; i < t_dim; i++) { + dim_array.push_back(static_cast(t.sizes().vec()[i])); + } + } + set_desc(t, layout, data_type, dim_array); +} + +void MluOpTensorDescriptor::set_with_layout(Tensor t, + mluOpTensorLayout_t layout) { + mluOpDataType_t data_type = getMluOpDataType(t.dtype()); + int t_dim = t.dim(); + std::vector shape_info = checkUpperBoundAndCastTo(t.sizes().vec()); + std::vector stride_info = + checkUpperBoundAndCastTo(t.strides().vec()); + if (layout == MLUOP_LAYOUT_NHWC || layout == MLUOP_LAYOUT_NDHWC || + layout == MLUOP_LAYOUT_NLC) { + convertShapeAndStride(shape_info, stride_info); + } else if (layout == MLUOP_LAYOUT_HWCN) { + auto convertDepthWiseConvShapeStride = [](const std::vector& vec, + std::vector& target_vec, + std::vector& stride_vec) { + // NCHW --> HWCN + target_vec[0] = static_cast(vec[2]); + target_vec[1] = static_cast(vec[3]); + target_vec[2] = static_cast(vec[1]); + target_vec[3] = static_cast(vec[0]); + // Calculate Stride just like contiguous of HWCN. 
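+      // Illustrative example (hypothetical filter shape): an NCHW filter of
+      // {64, 1, 3, 3} maps to HWCN {3, 3, 1, 64} with contiguous strides
+      // {192, 64, 64, 1}.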
+ stride_vec[3] = 1; + stride_vec[2] = target_vec[3] * stride_vec[3]; + stride_vec[1] = target_vec[2] * stride_vec[2]; + stride_vec[0] = target_vec[1] * stride_vec[1]; + }; + convertDepthWiseConvShapeStride(t.sizes().vec(), shape_info, stride_info); + } + TORCH_CHECK(mluOpSetTensorDescriptorEx( + desc_, layout, data_type, t_dim, shape_info.data(), + stride_info.data()) == MLUOP_STATUS_SUCCESS, + "mluOpSetTensorDescriptorEx execution failed."); +} + +void MluOpTensorDescriptor::set_desc(const at::Tensor& t, + mluOpTensorLayout_t layout, + mluOpDataType_t dtype, + std::vector& dims) { + int dimNb = dims.size(); + mluOpSetTensorDescriptor(desc_, layout, dtype, dimNb, dims.data()); +} + +// Handles +std::once_flag mmcv_mluop_init_flag; +std::mutex mmcv_mluop_mutex; +static std::vector mmcv_mluop_handles; + +mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index) { + std::call_once(mmcv_mluop_init_flag, + []() // Init mmcv_mluop_handles 1-device <-> 1-handle + { + c10::DeviceIndex num_devices = torch_mlu::device_count(); + mmcv_mluop_handles.resize(num_devices); + }); + + if (device_index == -1) { + device_index = torch_mlu::current_device(); + } + std::lock_guard mmcv_mluop_guard(mmcv_mluop_mutex); + auto queue = torch_mlu::getCurrentQueue(device_index).queue(); + mmcv_mluop_handles[device_index].setQueue(queue); + return mmcv_mluop_handles[device_index].handle; +} diff --git a/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h new file mode 100644 index 0000000000..436f055f04 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h @@ -0,0 +1,99 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ *************************************************************************/ +#pragma once +#include +#include + +#include "aten.h" +#include "mlu_op.h" +#include "pytorch_device_registry.hpp" + +#define MLUOP_MAJOR 0 +#define MLUOP_MINOR 5 +#define MLUOP_PATCHLEVEL 302 + +mluOpDataType_t getMluOpDataType(const caffe2::TypeMeta& data_type); +mluOpTensorLayout_t getMluOpSuggestLayout(const at::Tensor& input); + +class MluOpTensorDescriptor { + public: + MluOpTensorDescriptor() { mluOpCreateTensorDescriptor(&desc_); }; + ~MluOpTensorDescriptor() { mluOpDestroyTensorDescriptor(desc_); } + + void set(at::Tensor); + void set_with_layout(at::Tensor, mluOpTensorLayout_t layout); + mluOpTensorDescriptor_t desc() { return desc_; } + + private: + mluOpTensorDescriptor_t desc_; + void set_desc(const at::Tensor&, mluOpTensorLayout_t, mluOpDataType_t, + std::vector& dims); +}; + +mluOpHandle_t mluOpGetCurrentHandle(c10::DeviceIndex device_index = -1); + +class MluOpHandle { + public: + MluOpHandle() : handle(nullptr) { mluOpCreate(&handle); } + ~MluOpHandle() { + if (handle) { + mluOpDestroy(handle); + handle = nullptr; + } + } + void setQueue(cnrtQueue_t queue) { mluOpSetQueue(handle, queue); } + mluOpHandle_t handle; +}; + +// modify tensor size and stride order based on +// channels_first to channels_last or channels_last_3d. +// which this is not same with pytorch original layout, +// this real layout is based on data storage real order. +// example: modify channels_last tensor dim to nhwc tensor desc. +// N C H W --> N H W C +// C*H*W 1 W C --> C*H*W W C 1 +template +void convertShapeAndStride(std::vector& shape_info, + std::vector& stride_info) { + TORCH_MLU_CHECK(shape_info.size() == stride_info.size(), + "shape size need equal to stride size."); + const int dim = shape_info.size(); + std::vector temp_shape_info(dim); + std::vector temp_stride_info(dim); + temp_shape_info[0] = shape_info[0]; + temp_stride_info[0] = stride_info[0]; + for (size_t i = 0; i < dim - 1; ++i) { + const int index = (i + 1) % (dim - 1) + 1; + temp_shape_info[i + 1] = shape_info[index]; + temp_stride_info[i + 1] = stride_info[index]; + } + shape_info.assign(temp_shape_info.begin(), temp_shape_info.end()); + stride_info.assign(temp_stride_info.begin(), temp_stride_info.end()); +} + +// torch tensor provides int64_t type of shape and stride, +// but mluops descriptor requires type int32. +// use this function to ensure safe CAST, or report an error. +template +std::vector checkUpperBoundAndCastTo(const std::vector& input) { + std::vector output; + output.reserve(input.size()); + for (const auto& val : input) { + if (val > std::numeric_limits::max()) { + TORCH_MLU_CHECK(false, "Requires dim size not greater than ", + std::numeric_limits::max(), ". But got ", val, + "."); + } + output.push_back(static_cast(val)); + } + return output; +} diff --git a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp index e93fd984aa..f8e884d971 100644 --- a/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/ms_deform_attn_mlu.cpp @@ -14,7 +14,15 @@ #define MIN(a, b) (((a) < (b)) ? (a) : (b)) -void KernelMsDeformAttnForward( +typedef enum { + MS_DEFORM_ATTN_FORWARD_INVALID = 0, /*!< Index is invalid. 
*/ + MS_DEFORM_ATTN_FORWARD_DEFAULT = + 1, /*!< MLUKernelMsDeformAttnForwardDefault */ + MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL = + 2, /*!< MLUKernelMsDeformAttnForwardSmallChannel */ +} MsDeformAttnForwardPolicy; + +void KernelMsDeformAttnForwardDefault( cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const cnrtDataType_t d_type, const char* data_value_gdram, const char* data_spatial_shapes_gdram, @@ -23,7 +31,37 @@ void KernelMsDeformAttnForward( const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, const int32_t channels, const int32_t num_levels, const int32_t num_queries, const int32_t num_points, char* data_col_gdram); -void KernelMsDeformAttnBackward( +void KernelMsDeformAttnForwardSmallChannel( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const char* data_value_gdram, + const char* data_spatial_shapes_gdram, + const char* data_level_start_index_gdram, + const char* data_sampling_loc_gdram, const char* data_attn_weight_gdram, + const int32_t batch_size, const int32_t num_keys, const int32_t num_heads, + const int32_t channels, const int32_t num_levels, const int32_t num_queries, + const int32_t num_points, char* data_col_gdram); + +typedef enum { + MS_DEFORM_ATTN_BACKWARD_DEFAULT = 0, + MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL = 1, +} MsDeformAttnBackwardKernelPolicy; + +MsDeformAttnBackwardKernelPolicy msDeformAttnBackwardPolicyFunc( + const int32_t channels, const int32_t num_levels, const int32_t num_points, + const int32_t num_heads) { + const int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); + const int num_hlp = num_heads * num_levels * num_points; + int num_per_time_theory = (nram_size - num_levels * sizeof(float) - + 3 * num_levels * sizeof(int32_t)) / + sizeof(float) / (8 * PAD_UP(channels, 32) + 28) / + PAD_UP((num_hlp), 32); + if (num_per_time_theory >= 1) { + return MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL; + } + return MS_DEFORM_ATTN_BACKWARD_DEFAULT; +} + +void KernelMsDeformAttnBackwardDefaultKernel( cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, const cnrtDataType_t d_type, const float* data_value, const int32_t* spatial_shapes, const int32_t* data_level_start_index, @@ -32,10 +70,23 @@ void KernelMsDeformAttnBackward( const int32_t num_heads, const int32_t channels, const int32_t num_levels, const int32_t num_queries, const int32_t num_points, float* grad_value, float* grad_sampling_loc, float* grad_attn_weight); + +void KernelMsDeformAttnBackwardSmallChannelsKernel( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const cnrtDataType_t d_type, const float* data_value, + const int32_t* spatial_shapes, const int32_t* data_level_start_index, + const float* data_sampling_loc, const float* data_attn_weight, + const float* grad_output, const int32_t batch, const int32_t spatial_size, + const int32_t num_heads, const int32_t channels, const int32_t num_levels, + const int32_t num_query, const int32_t num_points, float* grad_value, + float* grad_sampling_loc, float* grad_attn_weight); + // policy function -static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, - const int batch_size, const int num_queries, - const int num_heads) { +MsDeformAttnForwardPolicy msDeformAttnForwardPolicyFunc( + cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, const int32_t batch_size, + const int32_t num_keys, const int32_t num_heads, const int32_t channels, + const int32_t num_levels, const int32_t num_queries, + const int32_t num_points) 
{ k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); k_dim->y = MIN((batch_size * num_queries * num_heads + k_dim->x - 1) / k_dim->x, @@ -46,6 +97,16 @@ static void policyFuncForward(cnrtDim3_t* k_dim, cnrtFunctionType_t* k_type, #else *k_type = CNRT_FUNC_TYPE_UNION1; #endif + + int32_t nram_size = torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore); + if (num_levels * num_points * 3 * sizeof(int32_t) > nram_size) { + return MS_DEFORM_ATTN_FORWARD_DEFAULT; + } else if (channels > nram_size / 12 / sizeof(float) || channels > 96 || + channels < 16) { + return MS_DEFORM_ATTN_FORWARD_DEFAULT; + } else { + return MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL; + } } // policy function for backward @@ -196,7 +257,9 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value, // calculate task dimension cnrtDim3_t k_dim; cnrtFunctionType_t k_type; - policyFuncForward(&k_dim, &k_type, batch_size, num_queries, num_heads); + MsDeformAttnForwardPolicy policy = msDeformAttnForwardPolicyFunc( + &k_dim, &k_type, batch_size, num_keys, num_heads, channels, num_levels, + num_queries, num_points); // get compute queue auto queue = torch_mlu::getCurQueue(); @@ -222,15 +285,33 @@ Tensor ms_deform_attn_mlu_forward(const Tensor& value, cnrtDataType_t data_type = torch_mlu::toCnrtDtype(value.dtype()); // launch kernel - CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForward<<<" << k_dim.x - << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - - KernelMsDeformAttnForward( - k_dim, k_type, queue, data_type, (char*)value_ptr, - (char*)spatial_shapes_ptr, (char*)level_start_index_ptr, - (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys, - num_heads, channels, num_levels, num_queries, num_points, - (char*)output_ptr); + switch (policy) { + default: { + VLOG(5) << "MsDeformAttnForward Policy not supported"; + }; break; + case MS_DEFORM_ATTN_FORWARD_DEFAULT: { + CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardDefault<<<" + << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; + KernelMsDeformAttnForwardDefault( + k_dim, k_type, queue, data_type, (char*)value_ptr, + (char*)spatial_shapes_ptr, (char*)level_start_index_ptr, + (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys, + num_heads, channels, num_levels, num_queries, num_points, + (char*)output_ptr); + break; + } + case MS_DEFORM_ATTN_FORWARD_SMALL_CHANNEL: { + CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnForwardSmallChannel<<<" + << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; + KernelMsDeformAttnForwardSmallChannel( + k_dim, k_type, queue, data_type, (char*)value_ptr, + (char*)spatial_shapes_ptr, (char*)level_start_index_ptr, + (char*)sampling_loc_ptr, (char*)attn_weight_ptr, batch_size, num_keys, + num_heads, channels, num_levels, num_queries, num_points, + (char*)output_ptr); + break; + } + } output = output.view({batch_size, num_queries, num_heads * channels}); return output; @@ -391,14 +472,32 @@ void ms_deform_attn_mlu_backward( // launch kernel CNLOG(INFO) << "Launch Kernel MLUKernelMsDeformAttnBackward<<<" << k_dim.x << ", " << k_dim.y << ", " << k_dim.z << ">>>"; - - KernelMsDeformAttnBackward( - k_dim, k_type, queue, data_type, (float*)value_ptr, - (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr, - (float*)sampling_loc_ptr, (float*)attn_weight_ptr, - (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels, - num_levels, num_queries, num_points, (float*)grad_value_ptr, - (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr); + 
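+  // Kernel dispatch sketch: msDeformAttnBackwardPolicyFunc selects the
+  // small-channel backward kernel when at least one padded
+  // (head, level, point) tile of channels fits in per-core NRAM, and falls
+  // back to the default backward kernel otherwise.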
MsDeformAttnBackwardKernelPolicy kernelPolicy = + msDeformAttnBackwardPolicyFunc(channels, num_levels, num_points, + num_heads); + switch (kernelPolicy) { + default: { + VLOG(5) << "NotImplemented."; + } break; + case MS_DEFORM_ATTN_BACKWARD_DEFAULT: { + KernelMsDeformAttnBackwardDefaultKernel( + k_dim, k_type, queue, data_type, (float*)value_ptr, + (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr, + (float*)sampling_loc_ptr, (float*)attn_weight_ptr, + (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels, + num_levels, num_queries, num_points, (float*)grad_value_ptr, + (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr); + } break; + case MS_DEFORM_ATTN_BACKWARD_SMALL_CHANNEL: { + KernelMsDeformAttnBackwardSmallChannelsKernel( + k_dim, k_type, queue, data_type, (float*)value_ptr, + (int32_t*)spatial_shapes_ptr, (int32_t*)level_start_index_ptr, + (float*)sampling_loc_ptr, (float*)attn_weight_ptr, + (float*)grad_output_ptr, batch_size, num_keys, num_heads, channels, + num_levels, num_queries, num_points, (float*)grad_value_ptr, + (float*)grad_sampling_loc_ptr, (float*)grad_attn_weight_ptr); + } break; + } } Tensor ms_deform_attn_impl_forward(const Tensor& value, diff --git a/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp new file mode 100644 index 0000000000..9b45a17805 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/nms_rotated_mlu.cpp @@ -0,0 +1,53 @@ +/************************************************************************* + * Copyright (C) 2021 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ *************************************************************************/ +#include "mlu_common_helper.h" + +Tensor nms_rotated_mlu(Tensor boxes, Tensor scores, float iou_threshold) { + if (boxes.numel() == 0) { + return at::empty({0}, boxes.options().dtype(at::kLong)); + } + + int boxes_num = boxes.size(0); + auto boxes_ = torch_mlu::cnnl::ops::cnnl_contiguous(boxes); + auto scores_ = torch_mlu::cnnl::ops::cnnl_contiguous(scores); + auto output = at::empty({boxes_num}, boxes.options().dtype(at::kInt)); + auto output_size = at::empty({1}, scores.options().dtype(at::kInt)); + + MluOpTensorDescriptor boxes_desc, scores_desc, output_desc; + boxes_desc.set(boxes_); + scores_desc.set(scores_); + output_desc.set(output); + + // workspace + size_t workspace_size = 0; + auto handle = mluOpGetCurrentHandle(); + mluOpGetNmsRotatedWorkspaceSize(handle, boxes_desc.desc(), &workspace_size); + auto workspace = at::empty(workspace_size, boxes.options().dtype(at::kByte)); + + auto boxes_impl = torch_mlu::getMluTensorImpl(boxes_); + auto boxes_ptr = boxes_impl->cnnlMalloc(); + auto scores_impl = torch_mlu::getMluTensorImpl(scores_); + auto scores_ptr = scores_impl->cnnlMalloc(); + auto workspace_impl = torch_mlu::getMluTensorImpl(workspace); + auto workspace_ptr = workspace_impl->cnnlMalloc(); + auto output_impl = torch_mlu::getMluTensorImpl(output); + auto output_ptr = output_impl->cnnlMalloc(); + auto output_size_impl = torch_mlu::getMluTensorImpl(output_size); + auto output_size_ptr = output_size_impl->cnnlMalloc(); + + mluOpNmsRotated(handle, iou_threshold, boxes_desc.desc(), boxes_ptr, + scores_desc.desc(), scores_ptr, workspace_ptr, workspace_size, + output_desc.desc(), output_ptr, (int *)output_size_ptr); + int output_num = *static_cast(output_size.cpu().data_ptr()); + auto ret = output.to(boxes.options().dtype(at::kLong)); + return ret.slice(0, 0, output_num); +} diff --git a/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp new file mode 100644 index 0000000000..165aae1715 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/sparse_conv_mlu.cpp @@ -0,0 +1,446 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include + +#include + +#include "mlu_common_helper.h" +#include "pytorch_device_registry.hpp" +#include "pytorch_mlu_helper.hpp" + +template +std::vector GetIndicePairsForwardMLUKernelLauncher( + torch::Tensor indices, int64_t batchSize, + std::vector outSpatialShape, std::vector spatialShape, + std::vector kernelSize, std::vector stride, + std::vector padding, std::vector dilation, + std::vector outPadding, int64_t _subM, int64_t _transpose) { + // The following code is copied from + // mmcv/ops/csrc/pytorch/cuda/spconv_ops_cuda.cu to ensure the output is + // available for network train. The outputs of this function have correct + // shape but wrong value. 
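+  // Note: the shape bookkeeping below only sizes indicePairs
+  // {kernelVolume, 2, numAct}, indiceNum {kernelVolume} and out_indices;
+  // the actual pair computation is delegated to mluOpGetIndicePairs further
+  // down.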
+ auto numAct = indices.size(0); + auto kernelVolume = kernelSize[0]; + int sub_m = (int)_subM; + int transpose = (int)_transpose; + int batch = (int)batchSize; + auto coorDim = indices.size(1) - 1; + + for (int i = 1; i < kernelSize.size(); ++i) { + kernelVolume *= kernelSize[i]; + } + + auto outputVolume = outSpatialShape[0]; + for (int i = 1; i < outSpatialShape.size(); ++i) { + outputVolume *= outSpatialShape[i]; + } + torch::Tensor indicePairs = at::full({kernelVolume, 2, numAct}, -1, + indices.options().dtype(at::kInt)); + torch::Tensor indiceNum = + at::zeros({kernelVolume}, indices.options().dtype(at::kInt)); + int out_size = sub_m == 1 + ? numAct + : std::min(numAct * kernelVolume, batch * outputVolume); + torch::Tensor out_indices = + at::zeros({out_size, coorDim + 1}, indices.options().dtype(at::kInt)); + auto indices_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + indices, at::MemoryFormat::Contiguous); + auto indicePairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + indicePairs, at::MemoryFormat::Contiguous); + auto indiceNum_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + indiceNum, at::MemoryFormat::Contiguous); + auto out_indices_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + out_indices, at::MemoryFormat::Contiguous); + + std::vector input_space; + std::vector filter_space; + std::vector output_space; + std::vector padding32; + std::vector stride32; + std::vector dilation32; + for (int i = 0; i < NDim; i++) { + input_space.push_back(spatialShape[i]); + filter_space.push_back(kernelSize[i]); + output_space.push_back(outSpatialShape[i]); + padding32.push_back(padding[i]); + stride32.push_back(stride[i]); + dilation32.push_back(dilation[i]); + } + MluOpTensorDescriptor indices_desc, out_indices_desc, indicePairs_desc, + indiceNum_desc; + indices_desc.set(indices_contiguous); + indicePairs_desc.set(indicePairs_contiguous); + indiceNum_desc.set(indiceNum_contiguous); + out_indices_desc.set(out_indices_contiguous); + { + mluOpTensorLayout_t layout = MLUOP_LAYOUT_ARRAY; + mluOpDataType_t dtype = MLUOP_DTYPE_INT32; + std::vector dims; + dims = {numAct, coorDim + 1}; + mluOpSetTensorDescriptor(indices_desc.desc(), layout, dtype, dims.size(), + dims.data()); + dims = {kernelVolume, 2, numAct}; + mluOpSetTensorDescriptor(indicePairs_desc.desc(), layout, dtype, + dims.size(), dims.data()); + dims = {kernelVolume}; + mluOpSetTensorDescriptor(indiceNum_desc.desc(), layout, dtype, dims.size(), + dims.data()); + dims = {out_size, coorDim + 1}; + mluOpSetTensorDescriptor(out_indices_desc.desc(), layout, dtype, + dims.size(), dims.data()); + } + + mluOpSparseConvolutionDescriptor_t sparse_conv_desc; + mluOpCreateSparseConvolutionDescriptor(&sparse_conv_desc); + mluOpSetSparseConvolutionDescriptor( + sparse_conv_desc, NDim + 2, batch, padding32.data(), stride32.data(), + dilation32.data(), input_space.data(), filter_space.data(), + output_space.data(), sub_m, transpose, 0); + + auto handle = mluOpGetCurrentHandle(); + size_t workspace_size = 0; + mluOpGetIndicePairsWorkspaceSize( + handle, sparse_conv_desc, indices_desc.desc(), indicePairs_desc.desc(), + out_indices_desc.desc(), indiceNum_desc.desc(), &workspace_size); + auto indice_workspace_size = + at::empty(workspace_size, indices.options().dtype(at::kByte)); + + auto indices_impl = torch_mlu::getMluTensorImpl(indices_contiguous); + auto out_indices_impl = torch_mlu::getMluTensorImpl(out_indices_contiguous); + auto indicePairs_impl = torch_mlu::getMluTensorImpl(indicePairs_contiguous); + auto indiceNum_impl = 
torch_mlu::getMluTensorImpl(indiceNum_contiguous); + auto indice_workspace_impl = + torch_mlu::getMluTensorImpl(indice_workspace_size); + + auto indices_ptr = indices_impl->cnnlMalloc(); + auto out_indices_ptr = out_indices_impl->cnnlMalloc(); + auto indicePairs_ptr = indicePairs_impl->cnnlMalloc(); + auto indiceNum_ptr = indiceNum_impl->cnnlMalloc(); + auto indice_workspace_ptr = indice_workspace_impl->cnnlMalloc(); + + mluOpGetIndicePairs(handle, sparse_conv_desc, indices_desc.desc(), + indices_ptr, indice_workspace_ptr, workspace_size, + indicePairs_desc.desc(), indicePairs_ptr, + out_indices_desc.desc(), out_indices_ptr, + indiceNum_desc.desc(), indiceNum_ptr); + int num_act_out = 0; + mluOpGetSparseConvolutionNumActOut(sparse_conv_desc, &num_act_out); + mluOpDestroySparseConvolutionDescriptor(sparse_conv_desc); + if (!sub_m) { + return {out_indices.slice(0, 0, num_act_out), indicePairs, indiceNum}; + } else { + return {indices, indicePairs, indiceNum}; + } +} + +torch::Tensor IndiceConvForwardMLUKernelLauncher( + torch::Tensor features, torch::Tensor filters, torch::Tensor indicePairs, + torch::Tensor indiceNum, int64_t numActOut, int64_t _inverse, + int64_t _subM) { + auto indice_num_cpu = indiceNum.to({torch::kCPU}); + auto indice_num_cpu_64 = indice_num_cpu.data_ptr(); + int indice_num_len = indiceNum.numel(); + int64_t indice_num[indice_num_len]; + for (int i = 0; i < indice_num_len; ++i) { + indice_num[i] = (int64_t)(((int *)indice_num_cpu_64)[i]); + } + + // generate empty output + int C = filters.dim() == 4 ? filters.size(3) : filters.size(4); + torch::Tensor output = + at::zeros({numActOut, C}, features.options().dtype(at::kFloat)); + // generate descriptor + auto features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + features, at::MemoryFormat::Contiguous); + auto filters_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + filters, at::MemoryFormat::Contiguous); + auto indice_pairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + indicePairs, at::MemoryFormat::Contiguous); + auto output_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + output, at::MemoryFormat::Contiguous); + + MluOpTensorDescriptor features_desc, filters_desc, indice_pairs_desc, + output_desc; + features_desc.set(features_contiguous); + filters_desc.set(filters_contiguous); + indice_pairs_desc.set(indice_pairs_contiguous); + output_desc.set(output_contiguous); + + // set layout + { + mluOpTensorLayout_t layout; + mluOpDataType_t dtype; + int dim; + int dims[8]; + + // features_desc + mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims); + mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, + dim, dims); + + // filters_desc + mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims); + mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, + dim, dims); + + // indice_pairs_desc + mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim, + dims); + mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY, + dtype, dim, dims); + + // output_desc + mluOpGetTensorDescriptor(output_desc.desc(), &layout, &dtype, &dim, dims); + mluOpSetTensorDescriptor(output_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, dim, + dims); + } + + auto handle = mluOpGetCurrentHandle(); + size_t workspace_size = 0; + mluOpGetIndiceConvolutionForwardWorkspaceSize( + handle, features_desc.desc(), filters_desc.desc(), + indice_pairs_desc.desc(), output_desc.desc(), indice_num, numActOut, + _inverse, _subM, 
&workspace_size); + + auto workspace = + at::empty(workspace_size, features.options().dtype(at::kByte)); + + auto features_impl = torch_mlu::getMluTensorImpl(features_contiguous); + auto filters_impl = torch_mlu::getMluTensorImpl(filters_contiguous); + auto indice_pairs_impl = torch_mlu::getMluTensorImpl(indice_pairs_contiguous); + auto workspace_impl = torch_mlu::getMluTensorImpl(workspace); + + auto features_ptr = features_impl->cnnlMalloc(); + auto filters_ptr = filters_impl->cnnlMalloc(); + auto indice_pairs_ptr = indice_pairs_impl->cnnlMalloc(); + auto workspace_ptr = workspace_impl->cnnlMalloc(); + + // outputs + auto output_impl = torch_mlu::getMluTensorImpl(output); + auto output_ptr = output_impl->cnnlMalloc(); + mluOpIndiceConvolutionForward( + handle, features_desc.desc(), features_ptr, filters_desc.desc(), + filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num, + numActOut, _inverse, _subM, workspace_ptr, workspace_size, + output_desc.desc(), output_ptr); + + return output; +} + +std::vector IndiceConvBackwardMLUKernelLauncher( + torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad, + torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse, + int64_t _subM) { + auto indice_num_cpu = indiceNum.to({torch::kCPU}); + auto indice_num_cpu_64 = indice_num_cpu.data_ptr(); + int indice_num_len = indiceNum.numel(); + int64_t indice_num[indice_num_len]; + for (int i = 0; i < indice_num_len; ++i) { + indice_num[i] = (int64_t)(((int *)(indice_num_cpu_64))[i]); + } + + // generate empty input_grad + torch::Tensor input_grad = at::zeros({features.size(0), features.size(1)}, + features.options().dtype(at::kFloat)); + torch::Tensor filters_grad; + if (filters.dim() == 4) { + int h = filters.size(0); + int w = filters.size(1); + int c = filters.size(2); + int n = filters.size(3); + filters_grad = at::zeros({h, w, c, n}, filters.options().dtype(at::kFloat)); + } else if (filters.dim() == 5) { + int d = filters.size(0); + int h = filters.size(1); + int w = filters.size(2); + int c = filters.size(3); + int n = filters.size(4); + filters_grad = + at::zeros({d, h, w, c, n}, filters.options().dtype(at::kFloat)); + } + + auto features_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + features, at::MemoryFormat::Contiguous); + auto filters_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + filters, at::MemoryFormat::Contiguous); + auto output_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + outGrad, at::MemoryFormat::Contiguous); + auto indice_pairs_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + indicePairs, at::MemoryFormat::Contiguous); + auto input_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + features, at::MemoryFormat::Contiguous); + auto filters_grad_contiguous = torch_mlu::cnnl::ops::cnnl_contiguous( + filters, at::MemoryFormat::Contiguous); + + MluOpTensorDescriptor features_desc, output_grad_desc, filters_desc, + indice_pairs_desc, input_grad_desc, filters_grad_desc; + features_desc.set(features_contiguous); + filters_desc.set(filters_contiguous); + output_grad_desc.set(output_grad_contiguous); + indice_pairs_desc.set(indice_pairs_contiguous); + input_grad_desc.set(input_grad_contiguous); + filters_grad_desc.set(filters_grad_contiguous); + + // need to set desc layout with mluOp functions + { + mluOpTensorLayout_t layout; + mluOpDataType_t dtype; + int dim; + int dims[8]; + + // features_desc + mluOpGetTensorDescriptor(features_desc.desc(), &layout, &dtype, &dim, dims); + 
mluOpSetTensorDescriptor(features_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, + dim, dims); + + // filters_desc + mluOpGetTensorDescriptor(filters_desc.desc(), &layout, &dtype, &dim, dims); + if (dim == 4) { + mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_HWCN, dtype, + dim, dims); + } else { + mluOpSetTensorDescriptor(filters_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, + dim, dims); + } + + // output_grad_desc + mluOpGetTensorDescriptor(output_grad_desc.desc(), &layout, &dtype, &dim, + dims); + mluOpSetTensorDescriptor(output_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, + dim, dims); + + // indice_pairs_desc + mluOpGetTensorDescriptor(indice_pairs_desc.desc(), &layout, &dtype, &dim, + dims); + mluOpSetTensorDescriptor(indice_pairs_desc.desc(), MLUOP_LAYOUT_ARRAY, + dtype, dim, dims); + + // input_grad_desc + mluOpGetTensorDescriptor(input_grad_desc.desc(), &layout, &dtype, &dim, + dims); + mluOpSetTensorDescriptor(input_grad_desc.desc(), MLUOP_LAYOUT_ARRAY, dtype, + dim, dims); + } + + auto handle = mluOpGetCurrentHandle(); + size_t data_workspace_size = 0; + mluOpGetIndiceConvolutionBackwardDataWorkspaceSize( + handle, output_grad_desc.desc(), filters_desc.desc(), + indice_pairs_desc.desc(), input_grad_desc.desc(), indice_num, _inverse, + &data_workspace_size); + + size_t filters_workspace_size = 0; + mluOpGetIndiceConvolutionBackwardFilterWorkspaceSize( + handle, features_desc.desc(), output_grad_desc.desc(), + indice_pairs_desc.desc(), filters_grad_desc.desc(), indice_num, _inverse, + _subM, &filters_workspace_size); + + auto indice_convbpdata_workspace = + at::empty(data_workspace_size, features.options().dtype(at::kByte)); + auto indice_convbpfilter_workspace = + at::empty(filters_workspace_size, filters.options().dtype(at::kByte)); + + auto features_impl = torch_mlu::getMluTensorImpl(features_contiguous); + auto filters_impl = torch_mlu::getMluTensorImpl(filters_contiguous); + auto output_grad_impl = torch_mlu::getMluTensorImpl(output_grad_contiguous); + auto indice_pairs_impl = torch_mlu::getMluTensorImpl(indice_pairs_contiguous); + auto indice_convbpdata_workspace_impl = + torch_mlu::getMluTensorImpl(indice_convbpdata_workspace); + auto indice_convbpfilter_workspace_impl = + torch_mlu::getMluTensorImpl(indice_convbpfilter_workspace); + + auto features_ptr = features_impl->cnnlMalloc(); + auto filters_ptr = filters_impl->cnnlMalloc(); + auto output_grad_ptr = output_grad_impl->cnnlMalloc(); + auto indice_pairs_ptr = indice_pairs_impl->cnnlMalloc(); + auto indice_convbpdata_workspace_ptr = + indice_convbpdata_workspace_impl->cnnlMalloc(); + auto indice_convbpfilter_workspace_ptr = + indice_convbpfilter_workspace_impl->cnnlMalloc(); + + // outputs + auto input_grad_impl = torch_mlu::getMluTensorImpl(input_grad); + auto input_grad_ptr = input_grad_impl->cnnlMalloc(); + auto filters_grad_impl = torch_mlu::getMluTensorImpl(filters_grad); + auto filters_grad_ptr = filters_grad_impl->cnnlMalloc(); + + mluOpIndiceConvolutionBackwardData( + handle, output_grad_desc.desc(), output_grad_ptr, filters_desc.desc(), + filters_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num, + _inverse, _subM, indice_convbpdata_workspace_ptr, data_workspace_size, + input_grad_desc.desc(), input_grad_ptr); + + mluOpIndiceConvolutionBackwardFilter( + handle, features_desc.desc(), features_ptr, output_grad_desc.desc(), + output_grad_ptr, indice_pairs_desc.desc(), indice_pairs_ptr, indice_num, + _inverse, _subM, indice_convbpfilter_workspace_ptr, + filters_workspace_size, filters_grad_desc.desc(), 
filters_grad_ptr); + + std::vector result; + result.push_back(input_grad); + result.push_back(filters_grad); + return result; +} + +torch::Tensor indice_conv_forward_mlu(torch::Tensor features, + torch::Tensor filters, + torch::Tensor indicePairs, + torch::Tensor indiceNum, + int64_t numActOut, int64_t _inverse, + int64_t _subM) { + return IndiceConvForwardMLUKernelLauncher( + features, filters, indicePairs, indiceNum, numActOut, _inverse, _subM); +} + +std::vector indice_conv_backward_mlu( + torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad, + torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse, + int64_t _subM) { + return IndiceConvBackwardMLUKernelLauncher( + features, filters, outGrad, indicePairs, indiceNum, _inverse, _subM); +} + +torch::Tensor indice_conv_forward_impl(torch::Tensor features, + torch::Tensor filters, + torch::Tensor indicePairs, + torch::Tensor indiceNum, + int64_t numActOut, int64_t _inverse, + int64_t _subM); + +std::vector indice_conv_backward_impl( + torch::Tensor features, torch::Tensor filters, torch::Tensor outGrad, + torch::Tensor indicePairs, torch::Tensor indiceNum, int64_t _inverse, + int64_t _subM); + +REGISTER_DEVICE_IMPL(indice_conv_forward_impl, MLU, indice_conv_forward_mlu); +REGISTER_DEVICE_IMPL(indice_conv_backward_impl, MLU, indice_conv_backward_mlu); + +template std::vector GetIndicePairsForwardMLUKernelLauncher<2>( + torch::Tensor indices, int64_t batchSize, + std::vector outSpatialShape, std::vector spatialShape, + std::vector kernelSize, std::vector stride, + std::vector padding, std::vector dilation, + std::vector outPadding, int64_t _subM, int64_t _transpose); + +template std::vector GetIndicePairsForwardMLUKernelLauncher<3>( + torch::Tensor indices, int64_t batchSize, + std::vector outSpatialShape, std::vector spatialShape, + std::vector kernelSize, std::vector stride, + std::vector padding, std::vector dilation, + std::vector outPadding, int64_t _subM, int64_t _transpose); + +template std::vector GetIndicePairsForwardMLUKernelLauncher<4>( + torch::Tensor indices, int64_t batchSize, + std::vector outSpatialShape, std::vector spatialShape, + std::vector kernelSize, std::vector stride, + std::vector padding, std::vector dilation, + std::vector outPadding, int64_t _subM, int64_t _transpose); diff --git a/mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp new file mode 100644 index 0000000000..c3d31bc0e5 --- /dev/null +++ b/mmcv/ops/csrc/pytorch/mlu/voxelization_mlu.cpp @@ -0,0 +1,268 @@ +/************************************************************************* + * Copyright (C) 2022 by Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + *************************************************************************/ +#include "pytorch_device_registry.hpp" +#include "pytorch_mlu_helper.hpp" + +#define MIN(a, b) (((a) < (b)) ? 
(a) : (b)) + +void KernelDynamicVoxelize( + cnrtDim3_t k_dim, cnrtFunctionType_t k_type, cnrtQueue_t queue, + const void *points, void *coors, const float voxel_x, const float voxel_y, + const float voxel_z, const float coors_x_min, const float coors_y_min, + const float coors_z_min, const float coors_x_max, const float coors_y_max, + const float coors_z_max, const int32_t grid_x, const int32_t grid_y, + const int32_t grid_z, const int32_t num_points, const int32_t num_features); + +void KernelPoint2Voxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, void *coors, void *point_to_pointidx, + void *point_to_voxelidx, const int32_t num_points, + const int32_t max_points); + +void KernelCalcPointsPerVoxel(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, void *point_to_pointidx, + void *point_to_voxelidx, void *coor_to_voxelidx, + void *num_points_per_voxel, void *voxel_num, + const int32_t max_voxels, + const int32_t num_points); + +void KernelAssignVoxelsCoors(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, + cnrtQueue_t queue, const void *points, + void *temp_coors, void *point_to_voxelidx, + void *coor_to_voxelidx, void *voxels, void *coors, + const int32_t max_points, const int32_t num_points, + const int32_t num_features); + +// policy function +static void policyFuncDefault(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type, + const int num_points) { + k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + k_dim->y = MIN((num_points + k_dim->x - 1) / k_dim->x, + torch_mlu::getDeviceAttr(cnrtAttrClusterCount)); + k_dim->z = 1; + *k_type = CNRT_FUNC_TYPE_UNION1; +} + +// policy function +static void policyFuncCalcPointsPerVoxel(cnrtDim3_t *k_dim, + cnrtFunctionType_t *k_type, + const int num_points) { + k_dim->x = 1; + k_dim->y = 1; + k_dim->z = 1; + *k_type = CNRT_FUNC_TYPE_BLOCK; +} + +int HardVoxelizeForwardMLUKernelLauncher( + const at::Tensor &points, at::Tensor &voxels, at::Tensor &coors, + at::Tensor &num_points_per_voxel, const std::vector voxel_size, + const std::vector coors_range, const int max_points, + const int max_voxels, const int NDim = 3) { + // check datatype + TORCH_CHECK(points.scalar_type() == at::kFloat, + "points type should be Float, got ", points.scalar_type(), "."); + TORCH_CHECK(voxels.scalar_type() == at::kFloat, + "voxels type should be Float, got ", voxels.scalar_type(), "."); + TORCH_CHECK(coors.scalar_type() == at::kInt, + "coors type should be Float, got ", coors.scalar_type(), "."); + TORCH_CHECK(num_points_per_voxel.scalar_type() == at::kInt, + "num_points_per_voxel type should be Float, got ", + num_points_per_voxel.scalar_type(), "."); + + // check shape + TORCH_CHECK(points.dim() == 2, "points should be a 2d tensor, got ", + points.dim(), "D."); + TORCH_CHECK(voxels.dim() == 3, "voxels should be a 3d tensor, got ", + voxels.dim(), "D."); + TORCH_CHECK(coors.dim() == 2, "coors should be a 2d tensor, got ", + coors.dim(), "D."); + TORCH_CHECK(num_points_per_voxel.dim() == 1, + "num_points_per_voxel should be a 1d tensor, got ", + num_points_per_voxel.dim(), "D."); + + const int num_points = points.size(0); + const int num_features = points.size(1); + + TORCH_CHECK(points.size(0) == num_points, + "the 1st dimensions of points should be num_points, got ", + points.size(0), "."); + TORCH_CHECK(points.size(1) == num_features, + "the 2nd dimensions of points should be num_features, got ", + points.size(1), "."); + TORCH_CHECK(voxels.size(0) == max_voxels, + "the 1st dimensions of voxels should be max_voxels, got ", + 
voxels.size(0), "."); + TORCH_CHECK(voxels.size(1) == max_points, + "the 2nd dimensions of voxels should be max_points, got ", + voxels.size(1), "."); + TORCH_CHECK(voxels.size(2) == num_features, + "the 3rd dimensions of voxels should be num_features, got ", + voxels.size(2), "."); + TORCH_CHECK(coors.size(0) == max_voxels, + "the 1st dimensions of coors should be max_voxels, got ", + coors.size(0), "."); + TORCH_CHECK(coors.size(1) == 3, + "the 2nd dimensions of coors should be 3, got ", coors.size(1), + "."); + TORCH_CHECK(num_points_per_voxel.size(0) == max_voxels, + "the 1st dimensions of num_points_per_voxel should be 3, got ", + num_points_per_voxel.size(0), "."); + + // large tensor check + const size_t max_input_size = 2147483648; + TORCH_CHECK(points.numel() < max_input_size, + "points element num should be less than 2^31, got ", + points.numel(), "."); + TORCH_CHECK(voxels.numel() < max_input_size, + "voxels element num should be less than 2^31, got ", + voxels.numel(), "."); + TORCH_CHECK(coors.numel() < max_input_size, + "coors element num should be less than 2^31, got ", coors.numel(), + "."); + + // check zero element + if (max_points == 0 || max_voxels == 0) { + return 0; + } + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get ptr of tensors + auto points_ = points.contiguous(); + auto points_impl = torch_mlu::getMluTensorImpl(points_); + auto points_ptr = points_impl->cnnlMalloc(); + auto voxels_ = voxels.contiguous(); + auto voxels_impl = torch_mlu::getMluTensorImpl(voxels_); + auto voxels_ptr = voxels_impl->cnnlMalloc(); + auto coors_ = coors.contiguous(); + auto coors_impl = torch_mlu::getMluTensorImpl(coors_); + auto coors_ptr = coors_impl->cnnlMalloc(); + auto num_points_per_voxel_ = num_points_per_voxel.contiguous(); + auto num_points_per_voxel_impl = + torch_mlu::getMluTensorImpl(num_points_per_voxel_); + auto num_points_per_voxel_ptr = num_points_per_voxel_impl->cnnlMalloc(); + + // calculate task dimension + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + policyFuncDefault(&k_dim, &k_type, num_points); + + // 1. link point to corresponding voxel coors + const float voxel_x = voxel_size[0]; + const float voxel_y = voxel_size[1]; + const float voxel_z = voxel_size[2]; + const float coors_x_min = coors_range[0]; + const float coors_y_min = coors_range[1]; + const float coors_z_min = coors_range[2]; + const float coors_x_max = coors_range[3]; + const float coors_y_max = coors_range[4]; + const float coors_z_max = coors_range[5]; + + const int grid_x = round((coors_x_max - coors_x_min) / voxel_x); + const int grid_y = round((coors_y_max - coors_y_min) / voxel_y); + const int grid_z = round((coors_z_max - coors_z_min) / voxel_z); + + auto temp_coors = + at::zeros({NDim, num_points}, points.options().dtype(at::kInt)) + .contiguous(); + auto temp_coors_impl = torch_mlu::getMluTensorImpl(temp_coors); + auto temp_coors_ptr = temp_coors_impl->cnnlMalloc(); + + KernelDynamicVoxelize(k_dim, k_type, queue, points_ptr, temp_coors_ptr, + voxel_x, voxel_y, voxel_z, coors_x_min, coors_y_min, + coors_z_min, coors_x_max, coors_y_max, coors_z_max, + grid_x, grid_y, grid_z, num_points, num_features); + + // 2. 
map point to the idx of the corresponding voxel, find duplicate coor + auto point_to_pointidx = at::zeros( + { + num_points, + }, + points.options().dtype(at::kInt)) + .contiguous(); + auto point_to_pointidx_impl = torch_mlu::getMluTensorImpl(point_to_pointidx); + auto point_to_pointidx_ptr = point_to_pointidx_impl->cnnlMalloc(); + auto point_to_voxelidx = at::zeros( + { + num_points, + }, + points.options().dtype(at::kInt)) + .contiguous(); + auto point_to_voxelidx_impl = torch_mlu::getMluTensorImpl(point_to_voxelidx); + auto point_to_voxelidx_ptr = point_to_voxelidx_impl->cnnlMalloc(); + + KernelPoint2Voxel(k_dim, k_type, queue, temp_coors_ptr, point_to_pointidx_ptr, + point_to_voxelidx_ptr, num_points, max_points); + + // calculate task dimension + cnrtDim3_t k_dim_calc_points_per_voxel; + cnrtFunctionType_t k_type_calc_points_per_voxel; + policyFuncCalcPointsPerVoxel(&k_dim_calc_points_per_voxel, + &k_type_calc_points_per_voxel, num_points); + + // 3. determine voxel num and voxel's coor index + auto coor_to_voxelidx = at::zeros( + { + num_points, + }, + points.options().dtype(at::kInt)) + .contiguous(); + auto coor_to_voxelidx_impl = torch_mlu::getMluTensorImpl(coor_to_voxelidx); + auto coor_to_voxelidx_ptr = coor_to_voxelidx_impl->cnnlMalloc(); + auto voxel_num = at::zeros( + { + 1, + }, + points.options().dtype(at::kInt)) + .contiguous(); + auto voxel_num_impl = torch_mlu::getMluTensorImpl(voxel_num); + auto voxel_num_ptr = voxel_num_impl->cnnlMalloc(); + + KernelCalcPointsPerVoxel( + k_dim_calc_points_per_voxel, k_type_calc_points_per_voxel, queue, + point_to_pointidx_ptr, point_to_voxelidx_ptr, coor_to_voxelidx_ptr, + num_points_per_voxel_ptr, voxel_num_ptr, max_voxels, num_points); + + // 4. copy point features and coors of each voxels to voxels + KernelAssignVoxelsCoors(k_dim, k_type, queue, points_ptr, temp_coors_ptr, + point_to_voxelidx_ptr, coor_to_voxelidx_ptr, + voxels_ptr, coors_ptr, max_points, num_points, + num_features); + + auto voxel_num_cpu = voxel_num.to(at::kCPU); + int voxel_num_int = voxel_num_cpu.data_ptr()[0]; + + return voxel_num_int; +} + +int hard_voxelize_forward_mlu(const at::Tensor &points, at::Tensor &voxels, + at::Tensor &coors, + at::Tensor &num_points_per_voxel, + const std::vector voxel_size, + const std::vector coors_range, + const int max_points, const int max_voxels, + const int NDim) { + return HardVoxelizeForwardMLUKernelLauncher( + points, voxels, coors, num_points_per_voxel, voxel_size, coors_range, + max_points, max_voxels, NDim); +}; + +int hard_voxelize_forward_impl(const at::Tensor &points, at::Tensor &voxels, + at::Tensor &coors, + at::Tensor &num_points_per_voxel, + const std::vector voxel_size, + const std::vector coors_range, + const int max_points, const int max_voxels, + const int NDim); + +REGISTER_DEVICE_IMPL(hard_voxelize_forward_impl, MLU, + hard_voxelize_forward_mlu); diff --git a/mmcv/ops/csrc/pytorch/nms_rotated.cpp b/mmcv/ops/csrc/pytorch/nms_rotated.cpp index b07ed5aa11..1d49c37dd6 100644 --- a/mmcv/ops/csrc/pytorch/nms_rotated.cpp +++ b/mmcv/ops/csrc/pytorch/nms_rotated.cpp @@ -17,6 +17,11 @@ Tensor nms_rotated_npu(const Tensor dets, const Tensor scores, const Tensor labels, const float iou_threshold); #endif +#ifdef MMCV_WITH_MLU +Tensor nms_rotated_mlu(const Tensor dets, const Tensor scores, + const float iou_threshold); +#endif + // Interface for Python // inline is needed to prevent multiple function definitions when this header is // included by different cpps @@ -36,6 +41,10 @@ Tensor nms_rotated(const Tensor dets, 
const Tensor scores, const Tensor order, return nms_rotated_npu(dets, scores, labels, iou_threshold); #else AT_ERROR("Not compiled with NPU support"); +#endif +#ifdef MMCV_WITH_MLU + } else if (dets.device().type() == at::kMLU) { + return nms_rotated_mlu(dets, scores, iou_threshold); #endif } diff --git a/mmcv/ops/csrc/pytorch/spconv_ops.cpp b/mmcv/ops/csrc/pytorch/spconv_ops.cpp index 09c8110ad8..723c6c7b90 100644 --- a/mmcv/ops/csrc/pytorch/spconv_ops.cpp +++ b/mmcv/ops/csrc/pytorch/spconv_ops.cpp @@ -35,6 +35,26 @@ std::vector get_indice_pairs_forward_cuda( padding, dilation, outPadding, _subM, _transpose); }; +template +std::vector GetIndicePairsForwardMLUKernelLauncher( + torch::Tensor indices, int64_t batchSize, + std::vector outSpatialShape, std::vector spatialShape, + std::vector kernelSize, std::vector stride, + std::vector padding, std::vector dilation, + std::vector outPadding, int64_t _subM, int64_t _transpose); + +template +std::vector get_indice_pairs_forward_mlu( + torch::Tensor indices, int64_t batchSize, + std::vector outSpatialShape, std::vector spatialShape, + std::vector kernelSize, std::vector stride, + std::vector padding, std::vector dilation, + std::vector outPadding, int64_t _subM, int64_t _transpose) { + return GetIndicePairsForwardMLUKernelLauncher( + indices, batchSize, outSpatialShape, spatialShape, kernelSize, stride, + padding, dilation, outPadding, _subM, _transpose); +} + template std::vector GetIndicePairsBackwardCUDAKernelLauncher( torch::Tensor indices, torch::Tensor gridOut, int64_t batchSize, @@ -71,6 +91,12 @@ std::vector get_indice_pairs_forward( padding, dilation, outPadding, _subM, _transpose); #else AT_ERROR("get_indice_pairs is not compiled with GPU support"); +#endif +#ifdef MMCV_WITH_MLU + } else if (indices.device().type() == at::kMLU) { + return get_indice_pairs_forward_mlu( + indices, batchSize, outSpatialShape, spatialShape, kernelSize, stride, + padding, dilation, outPadding, _subM, _transpose); #endif } else { AT_ERROR("get_indice_pairs is not implemented on CPU"); diff --git a/mmcv/ops/nms.py b/mmcv/ops/nms.py index 14df44a4be..feab4f3cad 100644 --- a/mmcv/ops/nms.py +++ b/mmcv/ops/nms.py @@ -410,11 +410,12 @@ def nms_rotated(dets: Tensor, input_labels = scores.new_empty(0, dtype=torch.int) else: input_labels = labels - if dets.device.type == 'npu': + if dets.device.type in ('npu', 'mlu'): order = scores.new_empty(0, dtype=torch.long) - coefficient = 57.29578 # 180 / PI - for i in range(dets.size()[0]): - dets_cw[i][4] *= coefficient # radians to angle + if dets.device.type == 'npu': + coefficient = 57.29578 # 180 / PI + for i in range(dets.size()[0]): + dets_cw[i][4] *= coefficient # radians to angle keep_inds = ext_module.nms_rotated(dets_cw, scores, order, dets_cw, input_labels, iou_threshold, multi_label) diff --git a/setup.py b/setup.py index 0453159503..6040117e6c 100644 --- a/setup.py +++ b/setup.py @@ -211,6 +211,7 @@ def get_extensions(): include_dirs = [] + extra_objects = [] is_rocm_pytorch = False try: from torch.utils.cpp_extension import ROCM_HOME @@ -238,16 +239,98 @@ def get_extensions(): torch.is_mlu_available()) or \ os.getenv('FORCE_MLU', '0') == '1': from torch_mlu.utils.cpp_extension import MLUExtension + + def get_mluops_version(file_path): + with open(file_path) as f: + for line in f: + if re.search('MLUOP_MAJOR', line): + major = line.strip().split(' ')[2] + if re.search('MLUOP_MINOR', line): + minor = line.strip().split(' ')[2] + if re.search('MLUOP_PATCHLEVEL', line): + patchlevel = line.strip().split(' ')[2] + 
mluops_version = f'v{major}.{minor}.{patchlevel}' + return mluops_version + + mmcv_mluops_version = get_mluops_version( + './mmcv/ops/csrc/pytorch/mlu/mlu_common_helper.h') + mlu_ops_path = os.getenv('MMCV_MLU_OPS_PATH') + if mlu_ops_path: + exists_mluops_version = get_mluops_version( + mlu_ops_path + '/bangc-ops/mlu_op.h') + if exists_mluops_version != mmcv_mluops_version: + print('the version of mlu-ops provided is %s,' + ' while %s is needed.' % + (exists_mluops_version, mmcv_mluops_version)) + exit() + try: + if os.path.exists('mlu-ops'): + if os.path.islink('mlu-ops'): + os.remove('mlu-ops') + os.symlink(mlu_ops_path, 'mlu-ops') + elif os.path.abspath('mlu-ops') != mlu_ops_path: + os.symlink(mlu_ops_path, 'mlu-ops') + else: + os.symlink(mlu_ops_path, 'mlu-ops') + except Exception: + raise FileExistsError( + 'mlu-ops already exists, please move it out,' + 'or rename or remove it.') + else: + if not os.path.exists('mlu-ops'): + import requests + mluops_url = 'https://github.com/Cambricon/mlu-ops/' + \ + 'archive/refs/tags/' + mmcv_mluops_version + '.zip' + req = requests.get(mluops_url) + with open('./mlu-ops.zip', 'wb') as f: + try: + f.write(req.content) + except Exception: + raise ImportError('failed to download mlu-ops') + + from zipfile import BadZipFile, ZipFile + with ZipFile('./mlu-ops.zip', 'r') as archive: + try: + archive.extractall() + dir_name = archive.namelist()[0].split('/')[0] + os.rename(dir_name, 'mlu-ops') + except BadZipFile: + print('invalid mlu-ops.zip file') + else: + exists_mluops_version = get_mluops_version( + './mlu-ops/bangc-ops/mlu_op.h') + if exists_mluops_version != mmcv_mluops_version: + print('the version of provided mlu-ops is %s,' + ' while %s is needed.' % + (exists_mluops_version, mmcv_mluops_version)) + exit() + define_macros += [('MMCV_WITH_MLU', None)] - mlu_args = os.getenv('MMCV_MLU_ARGS') - extra_compile_args['cncc'] = [mlu_args] if mlu_args else [] + mlu_args = os.getenv('MMCV_MLU_ARGS', '-DNDEBUG ') + mluops_includes = [] + mluops_includes.append('-I' + + os.path.abspath('./mlu-ops/bangc-ops')) + mluops_includes.append( + '-I' + os.path.abspath('./mlu-ops/bangc-ops/kernels')) + extra_compile_args['cncc'] = [mlu_args] + \ + mluops_includes if mlu_args else mluops_includes + extra_compile_args['cxx'] += ['-fno-gnu-unique'] op_files = glob.glob('./mmcv/ops/csrc/pytorch/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/cpu/*.cpp') + \ glob.glob('./mmcv/ops/csrc/pytorch/mlu/*.cpp') + \ - glob.glob('./mmcv/ops/csrc/common/mlu/*.mlu') + glob.glob('./mmcv/ops/csrc/common/mlu/*.mlu') + \ + glob.glob( + './mlu-ops/bangc-ops/core/**/*.cpp', recursive=True) + \ + glob.glob( + './mlu-ops/bangc-ops/kernels/**/*.cpp', recursive=True) + \ + glob.glob( + './mlu-ops/bangc-ops/kernels/**/*.mlu', recursive=True) + extra_objects = glob.glob( + './mlu-ops/bangc-ops/kernels/kernel_wrapper/*.o') extension = MLUExtension include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common')) include_dirs.append(os.path.abspath('./mmcv/ops/csrc/common/mlu')) + include_dirs.append(os.path.abspath('./mlu-ops/bangc-ops')) elif (hasattr(torch.backends, 'mps') and torch.backends.mps.is_available()) or os.getenv( 'FORCE_MPS', '0') == '1': @@ -309,6 +392,7 @@ def get_extensions(): sources=op_files, include_dirs=include_dirs, define_macros=define_macros, + extra_objects=extra_objects, extra_compile_args=extra_compile_args) extensions.append(ext_ops) return extensions diff --git a/tests/test_ops/test_ball_query.py b/tests/test_ops/test_ball_query.py index d3fc7912c5..a3f6518197 100644 
--- a/tests/test_ops/test_ball_query.py +++ b/tests/test_ops/test_ball_query.py @@ -3,55 +3,59 @@ import torch from mmcv.ops import ball_query +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -def test_ball_query(): - new_xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625], - [-2.2769, 2.7817, -0.2334], - [-0.4003, 2.4666, -0.5116], - [-0.0740, 1.3147, -1.3625], - [-0.0740, 1.3147, -1.3625]], - [[-2.0289, 2.4952, -0.1708], - [-2.0668, 6.0278, -0.4875], - [0.4066, 1.4211, -0.2947], - [-2.0289, 2.4952, -0.1708], - [-2.0289, 2.4952, -0.1708]]]).cuda() +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) +]) +def test_ball_query(device): + new_xyz = torch.tensor( + [[[-0.0740, 1.3147, -1.3625], [-2.2769, 2.7817, -0.2334], + [-0.4003, 2.4666, -0.5116], [-0.0740, 1.3147, -1.3625], + [-0.0740, 1.3147, -1.3625]], + [[-2.0289, 2.4952, -0.1708], [-2.0668, 6.0278, -0.4875], + [0.4066, 1.4211, -0.2947], [-2.0289, 2.4952, -0.1708], + [-2.0289, 2.4952, -0.1708]]], + device=device) - xyz = torch.tensor([[[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634], - [-0.4003, 2.4666, - -0.5116], [-0.5251, 2.4379, -0.8466], - [-0.9691, 1.1418, - -1.3733], [-0.2232, 0.9561, -1.3626], - [-2.2769, 2.7817, -0.2334], - [-0.2822, 1.3192, -1.3645], [0.1533, 1.5024, -1.0432], - [0.4917, 1.1529, -1.3496]], - [[-2.0289, 2.4952, - -0.1708], [-0.7188, 0.9956, -0.5096], - [-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610], - [0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791], - [-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947], - [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, - -1.2000]]]).cuda() + xyz = torch.tensor( + [[[-0.0740, 1.3147, -1.3625], [0.5555, 1.0399, -1.3634], + [-0.4003, 2.4666, -0.5116], [-0.5251, 2.4379, -0.8466], + [-0.9691, 1.1418, -1.3733], [-0.2232, 0.9561, -1.3626], + [-2.2769, 2.7817, -0.2334], [-0.2822, 1.3192, -1.3645], + [0.1533, 1.5024, -1.0432], [0.4917, 1.1529, -1.3496]], + [[-2.0289, 2.4952, -0.1708], [-0.7188, 0.9956, -0.5096], + [-2.0668, 6.0278, -0.4875], [-1.9304, 3.3092, 0.6610], + [0.0949, 1.4332, 0.3140], [-1.2879, 2.0008, -0.7791], + [-0.7252, 0.9611, -0.6371], [0.4066, 1.4211, -0.2947], + [0.3220, 1.4447, 0.3548], [-0.9744, 2.3856, -1.2000]]], + device=device) idx = ball_query(0, 0.2, 5, xyz, new_xyz) - expected_idx = torch.tensor([[[0, 0, 0, 0, 0], [6, 6, 6, 6, 6], - [2, 2, 2, 2, 2], [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]], - [[0, 0, 0, 0, 0], [2, 2, 2, 2, 2], - [7, 7, 7, 7, 7], [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]]]).cuda() + expected_idx = torch.tensor( + [[[0, 0, 0, 0, 0], [6, 6, 6, 6, 6], [2, 2, 2, 2, 2], [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], + [[0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [7, 7, 7, 7, 7], [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]], + device=device) assert torch.all(idx == expected_idx) # test dilated ball query idx = ball_query(0.2, 0.4, 5, xyz, new_xyz) - expected_idx = torch.tensor([[[0, 5, 7, 0, 0], [6, 6, 6, 6, 6], - [2, 3, 2, 2, 2], [0, 5, 7, 0, 0], - [0, 5, 7, 0, 0]], - [[0, 0, 0, 0, 0], [2, 2, 2, 2, 2], - [7, 7, 7, 7, 7], [0, 0, 0, 0, 0], - [0, 0, 0, 0, 0]]]).cuda() + expected_idx = torch.tensor( + [[[0, 5, 7, 0, 0], [6, 6, 6, 6, 6], [2, 3, 2, 2, 2], [0, 5, 7, 0, 0], + [0, 5, 7, 0, 0]], + [[0, 0, 0, 0, 0], [2, 2, 2, 2, 2], [7, 7, 7, 7, 7], [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]]], 
+ device=device) assert torch.all(idx == expected_idx) diff --git a/tests/test_ops/test_nms_rotated.py b/tests/test_ops/test_nms_rotated.py index bee562a6f1..88b41fec85 100644 --- a/tests/test_ops/test_nms_rotated.py +++ b/tests/test_ops/test_nms_rotated.py @@ -3,7 +3,7 @@ import pytest import torch -from mmcv.utils import IS_CUDA_AVAILABLE, IS_NPU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE class TestNmsRotated: @@ -16,7 +16,11 @@ class TestNmsRotated: pytest.param( 'cuda', marks=pytest.mark.skipif( - not IS_CUDA_AVAILABLE, reason='requires CUDA support')) + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) ]) def test_ml_nms_rotated(self, device): from mmcv.ops import nms_rotated @@ -58,7 +62,11 @@ def test_ml_nms_rotated(self, device): pytest.param( 'cuda', marks=pytest.mark.skipif( - not IS_CUDA_AVAILABLE, reason='requires CUDA support')) + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) ]) def test_nms_rotated(self, device): from mmcv.ops import nms_rotated diff --git a/tests/test_ops/test_spconv.py b/tests/test_ops/test_spconv.py index 098ff2189a..17ca5678ed 100644 --- a/tests/test_ops/test_spconv.py +++ b/tests/test_ops/test_spconv.py @@ -10,6 +10,8 @@ if torch.__version__ == 'parrots': pytest.skip('not supported in parrots now', allow_module_level=True) +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE + def make_sparse_convmodule(in_channels, out_channels, @@ -76,21 +78,29 @@ def make_sparse_convmodule(in_channels, return layers -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') -def test_make_sparse_convmodule(): +@pytest.mark.parametrize('device', [ + pytest.param( + 'cuda', + marks=pytest.mark.skipif( + not IS_CUDA_AVAILABLE, reason='requires CUDA support')), + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) +]) +def test_make_sparse_convmodule(device): torch.cuda.empty_cache() voxel_features = torch.tensor([[6.56126, 0.9648336, -1.7339306, 0.315], [6.8162713, -2.480431, -1.3616394, 0.36], [11.643568, -4.744306, -1.3580885, 0.16], [23.482342, 6.5036807, 0.5806964, 0.35]], dtype=torch.float32, - device='cuda') # n, point_features + device=device) # n, point_features coordinates = torch.tensor( [[0, 12, 819, 131], [0, 16, 750, 136], [1, 16, 705, 232], [1, 35, 930, 469]], dtype=torch.int32, - device='cuda') # n, 4(batch, ind_x, ind_y, ind_z) + device=device) # n, 4(batch, ind_x, ind_y, ind_z) # test input_sp_tensor = SparseConvTensor(voxel_features, coordinates, @@ -105,7 +115,7 @@ def test_make_sparse_convmodule(): padding=0, conv_type='SubMConv3d', norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), - order=('conv', 'norm', 'act')).cuda() + order=('conv', 'norm', 'act')).to(device) assert isinstance(sparse_block0[0], SubMConv3d) assert sparse_block0[0].in_channels == 4 assert sparse_block0[0].out_channels == 16 @@ -118,16 +128,18 @@ def test_make_sparse_convmodule(): out_features = sparse_block0(input_sp_tensor) assert out_features.features.shape == torch.Size([4, 16]) - sparse_block1 = make_sparse_convmodule( - 4, - 16, - 3, - 'test1', - stride=1, - padding=0, - conv_type='SparseInverseConv3d', - norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), - order=('norm', 'act', 'conv')).cuda() - assert 
isinstance(sparse_block1[0], torch.nn.BatchNorm1d) - assert isinstance(sparse_block1[1], torch.nn.ReLU) - assert isinstance(sparse_block1[2], SparseInverseConv3d) + # device == mlu: not support inverse==1 yet + if device != 'mlu': + sparse_block1 = make_sparse_convmodule( + 4, + 16, + 3, + 'test1', + stride=1, + padding=0, + conv_type='SparseInverseConv3d', + norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), + order=('norm', 'act', 'conv')).to(device) + assert isinstance(sparse_block1[2], SparseInverseConv3d) + assert isinstance(sparse_block1[0], torch.nn.BatchNorm1d) + assert isinstance(sparse_block1[1], torch.nn.ReLU) diff --git a/tests/test_ops/test_voxelization.py b/tests/test_ops/test_voxelization.py index 1422e0a3bd..cd01eb46e6 100644 --- a/tests/test_ops/test_voxelization.py +++ b/tests/test_ops/test_voxelization.py @@ -4,7 +4,7 @@ import torch from mmcv.ops import Voxelization -from mmcv.utils import IS_NPU_AVAILABLE +from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE, IS_NPU_AVAILABLE def _get_voxel_points_indices(points, coors, voxel): @@ -17,7 +17,7 @@ def _get_voxel_points_indices(points, coors, voxel): pytest.param( 'cuda:0', marks=pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support')) + not IS_CUDA_AVAILABLE, reason='requires CUDA support')) ]) def test_voxelization(device_type): voxel_size = [0.5, 0.5, 0.5] @@ -63,8 +63,7 @@ def test_voxelization(device_type): assert num_points_current_voxel == expected_num_points_per_voxel[i] -@pytest.mark.skipif( - not torch.cuda.is_available(), reason='requires CUDA support') +@pytest.mark.skipif(not IS_CUDA_AVAILABLE, reason='requires CUDA support') def test_voxelization_nondeterministic(): voxel_size = [0.5, 0.5, 0.5] point_cloud_range = [0, -40, -3, 70.4, 40, 1] @@ -140,6 +139,41 @@ def test_voxelization_nondeterministic(): assert len(coors_set) == len(coors) == len(coors_all_set) +@pytest.mark.parametrize('device_type', [ + pytest.param( + 'mlu', + marks=pytest.mark.skipif( + not IS_MLU_AVAILABLE, reason='requires MLU support')) +]) +def test_voxelization_mlu(device_type): + voxel_size = [0.5, 0.5, 0.5] + point_cloud_range = [0, -40, -3, 70.4, 40, 1] + + voxel_dict = np.load( + 'tests/data/for_3d_ops/test_voxel.npy', allow_pickle=True).item() + expected_coors = voxel_dict['coors'] + expected_voxels = voxel_dict['voxels'] + expected_num_points_per_voxel = voxel_dict['num_points_per_voxel'] + points = voxel_dict['points'] + + points = torch.tensor(points) + max_num_points = 1000 + hard_voxelization = Voxelization(voxel_size, point_cloud_range, + max_num_points) + + device = torch.device(device_type) + + # test hard_voxelization on mlu + points = points.contiguous().to(device) + coors, voxels, num_points_per_voxel = hard_voxelization.forward(points) + coors = coors.cpu().detach().numpy() + voxels = voxels.cpu().detach().numpy() + num_points_per_voxel = num_points_per_voxel.cpu().detach().numpy() + assert np.all(coors == expected_coors) + assert np.all(voxels == expected_voxels) + assert np.all(num_points_per_voxel == expected_num_points_per_voxel) + + @pytest.mark.parametrize('device_type', [ pytest.param( 'npu',