diff --git a/paddle/phi/kernels/stride/indexing.cu b/paddle/phi/kernels/stride/indexing.cu index 638a31eb9cf47d..ba61b2b1e14498 100644 --- a/paddle/phi/kernels/stride/indexing.cu +++ b/paddle/phi/kernels/stride/indexing.cu @@ -27,6 +27,7 @@ #include "paddle/phi/kernels/funcs/stride_utils.h" #include "paddle/phi/kernels/funcs/strided_utils.h" #include "paddle/phi/kernels/index_put_kernel.h" +#include "paddle/phi/kernels/stride/elementwise_stride_base.cu.h" #if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__) #include "paddle/phi/kernels/funcs/dims_simplifier.h" @@ -74,20 +75,6 @@ inline bool CheckIsDimsMatchBool(const DDim& first, const DDim& second) { return false; } -template -phi::DenseTensor Tensor2Contiguous(const Context& dev_ctx, - const phi::DenseTensor& tensor) { - phi::DenseTensor dense_out; - phi::MetaTensor meta_input(tensor); - phi::MetaTensor meta_out(&dense_out); - UnchangedInferMeta(meta_input, &meta_out); - PD_VISIT_ALL_TYPES(tensor.dtype(), "Tensor2Contiguous", ([&] { - phi::ContiguousKernel( - dev_ctx, tensor, &dense_out); - })); - return dense_out; -} - template void LaunchIndexPutKernel_V2(const Context& dev_ctx, const DenseTensor& x, @@ -110,11 +97,41 @@ void LaunchIndexPutKernel_V2(const Context& dev_ctx, false, common::errors::InvalidArgument("Indices cannot be empty.")); + bool is_initialized = out->initialized(); + auto meta = x.meta(); + meta.dims = out->dims(); + meta.strides = meta.calc_strides(out->dims()); + out->set_meta(meta); + T* out_data = dev_ctx.template Alloc(out); + if (!is_initialized) { + if (!x.meta().is_contiguous() || x.offset() != 0) { + StridedTensorCopy(x, + common::vectorize(out->dims()), + common::vectorize(out->strides()), + 0, + out); + } else { + phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out); + } + } + funcs::AdvancedIndex ad = - funcs::AdvancedIndex(dev_ctx, x, indices); + funcs::AdvancedIndex(dev_ctx, *out, indices); if (!CheckIsDimsMatchBool(ad.src.dims(), value.dims())) { + DenseTensor x_; + DenseTensor value_; + if (!x.meta().is_contiguous() || x.offset() != 0) { + x_ = Tensor2Contiguous(dev_ctx, x); + } else { + x_ = x; + } + if (!value.meta().is_contiguous() || value.offset() != 0) { + value_ = Tensor2Contiguous(dev_ctx, value); + } else { + value_ = value; + } phi::IndexPutKernel( - dev_ctx, x, indices, value, accumulate, out); + dev_ctx, x_, indices, value_, accumulate, out); return; } @@ -151,16 +168,6 @@ void LaunchIndexPutKernel_V2(const Context& dev_ctx, auto* val_data = value.data(); - bool is_initialized = out->initialized(); - T* out_data = dev_ctx.template Alloc(out); - if (!is_initialized) { - StridedTensorCopy(x, - common::vectorize(x.dims()), - common::vectorize(x.strides()), - x.offset(), - out); - } - const char* in_ptr = reinterpret_cast(val_data); char* out_ptr = reinterpret_cast(out_data); funcs::index_put_kernel<<>>(