Revert elementwise add #53745

Merged
@@ -67,7 +67,6 @@
 prim_white_list = [
     "matmul_double_grad",
     "tanh_double_grad",
-    "add_double_grad",
     "subtract_double_grad",
 ]

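Removing "add_double_grad" from prim_white_list means the second-order gradient of elementwise add is no longer handled by the composite (prim) rule and falls back to the native add_double_grad kernel. A minimal sketch of the kind of computation that path serves, assuming eager mode and the public paddle.grad API (the switch that enables or disables the prim backward lives in the surrounding framework/test machinery and is not shown here):

```python
import paddle

# Second-order gradient through an elementwise add.
x = paddle.randn([4], dtype='float32')
y = paddle.randn([4], dtype='float32')
x.stop_gradient = False
y.stop_gradient = False

out = (x + y) * (x + y)                            # square of the add, so the 2nd derivative is nonzero
(dx,) = paddle.grad(out, [x], create_graph=True)   # first-order grad through the add node
(dxx,) = paddle.grad(dx, [x])                      # second-order grad -> add_double_grad path
print(dxx)                                         # elementwise 2.0 for this function
```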
12 changes: 12 additions & 0 deletions paddle/phi/api/yaml/legacy_backward.yaml
@@ -30,6 +30,7 @@
   kernel :
     func : add_double_grad
   optional : grad_x_grad, grad_y_grad
+  backward : add_triple_grad
   inplace : (grad_x_grad -> grad_out_grad)
   composite : add_double_grad(y, grad_out, grad_x_grad, grad_y_grad, axis, grad_out_grad)

@@ -47,6 +48,17 @@
   backward : add_double_grad
   inplace : (out_grad -> x_grad)
 
+- backward_op : add_triple_grad
+  forward : add_double_grad (Tensor y, Tensor grad_out, Tensor grad_grad_x, Tensor grad_grad_y, int axis = -1) -> Tensor(grad_grad_out)
+  args : (Tensor grad_grad_x, Tensor grad_grad_y, Tensor grad_grad_out_grad, int axis = -1)
+  output : Tensor(grad_grad_x_grad), Tensor(grad_grad_y_grad)
+  infer_meta :
+    func : GeneralBinaryGradInferMeta
+    param : [grad_grad_x, grad_grad_y]
+  kernel :
+    func : add_triple_grad
+  inplace : (grad_grad_out_grad -> grad_grad_x_grad)
+
 - backward_op : amax_grad
   forward: amax (Tensor x, int64_t[] axis={}, bool keepdim=false) -> Tensor(out)
   args : (Tensor x, Tensor out, Tensor out_grad, int64_t[] axis={}, bool keepdim=false, bool reduce_all=false)
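The restored add_triple_grad entry registers a backward for add_double_grad, so third-order differentiation through an add node has a kernel to call. GeneralBinaryGradInferMeta shapes the two outputs after grad_grad_x and grad_grad_y, and the inplace mapping reuses grad_grad_out_grad for grad_grad_x_grad. A hedged sketch of a call pattern that would exercise this path, assuming eager mode and that every op along the chain has the needed higher-order kernels registered (tanh does); which kernels actually fire depends on how the backward graph is built:

```python
import paddle

# Third-order gradient: differentiating the add_double_grad node once more.
x = paddle.randn([4], dtype='float32')
y = paddle.randn([4], dtype='float32')
x.stop_gradient = False
y.stop_gradient = False

out = paddle.tanh(x + y)                          # elementwise add followed by tanh
(dx,) = paddle.grad(out, [x], create_graph=True)  # 1st order: add_grad
(dxx,) = paddle.grad(dx, [x], create_graph=True)  # 2nd order: add_double_grad
(dxxx,) = paddle.grad(dxx, [x])                   # 3rd order: add_triple_grad
print(dxxx.shape)
```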
3 changes: 2 additions & 1 deletion test/prim/prim/vjp/test_comp_high_grad.py
@@ -25,7 +25,7 @@
 from paddle import fluid
 from paddle.fluid import core
 
-
+'''
 @param.parameterized_class(
     ('shape1', 'shape2'),
     [
@@ -120,6 +120,7 @@ def test_high_grad(self):
         for p in places:
             self.func_double(p)
             self.func_triple(p)
+'''
 
 
 @param.parameterized_class(