add gradient aggregation, fix add_transpose.
tongxin committed Apr 13, 2022
1 parent 2aae968 commit e980c9d
Showing 2 changed files with 51 additions and 24 deletions.
9 changes: 5 additions & 4 deletions python/paddle/autograd/primrules.py
@@ -172,15 +172,15 @@ def scatter_add_jvp(op, x_dot, y_dot):
@REGISTER_TRANSPOSE('add_p')
def add_transpose(op, check_dot, z_bar):
x, y = get_input_vars(op)
assert check_dot(x) and check_dot(y)
assert check_dot(x) or check_dot(y)
x_bar = z_bar if check_dot(x) else None
y_bar = z_bar if check_dot(y) else None
return x_bar, y_bar


def sub_transpose(op, check_dot, z_bar):
x, y = get_input_vars(op)
assert check_dot(x) and check_dot(y)
assert check_dot(x) or check_dot(y)
x_bar = z_bar if check_dot(x) else None
y_bar = neg(z_bar) if check_dot(y) else None
return x_bar, y_bar
@@ -199,8 +199,9 @@ def mul_transpose(op, check_dot, z_bar):
@REGISTER_TRANSPOSE('div_p')
def div_transpose(op, check_dot, z_bar):
x, y = get_input_vars(op)
assert check_dot(x) and not check_dot(y)
return div(z_bar, y), None
assert not check_dot(y)
x_bar = div(z_bar, y) if check_dot(x) else None
return x_bar, None


@REGISTER_TRANSPOSE('reshape_p')
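To illustrate why the assertions above were relaxed: an operand may legitimately sit off the forward-gradient (dot) path, in which case its transpose result is simply None rather than an error. Below is a minimal, framework-free sketch of that pattern, using plain floats in place of Paddle variables; the scalar division is only a stand-in for the div primitive.

# Minimal, framework-free sketch of the relaxed transpose rules above:
# check_dot(v) reports whether v carries a forward gradient, and operands
# off the dot path now receive None instead of tripping an assertion.

def add_transpose(check_dot, x, y, z_bar):
    assert check_dot(x) or check_dot(y)
    x_bar = z_bar if check_dot(x) else None
    y_bar = z_bar if check_dot(y) else None
    return x_bar, y_bar

def div_transpose(check_dot, x, y, z_bar):
    # the denominator must stay off the dot path
    assert not check_dot(y)
    x_bar = z_bar / y if check_dot(x) else None
    return x_bar, None

x, y = 3.0, 2.0
is_dot = lambda v: v is x            # only x carries a forward gradient

print(add_transpose(is_dot, x, y, z_bar=1.0))   # (1.0, None)
print(div_transpose(is_dot, x, y, z_bar=1.0))   # (0.5, None)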
66 changes: 46 additions & 20 deletions python/paddle/autograd/primx.py
@@ -14,7 +14,7 @@

from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid import unique_name, core
from .primops import fill_const
from .primops import fill_const, add
from .primrules import get_input_vars, get_output_vars, _jvp, _transpose
from collections import OrderedDict

@@ -79,8 +79,10 @@ def add(self, key_var, value_var):

def lookup(self, key_var):
value_id = self.tab.get(id(key_var))
value_var = self.varset.get(value_id)
return value_var
if value_id is not None:
return self.varset.get(value_id)
else:
return None

def delete(self, key_var):
del self.tab[id(key_var)]
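The lookup change above makes a miss explicit: an unmapped variable now yields None instead of indexing the value set with a missing id. A small sketch of that contract, using a simplified, id-keyed stand-in rather than the real VarMap class:

# Simplified stand-in for the id-keyed VarMap used above.
class TinyVarMap:
    def __init__(self):
        self.tab = {}      # id(key_var) -> id(value_var)
        self.varset = {}   # id(var) -> var

    def add(self, key_var, value_var):
        self.tab[id(key_var)] = id(value_var)
        self.varset[id(value_var)] = value_var

    def lookup(self, key_var):
        value_id = self.tab.get(id(key_var))
        if value_id is not None:
            return self.varset.get(value_id)
        return None

m = TinyVarMap()
x, x_dot = object(), object()
m.add(x, x_dot)
assert m.lookup(x) is x_dot
assert m.lookup(object()) is None   # unmapped vars report None explicitly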
@@ -101,7 +103,6 @@ def contain_var(self, key_var):
def contain_value(self, value_var):
return id(value_var) in self.tab.values()


class Transform(object):
""" An object that maintains the state of transformations applied to a
primitive program. """
@@ -136,31 +137,44 @@ def lower2prim(self):
pass

def linearize(self, xs, ys, xs_dot=None):
dotvars = VarMap('dotvars', self.vars)

if xs_dot is None:
xs_dot = []
for x in xs:
xs_dot.append(fill_const(1.0, shape=x.shape, dtype=x.dtype))
xs_dot = [fill_const(1.0, shape=x.shape, dtype=x.dtype) for x in xs]
self.add_vars(xs_dot)
else:
assert len(xs) == len(xs_dot)
assert all(x.dtype == x_dot.dtype for x, x_dot in zip(xs, xs_dot))
assert all(x.shape == x_dot.shape for x, x_dot in zip(xs, xs_dot))

self.add_vars(xs_dot)
for x, dot in zip(xs, xs_dot):
assert x.dtype == dot.dtype
assert x.shape == dot.shape
self.var2dot.add(x, dot)

for op in topo_path(xs, ys, self.block):
ins = get_input_vars(op)
ins_dot = [self.var2dot.lookup(var) for var in ins]
jvp_ins = [x if dot is None else dot for dot, x in zip(ins_dot, ins)]
# An input var may not lie on the input-output path, in which case
# no forward gradient is linked to it and the lookup below returns
# None. In that case we pass the original input in place of the
# missing forward gradient.
jvp_ins = []
for var in ins:
dot = self.var2dot.lookup(var)
if dot is None:
jvp_ins.append(var)
else:
jvp_ins.append(dot)

# apply op's forward ad rule
outs_dot = _jvp(op, *jvp_ins)

if not isinstance(outs_dot, list):
outs_dot = [outs_dot]

self.add_vars(outs_dot)

for x, dot in zip(get_output_vars(op), outs_dot):
self.var2dot.add(x, dot)
outs = get_output_vars(op)
assert len(outs) == len(outs_dot)
for out, dot in zip(outs, outs_dot):
self.var2dot.add(out, dot)

ys_dot = [self.var2dot.lookup(y) for y in ys]
return xs_dot, ys_dot
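To make the substitution above concrete: when an input carries no forward gradient (for example an index or shape operand), the original variable is passed to the op's jvp rule in the gradient's place. The following is a hedged, framework-free sketch of that branch; gather_jvp and the plain Python lists are illustrative stand-ins, not Paddle primitives.

# Toy mirror of the substitution in linearize(): inputs that carry no
# forward gradient (e.g. an integer index) are passed through as-is so the
# op's jvp rule can still use them.

def gather_jvp(x_dot, index):
    # the tangent of gather(x, index) is gather(x_dot, index); the index
    # input has no tangent of its own and is used directly.
    return [x_dot[i] for i in index]

x     = [3.0, 5.0, 7.0]
x_dot = [1.0, 1.0, 1.0]
index = [2, 0]

var2dot = {id(x): x_dot}            # the index var is deliberately absent

jvp_ins = []
for var in (x, index):
    dot = var2dot.get(id(var))
    # mirrors the `if dot is None: jvp_ins.append(var)` branch above
    jvp_ins.append(var if dot is None else dot)

print(gather_jvp(*jvp_ins))         # [1.0, 1.0]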
@@ -170,17 +184,21 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
ys_bar = []
for y in ys_dot:
ys_bar.append(fill_const(1.0, shape=y.shape, dtype=y.dtype))
self.add_vars(ys_bar)
else:
assert len(ys_dot) == len(ys_bar)
for y_dot, y_bar in zip(ys_dot, ys_bar):
assert y_dot.shape == y_bar.shape
assert y_dot.dtype == y_bar.dtype

dotvars = vars_on_path(xs_dot, ys_dot)
is_dot = lambda v: id(v) in dotvars
self.add_vars(ys_bar)
for dot, bar in zip(ys_dot, ys_bar):
self.dot2bar.add(dot, bar)

# find all the relevant forward gradients
dotvars = vars_on_path(xs_dot, ys_dot)

is_dot = lambda v: id(v) in dotvars

for op in reversed(topo_path(xs_dot, ys_dot, self.block)):
outs_bar = [self.dot2bar.lookup(var) for var in get_output_vars(op)]
ins_bar = _transpose(op, is_dot, *outs_bar)
@@ -191,7 +209,15 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
self.add_vars(ins_bar)
for dot, bar in zip(get_input_vars(op), ins_bar):
if bar is not None:
self.dot2bar.add(dot, bar)
# aggregate gradient
grad = self.dot2bar.lookup(dot)
if grad is None:
self.dot2bar.add(dot, bar)
else:
grad = add(grad, bar)
self.add_vars([grad])
self.dot2bar.add(dot, grad)

xs_bar = [self.dot2bar.lookup(x) for x in xs_dot]

if not retain_fwd:
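The aggregation introduced above handles fan-out: a forward-gradient variable consumed by several ops receives one cotangent (bar) per use, and these contributions must be summed rather than overwritten. A minimal sketch of that bookkeeping, with a plain dict and float addition standing in for dot2bar and the add primitive:

# Toy mirror of the gradient-aggregation loop in transpose(): a variable
# used by several ops gets one cotangent per use; they are summed instead
# of a later one silently overwriting an earlier one.

dot2bar = {}                     # id(var) -> accumulated cotangent

def accumulate(var, bar):
    grad = dot2bar.get(id(var))
    if grad is None:
        dot2bar[id(var)] = bar          # first contribution: record it
    else:
        dot2bar[id(var)] = grad + bar   # later ones: aggregate

x = object()                     # a dot variable consumed by two ops
accumulate(x, 0.25)              # contribution from op 1's transpose
accumulate(x, 0.75)              # contribution from op 2's transpose
assert dot2bar[id(x)] == 1.0     # total gradient, not just the last bar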
