Skip to content

Commit

Permalink
improve the performance of divide_double_grad (#62533)
Browse files Browse the repository at this point in the history
* improve the performance of divide double grad

* fix infermeta

* update

* fix some bug

* fix bug and update test

* update

* fix bug

* update

* update

* update test

* update ddout

* update device

* add constant

* update

* fix bug

* remove vlog
  • Loading branch information
YibinLiu666 authored Apr 1, 2024
1 parent 149e543 commit 52984e3
Show file tree
Hide file tree
Showing 9 changed files with 496 additions and 97 deletions.
1 change: 1 addition & 0 deletions paddle/fluid/operators/elementwise/elementwise_div_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ class ElementwiseDivDoubleGradMaker : public framework::SingleGradOpMaker<T> {
op->SetType("elementwise_div_grad_grad");
op->SetInput("Y", this->Input("Y"));
op->SetInput("Out", this->Input("Out"));
op->SetInput("Out@GRAD", this->Input(framework::GradVarName("Out")));
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
op->SetInput("DDY", this->OutputGrad(framework::GradVarName("Y")));
op->SetInput("DX", this->Output(framework::GradVarName("X")));
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/ops_signature/elementwise_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ KernelSignature ElementwiseDivGradOpArgumentMapping(
// Maps the legacy fluid op "elementwise_div_grad_grad" onto the phi kernel
// "divide_double_grad".
//
// Inputs  : Y, Out, Out@GRAD, DX, DDX, DDY  (order must match the phi
//           kernel's argument list; "Out@GRAD" was added by this commit so
//           the kernel receives the upstream gradient directly).
// Attrs   : axis
// Outputs : Y@GRAD, DOut, DDOut
//
// NOTE(review): the scraped diff contained both the pre-change input list
// ({"Y", "Out", "DX", "DDX", "DDY"}) and the post-change one on consecutive
// lines; only the post-change list is valid and is kept here.
KernelSignature ElementwiseDivDoubleGradOpArgumentMapping(
    const ArgumentMappingContext& ctx UNUSED) {
  return KernelSignature("divide_double_grad",
                         {"Y", "Out", "Out@GRAD", "DX", "DDX", "DDY"},
                         {"axis"},
                         {"Y@GRAD", "DOut", "DDOut"});
}
Expand Down
6 changes: 3 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/ops_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -201,15 +201,15 @@

# Double grad of divide. grad_out (Out@GRAD) is now an explicit input — added
# in #62533, presumably so the kernel can reuse the upstream gradient instead
# of recomputing it (TODO confirm against the kernel implementation).
# infer_meta params use [y, out, out] (outputs match y/out shapes); grad_x is
# now optional alongside grad_x_grad and grad_y_grad.
# NOTE(review): the scraped diff showed both pre- and post-change values for
# `args`, `param`, and `optional`; only the post-change lines are kept.
- backward_op : divide_double_grad
  forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
  args : (Tensor y, Tensor out, Tensor grad_out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
  output : Tensor(y_grad), Tensor(out_grad), Tensor(grad_out_grad)
  infer_meta :
    func : GeneralTernaryGradInferMeta
    param : [y, out, out]
  kernel :
    func : divide_double_grad
    data_type : out
  optional : grad_x, grad_x_grad, grad_y_grad
  inplace : (grad_x_grad -> grad_out_grad)

- backward_op : divide_grad
Expand Down
6 changes: 3 additions & 3 deletions paddle/phi/api/yaml/legacy_backward.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -175,15 +175,15 @@

# Double grad of divide (legacy_backward.yaml copy — must stay in sync with
# the same entry in ops_backward.yaml). grad_out (Out@GRAD) is now an explicit
# input, added in #62533; grad_x is now optional alongside grad_x_grad and
# grad_y_grad, and infer_meta params are [y, out, out].
# NOTE(review): the scraped diff showed both pre- and post-change values for
# `args`, `param`, and `optional`; only the post-change lines are kept.
- backward_op : divide_double_grad
  forward : divide_grad (Tensor x, Tensor y, Tensor out, Tensor grad_out, int axis = -1) -> Tensor(grad_x), Tensor(grad_y)
  args : (Tensor y, Tensor out, Tensor grad_out, Tensor grad_x, Tensor grad_x_grad, Tensor grad_y_grad, int axis = -1)
  output : Tensor(y_grad), Tensor(out_grad), Tensor(grad_out_grad)
  infer_meta :
    func : GeneralTernaryGradInferMeta
    param : [y, out, out]
  kernel :
    func : divide_double_grad
    data_type : out
  optional : grad_x, grad_x_grad, grad_y_grad
  inplace : (grad_x_grad -> grad_out_grad)

- backward_op : divide_grad
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/elementwise_divide_grad_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ template <typename T, typename Context>
void DivideDoubleGradKernel(const Context& dev_ctx,
const DenseTensor& y,
const DenseTensor& out,
const DenseTensor& dx,
const DenseTensor& grad_out,
const paddle::optional<DenseTensor>& dx,
const paddle::optional<DenseTensor>& ddx,
const paddle::optional<DenseTensor>& ddy,
int axis,
Expand Down
2 changes: 0 additions & 2 deletions paddle/phi/kernels/funcs/common_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
"Axis should be less than or equal to %d, but received axis is %d.",
max_dim,
axis));

if (x_dims.size() > y_dims.size()) {
std::fill(y_dims_array, y_dims_array + axis, 1);
if (axis + y_dims.size() < max_dim) {
Expand All @@ -68,7 +67,6 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array + axis);
std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
}

for (int i = 0; i < max_dim; ++i) {
PADDLE_ENFORCE_EQ(
x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
Expand Down
Loading

0 comments on commit 52984e3

Please sign in to comment.