Commit 21c8113: merge tx/ad

levi131 committed Apr 13, 2022
2 parents: 5664872 + 4b99805
Showing 4 changed files with 211 additions and 119 deletions.
29 changes: 29 additions & 0 deletions python/paddle/autograd/primops.py
@@ -12,10 +12,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.fluid import unique_name, core
from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid.layer_helper import LayerHelper
from .primreg import REGISTER_FN


def make_var(dtype, varset=None, shape=None, block=None, namekey='',
             stop_gradient=False):
    """ Create a type inferred variable. """

    if block is None:
        block = default_main_program().current_block()

    name = unique_name.generate_with_ignorable_key(namekey + '%')

    var = block.create_var(
        name=name,
        dtype=dtype,
        shape=shape,
        type=core.VarDesc.VarType.LOD_TENSOR,
        persistable=False,
        stop_gradient=stop_gradient)

    if varset is not None:
        varset.add(var)

    return var


def make_varlike(x, block=None, namekey='', stop_gradient=False):
    """ Make a variable using the dtype and shape of the given input. """
    return make_var(x.dtype, shape=x.shape, block=block, namekey=namekey,
                    stop_gradient=stop_gradient)


def _simple_unop(helper):
    optype = helper.layer_type
    x, out = tuple(map(helper.kwargs.get, ('x', 'out')))
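
For orientation, here is a minimal usage sketch of the two helpers above. It assumes a static-graph Paddle program and that the module is importable as paddle.autograd.primops; the variable names are illustrative only.

import paddle
from paddle.autograd.primops import make_var, make_varlike

paddle.enable_static()  # the helpers create variables on static-graph blocks

# Create a fresh LOD_TENSOR variable on the current block, then a second
# variable with the same dtype and shape as the first.
x = make_var('float32', shape=[2, 3], namekey='x')
y = make_varlike(x, namekey='y')
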
4 changes: 2 additions & 2 deletions python/paddle/autograd/primreg.py
@@ -34,8 +34,8 @@ def lookup(self, name):
_primop_fn = Registry('primop_fn')
_orig2prim = Registry('orig2prim')
_prim2orig = Registry('prim2orig')
-_primop_jvp = Registry('primop_jvps')
-_primop_transpose = Registry('primop_vjps')
+_primop_jvp = Registry('primop_jvp')
+_primop_transpose = Registry('primop_transpose')


def lookup_fn(optype):
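
The rename makes each registry's key match its variable name. A standalone sketch of the registry pattern these objects implement follows; the register method is an assumption, since only lookup appears in this diff.

class Registry:
    """Maps a primitive-op name to a registered rule function."""

    def __init__(self, name):
        self.name = name
        self.tab = {}

    def register(self, name, value):
        assert name not in self.tab, '%s is already registered' % name
        self.tab[name] = value

    def lookup(self, name):
        return self.tab.get(name)


_primop_transpose = Registry('primop_transpose')
_primop_transpose.register('add_p', lambda op, check_dot, z_bar: (z_bar, z_bar))
rule = _primop_transpose.lookup('add_p')  # fetch the transpose rule for add_p
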
35 changes: 21 additions & 14 deletions python/paddle/autograd/primrules.py
@@ -470,21 +470,25 @@ def scatter_add_jvp(op, x_dot, y_dot):
@REGISTER_TRANSPOSE('add_p')
def add_transpose(op, check_dot, z_bar):
    x, y = get_input_vars(op)
-    assert check_dot(x) and check_dot(y)
-    return z_bar, z_bar
+    assert check_dot(x) or check_dot(y)
+    x_bar = z_bar if check_dot(x) else None
+    y_bar = z_bar if check_dot(y) else None
+    return x_bar, y_bar


+@REGISTER_TRANSPOSE('sub_p')
def sub_transpose(op, check_dot, z_bar):
    x, y = get_input_vars(op)
-    assert check_dot(x) and check_dot(y)
-    return z_bar, neg(z_bar)
+    assert check_dot(x) or check_dot(y)
+    x_bar = z_bar if check_dot(x) else None
+    y_bar = neg(z_bar) if check_dot(y) else None
+    return x_bar, y_bar


@REGISTER_TRANSPOSE('mul_p')
def mul_transpose(op, check_dot, z_bar):
    x, y = get_input_vars(op)
    assert check_dot(x) ^ check_dot(y)
-    if x.is_dot:
+    if check_dot(x):
        return mul(z_bar, y), None
    else:
        return None, mul(x, z_bar)
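
The relaxed assertions reflect how transpose rules are used: only inputs on the tangent (dot) path receive a cotangent, and the rest now get None instead of tripping the assert. Each returned expression is the adjoint of the op viewed as a linear map in that input. A standalone NumPy check of the mul_p identity, independent of Paddle:

import numpy as np

rng = np.random.default_rng(0)
x, y, z_bar = (rng.standard_normal((2, 3)) for _ in range(3))

# With y held constant, z = x * y is linear in x, and its transpose rule
# returns x_bar = z_bar * y. The adjoint identity <z_bar, x * y> == <x_bar, x>
# must hold for all x and z_bar:
assert np.isclose(np.sum(z_bar * (x * y)), np.sum((z_bar * y) * x))
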
@@ -493,8 +497,9 @@ def mul_transpose(op, check_dot, z_bar):
@REGISTER_TRANSPOSE('div_p')
def div_transpose(op, check_dot, z_bar):
    x, y = get_input_vars(op)
-    assert check_dot(x) and not check_dot(y)
-    return div(z_bar, y), None
+    assert not check_dot(y)
+    x_bar = div(z_bar, y) if check_dot(x) else None
+    return x_bar, None


@REGISTER_TRANSPOSE('reshape_p')
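
div_p follows the same pattern: z = x / y is linear in x for constant y, so the only cotangent produced is x_bar = z_bar / y, and y must never be on the dot path. The corresponding adjoint identity, checked in NumPy:

import numpy as np

rng = np.random.default_rng(1)
x, z_bar = rng.standard_normal((2, 3)), rng.standard_normal((2, 3))
y = rng.standard_normal((2, 3)) + 3.0  # keep the divisor away from zero

# <z_bar, x / y> == <z_bar / y, x>, so x_bar = z_bar / y:
assert np.isclose(np.sum(z_bar * (x / y)), np.sum((z_bar / y) * x))
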
@@ -512,7 +517,8 @@ def broadcast_transpose(op, check_dot, y_bar):
    axis = list(range(bat))
    keepdim = [(bat + i) for i, s in enumerate(x.shape) if s == 1]
    axis += keepdim
-    return reduce(y_bar, axis=axis, keepdim=keepdim)
+    # TODO: Change it. keepdim boolean
+    return reduce(y_bar, axis=axis, keepdim=False)


@REGISTER_TRANSPOSE('transpose_p')
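
The transpose of broadcast_p is summation over the broadcast axes. Leading batch axes must be summed away entirely, while axes where x has size 1 must be summed with the dimension kept, which is why a single keepdim flag cannot handle both groups and the TODO remains. In NumPy terms, a standalone illustration:

import numpy as np

x = np.ones((3, 1, 5))
y_bar = np.random.default_rng(2).standard_normal((2, 3, 4, 5))  # x broadcast to this shape

bat = y_bar.ndim - x.ndim
x_bar = y_bar.sum(axis=tuple(range(bat)))   # drop the leading batch axes
x_bar = x_bar.sum(axis=1, keepdims=True)    # keep the size-1 axis of x
assert x_bar.shape == x.shape
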
@@ -546,9 +552,8 @@ def concat_transpose(op, check_dot, y_bar):
def reduce_transpose(op, check_dot, y_bar):
    x, = get_input_vars(op)
    assert check_dot(x)
-    shape = x.shape
-    for i in op.attr('axis'):
-        shape[i] = 1
+    axes = op.attr('axis')
+    shape = tuple(1 if i in axes else size for i, size in enumerate(x.shape))
    t = reshape(y_bar, shape=shape)
    return broadcast(t, shape=x.shape)
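
The rewrite also stops mutating the list returned by x.shape in place and builds a fresh tuple instead. The rule itself undoes a reduction by reshaping the cotangent to a keepdim-style shape and broadcasting it back to x's shape, which NumPy reproduces directly:

import numpy as np

x = np.random.default_rng(3).standard_normal((2, 3, 4))
axes = (0, 2)
y_bar = np.random.default_rng(4).standard_normal((3,))  # cotangent of x.sum(axis=axes)

# Reinsert the reduced axes with size 1, then broadcast back to x's shape.
shape = tuple(1 if i in axes else size for i, size in enumerate(x.shape))
x_bar = np.broadcast_to(y_bar.reshape(shape), x.shape)
assert x_bar.shape == x.shape
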

@@ -557,10 +562,12 @@ def reduce_transpose(op, check_dot, y_bar):
def matmul_transpose(op, check_dot, z_bar):
    x, y = get_input_vars(op)
    assert check_dot(x) ^ check_dot(y)
-    if x.is_dot:
-        return matmul(z_bar, transpose(y)), None
+    # TODO: replace it. this is hacky
+    axis = [1, 0] if len(x.shape) == 2 else [0, 2, 1]
+    if check_dot(x):
+        return matmul(z_bar, transpose(y, axis=axis)), None
    else:
-        return None, matmul(transpose(x), z_bar)
+        return None, matmul(transpose(x, axis=axis), z_bar)


@REGISTER_TRANSPOSE('slice_select_p')
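
The identities behind matmul_transpose are x_bar = matmul(z_bar, y^T) and y_bar = matmul(x^T, z_bar), with the transpose taken over the last two axes; the hard-coded permutations [1, 0] and [0, 2, 1] cover only the 2-D and single-batch-dim 3-D cases, hence the TODO. A NumPy check of the defining identity in the 2-D case:

import numpy as np

rng = np.random.default_rng(5)
x, y = rng.standard_normal((2, 3)), rng.standard_normal((3, 4))
z_bar = rng.standard_normal((2, 4))

# For constant y, z = x @ y is linear in x and the adjoint satisfies
# <z_bar, x @ y> == <z_bar @ y.T, x>, so x_bar = z_bar @ y.T:
assert np.isclose(np.sum(z_bar * (x @ y)), np.sum((z_bar @ y.T) * x))
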
