2727#include " paddle/phi/kernels/funcs/stride_utils.h"
2828#include " paddle/phi/kernels/funcs/strided_utils.h"
2929#include " paddle/phi/kernels/index_put_kernel.h"
30+ #include " paddle/phi/kernels/stride/elementwise_stride_base.cu.h"
3031
3132#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
3233#include " paddle/phi/kernels/funcs/dims_simplifier.h"
@@ -74,20 +75,6 @@ inline bool CheckIsDimsMatchBool(const DDim& first, const DDim& second) {
7475 return false ;
7576}
7677
77- template <typename Context>
78- phi::DenseTensor Tensor2Contiguous (const Context& dev_ctx,
79- const phi::DenseTensor& tensor) {
80- phi::DenseTensor dense_out;
81- phi::MetaTensor meta_input (tensor);
82- phi::MetaTensor meta_out (&dense_out);
83- UnchangedInferMeta (meta_input, &meta_out);
84- PD_VISIT_ALL_TYPES (tensor.dtype (), " Tensor2Contiguous" , ([&] {
85- phi::ContiguousKernel<data_t , Context>(
86- dev_ctx, tensor, &dense_out);
87- }));
88- return dense_out;
89- }
90-
9178template <typename T, typename Context>
9279void LaunchIndexPutKernel_V2 (const Context& dev_ctx,
9380 const DenseTensor& x,
@@ -110,11 +97,41 @@ void LaunchIndexPutKernel_V2(const Context& dev_ctx,
11097 false ,
11198 common::errors::InvalidArgument (" Indices cannot be empty." ));
11299
100+ bool is_initialized = out->initialized ();
101+ auto meta = x.meta ();
102+ meta.dims = out->dims ();
103+ meta.strides = meta.calc_strides (out->dims ());
104+ out->set_meta (meta);
105+ T* out_data = dev_ctx.template Alloc <T>(out);
106+ if (!is_initialized) {
107+ if (!x.meta ().is_contiguous () || x.offset () != 0 ) {
108+ StridedTensorCopy<T>(x,
109+ common::vectorize<int64_t >(out->dims ()),
110+ common::vectorize<int64_t >(out->strides ()),
111+ 0 ,
112+ out);
113+ } else {
114+ phi::Copy (dev_ctx, x, dev_ctx.GetPlace (), false , out);
115+ }
116+ }
117+
113118 funcs::AdvancedIndex ad =
114- funcs::AdvancedIndex<T, Context>(dev_ctx, x , indices);
119+ funcs::AdvancedIndex<T, Context>(dev_ctx, *out , indices);
115120 if (!CheckIsDimsMatchBool (ad.src .dims (), value.dims ())) {
121+ DenseTensor x_;
122+ DenseTensor value_;
123+ if (!x.meta ().is_contiguous () || x.offset () != 0 ) {
124+ x_ = Tensor2Contiguous<Context>(dev_ctx, x);
125+ } else {
126+ x_ = x;
127+ }
128+ if (!value.meta ().is_contiguous () || value.offset () != 0 ) {
129+ value_ = Tensor2Contiguous<Context>(dev_ctx, value);
130+ } else {
131+ value_ = value;
132+ }
116133 phi::IndexPutKernel<T, Context>(
117- dev_ctx, x , indices, value , accumulate, out);
134+ dev_ctx, x_ , indices, value_ , accumulate, out);
118135 return ;
119136 }
120137
@@ -151,16 +168,6 @@ void LaunchIndexPutKernel_V2(const Context& dev_ctx,
151168
152169 auto * val_data = value.data <T>();
153170
154- bool is_initialized = out->initialized ();
155- T* out_data = dev_ctx.template Alloc <T>(out);
156- if (!is_initialized) {
157- StridedTensorCopy<T>(x,
158- common::vectorize<int64_t >(x.dims ()),
159- common::vectorize<int64_t >(x.strides ()),
160- x.offset (),
161- out);
162- }
163-
164171 const char * in_ptr = reinterpret_cast <const char *>(val_data);
165172 char * out_ptr = reinterpret_cast <char *>(out_data);
166173 funcs::index_put_kernel<nt, vt, T><<<grid, block, 0 , stream>>> (
0 commit comments