Skip to content

Commit

Permalink
add indextensor_dot for gather and scatter_add
Browse files Browse the repository at this point in the history
  • Loading branch information
levi131 committed Apr 19, 2022
1 parent bf5f4a6 commit 6a2f1fd
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions python/paddle/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,14 +493,14 @@ def slice_assign_jvp(op, x_dot, y_dot):


@REGISTER_JVP('gather_p')
def gather_jvp(op, x_dot):
def gather_jvp(op, x_dot, indextensor_dot):
_, indextensor = op_position_inputs(op)
axis = op.attr('axis')
return linear_jvp(op, x_dot, indextensor, axis=axis)


@REGISTER_JVP('scatter_add_p')
def scatter_add_jvp(op, x_dot, y_dot):
def scatter_add_jvp(op, x_dot, y_dot, indextensor_dot):
_, _, indextensor = op_position_inputs(op)
axis = op.attr('axis')
return linear_jvp(op, x_dot, y_dot, indextensor, axis=axis)
Expand Down Expand Up @@ -662,5 +662,5 @@ def scatter_add_transpose(op, check_dot, z_bar):
zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
x_bar = scatter_add(z_bar, zeros, indextensor, axis=axis)
y_bar = gather(z_bar, indextensor, axis=axis)
indextensor_bar = None
indextensor_bar = None
return x_bar, y_bar, indextensor_bar

0 comments on commit 6a2f1fd

Please sign in to comment.