diff --git a/docs/en/understand_mmcv/ops.md b/docs/en/understand_mmcv/ops.md index aa44b59372..0d955f3441 100644 --- a/docs/en/understand_mmcv/ops.md +++ b/docs/en/understand_mmcv/ops.md @@ -10,7 +10,7 @@ We implement common ops used in detection, segmentation, etc. | BBoxOverlaps | | √ | √ | | BorderAlign | | √ | | | BoxIouRotated | √ | √ | | -| CARAFE | | √ | | +| CARAFE | | √ | √ | | ChamferDistance | | √ | | | CrissCrossAttention | | √ | | | ContourExpand | √ | | | diff --git a/docs/zh_cn/understand_mmcv/ops.md b/docs/zh_cn/understand_mmcv/ops.md index 9a83fcbdc4..ef05a6fefa 100644 --- a/docs/zh_cn/understand_mmcv/ops.md +++ b/docs/zh_cn/understand_mmcv/ops.md @@ -10,7 +10,7 @@ MMCV 提供了检测、分割等任务中常用的算子 | BBoxOverlaps | | √ | √ | | BorderAlign | | √ | | | BoxIouRotated | √ | √ | | -| CARAFE | | √ | | +| CARAFE | | √ | √ | | ChamferDistance | | √ | | | CrissCrossAttention | | √ | | | ContourExpand | √ | | | diff --git a/mmcv/ops/carafe.py b/mmcv/ops/carafe.py index 18230c0807..cb2d346458 100644 --- a/mmcv/ops/carafe.py +++ b/mmcv/ops/carafe.py @@ -158,8 +158,6 @@ def forward(ctx, features: Tensor, masks: Tensor, kernel_size: int, def backward( ctx, grad_output: Tensor) -> Tuple[Tensor, Tensor, None, None, None]: - assert grad_output.is_cuda - features, masks, rfeatures = ctx.saved_tensors kernel_size = ctx.kernel_size group_size = ctx.group_size diff --git a/mmcv/ops/csrc/common/mlu/carafe_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/carafe_mlu_kernel.mlu new file mode 100644 index 0000000000..ac5ea0d653 --- /dev/null +++ b/mmcv/ops/csrc/common/mlu/carafe_mlu_kernel.mlu @@ -0,0 +1,552 @@ +/************************************************************************* + * Copyright (C) 2022 Cambricon. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS + * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. + * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY + * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, + * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE + * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ *************************************************************************/ +#include "carafe_utils.hpp" +#include "common_mlu_helper.hpp" + +#define INDEX3(n, h, w, c, strN, strH, strW) \ + (strN) * (n) + (strH) * (h) + (strW) * (w) + (c) + +#define NRAM_BLOCK PAD_DOWN(MAX_NRAM_SIZE / 5, NRAM_ALIGN_SIZE) + +__nram__ char nram_buf[MAX_NRAM_SIZE]; + +namespace forward { +struct BlockId { + int Ho; + int Wo; + int G; + int Cg; + int Kh; + int Kw; + int Hi; + int Wi; +}; + +// start indices of block +struct BlockStart { + int Ho; + int Wo; + int G; + int Cg; + int Kh; + int Kw; + int Hi; + int Wi; + int C; +}; + +struct BlockEnd { + int Ho; + int Wo; + int Kh; + int Kw; + int Hi; + int Wi; +}; + +struct BlockSize { + int Ho; + int Wo; + int G; + int Cg; + int Kh; + int Kw; + int Hi; + int Wi; +}; + +template +__mlu_func__ void carafeForwardBLOCK(T *input, T *mask, + const CarafeForwardParam param, + const CarafeForwardBlockDim block_dim, + const CarafeForwardGridDim grid_dim, + T *output) { + // data block info + BlockId blkId; + BlockStart blkStart; + BlockEnd blkEnd; + BlockSize blkSize; + + // set pointers on NRAM arrays + + // input_nram[blkDim_(Hi+Kh)-1, blkDim_(Wi+Kw)-1, blkDim_(G*Cg)] + T *input_nram = (T *)nram_buf; + + // mask_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Kh*Kw)] + T *mask_nram = input_nram + param.input_nram_size; + + // output_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Cg)] + T *output_nram = mask_nram + param.mask_nram_size; + + // sum_array[blkDim_(G*Cg)] + T *sum_array = output_nram + param.output_nram_size; + + /* ===== loop over N, grid_dim(Ho,Wo,G,Cg) + * iterations are distributed over computing cores + */ + for (int loop_index = taskId; loop_index < param.job_num; + loop_index += taskDim) { + // block idx + blkId.Cg = loop_index; + blkId.G = blkId.Cg / grid_dim.Cg; + blkId.Wo = blkId.G / grid_dim.G; + blkId.Ho = blkId.Wo / grid_dim.Wo; + int sample_idx = blkId.Ho / grid_dim.Ho; + + blkId.Cg %= grid_dim.Cg; + blkId.G %= grid_dim.G; + blkId.Wo %= grid_dim.Wo; + blkId.Ho %= grid_dim.Ho; + + // block starting indices + blkStart.Ho = blkId.Ho * block_dim.Ho; + blkStart.Wo = blkId.Wo * block_dim.Wo; + blkStart.G = blkId.G * block_dim.G; + blkStart.Cg = blkId.Cg * block_dim.Cg; + blkStart.C = blkStart.G * param.Cg + blkStart.Cg; + + // block size + blkSize.Ho = block_dim.Ho; + blkSize.Wo = block_dim.Wo; + blkSize.G = block_dim.G; + blkSize.Cg = block_dim.Cg; + + // take care of blocks near the end of each dimension + if (blkId.Ho == (grid_dim.Ho - 1)) { + blkSize.Ho = param.Ho - (grid_dim.Ho - 1) * block_dim.Ho; + } + if (blkId.Wo == (grid_dim.Wo - 1)) { + blkSize.Wo = param.Wo - (grid_dim.Wo - 1) * block_dim.Wo; + } + if (blkId.G == (grid_dim.G - 1)) { + blkSize.G = param.group_size - (grid_dim.G - 1) * block_dim.G; + } + if (blkId.Cg == (grid_dim.Cg - 1)) { + blkSize.Cg = param.Cg - (grid_dim.Cg - 1) * block_dim.Cg; + } + + // block end indices + blkEnd.Ho = blkStart.Ho + blkSize.Ho - 1; + blkEnd.Wo = blkStart.Wo + blkSize.Wo - 1; + + // set output_nram to zero + __nramset(output_nram, param.output_nram_size, T(0)); + + // loop blocks of kernel window: grid_dim.(Kh, Kw) + for (blkId.Kh = 0; blkId.Kh < grid_dim.Kh; ++blkId.Kh) { + blkStart.Kh = blkId.Kh * block_dim.Kh; + blkSize.Kh = block_dim.Kh; + if (blkId.Kh == (grid_dim.Kh - 1)) { + blkSize.Kh = param.kernel_size - (grid_dim.Kh - 1) * block_dim.Kh; + } + blkEnd.Kh = blkStart.Kh + blkSize.Kh - 1; + + blkStart.Hi = blkStart.Ho / param.scale_factor - param.kernel_size_half + + blkStart.Kh; + blkEnd.Hi = + blkEnd.Ho / 
param.scale_factor - param.kernel_size_half + blkEnd.Kh; + blkSize.Hi = blkEnd.Hi - blkStart.Hi + 1; + + for (blkId.Kw = 0; blkId.Kw < grid_dim.Kw; ++blkId.Kw) { + blkStart.Kw = blkId.Kw * block_dim.Kw; + blkSize.Kw = block_dim.Kw; + if (blkId.Kw == (grid_dim.Kw - 1)) { + blkSize.Kw = param.kernel_size - (grid_dim.Kw - 1) * block_dim.Kw; + } + blkEnd.Kw = blkStart.Kw + blkSize.Kw - 1; + + blkStart.Wi = blkStart.Wo / param.scale_factor - + param.kernel_size_half + blkStart.Kw; + blkEnd.Wi = + blkEnd.Wo / param.scale_factor - param.kernel_size_half + blkEnd.Kw; + blkSize.Wi = blkEnd.Wi - blkStart.Wi + 1; + + // load input block from gdram2nram + // + // input_nram[ | input[ sample_idx, + // 0:blkSize.Hi-1, | blkStart.Hi + 0:blkSize.Hi-1, + // 0:blkSize.Wi-1, | blkStart.Wi + 0:blkSize.Wi-1, + // 0:blkSize.G-1 | blkStart.G + 0:blkSize.G-1 + // 0:blkSize.Cg-1] | blkStart.Cg + 0:blkSize.Cg-1] + // + // To skip out of bound indices: + // + // input_nram[ + // hi_start_local:hi_end_local, + // wi_start_local:wi_end_local, ...] + // = input[n, + // hi_start_global:hi_end_global, + // wi_start_global:wi_end_global, ...] + // + int hi_start_local = 0; + int hi_start_global = blkStart.Hi; + if (blkStart.Hi < 0) { + hi_start_local = -blkStart.Hi; + hi_start_global = 0; + } + int wi_start_local = 0; + int wi_start_global = blkStart.Wi; + if (blkStart.Wi < 0) { + wi_start_local = -blkStart.Wi; + wi_start_global = 0; + } + int hi_end_local = blkSize.Hi - 1; + int hi_end_global = blkEnd.Hi; + if (blkEnd.Hi > param.Hi - 1) { + hi_end_global = param.Hi - 1; + hi_end_local -= blkEnd.Hi - hi_end_global; + } + int wi_end_local = blkSize.Wi - 1; + int wi_end_global = blkEnd.Wi; + if (blkEnd.Wi > param.Wi - 1) { + wi_end_global = param.Wi - 1; + wi_end_local -= blkEnd.Wi - wi_end_global; + } + + int dst_offset = param.input_nram_stride_h * hi_start_local + + param.input_nram_stride_w * wi_start_local; + T *dst = input_nram + dst_offset; + + int src_offset = INDEX3(sample_idx, hi_start_global, wi_start_global, + blkStart.C, param.input_stride_n, + param.input_stride_h, param.input_stride_w); + T *src = input + src_offset; + + int input_seg_num_h = hi_end_local - hi_start_local + 1; + int input_seg_num_w = wi_end_local - wi_start_local + 1; + for (int i = 0; i < input_seg_num_h; ++i) { + loadStr3D(dst, src, blkSize.Cg, blkSize.G, input_seg_num_w, + param.input_nram_stride_g, param.input_nram_stride_w, + param.input_stride_g, param.input_stride_w); + dst += param.input_nram_stride_h; + src += param.input_stride_h; + } + + /* load mask block from gdram2nram + * + * mask_nram[ | mask[sample_idx, + * 0:blkSize.Ho-1 , | blkStart.Ho + 0:blkSize.Ho-1, + * 0:blkSize.Wo-1, | blkStart.Wo + 0:blkSize.Wo-1, + * 0:blkSize.G-1, | blkStart.G + 0:blkSize.G-1, + * 0:blkSize.Kh-1, | blkStart.Kh + 0:blkSize.Kh-1, + * 0:blkSize.Kw-1] | blkStart.Kw + 0:blkSize.Kw-1] + */ + src_offset = INDEX3(blkStart.Wo, blkStart.G, blkStart.Kh, blkStart.Kw, + param.mask_stride_w, param.mask_stride_g, + param.mask_stride_kh); + src_offset += sample_idx * param.mask_stride_n + + blkStart.Ho * param.mask_stride_h; + + for (int ho = 0; ho < blkSize.Ho; ++ho) { + dst = mask_nram + ho * param.mask_nram_stride_h; + src = mask + src_offset + ho * param.mask_stride_h; + + for (int wo = 0; wo < blkSize.Wo; ++wo) { + loadStr3D(dst, src, blkSize.Kw, blkSize.Kh, blkSize.G, + param.mask_nram_stride_kh, param.mask_nram_stride_g, + param.mask_stride_kh, param.mask_stride_g); + dst += param.mask_nram_stride_w; + src += param.mask_stride_w; + } + } + + // loop each 
pixel of the output block
+        for (int ho = 0; ho < blkSize.Ho; ++ho) {
+          int kernel_hi_start_global = (blkStart.Ho + ho) / param.scale_factor -
+                                       param.kernel_size_half + blkStart.Kh;
+          int kernel_hi_start_local = kernel_hi_start_global - blkStart.Hi;
+
+          // int kernel_hi_end_global = kernel_hi_start_global + blkSize.Kh - 1;
+          // int kernel_hi_end_local = kernel_hi_end_global - blkStart.Hi;
+
+          // exclude out of bound indices which should be ignored
+          int kh_min = hi_start_local - kernel_hi_start_local > 0
+                           ? hi_start_local - kernel_hi_start_local
+                           : 0;
+          int kh_max = hi_end_local - kernel_hi_start_local < blkSize.Kh - 1
+                           ? hi_end_local - kernel_hi_start_local
+                           : blkSize.Kh - 1;
+
+          for (int wo = 0; wo < blkSize.Wo; ++wo) {
+            int kernel_wi_start_global =
+                (blkStart.Wo + wo) / param.scale_factor -
+                param.kernel_size_half + blkStart.Kw;
+            int kernel_wi_start_local = kernel_wi_start_global - blkStart.Wi;
+
+            // exclude out of bound indices which should be ignored
+            int kw_min = wi_start_local - kernel_wi_start_local > 0
+                             ? wi_start_local - kernel_wi_start_local
+                             : 0;
+            int kw_max = wi_end_local - kernel_wi_start_local < blkSize.Kw - 1
+                             ? wi_end_local - kernel_wi_start_local
+                             : blkSize.Kw - 1;
+
+            // output_nram[ho, wo, g, c] = sum(mask_nram[ho, wo, g, kh, kw]
+            //     * input_nram[hi+kh, wi+kw, g, c],
+            //     for (kh,kw) in [0:blkSize.Kw-1] x [0:blkSize.Kh-1])
+            //
+            // sum(mask_nram[ho, wo, g, kh, kw]
+            //     * input_nram[hi+kh, wi+kw, g, c], (kh,kw))
+            //
+            T *mask_array = mask_nram + param.mask_nram_stride_h * ho +
+                            param.mask_nram_stride_w * wo;
+
+            for (int kh = kh_min; kh <= kh_max; ++kh) {
+              for (int kw = kw_min; kw <= kw_max; ++kw) {
+                T *src =
+                    input_nram +
+                    param.input_nram_stride_h * (kernel_hi_start_local + kh) +
+                    param.input_nram_stride_w * (kernel_wi_start_local + kw);
+
+                int mask_index = param.mask_nram_stride_kh * kh + kw;
+
+                // multiply mask weight with channels for each channel group
+                T *sum = sum_array;
+
+                for (int g = 0; g < blkSize.G; ++g) {
+                  __bang_mul_const(sum, src, mask_array[mask_index],
+                                   param.block_Cg_NFU);
+                  //
+                  // NOTE: Since block_Cg_NFU >= block_Cg_stride,
+                  // overlapped writing may occur on sum_array.
+                  // So this loop must be executed in order to
+                  // avoid data contamination, as shown below.
+                  //
+                  // |-----block_Cg_NFU---------|
+                  // xxxxxxxxxxxxxxxxxxxxyyyzzzzz------------
+                  // |---block_Cg_stride---|^^^^^will be overwritten
+                  //                             in the next iteration.
+                  //
+                  // x: actual data used, y: not used, z: overwritten
+                  //
+                  sum += param.input_nram_stride_g;
+                  src += param.input_nram_stride_g;
+                  mask_index += param.mask_nram_stride_g;
+                }  // loop blk_G
+
+                // add array[blk_G * blk_C] to output_nram
+                dst = output_nram + param.output_nram_stride_h * ho +
+                      param.output_nram_stride_w * wo;
+
+                __bang_add(dst, dst, sum_array, param.output_nram_stride_w);
+              }  // end loop blk_Kw
+            }  // end loop blk_Kh
+          }  // end loop blk_Wo
+        }  // end loop blk_Ho
+      }  // end loop grid_dim.Kw
+    }  // end loop grid_dim.Kh
+
+    /* write output from nram2gdram
+     *
+     * output_nram[        |  output[sample_idx,
+     *   0:blkSize.Ho-1,   |    blkStart.Ho + 0:blkSize.Ho-1,
+     *   0:blkSize.Wo-1,   |    blkStart.Wo + 0:blkSize.Wo-1,
+     *   0:blkSize.G-1,    |    blkStart.G + 0:blkSize.G-1,
+     *   0:blkSize.Cg-1]   |    blkStart.Cg + 0:blkSize.Cg-1]
+     */
+    int dst_offset = INDEX3(sample_idx, blkStart.Ho, blkStart.Wo, blkStart.C,
+                            param.output_stride_n, param.output_stride_h,
+                            param.output_stride_w);
+    T *dst = output + dst_offset;
+    T *src = output_nram;
+    for (int i = 0; i < blkSize.Ho; ++i) {
+      storeStr3D(dst, src, blkSize.Cg, blkSize.G, blkSize.Wo,
+                 param.output_stride_g, param.output_stride_w,
+                 param.output_nram_stride_g, param.output_nram_stride_w);
+      dst += param.output_stride_h;
+      src += param.output_nram_stride_h;
+    }
+  }  // end loop N, grid_dim.(Hi,Wi,G,Cg)
+}
+
+template <typename T>
+__mlu_global__ void MLUBLOCKKernelCarafeForward(
+    const void *input, const void *mask, const CarafeForwardParam param,
+    const CarafeForwardBlockDim block_dim, const CarafeForwardGridDim grid_dim,
+    void *output) {
+  carafeForwardBLOCK((T *)input, (T *)mask, param, block_dim, grid_dim,
+                     (T *)output);
+}
+}  // namespace forward
+
+namespace backward {
+template <typename T>
+__mlu_func__ void CarafeCompute(T *input, T *mask, T *grad_output,
+                                T *grad_input, T *grad_mask, const int n,
+                                const int hi, const int wi, const int c,
+                                const int k_up, const int group,
+                                const int scale) {
+  char *input_buff = nram_buf;
+  char *mask_buff = input_buff + NRAM_BLOCK;
+  char *grad_input_buff = mask_buff + NRAM_BLOCK;
+  char *grad_output_buff = grad_input_buff + NRAM_BLOCK;
+  char *grad_mask_buff = grad_output_buff + NRAM_BLOCK;
+
+  int wo = wi * scale;
+  int ho = hi * scale;
+  int out_num = n * ho * wo * group;
+  int group_size = c / group;
+  int repeat = out_num / taskDim + (int)(taskId < out_num % taskDim);
+  int num_align = PAD_DOWN(NRAM_BLOCK / sizeof(T), NFU_ALIGN_SIZE / sizeof(T));
+  int num_per_loop = group_size / num_align;
+  int rem_for_loop = group_size % num_align;
+  int rem_for_loop_align = PAD_UP(rem_for_loop, NFU_ALIGN_SIZE / sizeof(T));
+  for (int k = 0; k < repeat; k++) {
+    int iter = k * taskDim + taskId;
+    int group_k = iter % group;
+    int w_k = (iter / group) % wo;
+    int h_k = (iter / wo / group) % ho;
+    int n_k = (iter / ho / wo / group) % n;
+    int h_i = h_k / scale;
+    int w_i = w_k / scale;
+    int start_h = h_i - ((k_up - 1) / 2);
+    int end_h = h_i + ((k_up - 1) / 2) + 1;
+    int start_w = w_i - ((k_up - 1) / 2);
+    int end_w = w_i + ((k_up - 1) / 2) + 1;
+    T *base_mask = (T *)mask + n_k * ho * wo * group * k_up * k_up +
+                   h_k * wo * group * k_up * k_up + w_k * group * k_up * k_up +
+                   group_k * k_up * k_up;
+    T *base_grad_mask = (T *)grad_mask + n_k * ho * wo * group * k_up * k_up +
+                        h_k * wo * group * k_up * k_up +
+                        w_k * group * k_up * k_up + group_k * k_up * k_up;
+
+    __bang_write_zero((T *)grad_input_buff, NRAM_BLOCK / sizeof(T));
+    __bang_write_zero((T *)grad_mask_buff, NRAM_BLOCK / sizeof(T));
+    __bang_write_zero((T *)grad_output_buff,
NRAM_BLOCK / sizeof(T)); + + __memcpy((T *)mask_buff, (T *)base_mask, k_up * k_up * sizeof(T), + GDRAM2NRAM); + for (int i = 0; i < num_per_loop; i++) { + __bang_write_zero((T *)input_buff, NRAM_BLOCK / sizeof(T)); + T *base_grad_output = (T *)grad_output + n_k * ho * wo * c + + h_k * wo * c + w_k * c + group_k * group_size + + i * num_align; + __memcpy((T *)grad_output_buff, (T *)base_grad_output, + num_align * sizeof(T), GDRAM2NRAM); + for (int ih = start_h; ih < end_h; ih++) { + for (int iw = start_w; iw < end_w; iw++) { + if (ih < 0 || ih > hi - 1 || iw < 0 || iw > wi - 1) { + continue; + } + int mask_ih = ih - h_i + (k_up - 1) / 2; + int mask_iw = iw - w_i + (k_up - 1) / 2; + int mask_index = mask_ih * k_up + mask_iw; + int input_index = n_k * hi * wi * c + ih * wi * c + iw * c + + group_k * group_size + i * num_align; + T *base_input = (T *)input + input_index; + T *base_grad_input = (T *)grad_input + input_index; + __memcpy((T *)input_buff, (T *)base_input, num_align * sizeof(T), + GDRAM2NRAM); + __bang_mul_const((T *)grad_input_buff, (T *)grad_output_buff, + ((T *)mask_buff)[mask_index], num_align); + __bang_atomic_add((T *)grad_input_buff, (T *)base_grad_input, + (T *)grad_input_buff, num_align); + __bang_mul((T *)input_buff, (T *)grad_output_buff, (T *)input_buff, + num_align); + + __bang_sumpool((T *)input_buff, (T *)input_buff, + NFU_ALIGN_SIZE / sizeof(T), + num_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, + num_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, 1, 1); + + __bang_reduce_sum((T *)input_buff, (T *)input_buff, + NFU_ALIGN_SIZE / sizeof(T)); + ((T *)grad_mask_buff)[mask_index] += ((T *)input_buff)[0]; + } + } + } + if (rem_for_loop) { + __bang_write_zero((T *)input_buff, NRAM_BLOCK / sizeof(T)); + T *base_grad_output = (T *)grad_output + n_k * ho * wo * c + + h_k * wo * c + w_k * c + group_k * group_size + + num_per_loop * num_align; + __memcpy((T *)grad_output_buff, (T *)base_grad_output, + rem_for_loop * sizeof(T), GDRAM2NRAM); + for (int ih = start_h; ih < end_h; ih++) { + for (int iw = start_w; iw < end_w; iw++) { + if (ih < 0 || ih > hi - 1 || iw < 0 || iw > wi - 1) { + continue; + } + int mask_ih = ih - h_i + (k_up - 1) / 2; + int mask_iw = iw - w_i + (k_up - 1) / 2; + int mask_index = mask_ih * k_up + mask_iw; + int input_index = n_k * hi * wi * c + ih * wi * c + iw * c + + group_k * group_size + num_per_loop * num_align; + T *base_input = (T *)input + input_index; + T *base_grad_input = (T *)grad_input + input_index; + __memcpy((T *)input_buff, (T *)base_input, rem_for_loop * sizeof(T), + GDRAM2NRAM); + __bang_mul_const((T *)grad_input_buff, (T *)grad_output_buff, + ((T *)mask_buff)[mask_index], rem_for_loop_align); + __bang_atomic_add((T *)grad_input_buff, (T *)base_grad_input, + (T *)grad_input_buff, rem_for_loop); + __bang_mul((T *)input_buff, (T *)grad_output_buff, (T *)input_buff, + rem_for_loop_align); + + __bang_sumpool( + (T *)input_buff, (T *)input_buff, NFU_ALIGN_SIZE / sizeof(T), + rem_for_loop_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, + rem_for_loop_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, 1, 1); + __bang_reduce_sum((T *)input_buff, (T *)input_buff, + NFU_ALIGN_SIZE / sizeof(T)); + + ((T *)grad_mask_buff)[mask_index] += ((T *)input_buff)[0]; + } + } + } + __memcpy((T *)base_grad_mask, (T *)grad_mask_buff, k_up * k_up * sizeof(T), + NRAM2GDRAM); + } +} + +template +__mlu_global__ void MLUUnion1KernelCarafeBackward( + const void *input, const void *mask, const void *grad_output, + void *grad_input, void *grad_mask, const int n, const int hi, const int wi, + 
const int c, const int k_up, const int group, const int scale) {
+  CarafeCompute((T *)input, (T *)mask, (T *)grad_output, (T *)grad_input,
+                (T *)grad_mask, n, hi, wi, c, k_up, group, scale);
+}
+}  // namespace backward
+
+void KernelCarafeForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
+                         cnrtQueue_t queue, const cnrtDataType_t d_type,
+                         const void *input, const void *mask,
+                         const CarafeForwardParam &param,
+                         const CarafeForwardBlockDim &block_dim,
+                         const CarafeForwardGridDim &grid_dim, void *output) {
+  if (d_type == CNRT_FLOAT16) {
+    forward::MLUBLOCKKernelCarafeForward<half><<<k_dim, k_type, queue>>>(
+        input, mask, param, block_dim, grid_dim, output);
+  } else {
+    forward::MLUBLOCKKernelCarafeForward<float><<<k_dim, k_type, queue>>>(
+        input, mask, param, block_dim, grid_dim, output);
+  }
+}
+
+void KernelCarafeBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
+                          cnrtQueue_t queue, cnrtDataType_t dtype,
+                          const void *input, const void *mask,
+                          const void *grad_output, void *grad_input,
+                          void *grad_mask, const int n, const int hi,
+                          const int wi, const int c, const int k_up,
+                          const int group, const int scale) {
+  if (dtype == CNRT_FLOAT16) {
+    backward::MLUUnion1KernelCarafeBackward<half>
+        <<<k_dim, k_type, queue>>>(input, mask, grad_output, grad_input,
+                                   grad_mask, n, hi, wi, c, k_up, group, scale);
+  } else {
+    backward::MLUUnion1KernelCarafeBackward<float>
+        <<<k_dim, k_type, queue>>>(input, mask, grad_output, grad_input,
+                                   grad_mask, n, hi, wi, c, k_up, group, scale);
+  }
+}
diff --git a/mmcv/ops/csrc/common/mlu/carafe_utils.hpp b/mmcv/ops/csrc/common/mlu/carafe_utils.hpp
new file mode 100644
index 0000000000..09ca60ab11
--- /dev/null
+++ b/mmcv/ops/csrc/common/mlu/carafe_utils.hpp
@@ -0,0 +1,95 @@
+/*************************************************************************
+ * Copyright (C) 2022 Cambricon.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/ +#ifndef CARAFE_UTILS_HPP_ +#define CARAFE_UTILS_HPP_ + +#define NRAM_ALIGN_SIZE 64 + +struct CarafeForwardParam { + int N; // batch size + int Hi; // input height + int Wi; // input width + int Ci; // input channels + int Ho; // output height + int Wo; // output width + int Cg; // channels per group + + int kernel_size; // kernel_size + int group_size; // group_size + int scale_factor; // scale_factor + int kernel_size_half; // kernel half size (K-1)/2 + int kernel_size_sq; // square of kernel size + + int dtype_size; // size of tensor data type + + // Host arrays' geometry + int input_stride_g; + int input_stride_w; + int input_stride_h; + int input_stride_n; + int input_size; + int mask_stride_kh; + int mask_stride_g; + int mask_stride_w; + int mask_stride_h; + int mask_stride_n; + int mask_size; + int output_stride_g; + int output_stride_w; + int output_stride_h; + int output_stride_n; + int output_size; + + // NRAM arrays' geometry + int input_nram_stride_g; + int input_nram_stride_w; + int input_nram_stride_h; + int input_nram_size; + int mask_nram_stride_kh; + int mask_nram_stride_g; + int mask_nram_stride_w; + int mask_nram_stride_h; + int mask_nram_size; + int output_nram_stride_g; + int output_nram_stride_w; + int output_nram_stride_h; + int output_nram_size; + + // for address/compute alignment + int align_size_NRAM; // for addressing on NRAM + int align_size_NFU; // for NFU operation length + int block_Cg_NFU; // for bang_mul_const + + int job_num; // total job number +}; + +struct CarafeForwardBlockDim { + int Ho; // block size of output height + int Wo; // block size of output width + int Kh; // block size of kernel height + int Kw; // block size of kernel width + int G; // block size of groups + int Cg; // block size of channels within a group + int Hi; // block size of input height + int Wi; // block size of input width +}; + +struct CarafeForwardGridDim { + int Ho; // number of blocks of output height + int Wo; + int Kh; + int Kw; + int G; + int Cg; +}; + +#endif // CARAFE_UTILS_HPP_ diff --git a/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp b/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp index 669a9d78e0..89d0151096 100644 --- a/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp +++ b/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp @@ -35,6 +35,148 @@ #define CEIL_ALIGN(x, y) (((x) + (y)-1) / (y) * (y)) +/*! + * @brief loads data from global DRAM to NRAM with 2D pattern. + * + * @param[out] dst + * Pointer to NRAM that stores dst data. + * @param[in] src + * Pointer to global DRAM that stores src data. + * @param[in] size + * The byte size of segment in the lower dimension. + * @param[in] dst_str + * The data stride in bytes between segments in the lower dimension of dst. + * @param[in] src_str + * The data stride in bytes between segments in the lower dimension of src. + * @param[in] seg_num + * The total count of data segments in the lower dimension. 
+ */
+template <typename T>
+__mlu_func__ void loadStr2D(T *dst, T *src, const int size, const int dst_str,
+                            const int src_str, const int seg_num) {
+  if (dst_str == src_str && size == src_str) {
+    __memcpy(dst, src, src_str * seg_num * sizeof(T), GDRAM2NRAM);
+  } else if ((size == src_str || src_str <= dst_str) &&
+             src_str * sizeof(T) <= 512) {
+    // gather data less than 512Bytes to improve IO efficiency
+    T *tmp = (T *)dst + (dst_str - src_str) * seg_num;
+    __memcpy(tmp, src, (src_str * (seg_num - 1) + size) * sizeof(T),
+             GDRAM2NRAM);
+    if (dst_str != src_str) {
+      __memcpy(dst, tmp, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
+               src_str * sizeof(T), seg_num - 1);
+    }
+  } else {
+    __memcpy(dst, src, size * sizeof(T), GDRAM2NRAM, dst_str * sizeof(T),
+             src_str * sizeof(T), seg_num - 1);
+  }
+}
+
+/*!
+ * @brief loads data from global DRAM to NRAM with 3D pattern.
+ *
+ * @param[out] dst
+ *   Pointer to NRAM that stores dst data.
+ * @param[in] src
+ *   Pointer to global DRAM that stores src data.
+ * @param[in] size
+ *   The byte size of segment in the lowest dimension.
+ * @param[in] seg_num_in
+ *   The total count of data segments in the lowest dimension.
+ * @param[in] seg_num_out
+ *   The total count of data segments in the middle dimension.
+ * @param[in] dst_str_in
+ *   The data stride in bytes between segments in the lowest dimension of dst.
+ * @param[in] dst_str_out
+ *   The data stride in bytes between segments in the middle dimension of dst.
+ * @param[in] src_str_in
+ *   The data stride in bytes between segments in the lowest dimension of src.
+ * @param[in] src_str_out
+ *   The data stride in bytes between segments in the middle dimension of src.
+ */
+template <typename T>
+__mlu_func__ void loadStr3D(T *dst, T *src, const int size,
+                            const int seg_num_in, const int seg_num_out,
+                            const int dst_str_in, const int dst_str_out,
+                            const int src_str_in, const int src_str_out) {
+  T *tmp_dst = dst;
+  T *tmp_src = src;
+
+  for (int i = 0; i < seg_num_out; ++i) {
+    loadStr2D(tmp_dst, tmp_src, size, dst_str_in, src_str_in, seg_num_in);
+    tmp_src += src_str_out;
+    tmp_dst += dst_str_out;
+  }
+}
+
+/*!
+ * @brief stores data from NRAM to global DRAM with 2D pattern.
+ *
+ * @param[out] dst
+ *   Pointer to global DRAM that stores dst data.
+ * @param[in] src
+ *   Pointer to NRAM that stores src data.
+ * @param[in] size
+ *   The byte size of segment in the lower dimension.
+ * @param[in] dst_str
+ *   The data stride in bytes between segments in the lower dimension of dst.
+ * @param[in] src_str
+ *   The data stride in bytes between segments in the lower dimension of src.
+ * @param[in] seg_num
+ *   The total count of data segments in the lower dimension.
+ */
+template <typename T>
+__mlu_func__ void storeStr2D(T *dst, T *src, const int size, const int seg_num,
+                             const int dst_str, const int src_str) {
+  if ((size == dst_str && dst_str <= src_str) && dst_str * sizeof(T) <= 512) {
+    // gather data less than 512Bytes to improve IO efficiency
+    if (dst_str != src_str) {
+      __memcpy(src, src, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T),
+               src_str * sizeof(T), seg_num - 1);
+    }
+    __memcpy(dst, src, size * seg_num * sizeof(T), NRAM2GDRAM);
+  } else {
+    __memcpy(dst, src, size * sizeof(T), NRAM2GDRAM, dst_str * sizeof(T),
+             src_str * sizeof(T), seg_num - 1);
+  }
+}
+
+/*!
+ * @brief stores data from NRAM to global DRAM with 3D pattern.
+ *
+ * @param[out] dst
+ *   Pointer to global DRAM that stores dst data.
+ * @param[in] src
+ *   Pointer to NRAM that stores src data.
+ * @param[in] size
+ *   The byte size of segment in the lowest dimension.
+ * @param[in] seg_num_in
+ *   The total count of data segments in the lowest dimension.
+ * @param[in] seg_num_out
+ *   The total count of data segments in the middle dimension.
+ * @param[in] dst_str_in
+ *   The data stride in bytes between segments in the lowest dimension of dst.
+ * @param[in] dst_str_out
+ *   The data stride in bytes between segments in the middle dimension of dst.
+ * @param[in] src_str_in
+ *   The data stride in bytes between segments in the lowest dimension of src.
+ * @param[in] src_str_out
+ *   The data stride in bytes between segments in the middle dimension of src.
+ */
+template <typename T>
+__mlu_func__ void storeStr3D(T *dst, T *src, const int size,
+                             const int seg_num_in, const int seg_num_out,
+                             const int dst_str_in, const int dst_str_out,
+                             const int src_str_in, const int src_str_out) {
+  T *tmp_dst = dst;
+  T *tmp_src = src;
+  for (int i = 0; i < seg_num_out; ++i) {
+    storeStr2D(tmp_dst, tmp_src, size, seg_num_in, dst_str_in, src_str_in);
+    tmp_src += src_str_out;
+    tmp_dst += dst_str_out;
+  }
+}
+
 /*!
  * @brief Converts int32 to float32 data type.
  *
diff --git a/mmcv/ops/csrc/common/pytorch_mlu_helper.hpp b/mmcv/ops/csrc/common/pytorch_mlu_helper.hpp
index 72dbe5880b..65bf3856a8 100644
--- a/mmcv/ops/csrc/common/pytorch_mlu_helper.hpp
+++ b/mmcv/ops/csrc/common/pytorch_mlu_helper.hpp
@@ -21,8 +21,10 @@
 
 #define PAD_DOWN(x, y) (((x) / (y)) * (y))
 
+#define CEIL_DIV(x, y) (((x) + (y)-1) / (y))
+
 #define CEIL_ALIGN(x, y) (((x) + (y)-1) / (y) * (y))
 
-#endif
+#endif  // MMCV_WITH_MLU
 
 #endif  // PYTORCH_MLU_HELPER_HPP_
diff --git a/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp
new file mode 100644
index 0000000000..25e0b85d12
--- /dev/null
+++ b/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp
@@ -0,0 +1,429 @@
+/*************************************************************************
+ * Copyright (C) 2022 Cambricon.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
+ * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+ * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+ * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+ * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ *************************************************************************/
+#include "carafe_utils.hpp"
+#include "pytorch_device_registry.hpp"
+#include "pytorch_mlu_helper.hpp"
+
+void KernelCarafeForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
+                         cnrtQueue_t queue, const cnrtDataType_t d_type,
+                         const void *input, const void *mask,
+                         const CarafeForwardParam &param,
+                         const CarafeForwardBlockDim &block_dim,
+                         const CarafeForwardGridDim &grid_dim, void *output);
+
+void KernelCarafeBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type,
+                          cnrtQueue_t queue, cnrtDataType_t dtype,
+                          const void *input, const void *mask,
+                          const void *grad_output, void *grad_input,
+                          void *grad_mask, const int n, const int hi,
+                          const int wi, const int c, const int k_up,
+                          const int group, const int scale);
+
+// Get total NRAM usage and set strides of NRAM arrays.
+static void getNramUsage(CarafeForwardParam *param, + CarafeForwardBlockDim *block_dim, int *nram_usage) { + // input_nram[blkDim_(Hi+Kh)-1, blkDim_(Wi+Kw)-1, blkDim_G, blkDim_Cg] + block_dim->Hi = CEIL_DIV(block_dim->Ho, param->scale_factor) + 1; + block_dim->Wi = CEIL_DIV(block_dim->Wo, param->scale_factor) + 1; + + param->input_nram_stride_g = PAD_UP(block_dim->Cg, param->align_size_NRAM); + param->input_nram_stride_w = param->input_nram_stride_g * block_dim->G; + param->input_nram_stride_h = + (block_dim->Wi + block_dim->Kw - 1) * param->input_nram_stride_w; + param->input_nram_size = + (block_dim->Hi + block_dim->Kh - 1) * param->input_nram_stride_h; + + // mask_nram[blkDim_Ho, blkDim_Wo, blkDim_G, blkDim_Kh, blkDim_Kw] + param->mask_nram_stride_kh = block_dim->Kw; + param->mask_nram_stride_g = block_dim->Kh * param->mask_nram_stride_kh; + param->mask_nram_stride_w = block_dim->G * param->mask_nram_stride_g; + param->mask_nram_stride_h = block_dim->Wo * param->mask_nram_stride_w; + param->mask_nram_size = + PAD_UP(block_dim->Ho * param->mask_nram_stride_h, param->align_size_NRAM); + + // output_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Cg)] + param->output_nram_stride_g = param->input_nram_stride_g; + param->output_nram_stride_w = + PAD_UP(param->input_nram_stride_w, param->align_size_NFU); + param->output_nram_stride_h = block_dim->Wo * param->output_nram_stride_w; + param->output_nram_size = block_dim->Ho * param->output_nram_stride_h; + + // sum_array[blkDim_(G*Cg)] + + // ensure the last mul_const on Cg does not exceed memory boundary + int sum_array_size_bang_mul_const = + (block_dim->G - 1) * param->input_nram_stride_g + + PAD_UP(param->input_nram_stride_g, param->align_size_NFU); + + int sum_array_size = + std::max(param->output_nram_stride_w, sum_array_size_bang_mul_const); + + *nram_usage = param->input_nram_size + param->mask_nram_size + + param->output_nram_size + sum_array_size; +} + +// Policy Function for Forward +static void genPolicyForward(CarafeForwardParam *param, + CarafeForwardBlockDim *block_dim, + CarafeForwardGridDim *grid_dim, cnrtDim3_t *k_dim, + cnrtFunctionType_t *k_type) { + // device info + auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + auto cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + auto core_num = core_dim * cluster_num; + + // maximum NRAM size as the number of + auto max_nram_size = + torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore) / param->dtype_size; + + // determine grid and block dimensions + + // set initial values for block_dim and grid_dim + block_dim->Ho = param->Ho; + block_dim->Wo = param->Wo; + block_dim->Kh = param->kernel_size; + block_dim->Kw = param->kernel_size; + block_dim->G = param->group_size; + block_dim->Cg = param->Cg; + + grid_dim->Ho = 1; + grid_dim->Wo = 1; + grid_dim->Kh = 1; + grid_dim->Kw = 1; + grid_dim->G = 1; + grid_dim->Cg = 1; + + // decrease the block size to fit in the NRAM. + int nram_usage = 0; + while (true) { + getNramUsage(param, block_dim, &nram_usage); + + if (nram_usage > max_nram_size) { + // decrease Ho + // decrease block_Ho and block_Wo evenly + // so that the block is close to a square. 
+      if (block_dim->Ho > 1 && block_dim->Ho >= block_dim->Wo) {
+        grid_dim->Ho += 1;
+        block_dim->Ho = CEIL_DIV(param->Ho, grid_dim->Ho);
+      } else if (block_dim->Wo > 1 && block_dim->Wo > block_dim->Ho) {
+        // decrease Wo
+        grid_dim->Wo += 1;
+        block_dim->Wo = CEIL_DIV(param->Wo, grid_dim->Wo);
+      } else if (block_dim->Kh > 1) {
+        // decrease Kh
+        grid_dim->Kh += 1;
+        block_dim->Kh = CEIL_DIV(param->kernel_size, grid_dim->Kh);
+        // reset Hi, Wi to maximize NRAM usage
+        grid_dim->Ho = 1;
+        block_dim->Ho = param->Ho;
+        grid_dim->Wo = 1;
+        block_dim->Wo = param->Wo;
+      } else if (block_dim->Kw > 1) {
+        // decrease Kw
+        grid_dim->Kw += 1;
+        block_dim->Kw = CEIL_DIV(param->kernel_size, grid_dim->Kw);
+        // reset Kh
+        grid_dim->Kh = 1;
+        block_dim->Kh = param->kernel_size;
+      } else if (block_dim->G > 1) {
+        // decrease G
+        grid_dim->G += 1;
+        block_dim->G = CEIL_DIV(param->group_size, grid_dim->G);
+        // reset Kw
+        grid_dim->Kw = 1;
+        block_dim->Kw = param->kernel_size;
+      } else if (block_dim->Cg > 1) {
+        // decrease block_Cg
+        // This is done in the last since c is the continuous dim
+        // (input layout is NHWC) and large c can improve
+        // IO & compute efficiency.
+        grid_dim->Cg += 1;
+        block_dim->Cg = CEIL_DIV(param->Cg, grid_dim->Cg);
+        // reset G
+        grid_dim->G = 1;
+        block_dim->G = param->group_size;
+      } else {
+        // the block volume is one now, cannot decrease the block size anymore!
+        // this situation should not occur.
+        break;
+      }
+    } else {
+      break;
+    }
+  }
+
+  // define parameters depending on block_dim, grid_dim
+  param->block_Cg_NFU = PAD_UP(block_dim->Cg, param->align_size_NFU);
+
+  // define host arrays' strides
+
+  // input[N,H,W,G,Cg]
+  param->input_stride_g = param->Cg;
+  param->input_stride_w = param->Ci;
+  param->input_stride_h = param->Wi * param->input_stride_w;
+  param->input_stride_n = param->Hi * param->input_stride_h;
+  // mask[N,Ho,Wo,G,Kh,Kw]
+  param->mask_stride_kh = param->kernel_size;
+  param->mask_stride_g = param->kernel_size * param->mask_stride_kh;
+  param->mask_stride_w = param->group_size * param->mask_stride_g;
+  param->mask_stride_h = param->Wo * param->mask_stride_w;
+  param->mask_stride_n = param->Ho * param->mask_stride_h;
+  // output[N,Ho,Wo,G,Cg]
+  param->output_stride_g = param->Cg;
+  param->output_stride_w = param->Ci;
+  param->output_stride_h = param->Wo * param->output_stride_w;
+  param->output_stride_n = param->Ho * param->output_stride_h;
+
+  param->job_num =
+      param->N * grid_dim->Ho * grid_dim->Wo * grid_dim->G * grid_dim->Cg;
+
+  // determine task type and dims
+  *k_type = CNRT_FUNC_TYPE_BLOCK;
+  k_dim->x = std::min(param->job_num, static_cast<int>(core_num));
+  k_dim->y = 1;
+  k_dim->z = 1;
+}
+
+void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask,
+                                    Tensor rinput, Tensor routput, Tensor rmask,
+                                    Tensor output, const int kernel_size,
+                                    const int group_size,
+                                    const int scale_factor) {
+  const int batch_size = output.size(0);
+  const int channels = output.size(1);
+  const int ho = output.size(2);
+  const int wo = output.size(3);
+
+  // check tensor data type
+  TORCH_CHECK(
+      input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf,
+      "Data type of input should be Float or Half.
But now input type is ", + input.scalar_type(), "."); + + TORCH_CHECK(mask.scalar_type() == input.scalar_type(), + "Data types of input and mask should be the same, but got ", + input.scalar_type(), " and ", mask.scalar_type()); + + // check number of dimensions + TORCH_CHECK(input.dim() == 4, "input should be a 4-D tensor, but has ", + input.dim(), "D."); + TORCH_CHECK(mask.dim() == 4, "mask should be a 4-D tensor, but has ", + input.dim(), "D."); + + // return fast on zero-element tensor + if (output.numel() == 0) { + output = at::zeros({batch_size, channels, ho, wo}, output.options()); + return; + } + + // set param + CarafeForwardParam param; + param.N = input.size(0); + param.Ci = input.size(1); + param.Hi = input.size(2); + param.Wi = input.size(3); + + param.kernel_size = kernel_size; + param.group_size = group_size; + param.scale_factor = scale_factor; + param.Cg = param.Ci / group_size; + param.dtype_size = input.itemsize(); + param.align_size_NRAM = NRAM_ALIGN_SIZE / param.dtype_size; + param.align_size_NFU = NFU_ALIGN_SIZE / param.dtype_size; + param.kernel_size_sq = param.kernel_size * param.kernel_size; + param.kernel_size_half = (param.kernel_size - 1) / 2; + param.Ho = param.Hi * param.scale_factor; + param.Wo = param.Wi * param.scale_factor; + + // generate policy + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + CarafeForwardBlockDim block_dim; + CarafeForwardGridDim grid_dim; + + genPolicyForward(¶m, &block_dim, &grid_dim, &k_dim, &k_type); + + // convert NCHW to NHWC + auto memory_format_input_nhwc = + torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim()); + auto rinput_ = + torch_mlu::cnnl::ops::cnnl_contiguous(input, memory_format_input_nhwc); + + auto memory_format_mask_nhwc = + torch_mlu::cnnl::ops::get_channels_last_memory_format(mask.dim()); + auto rmask_ = + torch_mlu::cnnl::ops::cnnl_contiguous(mask, memory_format_mask_nhwc); + + auto memory_format_output_nhwc = + torch_mlu::cnnl::ops::get_channels_last_memory_format(output.dim()); + auto routput_ = + torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format_output_nhwc); + + // get ptr of tensors + auto input_impl = torch_mlu::getMluTensorImpl(rinput_); + auto input_ptr = input_impl->cnnlMalloc(); + auto mask_impl = torch_mlu::getMluTensorImpl(rmask_); + auto mask_ptr = mask_impl->cnnlMalloc(); + auto output_impl = torch_mlu::getMluTensorImpl(routput_); + auto output_ptr = output_impl->cnnlMalloc(); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get dtype of input + cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype()); + + // launch kernel + auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + CNLOG(INFO) << "Launch Kernel KernelCarafeForward<<>>"; + + KernelCarafeForward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr, param, + block_dim, grid_dim, output_ptr); + + // copy output from NHWC back into NCHW + rinput.copy_(rinput_); + output.copy_(routput_); +} + +// Policy Function for Backward +static void policyFuncBackward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) { + // set Union1 Job + *k_type = CNRT_FUNC_TYPE_UNION1; + k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); + k_dim->z = 1; +} + +void CARAFEBackwardMLUKernelLauncher( + const Tensor grad_output, const Tensor rinput, const Tensor mask, + Tensor rgrad_output, Tensor rgrad_input_hs, Tensor rgrad_input, + Tensor rgrad_mask, Tensor grad_input, Tensor grad_mask, + const int kernel_size, const int group_size, 
const int scale_factor) { + const int batch_size = rinput.size(0); + const int channels = rinput.size(1); + const int hi = rinput.size(2); + const int wi = rinput.size(3); + + // data type check + TORCH_CHECK(grad_output.scalar_type() == at::kFloat || + grad_output.scalar_type() == at::kHalf, + "grad_output type should be Float or Half, got ", + grad_output.scalar_type()); + TORCH_CHECK(grad_output.scalar_type() == mask.scalar_type(), + "mask should have the same type as grad_output"); + + // dim check + TORCH_CHECK(grad_output.dim() == 4, "grad_output should be a 4d tensor, got ", + grad_output.dim(), "D"); + + // param check + TORCH_CHECK(kernel_size < 137, "kernel_size should be less than 137, got ", + kernel_size); + + // set task dimension + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + policyFuncBackward(&k_dim, &k_type); + + // convert NCHW to NHWC + auto memory_format_input_nhwc = + torch_mlu::cnnl::ops::get_channels_last_memory_format(rinput.dim()); + auto rinput_ = + torch_mlu::cnnl::ops::cnnl_contiguous(rinput, memory_format_input_nhwc); + + auto memory_format_mask_nhwc = + torch_mlu::cnnl::ops::get_channels_last_memory_format(mask.dim()); + auto rmask_ = + torch_mlu::cnnl::ops::cnnl_contiguous(mask, memory_format_mask_nhwc); + + auto memory_format_grad_output_nhwc = + torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_output.dim()); + auto rgrad_output_ = torch_mlu::cnnl::ops::cnnl_contiguous( + grad_output, memory_format_grad_output_nhwc); + + auto memory_format_grad_input_nhwc = + torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_input.dim()); + auto rgrad_input_ = torch_mlu::cnnl::ops::cnnl_contiguous( + grad_input, memory_format_grad_input_nhwc) + .zero_(); + + auto memory_format_grad_mask_nhwc = + torch_mlu::cnnl::ops::get_channels_last_memory_format(grad_mask.dim()); + auto rgrad_mask_ = torch_mlu::cnnl::ops::cnnl_contiguous( + grad_mask, memory_format_grad_mask_nhwc); + + // get compute queue + auto queue = torch_mlu::getCurQueue(); + + // get ptr of tensors + auto input_impl = torch_mlu::getMluTensorImpl(rinput_); + auto input_ptr = input_impl->cnnlMalloc(); + auto mask_impl = torch_mlu::getMluTensorImpl(rmask_); + auto mask_ptr = mask_impl->cnnlMalloc(); + auto grad_output_impl = torch_mlu::getMluTensorImpl(rgrad_output_); + auto grad_output_ptr = grad_output_impl->cnnlMalloc(); + auto grad_input_impl = torch_mlu::getMluTensorImpl(rgrad_input_); + auto grad_input_ptr = grad_input_impl->cnnlMalloc(); + auto grad_mask_impl = torch_mlu::getMluTensorImpl(rgrad_mask_); + auto grad_mask_ptr = grad_mask_impl->cnnlMalloc(); + + // get dtype of grad_output + cnrtDataType_t d_type = torch_mlu::toCnrtDtype(grad_output.dtype()); + auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); + + CNLOG(INFO) << "Launch Kernel KernelCarafeBackward<<>>"; + + // launch kernel + KernelCarafeBackward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr, + grad_output_ptr, grad_input_ptr, grad_mask_ptr, + batch_size, hi, wi, channels, kernel_size, group_size, + scale_factor); + + // copy output from NHWC back into NCHW + grad_input.copy_(rgrad_input_); + grad_mask.copy_(rgrad_mask_); +} + +void carafe_forward_mlu(Tensor features, Tensor masks, Tensor rfeatures, + Tensor routput, Tensor rmasks, Tensor output, + int kernel_size, int group_size, int scale_factor) { + CARAFEForwardMLUKernelLauncher(features, masks, rfeatures, routput, rmasks, + output, kernel_size, group_size, scale_factor); +} + +void carafe_backward_mlu(Tensor top_grad, Tensor rfeatures, Tensor masks, 
+                         Tensor rtop_grad, Tensor rbottom_grad_hs,
+                         Tensor rbottom_grad, Tensor rmask_grad,
+                         Tensor bottom_grad, Tensor mask_grad, int kernel_size,
+                         int group_size, int scale_factor) {
+  CARAFEBackwardMLUKernelLauncher(top_grad, rfeatures, masks, rtop_grad,
+                                  rbottom_grad_hs, rbottom_grad, rmask_grad,
+                                  bottom_grad, mask_grad, kernel_size,
+                                  group_size, scale_factor);
+}
+
+void carafe_forward_impl(Tensor features, Tensor masks, Tensor rfeatures,
+                         Tensor routput, Tensor rmasks, Tensor output,
+                         int kernel_size, int group_size, int scale_factor);
+
+void carafe_backward_impl(Tensor top_grad, Tensor rfeatures, Tensor masks,
+                          Tensor rtop_grad, Tensor rbottom_grad_hs,
+                          Tensor rbottom_grad, Tensor rmask_grad,
+                          Tensor bottom_grad, Tensor mask_grad, int kernel_size,
+                          int group_size, int scale_factor);
+
+REGISTER_DEVICE_IMPL(carafe_forward_impl, MLU, carafe_forward_mlu);
+REGISTER_DEVICE_IMPL(carafe_backward_impl, MLU, carafe_backward_mlu);
diff --git a/tests/data/for_carafe/carafe_feat.bin b/tests/data/for_carafe/carafe_feat.bin
new file mode 100755
index 0000000000..9402a7cdac
Binary files /dev/null and b/tests/data/for_carafe/carafe_feat.bin differ
diff --git a/tests/data/for_carafe/carafe_feat_grad.bin b/tests/data/for_carafe/carafe_feat_grad.bin
new file mode 100755
index 0000000000..d195bd1803
Binary files /dev/null and b/tests/data/for_carafe/carafe_feat_grad.bin differ
diff --git a/tests/data/for_carafe/carafe_mask.bin b/tests/data/for_carafe/carafe_mask.bin
new file mode 100755
index 0000000000..18dc01b0a4
Binary files /dev/null and b/tests/data/for_carafe/carafe_mask.bin differ
diff --git a/tests/data/for_carafe/carafe_mask_grad.bin b/tests/data/for_carafe/carafe_mask_grad.bin
new file mode 100755
index 0000000000..f6f93dc68d
Binary files /dev/null and b/tests/data/for_carafe/carafe_mask_grad.bin differ
diff --git a/tests/data/for_carafe/carafe_output.bin b/tests/data/for_carafe/carafe_output.bin
new file mode 100755
index 0000000000..5400527020
Binary files /dev/null and b/tests/data/for_carafe/carafe_output.bin differ
diff --git a/tests/test_ops/test_carafe.py b/tests/test_ops/test_carafe.py
index 6b545a0276..02d00f1ff8 100644
--- a/tests/test_ops/test_carafe.py
+++ b/tests/test_ops/test_carafe.py
@@ -1,7 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import numpy as np
+import pytest
 import torch
 from torch.autograd import gradcheck
 
+from mmcv.utils import IS_CUDA_AVAILABLE, IS_MLU_AVAILABLE
+
 
 class TestCarafe:
 
@@ -26,3 +30,56 @@ def test_carafe_gradcheck(self):
             2, 100, 6, 6, requires_grad=True,
             device='cuda').sigmoid().double()
         gradcheck(CARAFE(5, 4, 2), (feat, mask), atol=1e-4, eps=1e-4)
+
+    @pytest.mark.parametrize('device', [
+        pytest.param(
+            'cuda',
+            marks=pytest.mark.skipif(
+                not IS_CUDA_AVAILABLE, reason='requires CUDA support')),
+        pytest.param(
+            'mlu',
+            marks=pytest.mark.skipif(
+                not IS_MLU_AVAILABLE, reason='requires MLU support'))
+    ])
+    def test_carafe_allclose(self, device):
+        try:
+            from mmcv.ops import CARAFE
+        except ModuleNotFoundError:
+            pytest.skip('test requires compilation')
+
+        np_feat = np.fromfile(
+            'tests/data/for_carafe/carafe_feat.bin', dtype=np.float32)
+        np_mask = np.fromfile(
+            'tests/data/for_carafe/carafe_mask.bin', dtype=np.float32)
+        np_output = np.fromfile(
+            'tests/data/for_carafe/carafe_output.bin', dtype=np.float32)
+        np_feat_grad = np.fromfile(
+            'tests/data/for_carafe/carafe_feat_grad.bin', dtype=np.float32)
+        np_mask_grad = np.fromfile(
+            'tests/data/for_carafe/carafe_mask_grad.bin', dtype=np.float32)
+
+        np_feat = np_feat.reshape((2, 64, 3, 3))
+        np_mask = np_mask.reshape((2, 100, 6, 6))
+        np_output = np_output.reshape((2, 64, 6, 6))
+        np_feat_grad = np_feat_grad.reshape((2, 64, 3, 3))
+        np_mask_grad = np_mask_grad.reshape((2, 100, 6, 6))
+
+        feat = torch.tensor(
+            np_feat, dtype=torch.float, device=device, requires_grad=True)
+        mask = torch.tensor(
+            np_mask, dtype=torch.float, device=device, requires_grad=True)
+
+        carafe = CARAFE(5, 4, 2)
+
+        output = carafe(feat, mask)
+        output.backward(torch.ones_like(output))
+        assert np.allclose(
+            output.data.type(torch.float).cpu().numpy(), np_output, atol=1e-3)
+        assert np.allclose(
+            feat.grad.data.type(torch.float).cpu().numpy(),
+            np_feat_grad,
+            atol=1e-3)
+        assert np.allclose(
+            mask.grad.data.type(torch.float).cpu().numpy(),
+            np_mask_grad,
+            atol=1e-3)
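
Note (not part of the diff): a minimal usage sketch of the CARAFE op this change enables on MLU devices. It mirrors the shapes exercised by `test_carafe_allclose` above; it assumes an MLU-enabled build of mmcv with the `torch_mlu` extension installed, and the `IS_MLU_AVAILABLE` flag imported here comes from this change set. On a CUDA build the same code runs with `device = 'cuda'`.

```python
# Illustrative sketch only; assumes an MLU build of mmcv (torch_mlu installed).
import torch

from mmcv.ops import CARAFE
from mmcv.utils import IS_MLU_AVAILABLE

if IS_MLU_AVAILABLE:
    device = 'mlu'
    # features: (N, C, Hi, Wi); masks: (N, group_size * kernel_size**2, Ho, Wo)
    feat = torch.randn(2, 64, 3, 3, device=device, requires_grad=True)
    mask = torch.randn(2, 100, 6, 6, device=device).sigmoid()
    carafe = CARAFE(kernel_size=5, group_size=4, scale_factor=2)
    output = carafe(feat, mask)               # upsampled to (2, 64, 6, 6)
    output.backward(torch.ones_like(output))  # exercises the MLU backward kernel
```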