Skip to content

Commit

Permalink
merge tx/ad
Browse files Browse the repository at this point in the history
  • Loading branch information
levi131 committed Apr 18, 2022
2 parents 14ffcf1 + 13db4cd commit 32256d5
Show file tree
Hide file tree
Showing 2 changed files with 173 additions and 98 deletions.
48 changes: 25 additions & 23 deletions python/paddle/autograd/primrules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

import paddle
from .primreg import REGISTER_ORIG2PRIM, REGISTER_PRIM2ORIG, REGISTER_JVP, REGISTER_TRANSPOSE
from .primreg import lookup_fn, lookup_orig2prim, lookup_prim2orig, lookup_jvp, lookup_transpose
from .primreg import (lookup_fn, lookup_orig2prim, lookup_prim2orig, lookup_jvp,
lookup_transpose, op_position_inputs, op_position_output)
from .primops import (neg, add, sub, mul, div, sqrt, tanh, reshape, broadcast,
transpose, split, concat, reduce, matmul, slice_select,
slice_assign, gather, scatter_add, fill_const, set_value)
Expand Down Expand Up @@ -361,31 +362,31 @@ def sub_jvp(op, x_dot, y_dot):
@REGISTER_JVP('mul_p')
def mul_jvp(op, x_dot, y_dot):
assert op.type == 'mul_p'
x, y = get_input_vars(op)
x, y = op_position_inputs(op)
t1, t2 = mul(x_dot, y), mul(x, y_dot)
z_dot = add(t1, t2)
return z_dot


@REGISTER_JVP('div_p')
def div_jvp(op, x_dot, y_dot):
x, y = get_input_vars(op)
x, y = op_position_inputs(op)
t1, t2 = div(x_dot, y), div(mul(x, y_dot), mul(y, y))
z_dot = sub(t1, t2)
return z_dot


@REGISTER_JVP('sqrt_p')
def sqrt_jvp(op, x_dot):
x, = get_input_vars(op)
x, = op_position_inputs(op)
c2 = fill_const(value=2.0, shape=x.shape, dtype=x.dtype)
y_dot = div(x_dot, mul(c2, sqrt(x)))
return y_dot


@REGISTER_JVP('tanh_p')
def tanh_jvp(op, x_dot):
y, = get_output_vars(op)
y, = op_position_inputs(op)
c1 = fill_const(value=1.0, shape=y.shape, dtype=y.dtype)
y_dot = mul(x_dot, sub(c1, mul(y, y)))
return y_dot
Expand Down Expand Up @@ -431,7 +432,7 @@ def reduce_jvp(op, x_dot):

@REGISTER_JVP('matmul_p')
def matmul_jvp(op, x_dot, y_dot):
x, y = get_input_vars(op)
x, y = op_position_inputs(op)
t1 = matmul(x, y_dot)
t2 = matmul(x_dot, y)
z_dot = add(t1, t2)
Expand Down Expand Up @@ -460,14 +461,14 @@ def slice_assign_jvp(op, x_dot, y_dot):

@REGISTER_JVP('gather_p')
def gather_jvp(op, x_dot):
_, indextensor = get_input_vars(op)
_, indextensor = op_position_inputs(op)
axis = op.attr('axis')
return linear_jvp(op, x_dot, indextensor, axis=axis)


@REGISTER_JVP('scatter_add_p')
def scatter_add_jvp(op, x_dot, y_dot):
_, _, indextensor = get_input_vars(op)
_, _, indextensor = op_position_inputs(op)
axis = op.attr('axis')
return linear_jvp(op, x_dot, y_dot, indextensor, axis=axis)

Expand All @@ -477,7 +478,7 @@ def scatter_add_jvp(op, x_dot, y_dot):

@REGISTER_TRANSPOSE('add_p')
def add_transpose(op, check_dot, z_bar):
x, y = get_input_vars(op)
x, y = op_position_inputs(op)
assert check_dot(x) or check_dot(y)
x_bar = z_bar if check_dot(x) else None
y_bar = z_bar if check_dot(y) else None
Expand All @@ -486,7 +487,7 @@ def add_transpose(op, check_dot, z_bar):

@REGISTER_TRANSPOSE('sub_p')
def sub_transpose(op, check_dot, z_bar):
x, y = get_input_vars(op)
x, y = op_position_inputs(op)
assert check_dot(x) or check_dot(y)
x_bar = z_bar if check_dot(x) else None
y_bar = neg(z_bar) if check_dot(y) else None
Expand All @@ -495,7 +496,7 @@ def sub_transpose(op, check_dot, z_bar):

