fix behavior of put_along_axis and take_along_axis (Usability Improvement No.43) (#59163)
* fix behavior of put_along_axis and take_along_axis

* fix error

* fix take_along_axis used in stat

* update

* fix build error

* add test for error

* add param broadcast

* use origin example

* add param include_self

* update param name

* modify ut

* update test case

* add error UT

* update
YibinLiu666 authored Dec 4, 2023
1 parent 78534d8 commit 46176ef
Showing 9 changed files with 738 additions and 89 deletions.
6 changes: 3 additions & 3 deletions paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
@@ -60,10 +60,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
     value_grad->Resize(index.dims());
     dev_ctx.template Alloc<T>(value_grad);
     if (index_type == DataType::INT32) {
-      phi::funcs::cpu_gather_kernel<T, int32_t>(
+      phi::funcs::cpu_scatter_value_grad_kernel<T, int32_t>(
           out_grad, axis, index, *value_grad, dev_ctx);
-    } else if (index_type == DataType::INT64) {
-      phi::funcs::cpu_gather_kernel<T, int64_t>(
+    } else {
+      phi::funcs::cpu_scatter_value_grad_kernel<T, int64_t>(
           out_grad, axis, index, *value_grad, dev_ctx);
     }
   }
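
Note on this change: cpu_gather_kernel is only a correct gradient for value when every index is unique. The forward pass is scatter-assign, so with duplicate indices only the last write to a given self position survives, and only that one value element may receive the upstream gradient; a plain gather would hand the same out_grad element to every duplicate. A minimal standalone sketch of the difference (plain C++, not Paddle code; shapes and values are hypothetical):

#include <cstdio>

// Hypothetical 1-D example: the forward pass is scatter-assign,
// out[index[n]] = value[n]. With duplicate indices the last write wins, so
// value_grad[n] may be out_grad[index[n]] only for the surviving write n.
int main() {
  const int N = 3;
  int index[N] = {1, 1, 2};  // position 1 is written twice; n = 1 survives
  double out_grad[4] = {10, 20, 30, 40};
  double value_grad[N] = {0, 0, 0};
  bool used[4] = {false, false, false, false};

  // Walk backwards so the last (surviving) write claims the gradient first.
  for (int n = N - 1; n >= 0; --n) {
    if (!used[index[n]]) {
      value_grad[n] = out_grad[index[n]];
      used[index[n]] = true;  // overwritten duplicates keep gradient 0
    }
  }
  // A gather would yield {20, 20, 30}; the correct result is {0, 20, 30}.
  for (int n = 0; n < N; ++n) printf("%g ", value_grad[n]);
  return 0;
}

The new cpu_scatter_value_grad_kernel below generalizes this first-claim-in-reverse-order rule to n-D tensors.
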
110 changes: 93 additions & 17 deletions paddle/phi/kernels/funcs/gather_scatter_functor.cc
@@ -80,8 +80,10 @@ struct cpu_gather_scatter_functor {
   }
   int64_t select_dim_size = index_dims[dim];
   // index matrix has different shape with self matrix or src matrix.
-  int replaced_select_dim_size =
-      is_scatter_like ? self_dims[dim] : src_dims[dim];
+  int self_select_dim_size = self_dims[dim];
+  int src_select_dim_size = src_dims[dim];
+  int64_t outer_dim_size_self = 1;
+  int64_t outer_dim_size_src = 1;
   int64_t inner_dim_size = 1;
   int64_t outer_dim_size = 1;
   for (int i = 0; i < dim; ++i) {
@@ -90,10 +92,10 @@

   for (int i = dim + 1; i < index_dims.size(); i++) {
     outer_dim_size *= index_dims[i];
+    outer_dim_size_self *= self_dims[i];
+    outer_dim_size_src *= src_dims[i];
   }
   int64_t index_idx = 0;
-  int64_t self_idx = 0, src_idx = 0;
-
   // N layer loop squeezed into 3 layers loop
   for (int64_t i = 0; i < inner_dim_size; i++) {
     for (int64_t j = 0; j < select_dim_size; j++) {
@@ -117,13 +119,21 @@

         // This index might out of bound of index matrix's index, so here
         // multiply the replaced_select_dim_size.
-        int64_t replace_index = k + index * outer_dim_size +
-                                i * outer_dim_size * replaced_select_dim_size;
-
-        self_idx = is_scatter_like ? replace_index : index_idx;
-        src_idx = is_scatter_like ? index_idx : replace_index;
-        reduce_op((tensor_t*)(self_data + self_idx),  // NOLINT
-                  (tensor_t*)(src_data + src_idx));   // NOLINT
+        int64_t replace_index_self, replace_index_src;
+        if (is_scatter_like) {
+          replace_index_self = k + index * outer_dim_size_self +
+                               i * outer_dim_size_self * self_select_dim_size;
+
+          replace_index_src = k + j * outer_dim_size_src +
+                              i * outer_dim_size_src * src_select_dim_size;
+        } else {
+          replace_index_self = index_idx;
+
+          replace_index_src = k + index * outer_dim_size_src +
+                              i * outer_dim_size_src * src_select_dim_size;
+        }
+        reduce_op((tensor_t*)(self_data + replace_index_self),  // NOLINT
+                  (tensor_t*)(src_data + replace_index_src));   // NOLINT
         index_idx++;
       }
     }
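
Why the index split works: each kernel views its n-D tensor as a row-major [inner, select, outer] block, where select is the extent along dim. The linear offset of element (i, sel, k) is then k + sel * outer + i * outer * select, which is the replace_index_* formula above; self and src share i, j, and k but have their own outer and select extents, hence the separate outer_dim_size_self and outer_dim_size_src accumulators. A self-contained check of the offset formula (illustrative extents only):

#include <cassert>

// Verify that k + sel * outer + i * outer * select enumerates a row-major
// [inner, select, outer] block in storage order.
int main() {
  const int inner = 2, select = 3, outer = 4;  // arbitrary example extents
  int linear = 0;
  for (int i = 0; i < inner; ++i)
    for (int sel = 0; sel < select; ++sel)
      for (int k = 0; k < outer; ++k)
        assert(k + sel * outer + i * outer * select == linear++);
  return 0;
}
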
@@ -183,45 +193,111 @@
 template <typename tensor_t, typename index_t>
 void cpu_scatter_input_grad_kernel(phi::DenseTensor self UNUSED,
                                    int dim,
                                    const phi::DenseTensor& index,
-                                   phi::DenseTensor output,
+                                   phi::DenseTensor grad,
                                    const phi::DeviceContext& ctx UNUSED) {
   auto* index_data = index.data<index_t>();
-  auto* output_data = output.data<tensor_t>();
+  auto* grad_data = grad.data<tensor_t>();
 
   auto index_dims = index.dims();
-  auto output_dims = output.dims();
+  auto grad_dims = grad.dims();
 
   int64_t inner_dim_size = 1;
   int64_t outer_dim_size = 1;
+  int64_t outer_dim_size_data = 1;
   int64_t select_dim_size = index_dims[dim];
-  int64_t output_select_dim_size = output_dims[dim];
+  int64_t grad_select_dim_size = grad_dims[dim];
   for (int i = 0; i < dim; ++i) {
     inner_dim_size *= index_dims[i];
   }
 
   for (int i = dim + 1; i < index_dims.size(); i++) {
     outer_dim_size *= index_dims[i];
+    outer_dim_size_data *= grad_dims[i];
   }
 
   int64_t index_idx = 0;
   for (int64_t i = 0; i < inner_dim_size; i++) {
     for (int64_t j = 0; j < select_dim_size; j++) {
       for (int64_t k = 0; k < outer_dim_size; k++) {
         int64_t index = index_data[index_idx];
-        int64_t replace_index = k + index * outer_dim_size +
-                                i * outer_dim_size * output_select_dim_size;
-        output_data[replace_index] = 0;
+        int64_t replace_index = k + index * outer_dim_size_data +
+                                i * outer_dim_size_data * grad_select_dim_size;
+        grad_data[replace_index] = 0;
         index_idx++;
       }
     }
   }
 }
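
For the assign reduce, put_along_axis behaves like out = copy(x); out[index] = value, so the gradient with respect to x is the upstream gradient with zeros at every scattered position. The rename from output to grad reflects that the kernel writes those zeros directly into the input gradient, and the new outer_dim_size_data accumulator sizes offsets by grad's dims, which need not match index's trailing extents. A 1-D sketch of the semantics (hypothetical shapes, not Paddle code):

#include <cstdio>

// Hypothetical 1-D example of the input-gradient rule for the assign
// reduce: forward is out = x, then out[index[n]] = value[n], so dL/dx is
// the upstream gradient with zeros at each scattered position.
int main() {
  int index[2] = {0, 3};
  double grad[4] = {1, 2, 3, 4};  // starts as a copy of out_grad
  for (int n = 0; n < 2; ++n) grad[index[n]] = 0;  // what the kernel writes
  for (double g : grad) printf("%g ", g);  // prints: 0 2 3 0
  return 0;
}
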

+template <typename tensor_t, typename index_t>
+void cpu_scatter_value_grad_kernel(phi::DenseTensor self,
+                                   int dim,
+                                   const phi::DenseTensor& index,
+                                   phi::DenseTensor grad,
+                                   const phi::DeviceContext& ctx UNUSED) {
+  auto* self_data = self.data<tensor_t>();
+  auto* index_data = index.data<index_t>();
+  auto* grad_data = grad.data<tensor_t>();
+
+  auto index_dims = index.dims();
+  auto self_dims = self.dims();
+  auto grad_dims = grad.dims();
+
+  int64_t self_size = self.numel();
+  int64_t grad_size = grad.numel();
+  bool* is_self_grad_used = new bool[self_size];
+
+  for (int i = 0; i < self_size; i++) {
+    is_self_grad_used[i] = false;
+  }
+
+  int64_t inner_dim_size = 1;
+  int64_t outer_dim_size = 1;
+  int64_t outer_dim_size_self = 1;
+  int64_t outer_dim_size_grad = 1;
+  int64_t select_dim_size = index_dims[dim];
+  int64_t self_select_dim_size = self_dims[dim];
+  int64_t grad_select_dim_size = grad_dims[dim];
+  for (int i = 0; i < dim; ++i) {
+    inner_dim_size *= index_dims[i];
+  }
+
+  for (int i = dim + 1; i < index_dims.size(); i++) {
+    outer_dim_size *= index_dims[i];
+    outer_dim_size_self *= self_dims[i];
+    outer_dim_size_grad *= grad_dims[i];
+  }
+  int64_t index_idx = index.numel() - 1;
+  for (int i = 0; i < grad_size; i++) {
+    grad_data[i] = static_cast<tensor_t>(0);
+  }
+  for (int64_t i = inner_dim_size - 1; i >= 0; i--) {
+    for (int64_t j = select_dim_size - 1; j >= 0; j--) {
+      for (int64_t k = outer_dim_size - 1; k >= 0; k--) {
+        int64_t index = index_data[index_idx];
+        int64_t replace_index_self =
+            k + index * outer_dim_size_self +
+            i * outer_dim_size_self * self_select_dim_size;
+        int64_t replace_index_grad =
+            k + j * outer_dim_size_grad +
+            i * outer_dim_size_grad * grad_select_dim_size;
+        if (!is_self_grad_used[replace_index_self]) {
+          grad_data[replace_index_grad] = self_data[replace_index_self];
+          is_self_grad_used[replace_index_self] = true;
+        }
+        index_idx--;
+      }
+    }
+  }
+  delete[] is_self_grad_used;
+}
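
The triple loop above runs in reverse (index_idx starts at index.numel() - 1) so that the scatter write that survived the forward pass is the first to claim its self slot, and the is_self_grad_used bitmap keeps every earlier, overwritten duplicate at gradient zero. A standalone 2-D walk-through of that pass (hypothetical shapes; plain arrays instead of the DenseTensor API):

#include <cstdio>

// self/out_grad is 2x4, index and the value grad are 2x3, dim = 1, so
// inner = 2, select = 3, outer = 1. Iterating backwards, the last write to
// each self position claims its gradient; overwritten duplicates stay 0.
int main() {
  const int inner = 2, select = 3, outer = 1, self_select = 4;
  double out_grad[2 * 4] = {10, 20, 30, 40, 50, 60, 70, 80};
  int index[2 * 3] = {0, 2, 2, 3, 1, 0};  // row 0 writes column 2 twice
  double grad[2 * 3] = {0};
  bool used[2 * 4] = {false};

  int index_idx = 2 * 3 - 1;
  for (int i = inner - 1; i >= 0; --i)
    for (int j = select - 1; j >= 0; --j)
      for (int k = outer - 1; k >= 0; --k) {
        int idx = index[index_idx];
        int self_pos = k + idx * outer + i * outer * self_select;
        int grad_pos = k + j * outer + i * outer * select;
        if (!used[self_pos]) {
          grad[grad_pos] = out_grad[self_pos];
          used[self_pos] = true;
        }
        --index_idx;
      }
  for (double g : grad) printf("%g ", g);  // prints: 10 0 30 80 60 50
  return 0;
}

A std::vector<bool> could manage the bitmap without the manual new[]/delete[] pair; the behavior is identical either way.
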

 Instantiate_Template_Function(cpu_gather_kernel)
 Instantiate_Template_Function(cpu_scatter_assign_kernel)
 Instantiate_Template_Function(cpu_scatter_add_kernel)
 Instantiate_Template_Function(cpu_scatter_mul_kernel)
 Instantiate_Template_Function(cpu_scatter_input_grad_kernel)
+Instantiate_Template_Function(cpu_scatter_value_grad_kernel)

 }  // namespace funcs
 }  // namespace phi