diff --git a/paddle/phi/api/yaml/sparse_api.yaml b/paddle/phi/api/yaml/sparse_api.yaml index 2e44e201429f7..904b185448918 100644 --- a/paddle/phi/api/yaml/sparse_api.yaml +++ b/paddle/phi/api/yaml/sparse_api.yaml @@ -80,14 +80,14 @@ data_type : x backward : cast_grad -- api : conv3d - args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) - output : Tensor(out), Tensor(rulebook) +- api : conv3d_coo + args : (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) + output : Tensor(out), Tensor(rulebook), Tensor(counter) kernel : - func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense} + func : conv3d_coo{sparse_coo, dense -> sparse_coo, dense, dense} layout : x - intermediate : rulebook - backward : conv3d_grad + intermediate: rulebook, counter + backward : conv3d_coo_grad - api : coo_to_dense args : (Tensor x) @@ -352,11 +352,11 @@ - api: maxpool args : (Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) - output : Tensor(out), Tensor(rulebook) + output : Tensor(out), Tensor(rulebook), Tensor(counter) kernel : - func : maxpool_coo{sparse_coo -> sparse_coo, dense} + func : maxpool_coo{sparse_coo -> sparse_coo, dense, dense} layout : x - intermediate : rulebook + intermediate : rulebook, counter backward : maxpool_grad - api: mv diff --git a/paddle/phi/api/yaml/sparse_bw_api.yaml b/paddle/phi/api/yaml/sparse_bw_api.yaml index bde86f3816065..cf8de8ceea157 100644 --- a/paddle/phi/api/yaml/sparse_bw_api.yaml +++ b/paddle/phi/api/yaml/sparse_bw_api.yaml @@ -81,12 +81,12 @@ cast_csr_grad {sparse_csr, sparse_csr -> sparse_csr} data_type : out_grad -- backward_api : conv3d_grad - forward : conv3d (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) -> Tensor(out@SparseCooTensor), Tensor(rulebook@DenseTensor) - args : (Tensor x, Tensor kernel, Tensor rulebook, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm) +- backward_api : conv3d_coo_grad + forward : conv3d_coo (Tensor x, Tensor kernel, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) -> Tensor(out), Tensor(rulebook), Tensor(counter) + args : (Tensor x, Tensor kernel, Tensor out, Tensor rulebook, Tensor counter, Tensor out_grad, int[] paddings, int[] dilations, int[] strides, int groups, bool subm, str key) output : Tensor(x_grad), Tensor(kernel_grad) kernel : - func : conv3d_coo_grad{sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense} + func : conv3d_coo_grad{sparse_coo, dense, sparse_coo, dense, dense, sparse_coo -> sparse_coo, dense} - backward_api : coo_to_dense_grad forward : coo_to_dense(Tensor x) -> Tensor(out) @@ -164,11 +164,11 @@ matmul_coo_coo_grad {sparse_coo, sparse_coo, sparse_coo -> sparse_coo, sparse_coo} - backward_api : maxpool_grad - forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook) - args : (Tensor x, Tensor rulebook, Tensor out, Tensor out_grad, int[] kernel_sizes) + forward : maxpool(Tensor x, int[] kernel_sizes, int[] paddings, int[] dilations, int[] strides) -> Tensor(out), Tensor(rulebook), Tensor(counter) + args : (Tensor x, Tensor rulebook, Tensor counter, Tensor out, Tensor out_grad, int[] kernel_sizes) output : Tensor(x_grad) kernel : - func : maxpool_coo_grad {sparse_coo, dense, sparse_coo, sparse_coo -> sparse_coo} + func : maxpool_coo_grad {sparse_coo, dense, dense, sparse_coo, sparse_coo -> 
sparse_coo} - backward_api : multiply_grad forward : multiply(Tensor x, Tensor y) -> Tensor(out) diff --git a/paddle/phi/core/sparse_coo_tensor.h b/paddle/phi/core/sparse_coo_tensor.h index c65b5ce57430b..300ae8a0ab958 100644 --- a/paddle/phi/core/sparse_coo_tensor.h +++ b/paddle/phi/core/sparse_coo_tensor.h @@ -156,6 +156,48 @@ class SparseCooTensor : public TensorBase, /// \brief get the dnese dim int32_t dense_dim() const; + /// \brief query table according to key + const std::pair* IndicesPairs( + const std::string& key) const { + if (indices_dict_ == nullptr) { + return nullptr; + } + const auto& iter = indices_dict_->find(key); + if (iter == indices_dict_->end()) { + return nullptr; + } + return &iter->second; + } + + /// \brief save (key, indices_pairs) + void SaveIndicesPairs( + const std::string& key, + const std::pair& indices_pairs) { + if (indices_dict_ == nullptr) { + indices_dict_ = std::make_shared< + std::map>>(); + } + auto ret = indices_dict_->insert({key, indices_pairs}); + if (ret.second == false) { + ret.first->second = indices_pairs; + } + } + + /// \brief get indices_dict_ + const std::shared_ptr< + std::map>>& + GetIndicesDict() const { + return indices_dict_; + } + + /// \brief set indices_dict_ + void SetIndicesDict( + const std::shared_ptr< + std::map>>& + indices_dict) { + indices_dict_ = indices_dict; + } + private: // save the indices of non zero elements in original dense tensor DenseTensor non_zero_indices_; @@ -165,6 +207,14 @@ class SparseCooTensor : public TensorBase, bool coalesced_ = false; // save the number of non zero elements in each batch DDim dims_; + + // for submanifold conv + // SubmConv will generate a rulebook and a counter, which can be + // reused by different SubmConv. + // refer to sparse/gpu/convolution_kernel.cu. + std::shared_ptr>> + indices_dict_ = nullptr; + /* --------------------------- */ /* example: non zero element is scalar */ /* --------------------------- */ diff --git a/paddle/phi/kernels/funcs/sparse/convolution.h b/paddle/phi/kernels/funcs/sparse/convolution.h index f3caa2a62f4a8..0c6b8b76b54d8 100644 --- a/paddle/phi/kernels/funcs/sparse/convolution.h +++ b/paddle/phi/kernels/funcs/sparse/convolution.h @@ -15,6 +15,7 @@ limitations under the License. 
*/ #pragma once #include "paddle/phi/core/ddim.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/kernels/funcs/blas/blas.h" namespace phi { @@ -188,6 +189,88 @@ inline void PrefixSum(const T* counter, T* offsets, const int n) { offsets[n] = offset; } +template +inline const IntT* GetRulebookPtr(const SparseCooTensor& coo, + const DenseTensor& rulebook, + const std::string& key, + int* rulebook_len) { + if (!key.empty()) { + const auto* indices_pairs = coo.IndicesPairs(key); + if (indices_pairs != nullptr) { + const DenseTensor& tmp_rulebook = indices_pairs->first; + *rulebook_len = tmp_rulebook.dims()[1]; + return tmp_rulebook.data(); + } + } + *rulebook_len = rulebook.dims()[1]; + return rulebook.data(); +} + +inline const int* GetCounterPtr(const SparseCooTensor& coo, + const DenseTensor& counter, + const std::string& key) { + if (!key.empty()) { + const auto* indices_pairs = coo.IndicesPairs(key); + if (indices_pairs != nullptr) { + return indices_pairs->second.data(); + } + } + return counter.data(); +} + +template +inline const IntT* PrepareSubm(const Context& dev_ctx, + const SparseCooTensor& x, + const std::string& key, + const DDim& out_dims, + SparseCooTensor* out, + int* counter, + int* offsets, + int* rulebook_len, + bool* need_product_rulebook) { + const auto* indices_pairs = x.IndicesPairs(key); + if (indices_pairs != nullptr) { + *need_product_rulebook = false; + const DenseTensor& rulebook = indices_pairs->first; + const int counter_size = indices_pairs->second.numel(); + memcpy( + counter, indices_pairs->second.data(), counter_size * sizeof(int)); + out->SetIndicesDict(x.GetIndicesDict()); + + *rulebook_len = rulebook.dims()[1]; + + DenseTensor out_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor out_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); + phi::Copy( + dev_ctx, x.non_zero_indices(), dev_ctx.GetPlace(), false, &out_indices); + out->SetMember(out_indices, out_values, out_dims, false); + PrefixSum(counter, offsets, counter_size); + return rulebook.data(); + } + return nullptr; +} + +template +inline void SaveToTable(const Context& dev_ctx, + const SparseCooTensor& x, + const std::string& key, + const DenseTensor& in_rulebook, + const DenseTensor& h_counter, + SparseCooTensor* out, + DenseTensor* out_rulebook, + DenseTensor* counter) { + out->SetIndicesDict(x.GetIndicesDict()); + if (!key.empty()) { + out->SaveIndicesPairs(key, std::make_pair(in_rulebook, h_counter)); + } else { + *out_rulebook = in_rulebook; + counter->Resize({h_counter.numel()}); + int* counter_ptr = dev_ctx.template HostAlloc(counter); + memcpy(counter_ptr, h_counter.data(), h_counter.numel() * sizeof(int)); + } +} + } // namespace sparse } // namespace funcs } // namespace phi diff --git a/paddle/phi/kernels/funcs/sparse/scatter.cu.h b/paddle/phi/kernels/funcs/sparse/scatter.cu.h index b9568f1df716d..f27174d581818 100644 --- a/paddle/phi/kernels/funcs/sparse/scatter.cu.h +++ b/paddle/phi/kernels/funcs/sparse/scatter.cu.h @@ -13,6 +13,11 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #pragma once +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" + +#define VecBytes 16 namespace phi { namespace funcs { @@ -28,33 +33,126 @@ namespace sparse { * channels: the output channel size * out: the outputs **/ -template +template __global__ void ScatterKernel(const T* input, const int* unique_value, const int* out_index, const int non_zero_num, const int rulebook_len, const int channels, - T* out, - const bool subm = false) { + T* out) { int tid = threadIdx.x + blockIdx.x * blockDim.x; - for (int i = tid; i < non_zero_num * channels; i += gridDim.x * blockDim.x) { - int indices_i = i / channels; - int channels_i = i - indices_i * channels; + const int vec_channels = channels / VecSize; + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; + for (int i = tid; i < non_zero_num * vec_channels; + i += gridDim.x * blockDim.x) { + int indices_i = i / vec_channels; + int channels_i = i - indices_i * vec_channels; int start = unique_value[indices_i]; int end = indices_i == non_zero_num - 1 ? rulebook_len : unique_value[indices_i + 1]; // max(end-start) = kernel_size - T sum = static_cast(0); - if (subm) { - sum = out[indices_i * channels + channels_i]; - } + StoreT sums = {static_cast(0)}; for (int j = start; j < end; j++) { const int out_feature_i = out_index[j]; - sum += input[out_feature_i * channels + channels_i]; + LoadT vec_in; + phi::Load( + input + out_feature_i * channels + channels_i * VecSize, &vec_in); +#pragma unroll + for (int k = 0; k < VecSize; k++) { + sums[k] += vec_in[k]; + } } - out[indices_i * channels + channels_i] = sum; + phi::Store(sums, + out + indices_i * channels + channels_i * VecSize); + } +} + +// scatter's index has been grouped in advance +// index_counts record the count of each group +// index_groups save the index of each group +template +__global__ void ScatterKernelV2(const T* input, + const int* index_counts, + const int* index_groups, + const int non_zero_num, + const int kernel_size, + const int channels, + const int buffer_counts, + T* out) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + const int vec_channels = channels / VecSize; + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; + for (int i = tid; i < non_zero_num * vec_channels; + i += gridDim.x * blockDim.x) { + int indices_i = i / vec_channels; + int channels_i = i - indices_i * vec_channels; + + StoreT sums = {static_cast(0)}; + phi::Load(out + indices_i * channels + channels_i * VecSize, + &sums); + for (int it = 0; it < buffer_counts; it++) { + int len = index_counts[indices_i + it * non_zero_num]; + const int group_offset = it * kernel_size * non_zero_num; + for (int j = 0; j < len; j++) { + const int out_feature_i = + index_groups[indices_i * kernel_size + j + group_offset]; + LoadT vec_in; + phi::Load( + input + out_feature_i * channels + channels_i * VecSize, &vec_in); +#pragma unroll + for (int k = 0; k < VecSize; k++) { + sums[k] += vec_in[k]; + } + } + } + phi::Store(sums, + out + indices_i * channels + channels_i * VecSize); + } +} + +template +void ScatterV2(const GPUContext& dev_ctx, + const T* input, + const int* index_counts, + const int* index_groups, + const int non_zero_num, + const int kernel_size, + const int channels, + const int buffer_counts, + T* output) { + const int VecSize = VecBytes / sizeof(T); + if (channels % VecSize == 0) { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, 
non_zero_num * channels / VecSize, 1); + ScatterKernelV2<<>>(input, + index_counts, + index_groups, + non_zero_num, + kernel_size, + channels, + buffer_counts, + output); + } else { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, non_zero_num * channels, 1); + ScatterKernelV2<<>>(input, + index_counts, + index_groups, + non_zero_num, + kernel_size, + channels, + buffer_counts, + output); } } diff --git a/paddle/phi/kernels/sparse/conv_grad_kernel.h b/paddle/phi/kernels/sparse/conv_grad_kernel.h index 205823e620375..867f6b5a53f37 100644 --- a/paddle/phi/kernels/sparse/conv_grad_kernel.h +++ b/paddle/phi/kernels/sparse/conv_grad_kernel.h @@ -25,13 +25,16 @@ template void Conv3dCooGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& kernel, + const SparseCooTensor& out, const DenseTensor& rulebook, + const DenseTensor& counter, const SparseCooTensor& out_grad, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, const int groups, const bool subm, + const std::string& key, SparseCooTensor* x_grad, DenseTensor* kernel_grad); @@ -40,13 +43,16 @@ std::tuple Conv3dCooGrad( const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& kernel, + const SparseCooTensor& out, const DenseTensor& rulebook, + const DenseTensor& counter, const SparseCooTensor& out_grad, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, const int groups, - const bool subm) { + const bool subm, + const std::string& key) { SparseCooTensor x_grad; DenseTensor kernel_grad; @@ -54,13 +60,16 @@ std::tuple Conv3dCooGrad( Conv3dCooGradKernel(dev_ctx, x, kernel, + out, rulebook, + counter, out_grad, paddings, dilations, strides, groups, subm, + key, &x_grad, &kernel_grad); return std::make_tuple(x_grad, kernel_grad); diff --git a/paddle/phi/kernels/sparse/conv_kernel.h b/paddle/phi/kernels/sparse/conv_kernel.h index fbff46d4390ba..0c5a2081a6f3d 100644 --- a/paddle/phi/kernels/sparse/conv_kernel.h +++ b/paddle/phi/kernels/sparse/conv_kernel.h @@ -31,8 +31,10 @@ void Conv3dCooKernel(const Context& dev_ctx, const std::vector& strides, const int groups, const bool subm, + const std::string& key, SparseCooTensor* out, - DenseTensor* rulebook); + DenseTensor* rulebook, + DenseTensor* counter); template SparseCooTensor Conv3dCoo(const Context& dev_ctx, @@ -43,7 +45,9 @@ SparseCooTensor Conv3dCoo(const Context& dev_ctx, const std::vector& strides, const int groups, const bool subm, - DenseTensor* rulebook) { + const std::string& key, + DenseTensor* rulebook, + DenseTensor* counter) { SparseCooTensor coo; Conv3dCooKernel(dev_ctx, x, @@ -53,8 +57,10 @@ SparseCooTensor Conv3dCoo(const Context& dev_ctx, strides, groups, subm, + key, &coo, - rulebook); + rulebook, + counter); return coo; } diff --git a/paddle/phi/kernels/sparse/cpu/convolution.h b/paddle/phi/kernels/sparse/cpu/conv.h similarity index 97% rename from paddle/phi/kernels/sparse/cpu/convolution.h rename to paddle/phi/kernels/sparse/cpu/conv.h index 373087ade272b..e47f33c8c4834 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution.h +++ b/paddle/phi/kernels/sparse/cpu/conv.h @@ -41,13 +41,12 @@ void ProductRuleBook(const Context& dev_ctx, const DDim& out_dims, const bool subm, DenseTensor* rulebook, - DenseTensor* counter_per_kernel) { + int* counter_per_kernel) { const int64_t non_zero_num = x.nnz(); const auto& non_zero_indices = x.non_zero_indices(); const IntT* indices_ptr = non_zero_indices.data(); - int* counter_ptr = counter_per_kernel->data(); int 
kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; - memset(counter_ptr, 0, kernel_size * sizeof(int)); + memset(counter_per_kernel, 0, kernel_size * sizeof(int)); int rulebook_len = 0; // calc the rulebook_len @@ -107,7 +106,7 @@ void ProductRuleBook(const Context& dev_ctx, } if (rulebook_ptr == nullptr) { - counter_ptr[kernel_index - 1] += 1; + counter_per_kernel[kernel_index - 1] += 1; ++rulebook_len; } else { rulebook_ptr[rulebook_index] = kernel_index - 1; diff --git a/paddle/phi/kernels/sparse/cpu/conv_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/conv_grad_kernel.cc index a8f4441eae897..44ad2fa588b55 100644 --- a/paddle/phi/kernels/sparse/cpu/conv_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/conv_grad_kernel.cc @@ -17,7 +17,7 @@ limitations under the License. */ #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/sparse/cpu/convolution.h" +#include "paddle/phi/kernels/sparse/cpu/conv.h" namespace phi { namespace sparse { @@ -34,22 +34,27 @@ template void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx, const SparseCooTensor& x, const DenseTensor& kernel, + const SparseCooTensor& out, const DenseTensor& rulebook, + const DenseTensor& counter, const SparseCooTensor& out_grad, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, const int groups, const bool subm, + const std::string& key, SparseCooTensor* x_grad, DenseTensor* kernel_grad) { const auto& kernel_dims = kernel.dims(); const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; - const IntT* rulebook_ptr = rulebook.data(); - const int rulebook_len = rulebook.dims()[1]; + int rulebook_len = 0; + const IntT* rulebook_ptr = phi::funcs::sparse::GetRulebookPtr( + out, rulebook, key, &rulebook_len); + const int* counter_ptr = phi::funcs::sparse::GetCounterPtr(out, counter, key); DenseTensorMeta in_features_meta( x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW); @@ -86,16 +91,14 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx, &x_grad_indices); x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), true); - std::vector offsets(kernel_size + 1), counter(kernel_size, 0); - for (int i = 0; i < rulebook_len; i++) { - counter[rulebook_ptr[i]] += 1; - } - IntT offset = 0, max_count = 0; + std::vector offsets(kernel_size + 1); + IntT offset = 0; + int max_count = 0; for (int i = 0; i < kernel_size; i++) { offsets[i] = offset; - offset += counter[i]; + offset += counter_ptr[i]; if (i < half_kernel_size) { - max_count = std::max(max_count, counter[i]); + max_count = std::max(max_count, counter_ptr[i]); } } offsets[kernel_size] = offset; @@ -129,11 +132,11 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx, const T* kernel_ptr = kernel.data(); for (int i = 0; i < kernel_size; i++) { - if (counter[i] <= 0 || (subm && i == half_kernel_size)) { + if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) { continue; } - const int M = counter[i]; + const int M = counter_ptr[i]; const int K = in_channels; const int N = out_channels; T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels; @@ -171,7 +174,7 @@ void Conv3dCooGradCPUKernel(const CPUContext& dev_ctx, // 4. 
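// A minimal sketch of the scatter-add that the Scatter call in step 4 performs,
// assuming plain accumulate-by-index semantics (the helper name and loops below
// are illustrative only, not the actual phi::funcs::sparse implementation).
// Here `indices` is the input-index row of the rulebook (rulebook_ptr +
// rulebook_len), so each rulebook entry adds its gradient row into x_grad.
template <typename T, typename IntT>
void ScatterAddSketch(const T* src,        // [rulebook_len, channels]
                      const IntT* indices, // input index per rulebook entry
                      int len,
                      int channels,
                      T* dst) {            // [nnz, channels], pre-zeroed
  for (int i = 0; i < len; ++i) {
    for (int c = 0; c < channels; ++c) {
      dst[indices[i] * channels + c] += src[i * channels + c];
    }
  }
}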
scatter Scatter(d_x_features_ptr, - rulebook.data() + rulebook_len, + rulebook_ptr + rulebook_len, rulebook_len, in_channels, x_grad_values_ptr); @@ -181,13 +184,16 @@ template void Conv3dCooGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& kernel, + const SparseCooTensor& out, const DenseTensor& rulebook, + const DenseTensor& counter, const SparseCooTensor& out_grad, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, const int groups, const bool subm, + const std::string& key, SparseCooTensor* x_grad, DenseTensor* kernel_grad) { PD_VISIT_INTEGRAL_TYPES( @@ -195,13 +201,16 @@ void Conv3dCooGradKernel(const Context& dev_ctx, Conv3dCooGradCPUKernel(dev_ctx, x, kernel, + out, rulebook, + counter, out_grad, paddings, dilations, strides, groups, subm, + key, x_grad, kernel_grad); })); diff --git a/paddle/phi/kernels/sparse/cpu/conv_kernel.cc b/paddle/phi/kernels/sparse/cpu/conv_kernel.cc index 7147a29a9c832..f15a636f96d45 100644 --- a/paddle/phi/kernels/sparse/cpu/conv_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/conv_kernel.cc @@ -14,9 +14,10 @@ limitations under the License. */ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/kernels/sparse/cpu/convolution.h" +#include "paddle/phi/kernels/sparse/cpu/conv.h" namespace phi { namespace sparse { @@ -35,8 +36,10 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx, const std::vector& strides, const int groups, const bool subm, + const std::string& key, SparseCooTensor* out, - DenseTensor* rulebook) { + DenseTensor* rulebook, + DenseTensor* counter) { // update padding and dilation // Currently, only support x.layout is NDHWC, groups = 1 // if x.layout != NDHWC then transpose(x), transpose(weight) @@ -66,26 +69,50 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx, // Second algorithm: // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf // 1. 
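// This is where the key-based rulebook cache added to sparse_coo_tensor.h
// (IndicesPairs / SaveIndicesPairs / SetIndicesDict) gets wired in. A
// simplified host-side sketch of the flow, assuming a non-empty key and the
// phi headers included above plus <string>/<utility>; the function name and
// the BuildRulebook placeholder are illustrative, not real phi helpers. An
// earlier submanifold conv with the same key may already have cached a
// (rulebook, counter) pair on the tensor; if so, ProductRuleBook is skipped.
void ReuseOrBuildRulebookSketch(const phi::SparseCooTensor& x,
                                const std::string& key,
                                phi::SparseCooTensor* out,
                                phi::DenseTensor* rulebook,
                                phi::DenseTensor* counter) {
  out->SetIndicesDict(x.GetIndicesDict());   // pass the cache along the network
  const auto* cached = x.IndicesPairs(key);  // nullptr if nothing saved yet
  if (cached != nullptr) {
    *rulebook = cached->first;               // reuse the rulebook
    *counter = cached->second;               // reuse the per-kernel counters
    return;
  }
  // BuildRulebook(x, rulebook, counter);    // stands in for ProductRuleBook
  out->SaveIndicesPairs(key, std::make_pair(*rulebook, *counter));
}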
product rulebook - DenseTensorMeta counter_meta( - DataType::INT32, {kernel_size}, DataLayout::NCHW); - DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta)); - - ProductRuleBook(dev_ctx, - x, - kernel_sizes, - subm_paddings, - dilations, - subm_strides, - out_dims, - subm, - rulebook, - &counter_per_kernel); - - UpdateRulebookAndOutIndex( - dev_ctx, x, kernel_size, out_channels, out_dims, rulebook, out); - - int n = rulebook->dims()[1]; - const int* counter_ptr = counter_per_kernel.data(); + DenseTensor h_counter, h_offsets; + h_counter.Resize({kernel_size}); + h_offsets.Resize({kernel_size + 1}); + int* h_counter_ptr = dev_ctx.template HostAlloc(&h_counter); + int* h_offsets_ptr = dev_ctx.template HostAlloc(&h_offsets); + + // DenseTensor* rulebook = nullptr; + const IntT* rulebook_ptr = nullptr; + int n = 0; + bool need_product_rulebook = true; + if (subm && !key.empty()) { + rulebook_ptr = phi::funcs::sparse::PrepareSubm( + dev_ctx, + x, + key, + out_dims, + out, + h_counter_ptr, + h_offsets_ptr, + &n, + &need_product_rulebook); + } + if (need_product_rulebook) { + DenseTensor tmp_rulebook; + ProductRuleBook(dev_ctx, + x, + kernel_sizes, + subm_paddings, + dilations, + subm_strides, + out_dims, + subm, + &tmp_rulebook, + h_counter_ptr); + + UpdateRulebookAndOutIndex( + dev_ctx, x, kernel_size, out_channels, out_dims, &tmp_rulebook, out); + n = tmp_rulebook.dims()[1]; + rulebook_ptr = tmp_rulebook.data(); + + phi::funcs::sparse::SaveToTable( + dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter); + } + // int n = rulebook->dims()[1]; // 2. gather DenseTensorMeta in_features_meta( @@ -100,34 +127,33 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx, T* out_features_ptr = out_features.data(); Gather(x.non_zero_elements().data(), - rulebook->data() + n, + rulebook_ptr + n, n, in_channels, in_features_ptr); // 3. call gemm for every werght auto blas = phi::funcs::GetBlas(dev_ctx); - std::vector offsets(kernel_size + 1); int offset = 0; for (int i = 0; i < kernel_size; i++) { - offsets[i] = offset; - offset += counter_ptr[i]; + h_offsets_ptr[i] = offset; + offset += h_counter_ptr[i]; } - offsets[kernel_size] = offset; + h_offsets_ptr[kernel_size] = offset; const T* kernel_ptr = kernel.data(); for (int i = 0; i < kernel_size; i++) { - if (counter_ptr[i] <= 0) { + if (h_counter_ptr[i] <= 0) { continue; } // call gemm: (n, in_channels) * (in_channels, out_channels) - const int M = counter_ptr[i]; + const int M = h_counter_ptr[i]; const int K = in_channels; // in_channels const int N = out_channels; // out_channels - T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels; + T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels; const T* tmp_kernel_ptr = kernel_ptr + i * K * N; - T* tmp_out_ptr = out_features_ptr + offsets[i] * out_channels; + T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels; blas.GEMM(CblasNoTrans, CblasNoTrans, M, @@ -143,11 +169,8 @@ void Conv3dCooCPUKernel(const CPUContext& dev_ctx, // 4. 
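// Bookkeeping behind the GEMM loop above: offsets is an exclusive prefix sum
// of the per-kernel-element counters, so kernel element i multiplies rows
// [offsets[i], offsets[i] + counter[i]) of the gathered feature matrix by its
// own (in_channels x out_channels) weight slice. A plain-C++ illustration
// (not the phi::funcs::sparse::PrefixSum helper itself):
#include <vector>
std::vector<int> ExclusivePrefixSum(const std::vector<int>& counter) {
  std::vector<int> offsets(counter.size() + 1, 0);
  for (size_t i = 0; i < counter.size(); ++i) {
    offsets[i + 1] = offsets[i] + counter[i];
  }
  // e.g. counter = {3, 0, 5} -> offsets = {0, 3, 3, 8}
  return offsets;
}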
scatter T* out_values_ptr = out->mutable_non_zero_elements()->data(); memset(out_values_ptr, 0, sizeof(T) * out->nnz() * out_channels); - Scatter(out_features_ptr, - rulebook->data() + n * 2, - n, - out_channels, - out_values_ptr); + Scatter( + out_features_ptr, rulebook_ptr + n * 2, n, out_channels, out_values_ptr); } template @@ -159,8 +182,10 @@ void Conv3dCooKernel(const Context& dev_ctx, const std::vector& strides, const int groups, const bool subm, + const std::string& key, SparseCooTensor* out, - DenseTensor* rulebook) { + DenseTensor* rulebook, + DenseTensor* counter) { PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "Conv3dCooCPUKernel", ([&] { Conv3dCooCPUKernel(dev_ctx, @@ -171,8 +196,10 @@ void Conv3dCooKernel(const Context& dev_ctx, strides, groups, subm, + key, out, - rulebook); + rulebook, + counter); })); } diff --git a/paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc index dfdd00433680a..077ac07e8d38e 100644 --- a/paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/pool_grad_kernel.cc @@ -28,6 +28,7 @@ template void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx, const SparseCooTensor& x, const DenseTensor& rulebook, + const DenseTensor& counter, const SparseCooTensor& out, const SparseCooTensor& out_grad, const std::vector& kernel_sizes, @@ -36,11 +37,10 @@ void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx, const int channels = x.dims()[4]; int rulebook_len = rulebook.dims()[1]; const IntT* rulebook_ptr = rulebook.data(); - std::vector offsets(kernel_size + 1), counter(kernel_size, 0); - for (int i = 0; i < rulebook_len; i++) { - counter[rulebook_ptr[i]] += 1; - } - phi::funcs::sparse::PrefixSum(&counter[0], &offsets[0], kernel_size); + std::vector offsets(kernel_size + 1); + const int* counter_ptr = counter.data(); + + phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size); const T* in_features_ptr = x.non_zero_elements().data(); const T* out_features_ptr = out.non_zero_elements().data(); @@ -60,7 +60,7 @@ void MaxPoolCooGradCPUKernel(const CPUContext& dev_ctx, phi::funcs::MaxPoolGrad grad_functor; for (int i = 0; i < kernel_size; i++) { - for (int j = 0; j < counter[i]; j++) { + for (int j = 0; j < counter_ptr[i]; j++) { IntT in_i = rulebook_ptr[rulebook_len + offsets[i] + j]; IntT out_i = rulebook_ptr[rulebook_len * 2 + offsets[i] + j]; for (int c = 0; c < channels; c++) { @@ -78,6 +78,7 @@ template void MaxPoolCooGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& rulebook, + const DenseTensor& counter, const SparseCooTensor& out, const SparseCooTensor& out_grad, const std::vector& kernel_sizes, @@ -85,7 +86,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx, PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "MaxPoolCooGradCPUKernel", ([&] { MaxPoolCooGradCPUKernel( - dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad); + dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, x_grad); })); } diff --git a/paddle/phi/kernels/sparse/cpu/pool_kernel.cc b/paddle/phi/kernels/sparse/cpu/pool_kernel.cc index ae32b6cc1d695..f01017bba56f5 100644 --- a/paddle/phi/kernels/sparse/cpu/pool_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/pool_kernel.cc @@ -19,7 +19,7 @@ limitations under the License. 
*/ #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" -#include "paddle/phi/kernels/sparse/cpu/convolution.h" +#include "paddle/phi/kernels/sparse/cpu/conv.h" namespace phi { namespace sparse { @@ -37,7 +37,8 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx, const std::vector& dilations, const std::vector& strides, SparseCooTensor* out, - DenseTensor* rulebook) { + DenseTensor* rulebook, + DenseTensor* counter) { const auto& x_dims = x.dims(); int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const std::vector& real_kernel_sizes = @@ -47,9 +48,7 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx, x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims); const int in_channels = real_kernel_sizes[3]; - DenseTensorMeta counter_meta( - DataType::INT32, {kernel_size}, DataLayout::NCHW); - DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta)); + std::vector counter_per_kernel(kernel_size, 0); const T* in_features_ptr = x.non_zero_elements().data(); // 1. product rule book @@ -62,14 +61,17 @@ void MaxPoolCooCPUKernel(const CPUContext& dev_ctx, out_dims, false, rulebook, - &counter_per_kernel); + counter_per_kernel.data()); UpdateRulebookAndOutIndex( dev_ctx, x, kernel_size, in_channels, out_dims, rulebook, out); int rulebook_len = rulebook->dims()[1]; const IntT* rulebook_ptr = rulebook->data(); - const int* counter_ptr = counter_per_kernel.data(); + + counter->Resize({kernel_size}); + int* counter_ptr = dev_ctx.template HostAlloc(counter); + memcpy(counter_ptr, counter_per_kernel.data(), kernel_size * sizeof(int)); std::vector offsets(kernel_size + 1); phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size); @@ -105,7 +107,8 @@ void MaxPoolCooKernel(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, SparseCooTensor* out, - DenseTensor* rulebook) { + DenseTensor* rulebook, + DenseTensor* counter) { PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "MaxPoolCooCPUKernel", ([&] { MaxPoolCooCPUKernel(dev_ctx, @@ -115,7 +118,8 @@ void MaxPoolCooKernel(const Context& dev_ctx, dilations, strides, out, - rulebook); + rulebook, + counter); })); } diff --git a/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu index f6aedb8b68fc3..a8e88f351ccbc 100644 --- a/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/coalesce_kernel.cu @@ -125,16 +125,35 @@ void CoalesceGPUKernel(const GPUContext& dev_ctx, } // 5. scatter the values - config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1); - phi::funcs::sparse::ScatterKernel - <<>>( - x_values_ptr, - public_indexs.data(), - values_indexs_ptr, - out_nnz, - nnz, - stride, - out_values.data()); + const int VecSize = VecBytes / sizeof(T); + if (stride % VecSize == 0) { + config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, nnz * stride / VecSize, 1); + phi::funcs::sparse::ScatterKernel + <<>>(x_values_ptr, + public_indexs.data(), + values_indexs_ptr, + out_nnz, + nnz, + stride, + out_values.data()); + } else { + config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, nnz * stride, 1); + phi::funcs::sparse::ScatterKernel + <<>>(x_values_ptr, + public_indexs.data(), + values_indexs_ptr, + out_nnz, + nnz, + stride, + out_values.data()); + } // 6. 
convert index to coordinate Dim const_dims; diff --git a/paddle/phi/kernels/sparse/gpu/conv.cu.h b/paddle/phi/kernels/sparse/gpu/conv.cu.h new file mode 100644 index 0000000000000..859857ed7baac --- /dev/null +++ b/paddle/phi/kernels/sparse/gpu/conv.cu.h @@ -0,0 +1,760 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/phi/kernels/sparse/conv_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_info.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" +#include "paddle/phi/core/enforce.h" +#include "paddle/phi/core/tensor_utils.h" +#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/phi/kernels/funcs/index_impl.cu.h" +#include "paddle/phi/kernels/funcs/math_function.h" +#include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" +#include "paddle/phi/kernels/funcs/sparse/utils.cu.h" +#include "paddle/phi/kernels/primitive/compute_primitives.h" + +namespace phi { +namespace sparse { + +using Dims4D = phi::funcs::sparse::Dims4D; + +// Vectorize load and store global memory +// In the scene of 3D point cloud, the slice_size 4,8,16,32,64 are commonly +// used. +template +__global__ void GatherKernel(const T* params, + const IndexT* indices, + T* output, + size_t index_size, + size_t slice_size) { + CUDA_KERNEL_LOOP_TYPE(i, index_size * slice_size / VecSize, int64_t) { + const int vec_slice_size = slice_size / VecSize; + int indices_i = i / vec_slice_size; + int slice_i = i - indices_i * vec_slice_size; // offset inside the slice + IndexT gather_i = indices[indices_i]; + int64_t params_i = gather_i * slice_size + slice_i * VecSize; + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; + LoadT params_vec; + phi::Load(params + params_i, ¶ms_vec); + phi::Store(params_vec, output + i * VecSize); + } +} + +// double sparse, seed GroupIndexs +template +__global__ void GatherKernelV2(const T* inputs, + const int* index_counts, + const int* index_groups, + const int non_zero_num, + const int kernel_size, + const int channels, + const int buffer_count, + T* output) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + const int vec_channels = channels / VecSize; + using LoadT = phi::AlignedVector; + using StoreT = phi::AlignedVector; + for (int i = tid; i < non_zero_num * vec_channels; + i += gridDim.x * blockDim.x) { + int indices_i = i / vec_channels; + int channels_i = i - indices_i * vec_channels; + LoadT in_vec; + phi::Load(inputs + indices_i * channels + channels_i * VecSize, + &in_vec); +#pragma unroll + for (int it = 0; it < buffer_count; it++) { + int len = index_counts[indices_i + it * non_zero_num]; + const int group_offset = it * kernel_size * non_zero_num; +#pragma unroll + for (int j = 0; j < len; j++) { + int out_i = index_groups[indices_i * kernel_size + j + group_offset]; + phi::Store( + in_vec, output + out_i * channels + channels_i * VecSize); + } + } + } +} + +template +inline void Gather(const 
GPUContext& dev_ctx, + const T* inputs, + const IntT* indices, + const int indices_size, + const int channels, + T* output) { + const int VecSize = VecBytes / sizeof(T); + if (channels % VecSize == 0) { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, indices_size * channels / VecSize, 1); + GatherKernel + <<>>(inputs, indices, output, indices_size, channels); + } else { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, indices_size * channels, 1); + GatherKernel + <<>>(inputs, indices, output, indices_size, channels); + } +} + +template +inline void GatherV2(const GPUContext& dev_ctx, + const T* inputs, + const int* index_counts, + const int* index_groups, + const int non_zero_num, + const int kernel_size, + const int channels, + const int buffer_count, + T* output) { + const int VecSize = VecBytes / sizeof(T); + if (channels % VecSize == 0) { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, non_zero_num * channels / VecSize, 1); + GatherKernelV2<<>>(inputs, + index_counts, + index_groups, + non_zero_num, + kernel_size, + channels, + buffer_count, + output); + } else { + auto config = phi::backends::gpu::GetGpuLaunchConfig1D( + dev_ctx, non_zero_num * channels, 1); + GatherKernelV2<<>>(inputs, + index_counts, + index_groups, + non_zero_num, + kernel_size, + channels, + buffer_count, + output); + } +} + +// unique the out indexs in rulebook +template +__global__ void UniqueKernel(const IntT* in_indexs, + const int rulebook_len, + int* out_index_table, + int* out_indexs, + int* nnz) { + extern __shared__ int cache[]; + __shared__ int count, start; + if (threadIdx.x == 0) { + count = 0; + start = 0; + } + __syncthreads(); + + int i = threadIdx.x + blockDim.x * blockIdx.x; + if (i < rulebook_len) { + // atomicOr only support int + int index = static_cast(in_indexs[i]); + int change_index = index == 0 ? 
-1 : index; + int flag = atomicOr(out_index_table + index, change_index); + if (flag == 0) { + int j = atomicAdd(&count, 1); + cache[j] = index; + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + start = atomicAdd(nnz, count); + } + __syncthreads(); + for (int i = threadIdx.x; i < count; i += blockDim.x) { + out_indexs[start + i] = cache[i]; + } +} + +template +__global__ void GroupIndexs(const int* out_index_table, + const int n, + const int kernel_size, + IntT* out_indexs, + int* out_index_counts, + int* out_index_groups) { + CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) { + IntT index = out_indexs[i]; + int real_index = out_index_table[index]; + out_indexs[i] = real_index; + + // kernel_size at most + int j = atomicAdd(out_index_counts + real_index, 1); + // nnz * kernel_size + out_index_groups[real_index * kernel_size + j] = i; + } +} + +/** + * @brief product rulebook + * for input_i in x_indices: + * if input_i participate in the convolution calculation: + * infer the output_i by input_i and kernel_i + * save output_i + * + * x_indices: the indices of input features + * x_dims: the input dims + * kernel_dims: the kernel dims + * out_dims: the output dims + * non_zero_num: the number of input features + * rulebook: the rulebook to save the kernel index, input index and output index + * counter: save the number of times each location in the kernel participates in + *the caculation + **/ +template +__global__ void ProductRuleBookKernel(const T* x_indices, + const Dims4D x_dims, + const Dims4D kernel_dims, + const Dims4D out_dims, + const int64_t non_zero_num, + const Dims4D paddings, + const Dims4D dilations, + const Dims4D strides, + T* rulebook, + int* counter) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + extern __shared__ int counter_buf[]; // kernel_size + const int kernel_size = kernel_dims[3] * kernel_dims[2] * kernel_dims[1]; + const int offset = kernel_size * non_zero_num; + for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) { + counter_buf[i] = 0; + } + __syncthreads(); + + for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { + int kernel_index = 0; + T batch = x_indices[i]; + T in_z = x_indices[i + non_zero_num]; + T in_y = x_indices[i + 2 * non_zero_num]; + T in_x = x_indices[i + 3 * non_zero_num]; + for (int kz = 0; kz < kernel_dims[1]; kz++) { + for (int ky = 0; ky < kernel_dims[2]; ky++) { + for (int kx = 0; kx < kernel_dims[3]; kx++) { + int in_i = -1, out_index = -1, kernel_i = -1; + if (phi::funcs::sparse::Check(x_dims, + kernel_dims, + paddings, + dilations, + strides, + in_x, + in_y, + in_z, + kx, + ky, + kz)) { + T out_z = (in_z + paddings[1] - kz * dilations[1]) / strides[1]; + T out_y = (in_y + paddings[2] - ky * dilations[2]) / strides[2]; + T out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3]; + in_i = i; + out_index = phi::funcs::sparse::PointToIndex( + batch, out_x, out_y, out_z, out_dims); + atomicAdd(&counter_buf[kernel_index], 1); + kernel_i = kernel_index; + } + // rulebook[kernel_index * non_zero_num + i] = kernel_i; + rulebook[kernel_index * non_zero_num + i] = in_i; + rulebook[kernel_index * non_zero_num + offset + i] = out_index; + ++kernel_index; + } + } + } + } + __syncthreads(); + for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) { + atomicAdd(&counter[i], counter_buf[i]); + } +} + +template +__global__ void GetOutIndexTable(const IntT* indices, + const IntT non_zero_num, + const Dims4D dims, + int* out_index_table) { + CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) { + IntT batch = indices[i]; + 
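// Note on the encoding written a few lines below: out_index_table is
// zero-initialized, so the value 0 must mean "no input point at this voxel".
// Row index 0 is therefore stored as -1 (out_index_table[index] = i == 0 ? -1 : i),
// and ProductSubmRuleBookKernel undoes it with
// `real_out_index == -1 ? 0 : real_out_index` after checking `!= 0`.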
IntT in_z = indices[i + non_zero_num]; + IntT in_y = indices[i + 2 * non_zero_num]; + IntT in_x = indices[i + 3 * non_zero_num]; + IntT index = PointToIndex(batch, in_x, in_y, in_z, dims); + out_index_table[index] = i == 0 ? -1 : i; + } +} + +template +__global__ void GetOutIndexTable(int* indexs, + const int non_zero_num, + const Dims4D out_dims, + int* out_index_table, + IntT* out_indices) { + CUDA_KERNEL_LOOP_TYPE(i, non_zero_num, int64_t) { + IntT index = static_cast(indexs[i]); + out_index_table[index] = i; + IntT batch, x, y, z; + phi::funcs::sparse::IndexToPoint( + index, out_dims, &batch, &x, &y, &z); + // get out indices + out_indices[i] = batch; + out_indices[i + non_zero_num] = z; + out_indices[i + non_zero_num * 2] = y; + out_indices[i + non_zero_num * 3] = x; + indexs[i] = 0; + } +} + +template +__global__ void CopyRuleBook(const int* counters, + const int* offsets, + const IntT* in_rulebook, + const int len, + const int kernel_size, + const int non_zero_num, + IntT* out_rulebook) { + int tid = threadIdx.x + blockDim.x * blockIdx.x; + extern __shared__ int cache_counters[]; + int* cache_offsets = cache_counters + kernel_size; + for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) { + cache_counters[i] = counters[i]; + cache_offsets[i] = offsets[i]; + } + __syncthreads(); + for (int i = tid; i < len; i += gridDim.x * blockDim.x) { + // get the kernel index + int kernel_index = 0; + for (; kernel_index < kernel_size - 1; kernel_index++) { + if (i >= offsets[kernel_index] && i < offsets[kernel_index + 1]) { + break; + } + } + int inner_index = i - offsets[kernel_index]; + out_rulebook[i] = in_rulebook[kernel_index * non_zero_num + inner_index]; + out_rulebook[len + i] = + in_rulebook[kernel_size * non_zero_num + kernel_index * non_zero_num + + inner_index]; + } +} + +template +__global__ void ProductSubmRuleBookKernel(const T* x_indices, + const Dims4D x_dims, + const Dims4D kernel_dims, + const Dims4D out_dims, + const int64_t non_zero_num, + const Dims4D paddings, + const Dims4D dilations, + const Dims4D strides, + const int* out_index_table, + T* rulebook, + int* counter) { + int tid = threadIdx.x + blockIdx.x * blockDim.x; + const int kernel_size = kernel_dims[3] * kernel_dims[2] * kernel_dims[1]; + extern __shared__ int counter_buf[]; // kernel_size + int* counter_buf2 = counter_buf + kernel_size; + // length = kernel_size * blockDim.x * 2; + int* rulebook_buf = counter_buf + kernel_size * 2; + + const int offset = kernel_size * non_zero_num; + for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) { + counter_buf[i] = 0; + } + __syncthreads(); + + for (int i = tid; i < non_zero_num; i += gridDim.x * blockDim.x) { + int kernel_index = 0; + T batch = x_indices[i]; + T in_z = x_indices[i + non_zero_num]; + T in_y = x_indices[i + 2 * non_zero_num]; + T in_x = x_indices[i + 3 * non_zero_num]; + for (int kz = 0; kz < kernel_dims[1]; kz++) { + for (int ky = 0; ky < kernel_dims[2]; ky++) { + for (int kx = 0; kx < kernel_dims[3]; kx++) { + int in_i = -1, out_index = -1, kernel_i = -1; + if (phi::funcs::sparse::Check(x_dims, + kernel_dims, + paddings, + dilations, + strides, + in_x, + in_y, + in_z, + kx, + ky, + kz)) { + T out_z = (in_z + paddings[1] - kz * dilations[1]) / strides[1]; + T out_y = (in_y + paddings[2] - ky * dilations[2]) / strides[2]; + T out_x = (in_x + paddings[3] - kx * dilations[3]) / strides[3]; + out_index = phi::funcs::sparse::PointToIndex( + batch, out_x, out_y, out_z, out_dims); + int real_out_index = out_index_table[out_index]; + if 
(real_out_index != 0) { + real_out_index = real_out_index == -1 ? 0 : real_out_index; + in_i = i; + int buf_i = atomicAdd(&counter_buf[kernel_index], 1); + kernel_i = kernel_index; + rulebook_buf[kernel_index * blockDim.x + buf_i] = in_i; + rulebook_buf[kernel_index * blockDim.x + + kernel_size * blockDim.x + buf_i] = real_out_index; + } + } + ++kernel_index; + } + } + } + } + __syncthreads(); + for (int i = threadIdx.x; i < kernel_size; i += blockDim.x) { + counter_buf2[i] = atomicAdd(&counter[i], counter_buf[i]); + } + __syncthreads(); + for (int i = 0; i < kernel_size; i++) { + if (threadIdx.x < counter_buf[i]) { + // rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] = i; + rulebook[i * non_zero_num + counter_buf2[i] + threadIdx.x] = + rulebook_buf[i * blockDim.x + threadIdx.x]; + rulebook[i * non_zero_num + offset + counter_buf2[i] + threadIdx.x] = + rulebook_buf[i * blockDim.x + kernel_size * blockDim.x + threadIdx.x]; + } + } +} + +template +__global__ void GroupIndexs(const int n, + const int kernel_size, + const IntT* indexs, + int* index_counts, + int* index_groups) { + CUDA_KERNEL_LOOP_TYPE(i, n, int64_t) { + IntT index = indexs[i]; + // kernel_size at most + int j = atomicAdd(index_counts + index, 1); + // nnz * kernel_size + index_groups[index * kernel_size + j] = i; + } +} + +// double space to reduce atomicAdd conflict +template +__global__ void GroupIndexsV2(const int rulebook_len, + const int non_zero_num, + const int kernel_size, + const int half_kernel_offset, + const IntT* indexs, + int* index_counts, + int* index_groups) { + CUDA_KERNEL_LOOP_TYPE(i, rulebook_len, int64_t) { + IntT index = indexs[i]; + int* counts_ptr = + i < half_kernel_offset ? index_counts : index_counts + non_zero_num; + int* groups_ptr = i < half_kernel_offset + ? index_groups + : index_groups + non_zero_num * kernel_size; + // conflict kernel_size times at most + int j = atomicAdd(counts_ptr + index, 1); + // nnz * kernel_size + groups_ptr[index * kernel_size + j] = i; + } +} + +inline void CallThrustScan(const GPUContext& dev_ctx, + const int* counter_ptr, + const int kernel_size, + int* offsets_ptr, + int* h_counter_ptr, + int* h_offsets_ptr) { +#ifdef PADDLE_WITH_HIP + thrust::exclusive_scan(thrust::hip::par.on(dev_ctx.stream()), +#else + thrust::exclusive_scan(thrust::cuda::par.on(dev_ctx.stream()), +#endif + counter_ptr, + counter_ptr + kernel_size, + offsets_ptr); + + phi::backends::gpu::GpuMemcpyAsync(h_counter_ptr, + counter_ptr, + kernel_size * sizeof(int), + gpuMemcpyDeviceToHost, + dev_ctx.stream()); + + phi::backends::gpu::GpuMemcpyAsync(h_offsets_ptr, + offsets_ptr, + kernel_size * sizeof(int), + gpuMemcpyDeviceToHost, + dev_ctx.stream()); +} + +// the basic algorithm can refer to convolution_kernel.cc or +// the second paper +// example: +// 1. the rulebook: +// the kernel_index: 0, 0, 0, 1, 1, 1, 2, 2, .... +// the out_index(key): 20, 30, 33, 30, 33, 20, 25 +// 2. mark the index of out_index(value): 0, 1, 2, 3, 4, 5, 6, .... +// 3. sorted the (key, value) +// 4. unique the (key, value): +// unique_key: 20, 25, 30, 33 +// unique_values: 0, 2, 3, 5 +// the index of unique_values is: 0, 1, 2, 3 +// 5. 
update the out_index by unique_key, uniqe_value and the index of +// unique_value: +// the new out_index: 0, 2, 3, 2, 3, 0, 1 +template +int ProductRuleBook(const Context& dev_ctx, + const SparseCooTensor& x, + const std::vector& kernel_sizes, + const std::vector& paddings, + const std::vector& dilations, + const std::vector& strides, + const DDim& out_dims, + const bool subm, + DenseTensor* rulebook, + DenseTensor* counter_per_kernel, + DenseTensor* offsets_per_kernel, + DenseTensor* out_index, + DenseTensor* unique_value, + SparseCooTensor* out, + int* h_counter, + int* h_offsets) { + auto indices_dtype = paddle::experimental::CppTypeToDataType::Type(); + const int64_t non_zero_num = x.nnz(); + const auto& non_zero_indices = x.non_zero_indices(); + const IntT* indices_ptr = non_zero_indices.data(); + int* counter_ptr = counter_per_kernel->data(); + int* offsets_ptr = offsets_per_kernel->data(); + int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; + + const auto x_dims = x.dims(); + Dims4D d_x_dims(x_dims[0], x_dims[3], x_dims[2], x_dims[1]); + Dims4D d_kernel_dims(1, kernel_sizes[2], kernel_sizes[1], kernel_sizes[0]); + Dims4D d_out_dims(out_dims[0], out_dims[3], out_dims[2], out_dims[1]); + Dims4D d_paddings(1, paddings[2], paddings[1], paddings[0]); + Dims4D d_strides(1, strides[2], strides[1], strides[0]); + Dims4D d_dilations(1, dilations[2], dilations[1], dilations[0]); + // 1. product rule book + phi::backends::gpu::GpuMemsetAsync(counter_ptr, + 0, + sizeof(int) * counter_per_kernel->numel(), + dev_ctx.stream()); + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1); + + const int rulebook_rows = 2; + const int rulebook_cols = kernel_size * non_zero_num; + DenseTensorMeta rulebook_meta( + indices_dtype, {rulebook_rows, rulebook_cols}, DataLayout::NCHW); + + int64_t table_size = 1; + for (int i = 0; i < out_dims.size() - 1; i++) { + table_size *= out_dims[i]; + } + DenseTensor out_index_table = phi::Empty(dev_ctx, {table_size}); + int* out_index_table_ptr = out_index_table.data(); + + if (subm) { + DenseTensor tmp_rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta)); + IntT* rulebook_ptr = tmp_rulebook.data(); + DenseTensor out_indices = + phi::EmptyLike(dev_ctx, x.non_zero_indices()); + DenseTensor out_values = phi::Empty(dev_ctx, {x.nnz(), kernel_sizes[4]}); + + phi::Copy( + dev_ctx, x.non_zero_indices(), dev_ctx.GetPlace(), false, &out_indices); + + phi::backends::gpu::GpuMemsetAsync( + out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream()); + + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, non_zero_num, 1); + GetOutIndexTable<<>>( + out_indices.data(), non_zero_num, d_x_dims, out_index_table_ptr); + + size_t cache_size = kernel_size * 2 + kernel_size * + config.thread_per_block.x * 2 * + sizeof(int); + const int MAX_CACHE_SIZE = 48 * 1024; + while (cache_size >= MAX_CACHE_SIZE) { + config.thread_per_block.x /= 2; + config.block_per_grid.x *= 2; + PADDLE_ENFORCE_GE(config.thread_per_block.x, + 32, + phi::errors::Fatal("the shared memory is not enough")); + cache_size = kernel_size * 2 + + kernel_size * config.thread_per_block.x * 2 * sizeof(int); + } + ProductSubmRuleBookKernel<<>>(indices_ptr, + d_x_dims, + d_kernel_dims, + d_out_dims, + non_zero_num, + d_paddings, + d_dilations, + d_strides, + out_index_table_ptr, + rulebook_ptr, + counter_ptr); + + out->SetMember(out_indices, out_values, out_dims, false); + + CallThrustScan( + dev_ctx, counter_ptr, kernel_size, offsets_ptr, h_counter, 
h_offsets); + + dev_ctx.Wait(); + int rulebook_len = h_offsets[kernel_size - 1] + h_counter[kernel_size - 1]; + DenseTensor out_rulebook = + phi::Empty(dev_ctx, {rulebook_rows, rulebook_len}); + IntT* out_rulebook_ptr = out_rulebook.data(); + config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); + cache_size = kernel_size * 2 * sizeof(int); + CopyRuleBook<<>>(counter_ptr, + offsets_ptr, + rulebook_ptr, + rulebook_len, + kernel_size, + non_zero_num, + out_rulebook_ptr); + *rulebook = out_rulebook; + + return rulebook_len; + + } else { + *rulebook = phi::Empty(dev_ctx, std::move(rulebook_meta)); + IntT* rulebook_ptr = rulebook->data(); + ProductRuleBookKernel<<>>(indices_ptr, + d_x_dims, + d_kernel_dims, + d_out_dims, + non_zero_num, + d_paddings, + d_dilations, + d_strides, + rulebook_ptr, + counter_ptr); + + // 2. remove -1 +#ifdef PADDLE_WITH_HIP + IntT* last = thrust::remove(thrust::hip::par.on(dev_ctx.stream()), +#else + IntT* last = thrust::remove(thrust::cuda::par.on(dev_ctx.stream()), +#endif + rulebook_ptr, + rulebook_ptr + rulebook_rows * rulebook_cols, + -1); + + IntT rulebook_len = (last - rulebook_ptr) / 2; + + CallThrustScan( + dev_ctx, counter_ptr, kernel_size, offsets_ptr, h_counter, h_offsets); + + rulebook->Resize({rulebook_rows, static_cast(rulebook_len)}); + // 3. sorted or merge the out index + out_index->ResizeAndAllocate({static_cast(rulebook_len)}); + DenseTensor unique_key = + phi::Empty(dev_ctx, {static_cast(rulebook_len)}); + int* out_index_ptr = out_index->data(); + int* unique_key_ptr = unique_key.data(); + + phi::backends::gpu::GpuMemsetAsync( + out_index_table_ptr, 0, sizeof(int) * table_size, dev_ctx.stream()); + phi::backends::gpu::GpuMemsetAsync( + unique_key_ptr, 0, sizeof(int), dev_ctx.stream()); + + config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); + size_t cache_size = sizeof(int) * config.thread_per_block.x; + UniqueKernel<<>>(rulebook_ptr + rulebook_len, + rulebook_len, + out_index_table_ptr, + out_index_ptr, + unique_key_ptr); + int out_nnz = 0; + phi::backends::gpu::GpuMemcpyAsync(&out_nnz, + unique_key_ptr, + sizeof(int), + gpuMemcpyDeviceToHost, + dev_ctx.stream()); + dev_ctx.Wait(); + + const int64_t sparse_dim = 4; + phi::DenseTensor out_indices = + phi::Empty(dev_ctx, {sparse_dim, out_nnz}); + phi::DenseTensor out_values = + phi::Empty(dev_ctx, {out_nnz, kernel_sizes[4]}); + out->SetMember(out_indices, out_values, out_dims, false); + + IntT* out_indices_ptr = out_indices.data(); + + config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, out_nnz, 1); + GetOutIndexTable<<>>(out_index_ptr, + out_nnz, + d_out_dims, + out_index_table_ptr, + out_indices_ptr); + config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); + unique_value->ResizeAndAllocate({static_cast(out_nnz * kernel_size)}); + int* unique_value_ptr = unique_value->data(); + + GroupIndexs<<>>(out_index_table_ptr, + rulebook_len, + kernel_size, + rulebook_ptr + rulebook_len, + out_index_ptr, + unique_value_ptr); + + return rulebook_len; + } +} + +} // namespace sparse +} // namespace phi diff --git a/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu index 0ce3558e1d73f..848517aae2549 100644 --- a/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/conv_grad_kernel.cu @@ -19,13 +19,11 @@ limitations under the License. 
*/ #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" -#include "paddle/phi/core/tensor_meta.h" #include "paddle/phi/core/tensor_utils.h" #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" -#include "paddle/phi/kernels/funcs/scatter.cu.h" -#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" +#include "paddle/phi/kernels/sparse/gpu/conv.cu.h" namespace phi { namespace sparse { @@ -42,43 +40,42 @@ template void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, const SparseCooTensor& x, const DenseTensor& kernel, + const SparseCooTensor& out, const DenseTensor& rulebook, + const DenseTensor& counter, const SparseCooTensor& out_grad, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, const int groups, const bool subm, + const std::string& key, SparseCooTensor* x_grad, DenseTensor* kernel_grad) { const auto& kernel_dims = kernel.dims(); const int kernel_size = kernel_dims[0] * kernel_dims[1] * kernel_dims[2]; const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; - const IntT* rulebook_ptr = rulebook.data(); - const int rulebook_len = rulebook.dims()[1]; + int rulebook_len = 0; + const IntT* rulebook_ptr = phi::funcs::sparse::GetRulebookPtr( + out, rulebook, key, &rulebook_len); + const int* counter_ptr = phi::funcs::sparse::GetCounterPtr(out, counter, key); - DenseTensorMeta in_features_meta( - x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW); - DenseTensorMeta d_x_features_meta( - x.dtype(), {rulebook_len, in_channels}, DataLayout::NCHW); - DenseTensorMeta out_grad_features_meta( - x.dtype(), {rulebook_len, out_channels}, DataLayout::NCHW); phi::DenseTensor in_features = - phi::Empty(dev_ctx, std::move(in_features_meta)); + phi::Empty(dev_ctx, {rulebook_len, in_channels}); phi::DenseTensor d_x_features = - phi::Empty(dev_ctx, std::move(d_x_features_meta)); + phi::Empty(dev_ctx, {rulebook_len, in_channels}); phi::DenseTensor out_grad_features = - phi::Empty(dev_ctx, std::move(out_grad_features_meta)); + phi::Empty(dev_ctx, {rulebook_len, out_channels}); T* in_features_ptr = in_features.data(); T* d_x_features_ptr = d_x_features.data(); T* out_grad_features_ptr = out_grad_features.data(); *kernel_grad = phi::EmptyLike(dev_ctx, kernel); T* d_kernel_ptr = kernel_grad->data(); - phi::funcs::SetConstant set_zero; - set_zero(dev_ctx, kernel_grad, static_cast(0.0f)); + phi::backends::gpu::GpuMemsetAsync( + d_kernel_ptr, 0, sizeof(T) * kernel_grad->numel(), dev_ctx.stream()); int half_kernel_size = kernel_size / 2; auto blas = phi::funcs::GetBlas(dev_ctx); @@ -86,8 +83,12 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, phi::EmptyLike(dev_ctx, x.non_zero_indices()); DenseTensor x_grad_values = phi::EmptyLike(dev_ctx, x.non_zero_elements()); T* x_grad_values_ptr = x_grad_values.data(); - set_zero(dev_ctx, &x_grad_values, static_cast(0.0f)); - set_zero(dev_ctx, &d_x_features, static_cast(0.0f)); + phi::backends::gpu::GpuMemsetAsync(x_grad_values_ptr, + 0, + sizeof(T) * x_grad_values.numel(), + dev_ctx.stream()); + phi::backends::gpu::GpuMemsetAsync( + d_x_features_ptr, 0, sizeof(T) * d_x_features.numel(), dev_ctx.stream()); phi::Copy(dev_ctx, x.non_zero_indices(), dev_ctx.GetPlace(), @@ -95,29 +96,14 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, &x_grad_indices); x_grad->SetMember(x_grad_indices, x_grad_values, x.dims(), 
true); - std::vector offsets(kernel_size + 1), counter(kernel_size, 0), - h_counter(rulebook_len, 0); - phi::backends::gpu::GpuMemcpyAsync(&h_counter[0], - rulebook_ptr, - rulebook_len * sizeof(IntT), -#ifdef PADDLE_WITH_HIP - hipMemcpyDeviceToHost, -#else - cudaMemcpyDeviceToHost, -#endif - - dev_ctx.stream()); - dev_ctx.Wait(); + std::vector offsets(kernel_size + 1); - for (int i = 0; i < rulebook_len; i++) { - counter[h_counter[i]] += 1; - } - IntT offset = 0, max_count = 0; + int offset = 0, max_count = 0; for (int i = 0; i < kernel_size; i++) { offsets[i] = offset; - offset += counter[i]; + offset += counter_ptr[i]; if (i < half_kernel_size) { - max_count = std::max(max_count, counter[i]); + max_count = std::max(max_count, counter_ptr[i]); } } offsets[kernel_size] = offset; @@ -138,36 +124,52 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, } } - auto config = phi::backends::gpu::GetGpuLaunchConfig1D( - dev_ctx, rulebook_len * in_channels, 1); - GatherKernel<<>>(x.non_zero_elements().data(), - rulebook_ptr + rulebook_len, - in_features_ptr, - rulebook_len, - in_channels); + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); + DenseTensor unique_value = phi::Empty( + dev_ctx, {static_cast(x_grad->nnz() * kernel_size * 2)}); + DenseTensor out_index = + phi::Empty(dev_ctx, {static_cast(x.nnz() * 2)}); + int* out_index_ptr = out_index.data(); + int* unique_value_ptr = unique_value.data(); + phi::backends::gpu::GpuMemsetAsync( + out_index_ptr, 0, sizeof(int) * x.nnz() * 2, dev_ctx.stream()); - config = phi::backends::gpu::GetGpuLaunchConfig1D( - dev_ctx, rulebook_len * out_channels, 1); - GatherKernel - <<>>(out_grad.non_zero_elements().data(), - rulebook_ptr + rulebook_len * 2, - out_grad_features_ptr, - rulebook_len, - out_channels); + GroupIndexsV2<<>>(rulebook_len, + x.nnz(), + kernel_size, + offsets[kernel_size / 2], + rulebook_ptr, + out_index_ptr, + unique_value_ptr); + + GatherV2(dev_ctx, + x.non_zero_elements().data(), + out_index_ptr, + unique_value_ptr, + x.nnz(), + kernel_size, + in_channels, + 2, + in_features_ptr); + + Gather(dev_ctx, + out_grad.non_zero_elements().data(), + rulebook_ptr + rulebook_len, + rulebook_len, + out_channels, + out_grad_features_ptr); const T* kernel_ptr = kernel.data(); for (int i = 0; i < kernel_size; i++) { - if (counter[i] <= 0 || (subm && i == half_kernel_size)) { + if (counter_ptr[i] <= 0 || (subm && i == half_kernel_size)) { continue; } - const int M = counter[i]; + const int M = counter_ptr[i]; const int K = in_channels; const int N = out_channels; T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels; @@ -204,32 +206,31 @@ void Conv3dCooGradGPUKernel(const GPUContext& dev_ctx, } // 4. 
scatter - config = phi::backends::gpu::GetGpuLaunchConfig1D( - dev_ctx, rulebook_len * in_channels, 1); - - phi::funcs::ScatterCUDAKernel<<>>( - d_x_features_ptr, - rulebook_ptr + rulebook_len, - x_grad_values_ptr, - rulebook_len, - in_channels, - false); + phi::funcs::sparse::ScatterV2(dev_ctx, + d_x_features_ptr, + out_index.data(), + unique_value.data(), + x_grad->nnz(), + kernel_size, + in_channels, + 2, + x_grad_values_ptr); } template void Conv3dCooGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& kernel, + const SparseCooTensor& out, const DenseTensor& rulebook, + const DenseTensor& counter, const SparseCooTensor& out_grad, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, const int groups, const bool subm, + const std::string& key, SparseCooTensor* x_grad, DenseTensor* kernel_grad) { PD_VISIT_INTEGRAL_TYPES( @@ -237,13 +238,16 @@ void Conv3dCooGradKernel(const Context& dev_ctx, Conv3dCooGradGPUKernel(dev_ctx, x, kernel, + out, rulebook, + counter, out_grad, paddings, dilations, strides, groups, subm, + key, x_grad, kernel_grad); })); diff --git a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu index 6820b677147f3..543f3884edcb4 100644 --- a/paddle/phi/kernels/sparse/gpu/conv_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/conv_kernel.cu @@ -21,7 +21,9 @@ limitations under the License. */ #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/scatter.cu.h" #include "paddle/phi/kernels/funcs/sparse/scatter.cu.h" -#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" +#include "paddle/phi/kernels/sparse/gpu/conv.cu.h" + +#include "glog/logging.h" namespace phi { namespace sparse { @@ -35,8 +37,10 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, const std::vector& strides, const int groups, const bool subm, + const std::string& key, SparseCooTensor* out, - DenseTensor* rulebook) { + DenseTensor* rulebook, + DenseTensor* counter) { // update padding and dilation // Currently, only support x.layout is NDHWC, groups = 1 // if x.layout != NDHWC then transpose(x), transpose(weight) @@ -61,85 +65,117 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, x_dims, kernel_sizes, subm_paddings, dilations, subm_strides, &out_dims); const int in_channels = kernel_dims[3]; const int out_channels = kernel_dims[4]; - std::vector offsets(kernel_size + 1), h_counter(kernel_size); + DenseTensor h_counter, h_offsets; + h_counter.Resize({kernel_size}); + h_offsets.Resize({kernel_size + 1}); + int* h_counter_ptr = dev_ctx.template HostAlloc(&h_counter); + int* h_offsets_ptr = dev_ctx.template HostAlloc(&h_offsets); // Second algorithm: // https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf // 1. 
product rulebook - DenseTensorMeta counter_meta( - DataType::INT32, {kernel_size}, DataLayout::NCHW); - DenseTensorMeta offsets_meta( - DataType::INT32, {kernel_size}, DataLayout::NCHW); - DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta)); - DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, std::move(offsets_meta)); - DenseTensorMeta index_meta(DataType::INT32, {1}, DataLayout::NCHW); - DenseTensor out_index = phi::Empty(dev_ctx, std::move(index_meta)); - DenseTensor unique_value = phi::Empty(dev_ctx, std::move(index_meta)); - - int n = ProductRuleBook(dev_ctx, - x, - kernel_sizes, - subm_paddings, - dilations, - subm_strides, - out_dims, - subm, - rulebook, - &counter_per_kernel, - &offsets_per_kernel, - &out_index, - &unique_value, - out, - &h_counter, - &offsets); - - const int* counter_ptr = counter_per_kernel.data(); - const int* offsets_ptr = counter_per_kernel.data(); - const IntT* rulebook_ptr = rulebook->data(); + DenseTensor counter_per_kernel = phi::Empty(dev_ctx, {kernel_size}); + DenseTensor offsets_per_kernel = phi::Empty(dev_ctx, {kernel_size}); + DenseTensor out_index = phi::Empty(dev_ctx, {1}); + DenseTensor unique_value = phi::Empty(dev_ctx, {1}); + + VLOG(6) << "call SubmConv3D or Conv3D " << subm << " and the key is " << key; + int rulebook_len = 0; + const IntT* rulebook_ptr = nullptr; + bool need_product_rulebook = true; + if (subm && !key.empty()) { + rulebook_ptr = phi::funcs::sparse::PrepareSubm( + dev_ctx, + x, + key, + out_dims, + out, + h_counter.data(), + h_offsets.data(), + &rulebook_len, + &need_product_rulebook); + } + + if (need_product_rulebook) { + DenseTensor tmp_rulebook; + rulebook_len = ProductRuleBook(dev_ctx, + x, + kernel_sizes, + subm_paddings, + dilations, + subm_strides, + out_dims, + subm, + &tmp_rulebook, + &counter_per_kernel, + &offsets_per_kernel, + &out_index, + &unique_value, + out, + h_counter_ptr, + h_offsets_ptr); + rulebook_ptr = tmp_rulebook.data(); + + phi::funcs::sparse::SaveToTable( + dev_ctx, x, key, tmp_rulebook, h_counter, out, rulebook, counter); + } // 2. gather - DenseTensorMeta in_features_meta( - x.dtype(), {n, in_channels}, DataLayout::NCHW); - DenseTensorMeta out_features_meta( - x.dtype(), {n, out_channels}, DataLayout::NCHW); phi::DenseTensor in_features = - phi::Empty(dev_ctx, std::move(in_features_meta)); + phi::Empty(dev_ctx, {rulebook_len, in_channels}); phi::DenseTensor out_features = - phi::Empty(dev_ctx, std::move(out_features_meta)); + phi::Empty(dev_ctx, {rulebook_len, out_channels}); T* in_features_ptr = in_features.data(); T* out_features_ptr = out_features.data(); phi::funcs::SetConstant set_zero; set_zero(dev_ctx, &out_features, static_cast(0.0f)); - auto config = - phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * in_channels, 1); - GatherKernel<<>>(x.non_zero_elements().data(), - rulebook_ptr + n, - in_features_ptr, - n, - in_channels); + Gather(dev_ctx, + x.non_zero_elements().data(), + rulebook_ptr, + rulebook_len, + in_channels, + in_features_ptr); // 3. 
call gemm for every werght auto blas = phi::funcs::GetBlas(dev_ctx); auto* out_values = out->mutable_non_zero_elements(); T* out_values_ptr = out_values->data(); + set_zero(dev_ctx, out_values, static_cast(0.0f)); + + if (subm) { + auto config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, rulebook_len, 1); + unique_value.ResizeAndAllocate( + {static_cast(out->nnz() * kernel_size)}); + out_index.ResizeAndAllocate({static_cast(rulebook_len)}); + int* out_index_ptr = out_index.data(); + int* unique_value_ptr = unique_value.data(); + phi::backends::gpu::GpuMemsetAsync( + out_index_ptr, 0, sizeof(int) * rulebook_len, dev_ctx.stream()); + GroupIndexs<<>>(rulebook_len, + kernel_size, + rulebook_ptr + rulebook_len, + out_index_ptr, + unique_value_ptr); + } const T* kernel_ptr = kernel.data(); for (int i = 0; i < kernel_size; i++) { - if (h_counter[i] <= 0) { + if (h_counter_ptr[i] <= 0) { continue; } // call gemm: (n, in_channels) * (in_channels, out_channels) - const int M = h_counter[i]; + const int M = h_counter_ptr[i]; const int K = in_channels; const int N = out_channels; - T* tmp_in_ptr = in_features_ptr + offsets[i] * in_channels; + T* tmp_in_ptr = in_features_ptr + h_offsets_ptr[i] * in_channels; const T* tmp_kernel_ptr = kernel_ptr + i * K * N; - T* tmp_out_ptr = out_features_ptr + offsets[i] * out_channels; + T* tmp_out_ptr = out_features_ptr + h_offsets_ptr[i] * out_channels; blas.GEMM(CblasNoTrans, CblasNoTrans, @@ -154,40 +190,23 @@ void Conv3dCooGPUKernel(const GPUContext& dev_ctx, } // 4. scatter - if (subm) { - set_zero(dev_ctx, out_values, static_cast(0.0f)); - config = - phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, n * out_channels, 1); - phi::funcs::ScatterCUDAKernel - <<>>(out_features_ptr, - rulebook_ptr + 2 * n, - out_values_ptr, - n, - out_channels, - false); - } else { - config = phi::backends::gpu::GetGpuLaunchConfig1D( - dev_ctx, out->nnz() * out_channels, 1); - phi::funcs::sparse::ScatterKernel - <<>>(out_features_ptr, - unique_value.data(), - out_index.data(), - out->nnz(), - n, - out_channels, - out_values_ptr); - } + phi::funcs::sparse::ScatterV2(dev_ctx, + out_features_ptr, + out_index.data(), + unique_value.data(), + out->nnz(), + kernel_size, + out_channels, + 1, + out_values_ptr); } + /** - * x: (N, D, H, W, C) - * kernel: (D, H, W, C, OC) - * out: (N, D, H, W, OC) + * x: the input SparseCooTensor, shape is (N, D, H, W, C) + * kernel: the weight data, shape is (D, H, W, C, OC) + * out: the output SparseCooTensor, shape is (N, D, H, W, OC) + * rulebook: return rulebook if key is not vailed else return nullptr + * counter: return counter if key is not vailed else return nullptr **/ template void Conv3dCooKernel(const Context& dev_ctx, @@ -198,8 +217,10 @@ void Conv3dCooKernel(const Context& dev_ctx, const std::vector& strides, const int groups, const bool subm, + const std::string& key, SparseCooTensor* out, - DenseTensor* rulebook) { + DenseTensor* rulebook, + DenseTensor* counter) { PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "Conv3dCooGPUKernel", ([&] { Conv3dCooGPUKernel(dev_ctx, @@ -210,8 +231,10 @@ void Conv3dCooKernel(const Context& dev_ctx, strides, groups, subm, + key, out, - rulebook); + rulebook, + counter); })); } diff --git a/paddle/phi/kernels/sparse/gpu/mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu index 39fa89c0379b7..35d63b7630930 100644 --- a/paddle/phi/kernels/sparse/gpu/mask_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/mask_kernel.cu @@ -238,6 +238,7 @@ void SparseMaskHelperGPUKernel(const GPUContext& 
dev_ctx, x_indexs_ptr, x_indexs.numel(), table.data()); config = phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, mask_indexs.numel(), 1); + const int VecBytes = 16; const int VecSize = VecBytes / sizeof(T); if (stride % VecSize == 0) { diff --git a/paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu index 724072443a9ed..07bdd68a7f7b4 100644 --- a/paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/pool_grad_kernel.cu @@ -55,6 +55,7 @@ template void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx, const SparseCooTensor& x, const DenseTensor& rulebook, + const DenseTensor& counter, const SparseCooTensor& out, const SparseCooTensor& out_grad, const std::vector& kernel_sizes, @@ -63,23 +64,9 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx, const int in_channels = x.dims()[4]; int rulebook_len = rulebook.dims()[1]; const IntT* rulebook_ptr = rulebook.data(); - std::vector offsets(kernel_size + 1), counter(kernel_size, 0), - h_counter(rulebook_len, 0); - phi::backends::gpu::GpuMemcpyAsync(&h_counter[0], - rulebook_ptr, - rulebook_len * sizeof(IntT), -#ifdef PADDLE_WITH_HIP - hipMemcpyDeviceToHost, -#else - cudaMemcpyDeviceToHost, -#endif - - dev_ctx.stream()); - dev_ctx.Wait(); - for (int i = 0; i < rulebook_len; i++) { - counter[h_counter[i]] += 1; - } - phi::funcs::sparse::PrefixSum(&counter[0], &offsets[0], kernel_size); + std::vector offsets(kernel_size + 1); + const int* counter_ptr = counter.data(); + phi::funcs::sparse::PrefixSum(counter_ptr, &offsets[0], kernel_size); const T* in_features_ptr = x.non_zero_elements().data(); const T* out_features_ptr = out.non_zero_elements().data(); @@ -99,12 +86,12 @@ void MaxPoolCooGradGPUKernel(const GPUContext& dev_ctx, &x_grad_indices); for (int i = 0; i < kernel_size; i++) { - if (counter[i] <= 0) { + if (counter_ptr[i] <= 0) { continue; } auto config = phi::backends::gpu::GetGpuLaunchConfig1D( - dev_ctx, counter[i] * in_channels, 1); + dev_ctx, counter_ptr[i] * in_channels, 1); MaxPoolGradCudaKernel <<>>(in_features_ptr, out_features_ptr, out_grad_ptr, - rulebook_ptr + offsets[i] + rulebook_len, - counter[i], + rulebook_ptr + offsets[i], + counter_ptr[i], rulebook_len, in_channels, x_grad_ptr); @@ -124,6 +111,7 @@ template void MaxPoolCooGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& rulebook, + const DenseTensor& counter, const SparseCooTensor& out, const SparseCooTensor& out_grad, const std::vector& kernel_sizes, @@ -131,7 +119,7 @@ void MaxPoolCooGradKernel(const Context& dev_ctx, PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "MaxPoolCooGradGPUKernel", ([&] { MaxPoolCooGradGPUKernel( - dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad); + dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, x_grad); })); } diff --git a/paddle/phi/kernels/sparse/gpu/pool_kernel.cu b/paddle/phi/kernels/sparse/gpu/pool_kernel.cu index 0d24594f0a85f..8b1888e0a64cf 100644 --- a/paddle/phi/kernels/sparse/gpu/pool_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/pool_kernel.cu @@ -19,7 +19,7 @@ limitations under the License. 
*/ #include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" -#include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" +#include "paddle/phi/kernels/sparse/gpu/conv.cu.h" namespace phi { namespace sparse { @@ -55,7 +55,8 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx, const std::vector& dilations, const std::vector& strides, SparseCooTensor* out, - DenseTensor* rulebook) { + DenseTensor* rulebook, + DenseTensor* counter) { const auto& x_dims = x.dims(); int kernel_size = kernel_sizes[0] * kernel_sizes[1] * kernel_sizes[2]; const std::vector& real_kernel_sizes = @@ -65,7 +66,7 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx, x_dims, real_kernel_sizes, paddings, dilations, strides, &out_dims); const int in_channels = real_kernel_sizes[3]; - std::vector offsets(kernel_size + 1), counter(kernel_size); + std::vector offsets(kernel_size + 1), h_counter(kernel_size); DenseTensorMeta counter_meta( DataType::INT32, {kernel_size}, DataLayout::NCHW); DenseTensor counter_per_kernel = phi::Empty(dev_ctx, std::move(counter_meta)); @@ -89,13 +90,16 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx, &out_index, &unique_value, out, - &counter, - &offsets); + h_counter.data(), + offsets.data()); const IntT* rulebook_ptr = rulebook->data(); T* out_features_ptr = out->mutable_non_zero_elements()->data(); const T* in_features_ptr = x.non_zero_elements().data(); + counter->Resize({kernel_size}); + int* counter_ptr = dev_ctx.template HostAlloc(counter); + memcpy(counter_ptr, h_counter.data(), h_counter.size() * sizeof(int)); // 2. max pool #ifdef PADDLE_WITH_HIP thrust::fill(thrust::hip::par.on(dev_ctx.stream()), @@ -107,22 +111,21 @@ void MaxPoolCooGPUKernel(const GPUContext& dev_ctx, static_cast(0)); // TODO(zhangkaihuo) Replacing multiple calls with one kernel may be faster for (int i = 0; i < kernel_size; i++) { - if (counter[i] <= 0) { + if (h_counter[i] <= 0) { continue; } auto config = phi::backends::gpu::GetGpuLaunchConfig1D( - dev_ctx, counter[i] * in_channels, 1); - MaxPoolCudaKernel - <<>>(in_features_ptr, - rulebook_ptr + offsets[i] + rulebook_len, - counter[i], - rulebook_len, - in_channels, - out_features_ptr); + dev_ctx, h_counter[i] * in_channels, 1); + MaxPoolCudaKernel<<>>(in_features_ptr, + rulebook_ptr + offsets[i], + h_counter[i], + rulebook_len, + in_channels, + out_features_ptr); } } @@ -134,7 +137,8 @@ void MaxPoolCooKernel(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, SparseCooTensor* out, - DenseTensor* rulebook) { + DenseTensor* rulebook, + DenseTensor* counter) { PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "MaxPoolCooGPUKernel", ([&] { MaxPoolCooGPUKernel(dev_ctx, @@ -144,7 +148,8 @@ void MaxPoolCooKernel(const Context& dev_ctx, dilations, strides, out, - rulebook); + rulebook, + counter); })); } diff --git a/paddle/phi/kernels/sparse/pool_grad_kernel.h b/paddle/phi/kernels/sparse/pool_grad_kernel.h index 6afcbfea6ca26..47f3df6f9fb73 100644 --- a/paddle/phi/kernels/sparse/pool_grad_kernel.h +++ b/paddle/phi/kernels/sparse/pool_grad_kernel.h @@ -25,6 +25,7 @@ template void MaxPoolCooGradKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& rulebook, + const DenseTensor& counter, const SparseCooTensor& out, const SparseCooTensor& out_grad, const std::vector& kernel_sizes, @@ -34,12 +35,13 @@ template SparseCooTensor MaxPoolCooGrad(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& rulebook, + const 
DenseTensor& counter, const SparseCooTensor& out, const SparseCooTensor& out_grad, const std::vector& kernel_sizes) { SparseCooTensor x_grad; MaxPoolCooGradKernel( - dev_ctx, x, rulebook, out, out_grad, kernel_sizes, &x_grad); + dev_ctx, x, rulebook, counter, out, out_grad, kernel_sizes, &x_grad); return x_grad; } diff --git a/paddle/phi/kernels/sparse/pool_kernel.h b/paddle/phi/kernels/sparse/pool_kernel.h index 95349291efb69..085cea82d3350 100644 --- a/paddle/phi/kernels/sparse/pool_kernel.h +++ b/paddle/phi/kernels/sparse/pool_kernel.h @@ -29,7 +29,8 @@ void MaxPoolCooKernel(const Context& dev_ctx, const std::vector& dilations, const std::vector& strides, SparseCooTensor* out, - DenseTensor* rulebook); + DenseTensor* rulebook, + DenseTensor* counter); template SparseCooTensor MaxPoolCoo(const Context& dev_ctx, @@ -38,10 +39,18 @@ SparseCooTensor MaxPoolCoo(const Context& dev_ctx, const std::vector& paddings, const std::vector& dilations, const std::vector& strides, - DenseTensor* rulebook) { + DenseTensor* rulebook, + DenseTensor* counter) { SparseCooTensor coo; - MaxPoolCooKernel( - dev_ctx, x, kernel_sizes, paddings, dilations, strides, &coo, rulebook); + MaxPoolCooKernel(dev_ctx, + x, + kernel_sizes, + paddings, + dilations, + strides, + &coo, + rulebook, + counter); return coo; } diff --git a/paddle/phi/tests/api/test_sparse_conv_api.cc b/paddle/phi/tests/api/test_sparse_conv_api.cc index 95f4afe4d1540..b1df197f42f47 100644 --- a/paddle/phi/tests/api/test_sparse_conv_api.cc +++ b/paddle/phi/tests/api/test_sparse_conv_api.cc @@ -76,8 +76,8 @@ void TestConv3dBase(const std::vector& indices, kernel.size() * sizeof(T)); if (!std::is_same::value) { - auto tensor_out = paddle::experimental::sparse::conv3d( - x, weight, paddings, dilations, strides, 1, false); + auto tensor_out = paddle::experimental::sparse::conv3d_coo( + x, weight, paddings, dilations, strides, 1, false, "Conv3d"); auto out = std::dynamic_pointer_cast(tensor_out.impl()); diff --git a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc index 4a39f2bd8f1c4..f7c7b7e9486ee 100644 --- a/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_conv3d_dev_api.cc @@ -112,8 +112,7 @@ void TestConv3dBase(const std::vector& indices, }; if (!std::is_same::value) { - DenseTensor rulebook = phi::Empty( - dev_ctx_cpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW)); + DenseTensor rulebook, counter; SparseCooTensor out = sparse::Conv3dCoo(dev_ctx_cpu, x_tensor, kernel_tensor, @@ -122,7 +121,9 @@ void TestConv3dBase(const std::vector& indices, strides, 1, subm, - &rulebook); + "Conv3d", + &rulebook, + &counter); ASSERT_EQ(correct_out_dims.size(), out.dims().size()); for (int i = 0; i < correct_out_dims.size(); i++) { @@ -142,13 +143,16 @@ void TestConv3dBase(const std::vector& indices, sparse::Conv3dCooGrad(dev_ctx_cpu, x_tensor, kernel_tensor, + out, rulebook, + counter, out, paddings, dilations, strides, 1, - subm); + subm, + "Conv3d"); f_verify(std::get<0>(grads).non_zero_elements().data(), features_grad); f_verify(std::get<1>(grads).data(), kernel_grad); } @@ -196,8 +200,7 @@ void TestConv3dBase(const std::vector& indices, phi::Copy( dev_ctx_gpu, kernel_tensor, phi::GPUPlace(), true, &d_kernel_tensor); - DenseTensor d_rulebook = phi::Empty( - dev_ctx_gpu, DenseTensorMeta(indices_dtype, {1}, DataLayout::NCHW)); + DenseTensor d_rulebook, d_counter; SparseCooTensor d_out = sparse::Conv3dCoo(dev_ctx_gpu, d_x_tensor, 
d_kernel_tensor, @@ -206,8 +209,9 @@ void TestConv3dBase(const std::vector& indices, strides, 1, subm, - &d_rulebook); - + "Conv3d", + &d_rulebook, + &d_counter); SparseCooTensor tmp_d_out = sparse::Coalesce(dev_ctx_gpu, d_out); ASSERT_EQ(correct_out_dims.size(), d_out.dims().size()); @@ -245,13 +249,16 @@ void TestConv3dBase(const std::vector& indices, sparse::Conv3dCooGrad(dev_ctx_gpu, d_x_tensor, d_kernel_tensor, + d_out, d_rulebook, + d_counter, d_out, paddings, dilations, strides, 1, - subm); + subm, + "Conv3d"); DenseTensor d_features_grad = std::get<0>(grads).non_zero_elements(); DenseTensor d_kernel_grad = std::get<1>(grads); DenseTensor h_features_grad = diff --git a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc index 08f8cd8a73273..ffc2604e6b8da 100644 --- a/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc +++ b/paddle/phi/tests/kernels/test_sparse_pool_dev_api.cc @@ -90,14 +90,15 @@ void TestMaxPoolBase(const std::vector& indices, }; if (!std::is_same::value) { - DenseTensor rulebook; + DenseTensor rulebook, counter; SparseCooTensor out = sparse::MaxPoolCoo(dev_ctx_cpu, x_tensor, kernel_sizes, paddings, dilations, strides, - &rulebook); + &rulebook, + &counter); ASSERT_EQ(correct_out_dims.size(), out.dims().size()); for (int i = 0; i < correct_out_dims.size(); i++) { @@ -114,7 +115,7 @@ void TestMaxPoolBase(const std::vector& indices, if (backward) { SparseCooTensor x_grad = sparse::MaxPoolCooGrad( - dev_ctx_cpu, x_tensor, rulebook, out, out, kernel_sizes); + dev_ctx_cpu, x_tensor, rulebook, counter, out, out, kernel_sizes); f_verify(x_grad.non_zero_elements().data(), features_grad); } } @@ -150,14 +151,16 @@ void TestMaxPoolBase(const std::vector& indices, SparseCooTensor d_x_tensor(d_indices_tensor, d_features_tensor, x_dims); - DenseTensor d_rulebook; + DenseTensor d_rulebook, d_counter; SparseCooTensor d_out = sparse::MaxPoolCoo(dev_ctx_gpu, d_x_tensor, kernel_sizes, paddings, dilations, strides, - &d_rulebook); + &d_rulebook, + &d_counter); + SparseCooTensor tmp_d_out = sparse::Coalesce(dev_ctx_gpu, d_out); ASSERT_EQ(correct_out_dims.size(), d_out.dims().size()); @@ -191,8 +194,13 @@ void TestMaxPoolBase(const std::vector& indices, f_verify(h_features_tensor.data(), correct_out_features); if (backward) { - SparseCooTensor x_grad = sparse::MaxPoolCooGrad( - dev_ctx_gpu, d_x_tensor, d_rulebook, d_out, d_out, kernel_sizes); + SparseCooTensor x_grad = sparse::MaxPoolCooGrad(dev_ctx_gpu, + d_x_tensor, + d_rulebook, + d_counter, + d_out, + d_out, + kernel_sizes); DenseTensor h_features_grad = phi::EmptyLike(dev_ctx_cpu, x_grad.non_zero_elements()); phi::Copy(dev_ctx_gpu, diff --git a/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py b/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py index 9501b2c89531f..36ecfeccd1a1d 100644 --- a/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py +++ b/python/paddle/fluid/tests/unittests/test_sparse_conv_op.py @@ -67,7 +67,7 @@ def test_subm_conv3d(self): indices, values, dense_shape, stop_gradient=True) weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32') y = paddle.incubate.sparse.nn.functional.subm_conv3d( - sparse_x, weight) + sparse_x, weight, key='subm_conv') assert np.array_equal(sparse_x.indices().numpy(), y.indices().numpy()) @@ -91,7 +91,7 @@ def test_Conv3D(self): with self.assertRaises(ValueError): #Currently, only support data_format='NDHWC' conv3d = paddle.incubate.sparse.nn.SubmConv3D( - 1, 1, (1, 3, 3), data_format='NCDHW') + 1, 1, (1, 
3, 3), data_format='NCDHW', key='subm_conv')
 
     def test_SubmConv3D(self):
         with _test_eager_guard():
@@ -105,7 +105,7 @@ def test_SubmConv3D(self):
                 indices, values, dense_shape, False)
 
             subm_conv3d = paddle.incubate.sparse.nn.SubmConv3D(
-                1, 1, (1, 3, 3), data_format='NDHWC')
+                1, 1, (1, 3, 3), data_format='NDHWC', key='subm_conv')
             # test extra_repr
             print(subm_conv3d.extra_repr())
 
@@ -117,7 +117,7 @@ def test_SubmConv3D(self):
             with self.assertRaises(ValueError):
                 #Currently, only support data_format='NDHWC'
                 conv3d = paddle.incubate.sparse.nn.SubmConv3D(
-                    1, 1, (1, 3, 3), data_format='NCDHW')
+                    1, 1, (1, 3, 3), data_format='NCDHW', key='subm_conv')
 
     def test_Conv3D_bias(self):
         with _test_eager_guard():
diff --git a/python/paddle/incubate/sparse/nn/functional/conv.py b/python/paddle/incubate/sparse/nn/functional/conv.py
index 75c0514da8e0e..60cbb94bea236 100644
--- a/python/paddle/incubate/sparse/nn/functional/conv.py
+++ b/python/paddle/incubate/sparse/nn/functional/conv.py
@@ -29,6 +29,7 @@ def _conv3d(x,
             dilation=1,
             groups=1,
             subm=False,
+            key=None,
             data_format="NDHWC",
             name=None):
     assert in_dynamic_mode(), "Currently, only support dynamic mode"
@@ -62,8 +63,9 @@ def _conv3d(x,
     dilation = convert_to_list(dilation, dims, 'dilation')
     op_type = "conv3d"
 
-    pre_bias = _C_ops.final_state_sparse_conv3d(x, weight, padding, dilation,
-                                                stride, groups, subm)
+    pre_bias = _C_ops.final_state_sparse_conv3d_coo(
+        x, weight, padding, dilation, stride, groups, subm,
+        key if key is not None else "")
     if bias is not None:
         values = pre_bias.values()
         add_bias = elementwise_add(values, bias, axis=1)
@@ -186,7 +188,7 @@ def conv3d(x,
           # (1, 1, 1, 2, 1)
     """
     return _conv3d(x, weight, bias, stride, padding, dilation, groups, False,
-                   data_format, name)
+                   None, data_format, name)
 
 
 def subm_conv3d(x,
@@ -197,6 +199,7 @@ def subm_conv3d(x,
                 dilation=1,
                 groups=1,
                 data_format="NDHWC",
+                key=None,
                 name=None):
     r"""
 
@@ -274,6 +277,10 @@ def subm_conv3d(x,
            will be consistent with that of the input. An optional string from: `"NCDHW"`, `"NDHWC"`.
            The default is `"NDHWC"`. When it is `"NDHWC"`, the data is stored in the order of:
            `[batch_size, input_depth, input_height, input_width, input_channels]`.
+        key(str, optional): the key is used to save or reuse the same rulebook;
+            the definition and role of the rulebook are described in
+            https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf.
+            The default value is None.
         name(str|None): For detailed information, please refer to :ref:`api_guide_Name`.
            Usually name is no need to set and None by default.
 
@@ -301,4 +308,4 @@ def subm_conv3d(x,
           #(1, 1, 3, 4, 1)
     """
     return _conv3d(x, weight, bias, stride, padding, dilation, groups, True,
-                   data_format, name)
+                   key, data_format, name)
diff --git a/python/paddle/incubate/sparse/nn/layer/conv.py b/python/paddle/incubate/sparse/nn/layer/conv.py
index 74574ee61e3c2..f44358bbe9f3e 100644
--- a/python/paddle/incubate/sparse/nn/layer/conv.py
+++ b/python/paddle/incubate/sparse/nn/layer/conv.py
@@ -33,6 +33,7 @@ def __init__(self,
                  dilation=1,
                  groups=1,
                  subm=False,
+                 key=None,
                  padding_mode='zeros',
                  weight_attr=None,
                  bias_attr=None,
@@ -46,6 +47,7 @@ def __init__(self,
         self._out_channels = out_channels
         self._data_format = data_format
         self._subm = subm
+        self._key = key
 
         assert padding_mode == 'zeros', "Currently, only support padding_mode='zeros'"
         assert groups == 1, "Currently, only support groups=1"
@@ -95,6 +97,7 @@ def forward(self, x):
             dilation=self._dilation,
             groups=self._groups,
             subm=self._subm,
+            key=self._key,
             data_format=self._data_format)
         return out
 
@@ -240,6 +243,7 @@ def __init__(self,
             dilation=dilation,
             groups=groups,
             subm=False,
+            key=None,
             padding_mode=padding_mode,
             weight_attr=weight_attr,
             bias_attr=bias_attr,
@@ -293,6 +297,10 @@ class SubmConv3D(_Conv3D):
            of the input channels, while the second half of the filters is only
            connected to the second half of the input channels. The default value is 1.
         padding_mode(str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Currently only support ``'zeros'``.
+        key(str, optional): the key is used to save or reuse the same rulebook;
+            the definition and role of the rulebook are described in
+            https://pdfs.semanticscholar.org/5125/a16039cabc6320c908a4764f32596e018ad3.pdf.
+            The default value is None.
         weight_attr(ParamAttr, optional): The parameter attribute for learnable parameters/weights
            of conv3d. If it is set to None or one attribute of ParamAttr, conv3d
            will create ParamAttr as param_attr. If it is set to None, the parameter
@@ -361,6 +369,7 @@ def __init__(self,
                  dilation=1,
                  groups=1,
                  padding_mode='zeros',
+                 key=None,
                  weight_attr=None,
                  bias_attr=None,
                  data_format="NDHWC"):
@@ -372,6 +381,7 @@ def __init__(self,
             dilation=dilation,
             groups=groups,
             subm=True,
+            key=key,
             padding_mode=padding_mode,
             weight_attr=weight_attr,
             bias_attr=bias_attr,
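
For illustration, a minimal sketch of the new `key` argument on the functional API, mirroring the updated test_subm_conv3d case in this diff; the concrete indices and values and the key name 'subm_conv' are placeholders, and eager mode is assumed.

import numpy as np
import paddle

# A tiny COO input: dense shape (1, 1, 3, 4, 1) = (N, D, H, W, C);
# indices hold the 4 sparse dims (N, D, H, W), values hold the channel.
indices = paddle.to_tensor(
    [[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 2], [1, 3, 2, 3]], dtype='int32')
values = paddle.to_tensor([[1.0], [2.0], [3.0], [4.0]], dtype='float32')
sparse_x = paddle.incubate.sparse.sparse_coo_tensor(
    indices, values, [1, 1, 3, 4, 1], stop_gradient=True)

weight = paddle.randn((1, 3, 3, 1, 1), dtype='float32')

# key='subm_conv' asks the kernel to save the rulebook/counter it builds under
# that name, so a later submanifold conv with the same key can look them up.
y = paddle.incubate.sparse.nn.functional.subm_conv3d(
    sparse_x, weight, key='subm_conv')

# A submanifold convolution keeps the input sparsity pattern.
assert np.array_equal(sparse_x.indices().numpy(), y.indices().numpy())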
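And a sketch of the layer API: on my reading of the new PrepareSubm/SaveToTable path, giving two SubmConv3D layers the same key is what lets the second layer reuse the rulebook and counter cached by the first instead of recomputing them, provided the kernel sizes match. The layer sizes, the key name 'subm_block', and that reuse behavior are assumptions for this sketch, not guarantees from the patch.

import paddle
from paddle.incubate.sparse import nn

indices = paddle.to_tensor([[0, 0], [0, 0], [1, 2], [2, 3]], dtype='int32')
values = paddle.to_tensor([[1.0], [2.0]], dtype='float32')
sparse_x = paddle.incubate.sparse.sparse_coo_tensor(
    indices, values, [1, 1, 3, 4, 1], stop_gradient=False)

# Both layers use the same kernel size and the same key.
conv1 = nn.SubmConv3D(1, 2, (1, 3, 3), data_format='NDHWC', key='subm_block')
conv2 = nn.SubmConv3D(2, 2, (1, 3, 3), data_format='NDHWC', key='subm_block')

# conv1 builds the rulebook/counter and stores them in its output under
# 'subm_block'; conv2, seeing the same key on its sparse input, can fetch them
# instead of running the rulebook construction again.
out = conv2(conv1(sparse_x))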