@REGISTER_TRANSPOSE('mul_p')
def mul_transpose(op, check_dot, z_bar):
x, y = get_input_vars(op)
x, y = op_position_inputs(op)
assert check_dot(x) ^ check_dot(y)
if check_dot(x):
return mul(z_bar, y), None
Expand All @@ -505,22 +506,22 @@ def mul_transpose(op, check_dot, z_bar):

@REGISTER_TRANSPOSE('div_p')
def div_transpose(op, check_dot, z_bar):
x, y = get_input_vars(op)
x, y = op_position_inputs(op)
assert not check_dot(y)
x_bar = div(z_bar, y) if check_dot(x) else None
return x_bar, None


@REGISTER_TRANSPOSE('reshape_p')
def reshape_transpose(op, check_dot, y_bar):
x, = get_input_vars(op)
x, = op_position_inputs(op)
assert check_dot(x)
return reshape(y_bar, shape=x.shape)


@REGISTER_TRANSPOSE('broadcast_p')
def broadcast_transpose(op, check_dot, y_bar):
x, = get_input_vars(op)
x, = op_position_inputs(op)
assert check_dot(x)
bat = len(y_bar.shape) - len(x.shape)
axis = list(range(bat))
Expand All @@ -532,7 +533,7 @@ def broadcast_transpose(op, check_dot, y_bar):

@REGISTER_TRANSPOSE('transpose_p')
def transpose_transpose(op, check_dot, y_bar):
x, = get_input_vars(op)
x, = op_position_inputs(op)
assert check_dot(x)
axis = op.attr('axis')
reordered = sorted((k, i) for i, k in enumerate(axis))
Expand All @@ -542,14 +543,14 @@ def transpose_transpose(op, check_dot, y_bar):

@REGISTER_TRANSPOSE('split_p')
def split_transpose(op, check_dot, ys_bar):
x, = get_input_vars(op)
x, = op_position_inputs(op)
assert check_dot(x)
return concat(ys_bar, axis=op.attr('axis'))


@REGISTER_TRANSPOSE('concat_p')
def concat_transpose(op, check_dot, y_bar):
xs = get_input_vars(op)
xs, = op_position_inputs(op)
for x in xs:
assert check_dot(x)
axis = op.attr('axis')
Expand All @@ -559,7 +560,7 @@ def concat_transpose(op, check_dot, y_bar):

@REGISTER_TRANSPOSE('reduce_p')
def reduce_transpose(op, check_dot, y_bar):
x, = get_input_vars(op)
x, = op_position_inputs(op)
assert check_dot(x)
axes = op.attr('axis')
shape = tuple(1 if i in axes else size for i, size in enumerate(x.shape))
Expand All @@ -569,7 +570,7 @@ def reduce_transpose(op, check_dot, y_bar):

@REGISTER_TRANSPOSE('matmul_p')
def matmul_transpose(op, check_dot, z_bar):
x, y = get_input_vars(op)
x, y = op_position_inputs(op)
assert check_dot(x) ^ check_dot(y)
# TODO: replace it. this is hacky
axis = [1, 0] if len(x.shape) == 2 else [0, 2, 1]
Expand All @@ -581,7 +582,7 @@ def matmul_transpose(op, check_dot, z_bar):

@REGISTER_TRANSPOSE('slice_select_p')
def slice_select_transpose(op, check_dot, y_bar):
x, = get_input_vars(op)
x, = op_position_inputs(op)
assert check_dot(x)
zeros = fill_const(value=0.0, shape=x.shape, dtype=x.dtype)
axis = op.attr('axis')
Expand All @@ -594,7 +595,7 @@ def slice_select_transpose(op, check_dot, y_bar):

@REGISTER_TRANSPOSE('slice_assign_p')
def slice_assign_transpose(op, check_dot, z_bar):
x, y = get_input_vars(op)
x, y = op_position_inputs(op)
assert check_dot(x) and check_dot(y)
zeros = fill_const(value=0.0, shape=y.shape, dtype=y.dtype)
axis = op.attr('axis')
Expand All @@ -610,10 +611,11 @@ def slice_assign_transpose(op, check_dot, z_bar):

@REGISTER_TRANSPOSE('gather_p')
def gather_transpose(op, check_dot, y_bar):
x, indextensor = get_input_vars(op)
x, indextensor = op_position_inputs(op)
assert check_dot(x)
axis = op.attr('axis')
return scatter_add(y_bar, indextensor, axis=axis)
zeros = fill_const(0.0, x.shape, x.dtype)
return scatter_add(zeros, y_bar, indextensor, axis=axis), None


@REGISTER_TRANSPOSE('scatter_add_p')
Expand Down
Loading

0 comments on commit 32256d5

Please sign in to comment.