use cumprod fix bug of prod_grad (PaddlePaddle#64127)
* add reverse and exclusive

* support reverse and exclusive

* fix inplace test

* fix make on xpu

* remove print

* update test time

* fix prod_grad use cumprod

* add 1-5D test

* speed up test

* mul out_grad

* update test

* Update CMakeLists.txt

* update op_version

* update detail.h

* update composite

* update detail.h

* stage

* update test

* only test on gpu

* update details.h

* CI

* Update composite_backward_api.h

* Update details.h

* update test

* update

* Update test_reduce_op.py

* fix cumprod cpu bug

* update test

* update test

* update

* update

* ci

* ci

* remove comment

* CI

* Your commit message

* CI

* update
YibinLiu666 authored and co63oc committed Jun 3, 2024
1 parent be99ba2 commit 8e9e2e9
Showing 4 changed files with 145 additions and 23 deletions.
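The gist of the change, as it reads from the diffs below: prod_grad previously rebuilt the gradient as out_grad * out * (1 / x), which is undefined wherever x contains a zero. The new rule instead multiplies an exclusive forward cumprod by an exclusive reverse cumprod along the (flattened) reduced dimension, which stays finite at zeros. For a full reduction y = prod_j x_j (notation ours, not from the patch):

```latex
% old composite rule (breaks when x_i = 0):
\frac{\partial y}{\partial x_i} = \frac{y}{x_i}
% new rule (finite at zeros):
\frac{\partial y}{\partial x_i} = \prod_{j \ne i} x_j
  = \Big(\prod_{j < i} x_j\Big)\Big(\prod_{j > i} x_j\Big)
  = \mathrm{cumprod}_{\mathrm{excl}}(x)_i \cdot \mathrm{cumprod}_{\mathrm{excl,\,rev}}(x)_i
```

The two boolean arguments in cumprod<T>(x_reshape, -1, true, false) and cumprod<T>(x_reshape, -1, true, true) appear to be the exclusive and reverse flags added earlier in this PR ("add reverse and exclusive"); for a partial reduction, the reduced axes are first transposed to the end and flattened so the same one-dimensional identity applies along the last axis.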
1 change: 1 addition & 0 deletions paddle/fluid/prim/api/api.yaml
@@ -38,6 +38,7 @@
 - pad
 - sqrt
 - cumsum
+- cumprod
 - put_along_axis
 - sin
 - cos
77 changes: 66 additions & 11 deletions paddle/fluid/prim/api/composite_backward/composite_backward_api.h
@@ -1127,11 +1127,13 @@ void prod_grad(const Tensor& x,
     } else {
       reduce_all = false;
     }
-    auto x_grad_tmp = Tensor();
-    auto out_tmp = Tensor();
+    auto out_grad_tmp = Tensor();
+    auto x_reshape = Tensor();
+    std::vector<int64_t> unchange_axis, change_axis, transpose_shape,
+        cumprod_shape;
+    std::vector<int> transpose_dim, origin_position;
     if (x_dim_size == 1) {
-      x_grad_tmp = out_grad.expand(IntArray(x_dim));
-      out_tmp = out.expand(IntArray(x_dim));
+      out_grad_tmp = out_grad.expand(IntArray(x_dim));
     } else {
       if (!keep_dim) {
         auto axis_ = std::vector<int64_t>();
@@ -1149,16 +1151,69 @@
         }
         auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
         auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
-        x_grad_tmp = out_grad_.expand(IntArray(x_dim));
-        auto out_ = reshape<T>(out, out_grad_shape);
-        out_tmp = out_.expand(IntArray(x_dim));
+        out_grad_tmp = out_grad_.expand(IntArray(x_dim));
       } else {
-        x_grad_tmp = out_grad.expand(IntArray(x_dim));
-        out_tmp = out.expand(IntArray(x_dim));
+        out_grad_tmp = out_grad.expand(IntArray(x_dim));
       }
     }
-    auto x_grad_res = x_grad_tmp * out_tmp * (1 / x);
-    set_output<T>(x_grad_res, x_grad);
+    auto axis_ = std::vector<int64_t>();
+    if (reduce_all) {
+      int64_t numel = 1;
+      for (int64_t i = 0; i < x_dim_size; i++) {
+        axis_.push_back(i);
+        numel *= x_dim[i];
+      }
+      cumprod_shape.push_back(numel);
+      x_reshape = reshape<T>(x, cumprod_shape);
+      auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+      auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+      auto x_grad_tmp = left_cumprod * right_cumprod;
+      auto x_grad_tmp2 = reshape<T>(x_grad_tmp, x.shape());
+      auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
+      set_output<T>(x_grad_res, x_grad);
+    } else {
+      int64_t unchange_size = x_dim_size - axis_size;
+      int64_t unchange_index = 0;
+      for (int64_t i = 0; i < axis_size; i++) {
+        if (axis[i] < 0) {
+          axis_.push_back(axis[i] + x_dim_size);
+        } else {
+          axis_.push_back(axis[i]);
+        }
+      }
+      for (int64_t i = 0; i < x_dim_size; i++) {
+        auto it = find(axis_.begin(), axis_.end(), i);
+        if (it != axis_.end()) {
+          int64_t index = it - axis_.begin();
+          origin_position.push_back(static_cast<int>(unchange_size + index));
+        } else {
+          unchange_axis.push_back(i);
+          origin_position.push_back(static_cast<int>(unchange_index));
+          unchange_index += 1;
+        }
+      }
+      int64_t numel = 1;
+      for (int64_t i = 0; i < unchange_size; i++) {
+        transpose_shape.push_back(x_dim[unchange_axis[i]]);
+        cumprod_shape.push_back(x_dim[unchange_axis[i]]);
+        transpose_dim.push_back(static_cast<int>(unchange_axis[i]));
+      }
+      for (int64_t i = 0; i < axis_size; i++) {
+        transpose_shape.push_back(x_dim[axis_[i]]);
+        transpose_dim.push_back(static_cast<int>(axis_[i]));
+        numel *= x_dim[axis_[i]];
+      }
+      cumprod_shape.push_back(numel);
+      auto x_transpose = transpose<T>(x, transpose_dim);
+      x_reshape = reshape<T>(x_transpose, cumprod_shape);
+      auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+      auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+      auto x_grad_tmp = left_cumprod * right_cumprod;
+      auto x_grad_reshape = reshape<T>(x_grad_tmp, transpose_shape);
+      auto x_grad_tmp2 = transpose<T>(x_grad_reshape, origin_position);
+      auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
+      set_output<T>(x_grad_res, x_grad);
+    }
   }
 }
 
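A minimal standalone sketch of the identity the new composite rule uses (plain C++, not Paddle code; the example vector is made up). The old division-based gradient produces NaN at a zero entry, while the exclusive prefix/suffix products give the correct value:

```cpp
// Minimal sketch (not Paddle code): gradient of y = prod(x) w.r.t. each x[i],
// computed the old way (y / x[i]) and the new way (exclusive prefix product
// times exclusive suffix product). The second form stays finite at zeros.
#include <cstdio>
#include <vector>

int main() {
  std::vector<double> x = {2.0, 0.0, 3.0, 4.0};
  const int n = static_cast<int>(x.size());

  double y = 1.0;
  for (double v : x) y *= v;  // y == 0 because x[1] == 0

  // left[i] = x[0] * ... * x[i-1], right[i] = x[i+1] * ... * x[n-1]
  std::vector<double> left(n, 1.0), right(n, 1.0);
  for (int i = 1; i < n; ++i) left[i] = left[i - 1] * x[i - 1];
  for (int i = n - 2; i >= 0; --i) right[i] = right[i + 1] * x[i + 1];

  for (int i = 0; i < n; ++i) {
    double grad_old = y / x[i];            // 0/0 = nan at i == 1
    double grad_new = left[i] * right[i];  // 0, 24, 0, 0 -- the correct values
    std::printf("i=%d old=%g new=%g\n", i, grad_old, grad_new);
  }
  return 0;
}
```

In the reduce_all branch above, left_cumprod and right_cumprod play the role of left and right here, and their elementwise product is then multiplied by out_grad_tmp.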
77 changes: 66 additions & 11 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -1773,11 +1773,13 @@ void prod_grad(const Tensor& x,
     } else {
       reduce_all = false;
     }
-    auto x_grad_tmp = Tensor();
-    auto out_tmp = Tensor();
+    auto out_grad_tmp = Tensor();
+    auto x_reshape = Tensor();
+    std::vector<int64_t> unchange_axis, change_axis, transpose_shape,
+        cumprod_shape;
+    std::vector<int> transpose_dim, origin_position;
     if (x_dim_size == 1) {
-      x_grad_tmp = out_grad.expand(IntArray(x_dim));
-      out_tmp = out.expand(IntArray(x_dim));
+      out_grad_tmp = out_grad.expand(IntArray(x_dim));
     } else {
       if (!keep_dim) {
         auto axis_ = std::vector<int64_t>();
@@ -1795,16 +1797,69 @@
         }
         auto out_grad_shape = get_unsqueeze_dims(out_grad, axis_);
         auto out_grad_ = reshape<T>(out_grad, out_grad_shape);
-        x_grad_tmp = out_grad_.expand(IntArray(x_dim));
-        auto out_ = reshape<T>(out, out_grad_shape);
-        out_tmp = out_.expand(IntArray(x_dim));
+        out_grad_tmp = out_grad_.expand(IntArray(x_dim));
       } else {
-        x_grad_tmp = out_grad.expand(IntArray(x_dim));
-        out_tmp = out.expand(IntArray(x_dim));
+        out_grad_tmp = out_grad.expand(IntArray(x_dim));
      }
     }
-    auto x_grad_res = x_grad_tmp * out_tmp * (1 / x);
-    set_output<T>(x_grad_res, x_grad);
+    auto axis_ = std::vector<int64_t>();
+    if (reduce_all) {
+      int64_t numel = 1;
+      for (int64_t i = 0; i < x_dim_size; i++) {
+        axis_.push_back(i);
+        numel *= x_dim[i];
+      }
+      cumprod_shape.push_back(numel);
+      x_reshape = reshape<T>(x, cumprod_shape);
+      auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+      auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+      auto x_grad_tmp = left_cumprod * right_cumprod;
+      auto x_grad_tmp2 = reshape<T>(x_grad_tmp, x.shape());
+      auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
+      set_output<T>(x_grad_res, x_grad);
+    } else {
+      int64_t unchange_size = x_dim_size - axis_size;
+      int64_t unchange_index = 0;
+      for (int64_t i = 0; i < axis_size; i++) {
+        if (axis[i] < 0) {
+          axis_.push_back(axis[i] + x_dim_size);
+        } else {
+          axis_.push_back(axis[i]);
+        }
+      }
+      for (int64_t i = 0; i < x_dim_size; i++) {
+        auto it = find(axis_.begin(), axis_.end(), i);
+        if (it != axis_.end()) {
+          int64_t index = it - axis_.begin();
+          origin_position.push_back(static_cast<int>(unchange_size + index));
+        } else {
+          unchange_axis.push_back(i);
+          origin_position.push_back(static_cast<int>(unchange_index));
+          unchange_index += 1;
+        }
+      }
+      int64_t numel = 1;
+      for (int64_t i = 0; i < unchange_size; i++) {
+        transpose_shape.push_back(x_dim[unchange_axis[i]]);
+        cumprod_shape.push_back(x_dim[unchange_axis[i]]);
+        transpose_dim.push_back(static_cast<int>(unchange_axis[i]));
+      }
+      for (int64_t i = 0; i < axis_size; i++) {
+        transpose_shape.push_back(x_dim[axis_[i]]);
+        transpose_dim.push_back(static_cast<int>(axis_[i]));
+        numel *= x_dim[axis_[i]];
+      }
+      cumprod_shape.push_back(numel);
+      auto x_transpose = transpose<T>(x, transpose_dim);
+      x_reshape = reshape<T>(x_transpose, cumprod_shape);
+      auto left_cumprod = cumprod<T>(x_reshape, -1, true, false);
+      auto right_cumprod = cumprod<T>(x_reshape, -1, true, true);
+      auto x_grad_tmp = left_cumprod * right_cumprod;
+      auto x_grad_reshape = reshape<T>(x_grad_tmp, transpose_shape);
+      auto x_grad_tmp2 = transpose<T>(x_grad_reshape, origin_position);
+      auto x_grad_res = x_grad_tmp2 * out_grad_tmp;
+      set_output<T>(x_grad_res, x_grad);
+    }
   }
 }
 
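This file duplicates the same rule for the primitive VJP path. The least obvious part is the axis bookkeeping in the partial-reduction branch; the sketch below (standalone C++ with a hypothetical rank-4 shape reduced over axes 0 and 2, not Paddle code) traces what transpose_dim, cumprod_shape, and origin_position come out to:

```cpp
// Sketch of the axis bookkeeping in the partial-reduction branch above
// (standalone C++; the shape and reduced axes are hypothetical examples).
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  std::vector<int64_t> x_dim = {2, 3, 4, 5};  // hypothetical input shape
  std::vector<int64_t> axis = {0, 2};         // hypothetical reduced axes
  const int64_t rank = static_cast<int64_t>(x_dim.size());
  const int64_t unchange_size = rank - static_cast<int64_t>(axis.size());

  std::vector<int64_t> unchange_axis, cumprod_shape;
  std::vector<int> transpose_dim, origin_position;
  int64_t unchange_index = 0;
  for (int64_t i = 0; i < rank; ++i) {
    auto it = std::find(axis.begin(), axis.end(), i);
    if (it != axis.end()) {
      // Reduced axes are sent to the back of the transposed layout.
      origin_position.push_back(
          static_cast<int>(unchange_size + (it - axis.begin())));
    } else {
      unchange_axis.push_back(i);
      origin_position.push_back(static_cast<int>(unchange_index++));
    }
  }
  int64_t numel = 1;
  for (int64_t a : unchange_axis) {
    cumprod_shape.push_back(x_dim[a]);
    transpose_dim.push_back(static_cast<int>(a));
  }
  for (int64_t a : axis) {
    transpose_dim.push_back(static_cast<int>(a));
    numel *= x_dim[a];
  }
  cumprod_shape.push_back(numel);  // reduced axes flattened into one trailing dim

  // Prints: transpose_dim = 1 3 0 2, cumprod_shape = 3 5 8,
  // origin_position = 2 0 3 1 (the inverse permutation of transpose_dim).
  std::cout << "transpose_dim =";
  for (int d : transpose_dim) std::cout << ' ' << d;
  std::cout << "\ncumprod_shape =";
  for (int64_t s : cumprod_shape) std::cout << ' ' << s;
  std::cout << "\norigin_position =";
  for (int p : origin_position) std::cout << ' ' << p;
  std::cout << '\n';
  return 0;
}
```

With these values, an input of shape [2, 3, 4, 5] is transposed to [3, 5, 2, 4], reshaped to [3, 5, 8], the cumprod identity is applied along the last axis, and the result is reshaped back and transposed with origin_position to recover the original layout.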
13 changes: 12 additions & 1 deletion paddle/phi/kernels/cpu/cumprod_kernel.cc
@@ -32,8 +32,16 @@ void CumprodKernel(const Context& dev_ctx,
                    DenseTensor* out) {
   const DenseTensor* x = &input;
   auto* x_data = x->data<T>();
-  auto* out_data = dev_ctx.template Alloc<T>(out);
+  auto* out_ptr = dev_ctx.template Alloc<T>(out);
   DDim shape = x->dims();
+  DenseTensor out_tmp;
+  T* out_data = nullptr;
+  if (x_data == out_ptr) {
+    out_tmp.Resize(shape);
+    out_data = dev_ctx.template Alloc<T>(&out_tmp);
+  } else {
+    out_data = out_ptr;
+  }
 
   size_t outer_dim = 1;
   size_t mid_dim = 1;
@@ -88,6 +96,9 @@
       }
     }
   }
+  if (x_data == out_ptr) {
+    memcpy(out_ptr, out_data, out->numel() * sizeof(T));
+  }
 }
 
 }  // namespace phi
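The CPU kernel change covers the inplace case: when the output buffer aliases the input (x_data == out_ptr), the kernel now computes into a temporary tensor and memcpy's the result back, since otherwise later reads of x would see values the scan has already overwritten. A standalone sketch of the hazard (plain C++, not the Paddle kernel; an exclusive scan written in place is just one way the clobbering shows up):

```cpp
// Standalone sketch (not the Paddle kernel): an exclusive cumulative product
// written directly into its own input clobbers x[i] before it is read, which
// is the aliasing the temporary buffer above guards against.
#include <cstdio>
#include <cstring>
#include <vector>

void exclusive_cumprod(const double* x, double* out, int n) {
  double running = 1.0;
  for (int i = 0; i < n; ++i) {
    out[i] = running;  // if out aliases x, this overwrites x[i]...
    running *= x[i];   // ...before it is read here
  }
}

int main() {
  std::vector<double> x = {2.0, 3.0, 4.0};

  std::vector<double> separate(3);
  exclusive_cumprod(x.data(), separate.data(), 3);  // 1 2 6  (correct)

  std::vector<double> aliased = x;
  exclusive_cumprod(aliased.data(), aliased.data(), 3);  // 1 1 1  (clobbered)

  // The fix above does the equivalent of this: compute into a scratch buffer
  // when the pointers alias, then copy back into the real output.
  std::vector<double> inplace = x, tmp(3);
  exclusive_cumprod(inplace.data(), tmp.data(), 3);
  std::memcpy(inplace.data(), tmp.data(), 3 * sizeof(double));  // 1 2 6

  std::printf("%g %g %g vs %g %g %g\n", aliased[0], aliased[1], aliased[2],
              inplace[0], inplace[1], inplace[2]);
  return 0;
}
```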
