Skip to content

Commit 9396014

Browse files
fix index_put (#74944)
1 parent: ec1b290 · commit: 9396014

File tree

1 file changed

+33
-26
lines changed

1 file changed

+33
-26
lines changed

paddle/phi/kernels/stride/indexing.cu

Lines changed: 33 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,7 @@
2727
#include "paddle/phi/kernels/funcs/stride_utils.h"
2828
#include "paddle/phi/kernels/funcs/strided_utils.h"
2929
#include "paddle/phi/kernels/index_put_kernel.h"
30+
#include "paddle/phi/kernels/stride/elementwise_stride_base.cu.h"
3031

3132
#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
3233
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
@@ -74,20 +75,6 @@ inline bool CheckIsDimsMatchBool(const DDim& first, const DDim& second) {
7475
return false;
7576
}
7677

77-
template <typename Context>
78-
phi::DenseTensor Tensor2Contiguous(const Context& dev_ctx,
79-
const phi::DenseTensor& tensor) {
80-
phi::DenseTensor dense_out;
81-
phi::MetaTensor meta_input(tensor);
82-
phi::MetaTensor meta_out(&dense_out);
83-
UnchangedInferMeta(meta_input, &meta_out);
84-
PD_VISIT_ALL_TYPES(tensor.dtype(), "Tensor2Contiguous", ([&] {
85-
phi::ContiguousKernel<data_t, Context>(
86-
dev_ctx, tensor, &dense_out);
87-
}));
88-
return dense_out;
89-
}
90-
9178
template <typename T, typename Context>
9279
void LaunchIndexPutKernel_V2(const Context& dev_ctx,
9380
const DenseTensor& x,
@@ -110,11 +97,41 @@ void LaunchIndexPutKernel_V2(const Context& dev_ctx,
11097
false,
11198
common::errors::InvalidArgument("Indices cannot be empty."));
11299

100+
bool is_initialized = out->initialized();
101+
auto meta = x.meta();
102+
meta.dims = out->dims();
103+
meta.strides = meta.calc_strides(out->dims());
104+
out->set_meta(meta);
105+
T* out_data = dev_ctx.template Alloc<T>(out);
106+
if (!is_initialized) {
107+
if (!x.meta().is_contiguous() || x.offset() != 0) {
108+
StridedTensorCopy<T>(x,
109+
common::vectorize<int64_t>(out->dims()),
110+
common::vectorize<int64_t>(out->strides()),
111+
0,
112+
out);
113+
} else {
114+
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
115+
}
116+
}
117+
113118
funcs::AdvancedIndex ad =
114-
funcs::AdvancedIndex<T, Context>(dev_ctx, x, indices);
119+
funcs::AdvancedIndex<T, Context>(dev_ctx, *out, indices);
115120
if (!CheckIsDimsMatchBool(ad.src.dims(), value.dims())) {
121+
DenseTensor x_;
122+
DenseTensor value_;
123+
if (!x.meta().is_contiguous() || x.offset() != 0) {
124+
x_ = Tensor2Contiguous<Context>(dev_ctx, x);
125+
} else {
126+
x_ = x;
127+
}
128+
if (!value.meta().is_contiguous() || value.offset() != 0) {
129+
value_ = Tensor2Contiguous<Context>(dev_ctx, value);
130+
} else {
131+
value_ = value;
132+
}
116133
phi::IndexPutKernel<T, Context>(
117-
dev_ctx, x, indices, value, accumulate, out);
134+
dev_ctx, x_, indices, value_, accumulate, out);
118135
return;
119136
}
120137

@@ -151,16 +168,6 @@ void LaunchIndexPutKernel_V2(const Context& dev_ctx,
151168

152169
auto* val_data = value.data<T>();
153170

154-
bool is_initialized = out->initialized();
155-
T* out_data = dev_ctx.template Alloc<T>(out);
156-
if (!is_initialized) {
157-
StridedTensorCopy<T>(x,
158-
common::vectorize<int64_t>(x.dims()),
159-
common::vectorize<int64_t>(x.strides()),
160-
x.offset(),
161-
out);
162-
}
163-
164171
const char* in_ptr = reinterpret_cast<const char*>(val_data);
165172
char* out_ptr = reinterpret_cast<char*>(out_data);
166173
funcs::index_put_kernel<nt, vt, T><<<grid, block, 0, stream>>>(

0 commit comments

Comments (0)