Skip to content

Commit 77fde0c

Browse files
support masked_select double grad in eager mode (#73601)
1 parent bd35ca3 commit 77fde0c

File tree

2 files changed

+8
-0
lines changed

2 files changed

+8
-0
lines changed

paddle/phi/ops/yaml/backward.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2164,6 +2164,12 @@
21642164
no_need_buffer : x, value
21652165
backward: masked_fill_double_grad
21662166

2167+
- backward_op : masked_select_double_grad
2168+
forward: masked_select_grad (Tensor x, Tensor mask, Tensor grad_out) -> Tensor(grad_x)
2169+
args : (Tensor mask, Tensor grad_x_grad)
2170+
output : Tensor(grad_out_grad)
2171+
invoke : masked_select(grad_x_grad, mask)
2172+
21672173
- backward_op : masked_select_grad
21682174
forward : masked_select (Tensor x, Tensor mask) -> Tensor(out)
21692175
args : (Tensor x, Tensor mask, Tensor out_grad)
@@ -2175,6 +2181,7 @@
21752181
func : masked_select_grad
21762182
data_type : x
21772183
no_need_buffer : x
2184+
backward: masked_select_double_grad
21782185

21792186
- backward_op : match_matrix_tensor_grad
21802187
forward: match_matrix_tensor(Tensor x, Tensor y, Tensor w, int dim_t = 1) -> Tensor (out), Tensor (tmp)

paddle/phi/ops/yaml/op_compat.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2443,6 +2443,7 @@
24432443
{softmax : Softmax, loss : Loss}
24442444

24452445
- op : masked_select
2446+
backward : masked_select_grad, masked_select_double_grad
24462447
inputs :
24472448
{x : X, mask : Mask}
24482449
outputs :

0 commit comments

Comments
 (0)