Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 33 additions & 26 deletions paddle/phi/kernels/stride/indexing.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "paddle/phi/kernels/funcs/stride_utils.h"
#include "paddle/phi/kernels/funcs/strided_utils.h"
#include "paddle/phi/kernels/index_put_kernel.h"
#include "paddle/phi/kernels/stride/elementwise_stride_base.cu.h"

#if defined(__NVCC__) || defined(__HIPCC__) || defined(__xpu__)
#include "paddle/phi/kernels/funcs/dims_simplifier.h"
Expand Down Expand Up @@ -74,20 +75,6 @@ inline bool CheckIsDimsMatchBool(const DDim& first, const DDim& second) {
return false;
}

// Returns a freshly created DenseTensor produced by running
// phi::ContiguousKernel on `tensor`.
//
// The result's meta is first inferred from the input via UnchangedInferMeta,
// then the kernel is invoked through PD_VISIT_ALL_TYPES, which selects the
// element type (`data_t`) from the runtime dtype of `tensor`.
//
// dev_ctx: device context forwarded to the kernel.
// tensor:  source tensor; it is only read here.
// NOTE(review): presumably the result has a dense (contiguous) layout with the
// same dims/dtype as `tensor` -- confirm against ContiguousKernel.
template <typename Context>
phi::DenseTensor Tensor2Contiguous(const Context& dev_ctx,
                                   const phi::DenseTensor& tensor) {
  phi::DenseTensor contiguous;
  phi::MetaTensor src_meta(tensor);
  phi::MetaTensor dst_meta(&contiguous);
  UnchangedInferMeta(src_meta, &dst_meta);
  // `data_t` is provided by the dispatch macro for the matched dtype.
  PD_VISIT_ALL_TYPES(tensor.dtype(), "Tensor2Contiguous", ([&] {
                       phi::ContiguousKernel<data_t, Context>(
                           dev_ctx, tensor, &contiguous);
                     }));
  return contiguous;
}

template <typename T, typename Context>
void LaunchIndexPutKernel_V2(const Context& dev_ctx,
const DenseTensor& x,
Expand All @@ -110,11 +97,41 @@ void LaunchIndexPutKernel_V2(const Context& dev_ctx,
false,
common::errors::InvalidArgument("Indices cannot be empty."));

bool is_initialized = out->initialized();
auto meta = x.meta();
meta.dims = out->dims();
meta.strides = meta.calc_strides(out->dims());
out->set_meta(meta);
T* out_data = dev_ctx.template Alloc<T>(out);
if (!is_initialized) {
if (!x.meta().is_contiguous() || x.offset() != 0) {
StridedTensorCopy<T>(x,
common::vectorize<int64_t>(out->dims()),
common::vectorize<int64_t>(out->strides()),
0,
out);
} else {
phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
}
}

funcs::AdvancedIndex ad =
funcs::AdvancedIndex<T, Context>(dev_ctx, x, indices);
funcs::AdvancedIndex<T, Context>(dev_ctx, *out, indices);
if (!CheckIsDimsMatchBool(ad.src.dims(), value.dims())) {
DenseTensor x_;
DenseTensor value_;
if (!x.meta().is_contiguous() || x.offset() != 0) {
x_ = Tensor2Contiguous<Context>(dev_ctx, x);
} else {
x_ = x;
}
if (!value.meta().is_contiguous() || value.offset() != 0) {
value_ = Tensor2Contiguous<Context>(dev_ctx, value);
} else {
value_ = value;
}
phi::IndexPutKernel<T, Context>(
dev_ctx, x, indices, value, accumulate, out);
dev_ctx, x_, indices, value_, accumulate, out);
return;
}

Expand Down Expand Up @@ -151,16 +168,6 @@ void LaunchIndexPutKernel_V2(const Context& dev_ctx,

auto* val_data = value.data<T>();

bool is_initialized = out->initialized();
T* out_data = dev_ctx.template Alloc<T>(out);
if (!is_initialized) {
StridedTensorCopy<T>(x,
common::vectorize<int64_t>(x.dims()),
common::vectorize<int64_t>(x.strides()),
x.offset(),
out);
}

const char* in_ptr = reinterpret_cast<const char*>(val_data);
char* out_ptr = reinterpret_cast<char*>(out_data);
funcs::index_put_kernel<nt, vt, T><<<grid, block, 0, stream>>>(
Expand Down
Loading