From b91bbd320fccb49576937d26d0639dc8e0f94583 Mon Sep 17 00:00:00 2001
From: 201716010711 <87008376+201716010711@users.noreply.github.com>
Date: Thu, 8 Dec 2022 11:13:19 +0800
Subject: [PATCH] Optimize Paddle diagonal (#47904)
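
Rework the CPU and CUDA diagonal kernels so they iterate over output
elements instead of input elements. Each flat output index is decomposed
with the output strides, the coordinate pair for (axis1, axis2) is rebuilt
from the trailing diagonal dimension plus the offset, and the value is
gathered from the input through the input strides. This drops the
per-input-element check of whether an element lies on the requested
diagonal. The CUDA backward kernel now zero-fills the input gradient with
GpuMemsetAsync and scatters out_grad through the same mapping. Unit tests
are added for swapped axes with a positive offset and for a negative
offset on a 4-D input.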

---
 paddle/phi/kernels/cpu/diagonal_kernel.cc     | 67 ++++++++------
 paddle/phi/kernels/funcs/diagonal.h           | 88 +++++++++----------
 .../phi/kernels/gpu/diagonal_grad_kernel.cu   | 12 +++
 paddle/phi/kernels/gpu/diagonal_kernel.cu     | 11 ++-
 .../fluid/tests/unittests/test_diagonal_op.py | 29 ++++++
 5 files changed, 133 insertions(+), 74 deletions(-)
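
Reviewer note (not part of the patch): all of the rewritten kernels share
one output-to-input coordinate mapping. The snippet below is a minimal
NumPy re-implementation of that mapping as it appears in the new
diagonal_kernel.cc, cross-checked against np.diagonal. The helper name
diagonal_ref is illustrative only, and C-contiguous inputs are assumed.

import numpy as np

def diagonal_ref(x, offset=0, axis1=0, axis2=1):
    # The output shape is taken from np.diagonal; the sketch only
    # exercises the index mapping used by the rewritten kernels.
    expected = np.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)
    result = np.empty(expected.shape, dtype=x.dtype)

    in_strides = [s // x.itemsize for s in x.strides]   # input_stride
    out_strides = []                                     # output_stride
    acc = 1
    for d in reversed(result.shape):
        out_strides.append(acc)
        acc *= d
    out_strides.reverse()

    a1 = axis1 if axis1 >= 0 else x.ndim + axis1
    a2 = axis2 if axis2 >= 0 else x.ndim + axis2
    l, r = min(a1, a2), max(a1, a2)

    flat_x = x.reshape(-1)
    flat_out = result.reshape(-1)
    for idx in range(result.size):
        # Decompose the flat output index into output coordinates.
        rem, coords = idx, []
        for s in out_strides:
            coords.append(rem // s)
            rem -= coords[-1] * s
        tmp = coords[-1]     # position along the diagonal
        lst = coords[:-1]    # coordinates of the untouched dims
        # Rebuild the (axis1, axis2) pair exactly as the kernels do.
        if offset == 0:
            at_l, at_r = tmp, tmp
        elif offset > 0:
            at_l, at_r = (tmp, tmp + offset) if a1 < a2 else (tmp + offset, tmp)
        else:
            at_l, at_r = (tmp - offset, tmp) if a1 < a2 else (tmp, tmp - offset)
        lst.insert(l, at_l)
        lst.insert(r, at_r)
        # Gather from the input through the input strides.
        flat_out[idx] = flat_x[sum(c * s for c, s in zip(lst, in_strides))]

    assert np.array_equal(result, expected)
    return result

# Same configurations as the new unit tests.
diagonal_ref(np.random.randn(100, 100).astype('int64'), offset=1, axis1=1, axis2=0)
diagonal_ref(np.random.randn(4, 2, 4, 4).astype('float32'), offset=-2, axis1=0, axis2=3)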

diff --git a/paddle/phi/kernels/cpu/diagonal_kernel.cc b/paddle/phi/kernels/cpu/diagonal_kernel.cc
index f125802c19e242..d2361bee30a5fe 100644
--- a/paddle/phi/kernels/cpu/diagonal_kernel.cc
+++ b/paddle/phi/kernels/cpu/diagonal_kernel.cc
@@ -35,6 +35,7 @@ void DiagonalKernel(const Context& dev_ctx,
   auto* output = out;
   T* output_data = dev_ctx.template Alloc<T>(output);
   auto output_dim = vectorize(output->dims());
+  auto output_dim_size = output_dim.size();
 
   const int64_t offset_ = offset;
   int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1;
@@ -43,40 +44,48 @@ void DiagonalKernel(const Context& dev_ctx,
   std::vector<int64_t> input_stride = funcs::ComputeDimStride(input_dim);
   std::vector<int64_t> output_stride = funcs::ComputeDimStride(output_dim);
 
-  int64_t numel = input->numel();
-
-  for (int64_t idx = 0; idx < numel; idx++) {
-    std::vector<int64_t> idx_dim(input_dim_size);
+  int64_t out_numel = out->numel();
+  for (int64_t idx = 0; idx < out_numel; idx++) {
+    std::vector<int64_t> idx_dim(output_dim_size);
     int64_t temp = 0;
-    for (size_t i = 0; i < input_dim_size; i++) {
-      idx_dim[i] = (idx - temp) / input_stride[i];
-      temp = temp + idx_dim[i] * input_stride[i];
+    for (size_t i = 0; i < output_dim_size; i++) {
+      idx_dim[i] = (idx - temp) / output_stride[i];
+      temp = temp + idx_dim[i] * output_stride[i];
     }
-
-    int64_t axis1_dim = idx_dim[axis1_];
-    int64_t axis2_dim = idx_dim[axis2_];
-
-    idx_dim.erase(idx_dim.begin() + std::max(axis1_, axis2_));
-    idx_dim.erase(idx_dim.begin() + std::min(axis1_, axis2_));
-
-    bool flag = false;
-    if (offset_ == 0 && axis1_dim == axis2_dim) {
-      idx_dim.push_back(axis1_dim);
-      flag = true;
-    } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) {
-      idx_dim.push_back(axis1_dim);
-      flag = true;
-    } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) {
-      idx_dim.push_back(axis2_dim);
-      flag = true;
+    int64_t tmp = idx_dim[output_dim_size - 1];  // index along the diagonal
+    std::vector<int64_t> list;
+    list.clear();
+    int64_t l = std::min(axis1_, axis2_);
+    int64_t r = std::max(axis1_, axis2_);
+    for (size_t j = 0; j < output_dim_size - 1; j++) {
+      list.push_back(idx_dim[j]);
     }
-    if (flag) {
-      int64_t idx_output = 0;
-      for (size_t i = 0; i < idx_dim.size(); i++) {
-        idx_output = idx_output + idx_dim[i] * output_stride[i];
+    if (offset_ == 0) {
+      list.insert(list.begin() + l, tmp);
+      list.insert(list.begin() + r, tmp);
+    } else if (offset_ > 0) {
+      if (axis1_ < axis2_) {
+        list.insert(list.begin() + l, tmp);
+        list.insert(list.begin() + r, tmp + offset_);
+      } else {
+        list.insert(list.begin() + l, tmp + offset_);
+        list.insert(list.begin() + r, tmp);
       }
-      output_data[idx_output] = input_data[idx];
+    } else if (offset_ < 0) {
+      if (axis1_ < axis2_) {
+        list.insert(list.begin() + l, tmp - offset_);
+        list.insert(list.begin() + r, tmp);
+      } else {
+        list.insert(list.begin() + l, tmp);
+        list.insert(list.begin() + r, tmp - offset_);
+      }
+    }
+
+    int64_t input_offset = 0;
+    for (size_t i = 0; i < input_dim_size; i++) {
+      input_offset = input_offset + list[i] * input_stride[i];
     }
+    output_data[idx] = input_data[input_offset];
   }
 }
 }  // namespace phi
diff --git a/paddle/phi/kernels/funcs/diagonal.h b/paddle/phi/kernels/funcs/diagonal.h
index 92f970aed32795..a30fb79f8c8b04 100644
--- a/paddle/phi/kernels/funcs/diagonal.h
+++ b/paddle/phi/kernels/funcs/diagonal.h
@@ -156,59 +156,59 @@ __global__ void DiagonalCuda(const T* data1,
                              int64_t* x_stride,
                              int64_t* out_stride,
                              int64_t numel,
+                             int64_t out_numel,
                              bool is_grad) {
-  CUDA_KERNEL_LOOP(idx, numel) {
-    int64_t idx_dim[X_DIM_SIZE] = {0};
+  CUDA_KERNEL_LOOP(idx, out_numel) {
+    int64_t idx_dim[OUT_DIM_SIZE] = {0};
     int64_t temp = 0;
-    for (size_t i = 0; i < X_DIM_SIZE - 1; i++) {
-      idx_dim[i] = (idx - temp) / x_stride[i];
-      temp = temp + idx_dim[i] * x_stride[i];
+    for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
+      idx_dim[i] = (idx - temp) / out_stride[i];
+      temp = temp + idx_dim[i] * out_stride[i];
     }
-    idx_dim[X_DIM_SIZE - 1] = idx - temp;
-
-    int64_t axis1_dim = idx_dim[axis1_];
-    int64_t axis2_dim = idx_dim[axis2_];
-
-    int64_t out_dim[OUT_DIM_SIZE] = {0};
-    int temp_pos = 0;
-    for (int i = 0; i < X_DIM_SIZE; i++) {
-      if (i != axis1_ && i != axis2_) {
-        out_dim[temp_pos] = idx_dim[i];
-        temp_pos++;
+    idx_dim[OUT_DIM_SIZE - 1] = idx - temp;
+    int64_t tmp = idx - temp;
+    int64_t list[9];  // input coordinates; at most 9 dims are supported
+    int64_t p = 0;
+    for (size_t j = 0; j < X_DIM_SIZE; j++) {
+      if (j == axis1_ || j == axis2_) {
+        list[j] = 0;
+      } else {
+        list[j] = idx_dim[p];
+        p += 1;
       }
     }
-    bool flag = false;
-    if (offset_ == 0 && axis1_dim == axis2_dim) {
-      out_dim[temp_pos] = axis1_dim;
-      flag = true;
-    } else if (offset_ > 0 && (axis1_dim + offset_) == axis2_dim) {
-      out_dim[temp_pos] = axis1_dim;
-      flag = true;
-    } else if (offset_ < 0 && (axis1_dim + offset_) == axis2_dim) {
-      out_dim[temp_pos] = axis2_dim;
-      flag = true;
-    }
-    if (!is_grad) {
-      if (flag) {
-        int64_t idx_output = 0;
-        for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
-          idx_output = idx_output + out_dim[i] * out_stride[i];
-        }
-        idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1];
-        data2[idx_output] = data1[idx];
+    int64_t l = min(axis1_, axis2_);
+    int64_t r = max(axis1_, axis2_);
+    if (offset_ == 0) {
+      list[l] = tmp;
+      list[r] = tmp;
+    } else if (offset_ > 0) {
+      if (axis1_ < axis2_) {
+        list[l] = tmp;
+        list[r] = tmp + offset_;
+      } else {
+        list[l] = tmp + offset_;
+        list[r] = tmp;
       }
-    } else {
-      if (flag) {
-        int64_t idx_output = 0;
-        for (size_t i = 0; i < OUT_DIM_SIZE - 1; i++) {
-          idx_output = idx_output + out_dim[i] * out_stride[i];
-        }
-        idx_output = idx_output + out_dim[OUT_DIM_SIZE - 1];
-        data2[idx] = data1[idx_output];
+    } else if (offset_ < 0) {
+      if (axis1_ < axis2_) {
+        list[l] = tmp - offset_;
+        list[r] = tmp;
       } else {
-        data2[idx] = static_cast<T>(0);
+        list[l] = tmp;
+        list[r] = tmp - offset_;
       }
     }
+    int64_t input_offset = 0;
+
+    for (size_t i = 0; i < X_DIM_SIZE; i++) {
+      input_offset = input_offset + list[i] * x_stride[i];
+    }
+    if (!is_grad) {
+      data2[idx] = data1[input_offset];
+    } else {
+      data2[input_offset] = data1[idx];
+    }
   }
 }
 #endif
diff --git a/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu b/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu
index 05a57426fcb213..a65d9af75f6a33 100644
--- a/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu
+++ b/paddle/phi/kernels/gpu/diagonal_grad_kernel.cu
@@ -62,6 +62,10 @@ void DiagonalGradKernel(const Context& dev_ctx,
   int threads = PADDLE_CUDA_NUM_THREADS;
   int blocks = (numel + threads - 1) / threads;
 
+  int64_t dout_numel = out_grad.numel();
+  phi::backends::gpu::GpuMemsetAsync(
+      dx_data, 0, numel * sizeof(T), dev_ctx.stream());
+
   switch (dx_dim_size) {
     case 2:
       funcs::DiagonalCuda<T, 2, 1><<<blocks, threads>>>(dout_data,
@@ -72,6 +76,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 3:
@@ -83,6 +88,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 4:
@@ -94,6 +100,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 5:
@@ -105,6 +112,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 6:
@@ -116,6 +124,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 7:
@@ -127,6 +136,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 8:
@@ -138,6 +148,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     case 9:
@@ -149,6 +160,7 @@ void DiagonalGradKernel(const Context& dev_ctx,
                                                         dx_stride,
                                                         dout_stride,
                                                         numel,
+                                                        dout_numel,
                                                         true);
       break;
     default:
diff --git a/paddle/phi/kernels/gpu/diagonal_kernel.cu b/paddle/phi/kernels/gpu/diagonal_kernel.cu
index 74bad0ecd9a350..74e7db258c7d1a 100644
--- a/paddle/phi/kernels/gpu/diagonal_kernel.cu
+++ b/paddle/phi/kernels/gpu/diagonal_kernel.cu
@@ -54,9 +54,10 @@ void DiagonalKernel(const Context& dev_ctx,
   int64_t axis1_ = axis1 < 0 ? input_dim_size + axis1 : axis1;
   int64_t axis2_ = axis2 < 0 ? input_dim_size + axis2 : axis2;
   int64_t numel = input->numel();
+  int64_t out_numel = out->numel();
 
   int threads = PADDLE_CUDA_NUM_THREADS;
-  int blocks = (numel + threads - 1) / threads;
+  int blocks = (out_numel + threads - 1) / threads;
 
   switch (input_dim_size) {
     case 2:
@@ -68,6 +69,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 3:
@@ -79,6 +81,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 4:
@@ -90,6 +93,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 5:
@@ -101,6 +105,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 6:
@@ -112,6 +117,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 7:
@@ -123,6 +129,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 8:
@@ -134,6 +141,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     case 9:
@@ -145,6 +153,7 @@ void DiagonalKernel(const Context& dev_ctx,
                                                         input_stride,
                                                         output_stride,
                                                         numel,
+                                                        out_numel,
                                                         false);
       break;
     default:
diff --git a/python/paddle/fluid/tests/unittests/test_diagonal_op.py b/python/paddle/fluid/tests/unittests/test_diagonal_op.py
index 5b3c3830c57ca0..cb35a3fce5d030 100644
--- a/python/paddle/fluid/tests/unittests/test_diagonal_op.py
+++ b/python/paddle/fluid/tests/unittests/test_diagonal_op.py
@@ -101,6 +101,35 @@ def test_check_grad(self):
         pass
 
 
+class TestDiagonalOpCase4(TestDiagonalOp):
+    def init_config(self):
+        self.case = np.random.randn(100, 100).astype('int64')
+        self.inputs = {'Input': self.case}
+        self.attrs = {'offset': 1, 'axis1': 1, 'axis2': 0}
+        self.target = np.diagonal(
+            self.inputs['Input'],
+            offset=self.attrs['offset'],
+            axis1=self.attrs['axis1'],
+            axis2=self.attrs['axis2'],
+        )
+
+    def test_check_grad(self):
+        pass
+
+
+class TestDiagonalOpCase5(TestDiagonalOp):
+    def init_config(self):
+        self.case = np.random.randn(4, 2, 4, 4).astype('float32')
+        self.inputs = {'Input': self.case}
+        self.attrs = {'offset': -2, 'axis1': 0, 'axis2': 3}
+        self.target = np.diagonal(
+            self.inputs['Input'],
+            offset=self.attrs['offset'],
+            axis1=self.attrs['axis1'],
+            axis2=self.attrs['axis2'],
+        )
+
+
 class TestDiagonalAPI(unittest.TestCase):
     def setUp(self):
         self.shape = [10, 3, 4]
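
Reviewer note (not part of the patch): with this change the backward pass
in diagonal_grad_kernel.cu zero-fills dx with GpuMemsetAsync and lets
DiagonalCuda run with is_grad == true, i.e. data2[input_offset] =
data1[idx], the scatter dual of the forward gather. A rough NumPy
equivalent, with the helper name diagonal_grad_ref chosen purely for
illustration, is:

import numpy as np

def diagonal_grad_ref(dout, x_shape, offset=0, axis1=0, axis2=1):
    # dx starts from zeros (the GpuMemsetAsync added by the patch); each
    # out_grad element is then scattered back to the input position it
    # was gathered from in the forward pass.
    dx = np.zeros(x_shape, dtype=dout.dtype)
    ndim = len(x_shape)
    a1, a2 = axis1 % ndim, axis2 % ndim

    i = np.arange(dout.shape[-1])             # positions along the diagonal
    idx1 = i if offset >= 0 else i - offset   # coordinate along axis1
    idx2 = i + offset if offset >= 0 else i   # coordinate along axis2

    # View dx with axis1/axis2 in front, then scatter the diagonal slices.
    moved = np.moveaxis(dx, (a1, a2), (0, 1))
    moved[idx1, idx2, ...] = np.moveaxis(dout, -1, 0)
    return dx

# Shapes taken from TestDiagonalOpCase5 above.
x = np.random.randn(4, 2, 4, 4).astype('float32')
dout = np.ones_like(np.diagonal(x, offset=-2, axis1=0, axis2=3))
dx = diagonal_grad_ref(dout, x.shape, offset=-2, axis1=0, axis2=3)
assert dx.sum() == dout.size  # only diagonal positions receive gradient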