diff --git a/mmcv/ops/csrc/common/mlu/carafe_mlu_kernel.mlu b/mmcv/ops/csrc/common/mlu/carafe_mlu_kernel.mlu deleted file mode 100644 index 8dd6a8e582..0000000000 --- a/mmcv/ops/csrc/common/mlu/carafe_mlu_kernel.mlu +++ /dev/null @@ -1,552 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. - * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ -#include "carafe_utils.hpp" -#include "common_mlu_helper.hpp" - -#define INDEX3(n, h, w, c, strN, strH, strW) \ - (strN) * (n) + (strH) * (h) + (strW) * (w) + (c) - -#define NRAM_BLOCK PAD_DOWN(MAX_NRAM_SIZE / 5, NRAM_ALIGN_SIZE) - -__nram__ char nram_buf[MAX_NRAM_SIZE]; - -namespace forward { -struct BlockId { - int Ho; - int Wo; - int G; - int Cg; - int Kh; - int Kw; - int Hi; - int Wi; -}; - -// start indices of block -struct BlockStart { - int Ho; - int Wo; - int G; - int Cg; - int Kh; - int Kw; - int Hi; - int Wi; - int C; -}; - -struct BlockEnd { - int Ho; - int Wo; - int Kh; - int Kw; - int Hi; - int Wi; -}; - -struct BlockSize { - int Ho; - int Wo; - int G; - int Cg; - int Kh; - int Kw; - int Hi; - int Wi; -}; - -template -__mlu_func__ void carafeForwardBLOCK(T *input, T *mask, - const CarafeForwardParam param, - const CarafeForwardBlockDim block_dim, - const CarafeForwardGridDim grid_dim, - T *output) { - // data block info - BlockId blkId; - BlockStart blkStart; - BlockEnd blkEnd; - BlockSize blkSize; - - // set pointers on NRAM arrays - - // input_nram[blkDim_(Hi+Kh)-1, blkDim_(Wi+Kw)-1, blkDim_(G*Cg)] - T *input_nram = (T *)nram_buf; - - // mask_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Kh*Kw)] - T *mask_nram = input_nram + param.input_nram_size; - - // output_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Cg)] - T *output_nram = mask_nram + param.mask_nram_size; - - // sum_array[blkDim_(G*Cg)] - T *sum_array = output_nram + param.output_nram_size; - - /* ===== loop over N, grid_dim(Ho,Wo,G,Cg) - * iterations are distributed over computing cores - */ - for (int loop_index = taskId; loop_index < param.job_num; - loop_index += taskDim) { - // block idx - blkId.Cg = loop_index; - blkId.G = blkId.Cg / grid_dim.Cg; - blkId.Wo = blkId.G / grid_dim.G; - blkId.Ho = blkId.Wo / grid_dim.Wo; - int sample_idx = blkId.Ho / grid_dim.Ho; - - blkId.Cg %= grid_dim.Cg; - blkId.G %= grid_dim.G; - blkId.Wo %= grid_dim.Wo; - blkId.Ho %= grid_dim.Ho; - - // block starting indices - blkStart.Ho = blkId.Ho * block_dim.Ho; - blkStart.Wo = blkId.Wo * block_dim.Wo; - blkStart.G = blkId.G * block_dim.G; - blkStart.Cg = blkId.Cg * block_dim.Cg; - blkStart.C = blkStart.G * param.Cg + blkStart.Cg; - - // block size - blkSize.Ho = block_dim.Ho; - blkSize.Wo = block_dim.Wo; - blkSize.G = block_dim.G; - blkSize.Cg = block_dim.Cg; - - // take care of blocks near the end of each dimension - if (blkId.Ho == (grid_dim.Ho - 1)) { - blkSize.Ho = param.Ho - (grid_dim.Ho - 1) * block_dim.Ho; - } - if (blkId.Wo == (grid_dim.Wo - 1)) { - blkSize.Wo = param.Wo - (grid_dim.Wo - 1) * block_dim.Wo; - } - if (blkId.G == 
(grid_dim.G - 1)) { - blkSize.G = param.group_size - (grid_dim.G - 1) * block_dim.G; - } - if (blkId.Cg == (grid_dim.Cg - 1)) { - blkSize.Cg = param.Cg - (grid_dim.Cg - 1) * block_dim.Cg; - } - - // block end indices - blkEnd.Ho = blkStart.Ho + blkSize.Ho - 1; - blkEnd.Wo = blkStart.Wo + blkSize.Wo - 1; - - // set output_nram to zero - __bang_write_value(output_nram, param.output_nram_size, T(0)); - - // loop blocks of kernel window: grid_dim.(Kh, Kw) - for (blkId.Kh = 0; blkId.Kh < grid_dim.Kh; ++blkId.Kh) { - blkStart.Kh = blkId.Kh * block_dim.Kh; - blkSize.Kh = block_dim.Kh; - if (blkId.Kh == (grid_dim.Kh - 1)) { - blkSize.Kh = param.kernel_size - (grid_dim.Kh - 1) * block_dim.Kh; - } - blkEnd.Kh = blkStart.Kh + blkSize.Kh - 1; - - blkStart.Hi = blkStart.Ho / param.scale_factor - param.kernel_size_half + - blkStart.Kh; - blkEnd.Hi = - blkEnd.Ho / param.scale_factor - param.kernel_size_half + blkEnd.Kh; - blkSize.Hi = blkEnd.Hi - blkStart.Hi + 1; - - for (blkId.Kw = 0; blkId.Kw < grid_dim.Kw; ++blkId.Kw) { - blkStart.Kw = blkId.Kw * block_dim.Kw; - blkSize.Kw = block_dim.Kw; - if (blkId.Kw == (grid_dim.Kw - 1)) { - blkSize.Kw = param.kernel_size - (grid_dim.Kw - 1) * block_dim.Kw; - } - blkEnd.Kw = blkStart.Kw + blkSize.Kw - 1; - - blkStart.Wi = blkStart.Wo / param.scale_factor - - param.kernel_size_half + blkStart.Kw; - blkEnd.Wi = - blkEnd.Wo / param.scale_factor - param.kernel_size_half + blkEnd.Kw; - blkSize.Wi = blkEnd.Wi - blkStart.Wi + 1; - - // load input block from gdram2nram - // - // input_nram[ | input[ sample_idx, - // 0:blkSize.Hi-1, | blkStart.Hi + 0:blkSize.Hi-1, - // 0:blkSize.Wi-1, | blkStart.Wi + 0:blkSize.Wi-1, - // 0:blkSize.G-1 | blkStart.G + 0:blkSize.G-1 - // 0:blkSize.Cg-1] | blkStart.Cg + 0:blkSize.Cg-1] - // - // To skip out of bound indices: - // - // input_nram[ - // hi_start_local:hi_end_local, - // wi_start_local:wi_end_local, ...] - // = input[n, - // hi_start_global:hi_end_global, - // wi_start_global:wi_end_global, ...] 
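For orientation, the computation this blocked kernel performs is, stripped of the tiling and NRAM staging, the following naive NHWC reference. This is a hypothetical host-side sketch, not code from this file: `carafe_forward_reference`, `sigma` (= scale_factor) and `K` (= kernel_size) are illustrative names, and out-of-window reads are treated as zero, matching the boundary clipping below.

template <typename T>
void carafe_forward_reference(const T *input, const T *mask, T *output,
                              int N, int Hi, int Wi, int G, int Cg, int K,
                              int sigma) {
  const int Ho = Hi * sigma;
  const int Wo = Wi * sigma;
  const int half = (K - 1) / 2;  // kernel_size_half in the kernel above
  for (int n = 0; n < N; ++n)
    for (int ho = 0; ho < Ho; ++ho)
      for (int wo = 0; wo < Wo; ++wo)
        for (int g = 0; g < G; ++g)
          for (int c = 0; c < Cg; ++c) {
            T acc = T(0);
            for (int kh = 0; kh < K; ++kh)
              for (int kw = 0; kw < K; ++kw) {
                // same window mapping as blkStart.Hi / blkStart.Wi above
                const int hi = ho / sigma - half + kh;
                const int wi = wo / sigma - half + kw;
                if (hi < 0 || hi >= Hi || wi < 0 || wi >= Wi) continue;
                // mask[n, ho, wo, g, kh, kw] * input[n, hi, wi, g, c]
                acc += mask[((((n * Ho + ho) * Wo + wo) * G + g) * K + kh) *
                                K + kw] *
                       input[(((n * Hi + hi) * Wi + wi) * G + g) * Cg + c];
              }
            output[(((n * Ho + ho) * Wo + wo) * G + g) * Cg + c] = acc;
          }
}

Everything the kernel does beyond this — block decomposition, NRAM residency, strided loads, the sum_array accumulation — is an optimization of these six loops.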
- // - int hi_start_local = 0; - int hi_start_global = blkStart.Hi; - if (blkStart.Hi < 0) { - hi_start_local = -blkStart.Hi; - hi_start_global = 0; - } - int wi_start_local = 0; - int wi_start_global = blkStart.Wi; - if (blkStart.Wi < 0) { - wi_start_local = -blkStart.Wi; - wi_start_global = 0; - } - int hi_end_local = blkSize.Hi - 1; - int hi_end_global = blkEnd.Hi; - if (blkEnd.Hi > param.Hi - 1) { - hi_end_global = param.Hi - 1; - hi_end_local -= blkEnd.Hi - hi_end_global; - } - int wi_end_local = blkSize.Wi - 1; - int wi_end_global = blkEnd.Wi; - if (blkEnd.Wi > param.Wi - 1) { - wi_end_global = param.Wi - 1; - wi_end_local -= blkEnd.Wi - wi_end_global; - } - - int dst_offset = param.input_nram_stride_h * hi_start_local + - param.input_nram_stride_w * wi_start_local; - T *dst = input_nram + dst_offset; - - int src_offset = INDEX3(sample_idx, hi_start_global, wi_start_global, - blkStart.C, param.input_stride_n, - param.input_stride_h, param.input_stride_w); - T *src = input + src_offset; - - int input_seg_num_h = hi_end_local - hi_start_local + 1; - int input_seg_num_w = wi_end_local - wi_start_local + 1; - for (int i = 0; i < input_seg_num_h; ++i) { - loadStr3D(dst, src, blkSize.Cg, blkSize.G, input_seg_num_w, - param.input_nram_stride_g, param.input_nram_stride_w, - param.input_stride_g, param.input_stride_w); - dst += param.input_nram_stride_h; - src += param.input_stride_h; - } - - /* load mask block from gdram2nram - * - * mask_nram[ | mask[sample_idx, - * 0:blkSize.Ho-1 , | blkStart.Ho + 0:blkSize.Ho-1, - * 0:blkSize.Wo-1, | blkStart.Wo + 0:blkSize.Wo-1, - * 0:blkSize.G-1, | blkStart.G + 0:blkSize.G-1, - * 0:blkSize.Kh-1, | blkStart.Kh + 0:blkSize.Kh-1, - * 0:blkSize.Kw-1] | blkStart.Kw + 0:blkSize.Kw-1] - */ - src_offset = INDEX3(blkStart.Wo, blkStart.G, blkStart.Kh, blkStart.Kw, - param.mask_stride_w, param.mask_stride_g, - param.mask_stride_kh); - src_offset += sample_idx * param.mask_stride_n + - blkStart.Ho * param.mask_stride_h; - - for (int ho = 0; ho < blkSize.Ho; ++ho) { - dst = mask_nram + ho * param.mask_nram_stride_h; - src = mask + src_offset + ho * param.mask_stride_h; - - for (int wo = 0; wo < blkSize.Wo; ++wo) { - loadStr3D(dst, src, blkSize.Kw, blkSize.Kh, blkSize.G, - param.mask_nram_stride_kh, param.mask_nram_stride_g, - param.mask_stride_kh, param.mask_stride_g); - dst += param.mask_nram_stride_w; - src += param.mask_stride_w; - } - } - - // loop each pixel of the output block - for (int ho = 0; ho < blkSize.Ho; ++ho) { - int kernel_hi_start_global = (blkStart.Ho + ho) / param.scale_factor - - param.kernel_size_half + blkStart.Kh; - int kernel_hi_start_local = kernel_hi_start_global - blkStart.Hi; - - // int kernel_hi_end_global = kernel_hi_start_global + blkSize.Kh - 1; - // int kernel_hi_end_local = kernel_hi_end_global - blkStart.Hi; - - // exclude out of bound indices which should be ignored - int kh_min = hi_start_local - kernel_hi_start_local > 0 - ? hi_start_local - kernel_hi_start_local - : 0; - int kh_max = hi_end_local - kernel_hi_start_local < blkSize.Kh - 1 - ? hi_end_local - kernel_hi_start_local - : blkSize.Kh - 1; - - for (int wo = 0; wo < blkSize.Wo; ++wo) { - int kernel_wi_start_global = - (blkStart.Wo + wo) / param.scale_factor - - param.kernel_size_half + blkStart.Kw; - int kernel_wi_start_local = kernel_wi_start_global - blkStart.Wi; - - // exclude out of bound indices wwich should be ignored - int kw_min = wi_start_local - kernel_wi_start_local > 0 - ? 
wi_start_local - kernel_wi_start_local - : 0; - int kw_max = wi_end_local - kernel_wi_start_local < blkSize.Kw - 1 - ? wi_end_local - kernel_wi_start_local - : blkSize.Kw - 1; - - // output_nram[ho, wo, g, c] = sum(mask_nram[ho, wo, g, kh, kw] - // * input_nram[hi+kh, wi+kw, g, c], - // for (kh,kw) in [0:blkSize.Kw-1] x [0:blkSize.Kh-1]) - // - // sum(mask_nram[ho, wo, g, kh, kw] - // * input_nram[hi+kh, wi+kw, g, c], (kh,kw)) - // - T *mask_array = mask_nram + param.mask_nram_stride_h * ho + - param.mask_nram_stride_w * wo; - - for (int kh = kh_min; kh <= kh_max; ++kh) { - for (int kw = kw_min; kw <= kw_max; ++kw) { - T *src = - input_nram + - param.input_nram_stride_h * (kernel_hi_start_local + kh) + - param.input_nram_stride_w * (kernel_wi_start_local + kw); - - int mask_index = param.mask_nram_stride_kh * kh + kw; - - // mlutiply mask weight with channels for each channel group - T *sum = sum_array; - - for (int g = 0; g < blkSize.G; ++g) { - __bang_mul_scalar(sum, src, mask_array[mask_index], - param.block_Cg_NFU); - // - // NOTE: Since block_Cg_NFU >= block_Cg_stride, - // overlapped writing may occur on sum_array. - // So this loop must be executed in order to - // avoid data contamination, as shown below. - // - // |-----block_Cg_NFU---------| - // xxxxxxxxxxxxxxxxxxxxyyyzzzzz------------ - // |---block_Cg_stride---|^^^^^will be overwritten - // in the next iteration. - // - // x: actual data used, y: not used, z: overwritten - // - sum += param.input_nram_stride_g; - src += param.input_nram_stride_g; - mask_index += param.mask_nram_stride_g; - } // loop blk_G - - // add array[blk_G * blk_C] to output_nram - dst = output_nram + param.output_nram_stride_h * ho + - param.output_nram_stride_w * wo; - - __bang_add(dst, dst, sum_array, param.output_nram_stride_w); - } // end loop blk_Kw - } // end loop blk_Kh - } // end loop blk_Wo - } // end loop blk_Ho - } // end loop grid_dim.Kw - } // end loop grid_dim.Kh - - /* write output from nram2gdram - * - * output_nram[ | output[sample_idx, - * 0:blkSize.Ho-1, | blkStart.Ho + 0:blkSize.Ho-1, - * 0:blkSize.Wo-1, | blkStart.Wo + 0:blkSize.Wo-1, - * 0:blkSize.G-1, | blkStart.G + 0:blkSize.G-1, - * 0:blkSize.Cg-1] | blkStart.Cg + 0:blkSize.Cg-1] - */ - int dst_offset = INDEX3(sample_idx, blkStart.Ho, blkStart.Wo, blkStart.C, - param.output_stride_n, param.output_stride_h, - param.output_stride_w); - T *dst = output + dst_offset; - T *src = output_nram; - for (int i = 0; i < blkSize.Ho; ++i) { - storeStr3D(dst, src, blkSize.Cg, blkSize.G, blkSize.Wo, - param.output_stride_g, param.output_stride_w, - param.output_nram_stride_g, param.output_nram_stride_w); - dst += param.output_stride_h; - src += param.output_nram_stride_h; - } - } // end loop N, grid_dim.(Hi,Wi,G,Cg) -} - -template -__mlu_global__ void MLUBLOCKKernelCarafeForward( - const void *input, const void *mask, const CarafeForwardParam param, - const CarafeForwardBlockDim block_dim, const CarafeForwardGridDim grid_dim, - void *output) { - carafeForwardBLOCK((T *)input, (T *)mask, param, block_dim, grid_dim, - (T *)output); -} -} // namespace forward - -namespace backward { -template -__mlu_func__ void CarafeCompute(T *input, T *mask, T *grad_output, - T *grad_input, T *grad_mask, const int n, - const int hi, const int wi, const int c, - const int k_up, const int group, - const int scale) { - char *input_buff = nram_buf; - char *mask_buff = input_buff + NRAM_BLOCK; - char *grad_input_buff = mask_buff + NRAM_BLOCK; - char *grad_output_buff = grad_input_buff + NRAM_BLOCK; - char 
*grad_mask_buff = grad_output_buff + NRAM_BLOCK; - - int wo = wi * scale; - int ho = hi * scale; - int out_num = n * ho * wo * group; - int group_size = c / group; - int repeat = out_num / taskDim + (int)(taskId < out_num % taskDim); - int num_align = PAD_DOWN(NRAM_BLOCK / sizeof(T), NFU_ALIGN_SIZE / sizeof(T)); - int num_per_loop = group_size / num_align; - int rem_for_loop = group_size % num_align; - int rem_for_loop_align = PAD_UP(rem_for_loop, NFU_ALIGN_SIZE / sizeof(T)); - for (int k = 0; k < repeat; k++) { - int iter = k * taskDim + taskId; - int group_k = iter % group; - int w_k = (iter / group) % wo; - int h_k = (iter / wo / group) % ho; - int n_k = (iter / ho / wo / group) % n; - int h_i = h_k / scale; - int w_i = w_k / scale; - int start_h = h_i - ((k_up - 1) / 2); - int end_h = h_i + ((k_up - 1) / 2) + 1; - int start_w = w_i - ((k_up - 1) / 2); - int end_w = w_i + ((k_up - 1) / 2) + 1; - T *base_mask = (T *)mask + n_k * ho * wo * group * k_up * k_up + - h_k * wo * group * k_up * k_up + w_k * group * k_up * k_up + - group_k * k_up * k_up; - T *base_grad_mask = (T *)grad_mask + n_k * ho * wo * group * k_up * k_up + - h_k * wo * group * k_up * k_up + - w_k * group * k_up * k_up + group_k * k_up * k_up; - - __bang_write_zero((T *)grad_input_buff, NRAM_BLOCK / sizeof(T)); - __bang_write_zero((T *)grad_mask_buff, NRAM_BLOCK / sizeof(T)); - __bang_write_zero((T *)grad_output_buff, NRAM_BLOCK / sizeof(T)); - - __memcpy((T *)mask_buff, (T *)base_mask, k_up * k_up * sizeof(T), - GDRAM2NRAM); - for (int i = 0; i < num_per_loop; i++) { - __bang_write_zero((T *)input_buff, NRAM_BLOCK / sizeof(T)); - T *base_grad_output = (T *)grad_output + n_k * ho * wo * c + - h_k * wo * c + w_k * c + group_k * group_size + - i * num_align; - __memcpy((T *)grad_output_buff, (T *)base_grad_output, - num_align * sizeof(T), GDRAM2NRAM); - for (int ih = start_h; ih < end_h; ih++) { - for (int iw = start_w; iw < end_w; iw++) { - if (ih < 0 || ih > hi - 1 || iw < 0 || iw > wi - 1) { - continue; - } - int mask_ih = ih - h_i + (k_up - 1) / 2; - int mask_iw = iw - w_i + (k_up - 1) / 2; - int mask_index = mask_ih * k_up + mask_iw; - int input_index = n_k * hi * wi * c + ih * wi * c + iw * c + - group_k * group_size + i * num_align; - T *base_input = (T *)input + input_index; - T *base_grad_input = (T *)grad_input + input_index; - __memcpy((T *)input_buff, (T *)base_input, num_align * sizeof(T), - GDRAM2NRAM); - __bang_mul_scalar((T *)grad_input_buff, (T *)grad_output_buff, - ((T *)mask_buff)[mask_index], num_align); - __bang_atomic_add((T *)grad_input_buff, (T *)base_grad_input, - (T *)grad_input_buff, num_align); - __bang_mul((T *)input_buff, (T *)grad_output_buff, (T *)input_buff, - num_align); - - __bang_sumpool((T *)input_buff, (T *)input_buff, - NFU_ALIGN_SIZE / sizeof(T), - num_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, - num_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, 1, 1); - - __bang_reduce_sum((T *)input_buff, (T *)input_buff, - NFU_ALIGN_SIZE / sizeof(T)); - ((T *)grad_mask_buff)[mask_index] += ((T *)input_buff)[0]; - } - } - } - if (rem_for_loop) { - __bang_write_zero((T *)input_buff, NRAM_BLOCK / sizeof(T)); - T *base_grad_output = (T *)grad_output + n_k * ho * wo * c + - h_k * wo * c + w_k * c + group_k * group_size + - num_per_loop * num_align; - __memcpy((T *)grad_output_buff, (T *)base_grad_output, - rem_for_loop * sizeof(T), GDRAM2NRAM); - for (int ih = start_h; ih < end_h; ih++) { - for (int iw = start_w; iw < end_w; iw++) { - if (ih < 0 || ih > hi - 1 || iw < 0 || iw > wi - 1) { - continue; - } - int 
mask_ih = ih - h_i + (k_up - 1) / 2; - int mask_iw = iw - w_i + (k_up - 1) / 2; - int mask_index = mask_ih * k_up + mask_iw; - int input_index = n_k * hi * wi * c + ih * wi * c + iw * c + - group_k * group_size + num_per_loop * num_align; - T *base_input = (T *)input + input_index; - T *base_grad_input = (T *)grad_input + input_index; - __memcpy((T *)input_buff, (T *)base_input, rem_for_loop * sizeof(T), - GDRAM2NRAM); - __bang_mul_scalar((T *)grad_input_buff, (T *)grad_output_buff, - ((T *)mask_buff)[mask_index], rem_for_loop_align); - __bang_atomic_add((T *)grad_input_buff, (T *)base_grad_input, - (T *)grad_input_buff, rem_for_loop); - __bang_mul((T *)input_buff, (T *)grad_output_buff, (T *)input_buff, - rem_for_loop_align); - - __bang_sumpool( - (T *)input_buff, (T *)input_buff, NFU_ALIGN_SIZE / sizeof(T), - rem_for_loop_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, - rem_for_loop_align / (NFU_ALIGN_SIZE / sizeof(T)), 1, 1, 1); - __bang_reduce_sum((T *)input_buff, (T *)input_buff, - NFU_ALIGN_SIZE / sizeof(T)); - - ((T *)grad_mask_buff)[mask_index] += ((T *)input_buff)[0]; - } - } - } - __memcpy((T *)base_grad_mask, (T *)grad_mask_buff, k_up * k_up * sizeof(T), - NRAM2GDRAM); - } -} - -template -__mlu_global__ void MLUUnion1KernelCarafeBackward( - const void *input, const void *mask, const void *grad_output, - void *grad_input, void *grad_mask, const int n, const int hi, const int wi, - const int c, const int k_up, const int group, const int scale) { - CarafeCompute((T *)input, (T *)mask, (T *)grad_output, (T *)grad_input, - (T *)grad_mask, n, hi, wi, c, k_up, group, scale); -} -} // namespace backward - -void KernelCarafeForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, const cnrtDataType_t d_type, - const void *input, const void *mask, - const CarafeForwardParam ¶m, - const CarafeForwardBlockDim &block_dim, - const CarafeForwardGridDim &grid_dim, void *output) { - if (d_type == CNRT_FLOAT16) { - forward::MLUBLOCKKernelCarafeForward<<>>( - input, mask, param, block_dim, grid_dim, output); - } else { - forward::MLUBLOCKKernelCarafeForward<<>>( - input, mask, param, block_dim, grid_dim, output); - } -} - -void KernelCarafeBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, cnrtDataType_t dtype, - const void *input, const void *mask, - const void *grad_output, void *grad_input, - void *grad_mask, const int n, const int hi, - const int wi, const int c, const int k_up, - const int group, const int scale) { - if (dtype == CNRT_FLOAT16) { - backward::MLUUnion1KernelCarafeBackward<<>>( - input, mask, grad_output, grad_input, grad_mask, n, hi, wi, c, k_up, - group, scale); - } else { - backward::MLUUnion1KernelCarafeBackward<<>>( - input, mask, grad_output, grad_input, grad_mask, n, hi, wi, c, k_up, - group, scale); - } -} diff --git a/mmcv/ops/csrc/common/mlu/carafe_utils.hpp b/mmcv/ops/csrc/common/mlu/carafe_utils.hpp deleted file mode 100644 index 09ca60ab11..0000000000 --- a/mmcv/ops/csrc/common/mlu/carafe_utils.hpp +++ /dev/null @@ -1,95 +0,0 @@ -/************************************************************************* - * Copyright (C) 2022 Cambricon. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS - * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF - * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
- * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY - * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, - * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE - * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - *************************************************************************/ -#ifndef CARAFE_UTILS_HPP_ -#define CARAFE_UTILS_HPP_ - -#define NRAM_ALIGN_SIZE 64 - -struct CarafeForwardParam { - int N; // batch size - int Hi; // input height - int Wi; // input width - int Ci; // input channels - int Ho; // output height - int Wo; // output width - int Cg; // channels per group - - int kernel_size; // kernel_size - int group_size; // group_size - int scale_factor; // scale_factor - int kernel_size_half; // kernel half size (K-1)/2 - int kernel_size_sq; // square of kernel size - - int dtype_size; // size of tensor data type - - // Host arrays' geometry - int input_stride_g; - int input_stride_w; - int input_stride_h; - int input_stride_n; - int input_size; - int mask_stride_kh; - int mask_stride_g; - int mask_stride_w; - int mask_stride_h; - int mask_stride_n; - int mask_size; - int output_stride_g; - int output_stride_w; - int output_stride_h; - int output_stride_n; - int output_size; - - // NRAM arrays' geometry - int input_nram_stride_g; - int input_nram_stride_w; - int input_nram_stride_h; - int input_nram_size; - int mask_nram_stride_kh; - int mask_nram_stride_g; - int mask_nram_stride_w; - int mask_nram_stride_h; - int mask_nram_size; - int output_nram_stride_g; - int output_nram_stride_w; - int output_nram_stride_h; - int output_nram_size; - - // for address/compute alignment - int align_size_NRAM; // for addressing on NRAM - int align_size_NFU; // for NFU operation length - int block_Cg_NFU; // for bang_mul_const - - int job_num; // total job number -}; - -struct CarafeForwardBlockDim { - int Ho; // block size of output height - int Wo; // block size of output width - int Kh; // block size of kernel height - int Kw; // block size of kernel width - int G; // block size of groups - int Cg; // block size of channels within a group - int Hi; // block size of input height - int Wi; // block size of input width -}; - -struct CarafeForwardGridDim { - int Ho; // number of blocks of output height - int Wo; - int Kh; - int Kw; - int G; - int Cg; -}; - -#endif // CARAFE_UTILS_HPP_ diff --git a/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp b/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp index 88805ba8e9..8527372241 100644 --- a/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp +++ b/mmcv/ops/csrc/common/mlu/common_mlu_helper.hpp @@ -45,148 +45,6 @@ __mlu_func__ inline scalar_t max(scalar_t a, scalar_t b) { return a > b ? a : b; } -/*! - * @brief loads data from global DRAM to NRAM with 2D pattern. - * - * @param[out] dst - * Pointer to NRAM that stores dst data. - * @param[in] src - * Pointer to global DRAM that stores src data. - * @param[in] size - * The byte size of segment in the lower dimension. - * @param[in] dst_str - * The data stride in bytes between segments in the lower dimension of dst. - * @param[in] src_str - * The data stride in bytes between segments in the lower dimension of src. - * @param[in] seg_num - * The total count of data segments in the lower dimension. 
- */ -template -__mlu_func__ void loadStr2D(T *dst, T *src, const int size, const int dst_str, - const int src_str, const int seg_num) { - if (dst_str == src_str && size == src_str) { - __memcpy(dst, src, src_str * seg_num * sizeof(T), GDRAM2NRAM); - } else if ((size == src_str || src_str <= dst_str) && - src_str * sizeof(T) <= 512) { - // gather data less than 512Bytes to improve IO efficiency - T *tmp = (T *)dst + (dst_str - src_str) * seg_num; - __memcpy(tmp, src, (src_str * (seg_num - 1) + size) * sizeof(T), - GDRAM2NRAM); - if (dst_str != src_str) { - __memcpy(dst, tmp, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T), - src_str * sizeof(T), seg_num - 1); - } - } else { - __memcpy(dst, src, size * sizeof(T), GDRAM2NRAM, dst_str * sizeof(T), - src_str * sizeof(T), seg_num - 1); - } -} - -/*! - * @brief loads data from global DRAM to NRAM with 3D pattern. - * - * @param[out] dst - * Pointer to NRAM that stores dst data. - * @param[in] src - * Pointer to global DRAM that stores src data. - * @param[in] size - * The byte size of segment in the lowest dimension. - * @param[in] seg_num_in - * The total count of data segments in the lowest dimension. - * @param[in] seg_num_out - * The total count of data segments in the middle dimension. - * @param[in] dst_str_in - * The data stride in bytes between segments in the lowest dimension of dst. - * @param[in] dst_str_out - * The data stride in bytes between segments in the middle dimension of dst. - * @param[in] src_str_in - * The data stride in bytes between segments in the lowest dimension of src. - * @param[in] src_str_out - * The data stride in bytes between segments in the middle dimension of src. - */ -template -__mlu_func__ void loadStr3D(T *dst, T *src, const int size, - const int seg_num_in, const int seg_num_out, - const int dst_str_in, const int dst_str_out, - const int src_str_in, const int src_str_out) { - T *tmp_dst = dst; - T *tmp_src = src; - - for (int i = 0; i < seg_num_out; ++i) { - loadStr2D(tmp_dst, tmp_src, size, dst_str_in, src_str_in, seg_num_in); - tmp_src += src_str_out; - tmp_dst += dst_str_out; - } -} - -/*! - * @brief stores data from NRAM to global DRAM with 2D pattern. - * - * @param[out] dst - * Pointer to global DRAM that stores dst data. - * @param[in] src - * Pointer to NRAM that stores src data. - * @param[in] size - * The byte size of segment in the lower dimension. - * @param[in] dst_str - * The data stride in bytes between segments in the lower dimension of dst. - * @param[in] src_str - * The data stride in bytes between segments in the lower dimension of src. - * @param[in] seg_num - * The total count of data segments in the lower dimension. - */ -template -__mlu_func__ void storeStr2D(T *dst, T *src, const int size, const int seg_num, - const int dst_str, const int src_str) { - if ((size == dst_str && dst_str <= src_str) && dst_str * sizeof(T) <= 512) { - // gather data less than 512Bytes to improve IO efficiency - if (dst_str != src_str) { - __memcpy(src, src, size * sizeof(T), NRAM2NRAM, dst_str * sizeof(T), - src_str * sizeof(T), seg_num - 1); - } - __memcpy(dst, src, size * seg_num * sizeof(T), NRAM2GDRAM); - } else { - __memcpy(dst, src, size * sizeof(T), NRAM2GDRAM, dst_str * sizeof(T), - src_str * sizeof(T), seg_num - 1); - } -} - -/*! - * @brief stores data from NRAM to global DRAM with 3D pattern. - * - * @param[out] dst - * Pointer to global DRAM that stores dst data. - * @param[in] src - * Pointer to NRAM that stores src data. 
- * @param[in] size - * The byte size of segment in the lowest dimension. - * @param[in] seg_num_in - * The total count of data segments in the lowest dimension. - * @param[in] seg_num_out - * The total count of data segments in the middle dimension. - * @param[in] dst_str_in - * The data stride in bytes between segments in the lowest dimension of dst. - * @param[in] dst_str_out - * The data stride in bytes between segments in the middle dimension of dst. - * @param[in] src_str_in - * The data stride in bytes between segments in the lowest dimension of src. - * @param[in] src_str_out - * The data stride in bytes between segments in the middle dimension of src. - */ -template -__mlu_func__ void storeStr3D(T *dst, T *src, const int size, - const int seg_num_in, const int seg_num_out, - const int dst_str_in, const int dst_str_out, - const int src_str_in, const int src_str_out) { - T *tmp_dst = dst; - T *tmp_src = src; - for (int i = 0; i < seg_num_out; ++i) { - storeStr2D(tmp_dst, tmp_src, size, seg_num_in, dst_str_in, src_str_in); - tmp_src += src_str_out; - tmp_dst += dst_str_out; - } -} - /*! * @brief Converts int32 to float32 data type. * diff --git a/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp b/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp index 25e0b85d12..5a7d6c7e39 100644 --- a/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp +++ b/mmcv/ops/csrc/pytorch/mlu/carafe_mlu.cpp @@ -9,200 +9,13 @@ * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. *************************************************************************/ -#include "carafe_utils.hpp" -#include "pytorch_device_registry.hpp" -#include "pytorch_mlu_helper.hpp" - -void KernelCarafeForward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, const cnrtDataType_t d_type, - const void *input, const void *mask, - const CarafeForwardParam ¶m, - const CarafeForwardBlockDim &block_dim, - const CarafeForwardGridDim &grid_dim, void *output); - -void KernelCarafeBackward(cnrtDim3_t k_dim, cnrtFunctionType_t k_type, - cnrtQueue_t queue, cnrtDataType_t dtype, - const void *input, const void *mask, - const void *grad_output, void *grad_input, - void *grad_mask, const int n, const int hi, - const int wi, const int c, const int k_up, - const int group, const int scale); - -// Get total NRAM usage and set strides of NRAM arrays. 
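The loadStr2D/loadStr3D/storeStr2D/storeStr3D helpers removed above all reduce to strided segment copies; a minimal host-side equivalent is sketched below (the `copy_strided_*` names are illustrative, and the <=512-byte gather fast path is omitted). Note that, judging from the deleted code, `size` counts elements rather than bytes — the helpers multiply by sizeof(T) internally, despite the "byte size" wording in their doc comments.

template <typename T>
void copy_strided_2d(T *dst, const T *src, int size, int dst_str,
                     int src_str, int seg_num) {
  // seg_num segments of `size` elements, with independent dst/src strides
  for (int i = 0; i < seg_num; ++i)
    for (int j = 0; j < size; ++j) dst[i * dst_str + j] = src[i * src_str + j];
}

template <typename T>
void copy_strided_3d(T *dst, const T *src, int size, int seg_num_in,
                     int seg_num_out, int dst_str_in, int dst_str_out,
                     int src_str_in, int src_str_out) {
  // outer dimension is just a loop of 2D strided copies
  for (int i = 0; i < seg_num_out; ++i)
    copy_strided_2d(dst + i * dst_str_out, src + i * src_str_out, size,
                    dst_str_in, src_str_in, seg_num_in);
}

The store helpers are the same pattern with the copy direction reversed (NRAM2GDRAM instead of GDRAM2NRAM).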
-static void getNramUsage(CarafeForwardParam *param, - CarafeForwardBlockDim *block_dim, int *nram_usage) { - // input_nram[blkDim_(Hi+Kh)-1, blkDim_(Wi+Kw)-1, blkDim_G, blkDim_Cg] - block_dim->Hi = CEIL_DIV(block_dim->Ho, param->scale_factor) + 1; - block_dim->Wi = CEIL_DIV(block_dim->Wo, param->scale_factor) + 1; - - param->input_nram_stride_g = PAD_UP(block_dim->Cg, param->align_size_NRAM); - param->input_nram_stride_w = param->input_nram_stride_g * block_dim->G; - param->input_nram_stride_h = - (block_dim->Wi + block_dim->Kw - 1) * param->input_nram_stride_w; - param->input_nram_size = - (block_dim->Hi + block_dim->Kh - 1) * param->input_nram_stride_h; - - // mask_nram[blkDim_Ho, blkDim_Wo, blkDim_G, blkDim_Kh, blkDim_Kw] - param->mask_nram_stride_kh = block_dim->Kw; - param->mask_nram_stride_g = block_dim->Kh * param->mask_nram_stride_kh; - param->mask_nram_stride_w = block_dim->G * param->mask_nram_stride_g; - param->mask_nram_stride_h = block_dim->Wo * param->mask_nram_stride_w; - param->mask_nram_size = - PAD_UP(block_dim->Ho * param->mask_nram_stride_h, param->align_size_NRAM); - - // output_nram[blkDim_Ho, blkDim_Wo, blkDim_(G*Cg)] - param->output_nram_stride_g = param->input_nram_stride_g; - param->output_nram_stride_w = - PAD_UP(param->input_nram_stride_w, param->align_size_NFU); - param->output_nram_stride_h = block_dim->Wo * param->output_nram_stride_w; - param->output_nram_size = block_dim->Ho * param->output_nram_stride_h; - - // sum_array[blkDim_(G*Cg)] - - // ensure the last mul_const on Cg does not exceed memory boundary - int sum_array_size_bang_mul_const = - (block_dim->G - 1) * param->input_nram_stride_g + - PAD_UP(param->input_nram_stride_g, param->align_size_NFU); - - int sum_array_size = - std::max(param->output_nram_stride_w, sum_array_size_bang_mul_const); - - *nram_usage = param->input_nram_size + param->mask_nram_size + - param->output_nram_size + sum_array_size; -} - -// Policy Function for Forward -static void genPolicyForward(CarafeForwardParam *param, - CarafeForwardBlockDim *block_dim, - CarafeForwardGridDim *grid_dim, cnrtDim3_t *k_dim, - cnrtFunctionType_t *k_type) { - // device info - auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - auto cluster_num = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - auto core_num = core_dim * cluster_num; - - // maximum NRAM size as the number of - auto max_nram_size = - torch_mlu::getDeviceAttr(cnrtAttrNramSizePerMcore) / param->dtype_size; - - // determine grid and block dimensions - - // set initial values for block_dim and grid_dim - block_dim->Ho = param->Ho; - block_dim->Wo = param->Wo; - block_dim->Kh = param->kernel_size; - block_dim->Kw = param->kernel_size; - block_dim->G = param->group_size; - block_dim->Cg = param->Cg; - - grid_dim->Ho = 1; - grid_dim->Wo = 1; - grid_dim->Kh = 1; - grid_dim->Kw = 1; - grid_dim->G = 1; - grid_dim->Cg = 1; - - // decrease the block size to fit in the NRAM. - int nram_usage = 0; - while (true) { - getNramUsage(param, block_dim, &nram_usage); - - if (nram_usage > max_nram_size) { - // decrease Ho - // decrease block_Ho and block_Wo evenly - // so that the block is close to a square. 
- if (block_dim->Ho > 1 && block_dim->Ho >= block_dim->Wo) { - grid_dim->Ho += 1; - block_dim->Ho = CEIL_DIV(param->Ho, grid_dim->Ho); - } else if (block_dim->Wo > 1 && block_dim->Wo > block_dim->Ho) { - // decrease Wo - grid_dim->Wo += 1; - block_dim->Wo = CEIL_DIV(param->Wo, grid_dim->Wo); - } else if (block_dim->Kh > 1) { - // decrease Kh - grid_dim->Kh += 1; - block_dim->Kh = CEIL_DIV(param->kernel_size, grid_dim->Kh); - // reset Hi, Wi to maximize NRAM usage - grid_dim->Ho = 1; - block_dim->Ho = param->Ho; - grid_dim->Wo = 1; - block_dim->Wo = param->Wo; - } else if (block_dim->Kw > 1) { - // decrease Kw - grid_dim->Kw += 1; - block_dim->Kw = CEIL_DIV(param->kernel_size, grid_dim->Kw); - // reset Kh - grid_dim->Kh = 1; - block_dim->Kh = param->kernel_size; - } else if (block_dim->G > 1) { - // decrease G - grid_dim->G += 1; - block_dim->G = CEIL_DIV(param->group_size, grid_dim->G); - // reset Kw - grid_dim->Kw = 1; - block_dim->Kw = param->kernel_size; - } else if (block_dim->Cg > 1) { - // decrease block_Cg - // This is done in the last since c is the continuous dim - // (input layout is NHWC) and large c can improve - // IO & compute efficiency. - grid_dim->Cg += 1; - block_dim->Cg = CEIL_DIV(param->Cg, grid_dim->Cg); - // reset G - grid_dim->G = 1; - block_dim->G = param->group_size; - } else { - // the block volume is one now, cannot decrease the block size anymore! - // this situation should not occur. - break; - } - } else { - break; - } - } - - // define parameters depending on block_dim, grid_dim - param->block_Cg_NFU = PAD_UP(block_dim->Cg, param->align_size_NFU); - - // define host arrays' strides - - // input[N,H,W,G,Cg] - param->input_stride_g = param->Cg; - param->input_stride_w = param->Ci; - param->input_stride_h = param->Wi * param->input_stride_w; - param->input_stride_n = param->Hi * param->input_stride_h; - // mask[N,Ho,Wo,G,Kh,Kw] - param->mask_stride_kh = param->kernel_size; - param->mask_stride_g = param->kernel_size * param->mask_stride_kh; - param->mask_stride_w = param->group_size * param->mask_stride_g; - param->mask_stride_h = param->Wo * param->mask_stride_w; - param->mask_stride_n = param->Ho * param->mask_stride_h; - // output[N,Ho,Wo,G,Cg] - param->output_stride_g = param->Cg; - param->output_stride_w = param->Ci; - param->output_stride_h = param->Wo * param->output_stride_w; - param->output_stride_n = param->Ho * param->output_stride_h; - - param->job_num = - param->N * grid_dim->Ho * grid_dim->Wo * grid_dim->G * grid_dim->Cg; - - // determine task type and dims - *k_type = CNRT_FUNC_TYPE_BLOCK; - k_dim->x = std::min(param->job_num, static_cast(core_num)); - k_dim->y = 1; - k_dim->z = 1; -} +#include "mlu_common_helper.h" void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask, Tensor rinput, Tensor routput, Tensor rmask, Tensor output, const int kernel_size, const int group_size, const int scale_factor) { - const int batch_size = output.size(0); - const int channels = output.size(1); - const int ho = output.size(2); - const int wo = output.size(3); - // check tensor data type TORCH_CHECK( input.scalar_type() == at::kFloat || input.scalar_type() == at::kHalf, @@ -221,37 +34,10 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask, // return fast on zero-element tensor if (output.numel() == 0) { - output = at::zeros({batch_size, channels, ho, wo}, output.options()); + output = at::zeros(output.sizes().vec(), output.options()); return; } - // set param - CarafeForwardParam param; - param.N = input.size(0); - 
param.Ci = input.size(1); - param.Hi = input.size(2); - param.Wi = input.size(3); - - param.kernel_size = kernel_size; - param.group_size = group_size; - param.scale_factor = scale_factor; - param.Cg = param.Ci / group_size; - param.dtype_size = input.itemsize(); - param.align_size_NRAM = NRAM_ALIGN_SIZE / param.dtype_size; - param.align_size_NFU = NFU_ALIGN_SIZE / param.dtype_size; - param.kernel_size_sq = param.kernel_size * param.kernel_size; - param.kernel_size_half = (param.kernel_size - 1) / 2; - param.Ho = param.Hi * param.scale_factor; - param.Wo = param.Wi * param.scale_factor; - - // generate policy - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - CarafeForwardBlockDim block_dim; - CarafeForwardGridDim grid_dim; - - genPolicyForward(¶m, &block_dim, &grid_dim, &k_dim, &k_type); - // convert NCHW to NHWC auto memory_format_input_nhwc = torch_mlu::cnnl::ops::get_channels_last_memory_format(input.dim()); @@ -268,6 +54,12 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask, auto routput_ = torch_mlu::cnnl::ops::cnnl_contiguous(output, memory_format_output_nhwc); + // set tensor descriptor + MluOpTensorDescriptor input_desc, mask_desc, output_desc; + input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC); + mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC); + output_desc.set_with_layout(routput_, MLUOP_LAYOUT_NHWC); + // get ptr of tensors auto input_impl = torch_mlu::getMluTensorImpl(rinput_); auto input_ptr = input_impl->cnnlMalloc(); @@ -276,45 +68,29 @@ void CARAFEForwardMLUKernelLauncher(const Tensor input, const Tensor mask, auto output_impl = torch_mlu::getMluTensorImpl(routput_); auto output_ptr = output_impl->cnnlMalloc(); - // get compute queue - auto queue = torch_mlu::getCurQueue(); - - // get dtype of input - cnrtDataType_t d_type = torch_mlu::toCnrtDtype(input.dtype()); - + // set op descriptor + auto handle = mluOpGetCurrentHandle(); + mluOpCarafeDescriptor_t carafe_desc; + mluOpCreateCarafeDescriptor(&carafe_desc); + mluOpSetCarafeDescriptor(carafe_desc, input.dim(), kernel_size, group_size, + scale_factor); // launch kernel - auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - CNLOG(INFO) << "Launch Kernel KernelCarafeForward<<>>"; - - KernelCarafeForward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr, param, - block_dim, grid_dim, output_ptr); + mluOpCarafeForward(handle, carafe_desc, input_desc.desc(), input_ptr, + mask_desc.desc(), mask_ptr, output_desc.desc(), + output_ptr); + // destroy op descriptor + mluOpDestroyCarafeDescriptor(carafe_desc); // copy output from NHWC back into NCHW rinput.copy_(rinput_); output.copy_(routput_); } -// Policy Function for Backward -static void policyFuncBackward(cnrtDim3_t *k_dim, cnrtFunctionType_t *k_type) { - // set Union1 Job - *k_type = CNRT_FUNC_TYPE_UNION1; - k_dim->x = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - k_dim->y = torch_mlu::getDeviceAttr(cnrtAttrClusterCount); - k_dim->z = 1; -} - void CARAFEBackwardMLUKernelLauncher( const Tensor grad_output, const Tensor rinput, const Tensor mask, Tensor rgrad_output, Tensor rgrad_input_hs, Tensor rgrad_input, Tensor rgrad_mask, Tensor grad_input, Tensor grad_mask, const int kernel_size, const int group_size, const int scale_factor) { - const int batch_size = rinput.size(0); - const int channels = rinput.size(1); - const int hi = rinput.size(2); - const int wi = rinput.size(3); - // data type check TORCH_CHECK(grad_output.scalar_type() == at::kFloat || grad_output.scalar_type() == at::kHalf, @@ -331,11 +107,6 @@ void 
CARAFEBackwardMLUKernelLauncher( TORCH_CHECK(kernel_size < 137, "kernel_size should be less than 137, got ", kernel_size); - // set task dimension - cnrtDim3_t k_dim; - cnrtFunctionType_t k_type; - policyFuncBackward(&k_dim, &k_type); - // convert NCHW to NHWC auto memory_format_input_nhwc = torch_mlu::cnnl::ops::get_channels_last_memory_format(rinput.dim()); @@ -363,8 +134,15 @@ void CARAFEBackwardMLUKernelLauncher( auto rgrad_mask_ = torch_mlu::cnnl::ops::cnnl_contiguous( grad_mask, memory_format_grad_mask_nhwc); - // get compute queue - auto queue = torch_mlu::getCurQueue(); + // set tensor descriptor + MluOpTensorDescriptor input_desc, mask_desc; + input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC); + mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC); + + MluOpTensorDescriptor grad_output_desc, grad_input_desc, grad_mask_desc; + grad_output_desc.set_with_layout(rgrad_output_, MLUOP_LAYOUT_NHWC); + grad_input_desc.set_with_layout(rgrad_input_, MLUOP_LAYOUT_NHWC); + grad_mask_desc.set_with_layout(rgrad_mask_, MLUOP_LAYOUT_NHWC); // get ptr of tensors auto input_impl = torch_mlu::getMluTensorImpl(rinput_); @@ -378,19 +156,19 @@ void CARAFEBackwardMLUKernelLauncher( auto grad_mask_impl = torch_mlu::getMluTensorImpl(rgrad_mask_); auto grad_mask_ptr = grad_mask_impl->cnnlMalloc(); - // get dtype of grad_output - cnrtDataType_t d_type = torch_mlu::toCnrtDtype(grad_output.dtype()); - auto core_dim = torch_mlu::getDeviceAttr(cnrtAttrMcorePerCluster); - - CNLOG(INFO) << "Launch Kernel KernelCarafeBackward<<>>"; - + // set op descriptor + auto handle = mluOpGetCurrentHandle(); + mluOpCarafeDescriptor_t carafe_desc; + mluOpCreateCarafeDescriptor(&carafe_desc); + mluOpSetCarafeDescriptor(carafe_desc, grad_output.dim(), kernel_size, + group_size, scale_factor); // launch kernel - KernelCarafeBackward(k_dim, k_type, queue, d_type, input_ptr, mask_ptr, - grad_output_ptr, grad_input_ptr, grad_mask_ptr, - batch_size, hi, wi, channels, kernel_size, group_size, - scale_factor); + mluOpCarafeBackward(handle, carafe_desc, input_desc.desc(), input_ptr, + mask_desc.desc(), mask_ptr, grad_output_desc.desc(), + grad_output_ptr, grad_input_desc.desc(), grad_input_ptr, + grad_mask_desc.desc(), grad_mask_ptr); + // destroy op descriptor + mluOpDestroyCarafeDescriptor(carafe_desc); // copy output from NHWC back into NCHW grad_input.copy_(rgrad_input_);
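Condensed, the replacement pattern on the (+) side of this diff is the standard mluOp descriptor lifecycle: build NHWC tensor descriptors, create and configure a CARAFE descriptor, invoke the library kernel, then destroy the descriptor. A sketch of the forward path using only the calls visible above (status codes unchecked, tensors already made channels-last via cnnl_contiguous):

// set NHWC tensor descriptors for the contiguous channels-last tensors
MluOpTensorDescriptor input_desc, mask_desc, output_desc;
input_desc.set_with_layout(rinput_, MLUOP_LAYOUT_NHWC);
mask_desc.set_with_layout(rmask_, MLUOP_LAYOUT_NHWC);
output_desc.set_with_layout(routput_, MLUOP_LAYOUT_NHWC);

// create and fill the op descriptor: dimNb, K, G, sigma
auto handle = mluOpGetCurrentHandle();
mluOpCarafeDescriptor_t carafe_desc;
mluOpCreateCarafeDescriptor(&carafe_desc);
mluOpSetCarafeDescriptor(carafe_desc, input.dim(), kernel_size, group_size,
                         scale_factor);

// launch the library kernel, then release the descriptor
mluOpCarafeForward(handle, carafe_desc, input_desc.desc(), input_ptr,
                   mask_desc.desc(), mask_ptr, output_desc.desc(),
                   output_ptr);
mluOpDestroyCarafeDescriptor(carafe_desc);

With tiling, task-dimension policy, and NRAM budgeting now handled inside mluOpCarafeForward/mluOpCarafeBackward, the hand-written policy code (getNramUsage, genPolicyForward, policyFuncBackward) and the Bang kernels themselves can be deleted wholesale, which is exactly what this diff does.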