Merge remote-tracking branch 'tx/ad' into lml/lower2prim
levi131 committed Apr 16, 2022
2 parents 7c565c3 + 2fd5fde commit e347444
Showing 2 changed files with 102 additions and 17 deletions.
96 changes: 81 additions & 15 deletions python/paddle/autograd/primx.py
@@ -24,36 +24,67 @@


def topo_path(xs, ys, block=None):
""" Returns the ops in topological on the paths from input `xs` to
output `ys`. """
""" Returns the list of ops on the path from `xs` to `ys` in topological
order.
TODO(Tongxin): support control flow and nested blocks.
Args:
xs: a list|tuple of vars as source
ys: a list|tuple of vars as sink
block: the program block containing the path, optional
Returns:
path: a list of ops
"""

if block is None:
block = default_main_program().current_block()
path = []
def_vars = list(xs)
reached_vars = list(xs)
sink_vars = list(ys)
sink_ops = {}

# block.ops are supposedly in topological order as of now
for op in block.ops:
if len(sink_ops) == len(ys):
break
ins = set(get_input_vars(op))
if any(ins.intersection(def_vars)):
if any(ins.intersection(reached_vars)):
path.append(op)
outs = set(get_output_vars(op))
for out in outs:
if any(out is y for y in ys):
if any(out is y for y in sink_vars):
# Found an output op
assert not any(out is y for y in sink_ops)
sink_ops[id(out)] = op
# TODO(Tongxin): handling cases where sink vars
# have dependencies themselves
else:
def_vars.append(out)
if len(sink_ops) != len(ys):
not_reachable = (var for var in ys if id(var) not in sink_ops)
raise f"Output vars: {' '.join(not_reachable)} are not reachable from inputs."
reached_vars.append(out)

if len(sink_ops) != len(sink_vars):
for var in sink_vars:
assert id(var) in sink_ops, (
f"{var} is not reachable from input vars.")

return path
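
For orientation, a minimal usage sketch of topo_path, assuming primx is importable as in the test file below; the program and the variable names X, Y, Z are illustrative only, not part of this commit:

    import paddle
    from paddle.autograd.primx import topo_path

    paddle.enable_static()
    main = paddle.static.Program()
    startup = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        X = paddle.static.data('X', shape=[2, 3], dtype='float32')
        Y = paddle.static.data('Y', shape=[3, 4], dtype='float32')
        Z = paddle.matmul(X, Y)
        # Ops on the paths from {X, Y} to {Z}, in topological order.
        for op in topo_path([X, Y], [Z], main.current_block()):
            print(op.type)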


def vars_on_path(xs, ys, block=None):
def output_vars_on_path(xs, ys, block=None):
""" Returns the output variables of all the ops on the path from `xs`
to `ys`.
Args:
xs: a list|tuple of vars as source
ys: a list|tuple of vars as sink
block: the program block containing the path, optional
Returns:
vars: the output vars
"""
if block is None:
block = default_main_program().current_block()

@@ -73,6 +104,11 @@ def vars_on_path(xs, ys, block=None):


class VarMap(object):
""" A general map data structure for linking variables to variables.
An example is linking variables to their gradients.
"""

__slots__ = ['name', 'varset', 'tab']

def __init__(self, name, varset):
@@ -85,8 +121,10 @@ def add(self, key_var, value_var):

def add_rec(self, key_vars, value_vars):
if isinstance(key_vars, paddle.fluid.framework.Variable):
assert isinstance(value_vars, paddle.fluid.framework.Variable)
self.tab[id(key_vars)] = id(value_vars)
else:
assert len(key_vars) == len(value_vars)
for key_var, value_var in zip(key_vars, value_vars):
self.add_rec(key_var, value_var)

@@ -153,9 +191,6 @@ def erase_dots(self, vars_to_erase):
for var in vars_to_erase:
del var.block.vars[var.name]

# def is_dot(self, var):
# return self.var2dot.contain_value(var)

def var2dot_rec(self, vars, defaults=None):

if isinstance(vars, paddle.fluid.framework.Variable):
@@ -218,7 +253,7 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
self.dot2bar.add(dot, bar)

# find all the relevant forward gradients
dotvars = vars_on_path(xs_dot, ys_dot)
dotvars = output_vars_on_path(xs_dot, ys_dot)

is_dot = lambda v: id(v) in dotvars

@@ -258,6 +293,36 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):

return ys_bar, xs_bar

def _gradients(ys, xs, ys_bar=None):
""" A drop-in replacement of paddle.gradients for computing
the gradients of `ys` with respect to `xs` using primitive-op based
AD rules.
Args:
ys: the target tensor or tensors
xs: the input tensor or tensors
ys_bar: the optional gradient tensors of `ys`
Returns:
xs_bar: a list of gradients of input `xs`
"""

ys, xs = to_tensors(ys), to_tensors(xs)
block = ys[0].block

# TODO(Tongxin) without any prior knowledge about whether the program
# is already completely lowered to primitive ops, the lowering pass has to
# be run on every call. This is obviously inefficient and needs to be
# optimized.
orig2prim(block)

ad = Transform(block)
xs_dot, ys_dot = ad.linearize(xs, ys)
ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot, ys_bar)

prim2orig(block)
return xs_bar
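
The ys_bar argument is not exercised anywhere in this change. A minimal sketch of seeding the reverse pass with an explicit output cotangent, assuming the same X, Y, Z program as in test_first_order_gradients below and that ys_bar accepts a list of vars (an assumption based only on the signature):

    # Hypothetical: Z has shape [100, 1, 5] for the matmul of X [100, 1, 2] and Y [100, 2, 5].
    Z_bar = paddle.static.data('Z_bar', shape=[100, 1, 5], dtype='float32')
    X_grad, Y_grad = _gradients([Z], [X, Y], ys_bar=[Z_bar])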


def orig2prim(block=None, update_var_list=None):
_lower(block, reverse=False, update_var_list=update_var_list)
@@ -318,7 +383,8 @@ def _lower(block, reverse, update_var_list):
for op_idx in reversed(ops_to_remove):
block._remove_op(op_idx)
for var_name in vars_to_remove:
block._remove_var(var_name)
# block._remove_var(var_name)
del block.vars[var_name]

if update_var_list is not None:
for i in range(len(update_var_list)):
23 changes: 21 additions & 2 deletions python/paddle/fluid/tests/unittests/test_primops.py
@@ -21,9 +21,12 @@
transpose, split, concat, reduce, matmul, slice_select, slice_assign,
gather, scatter_add, fill_const)
from paddle.autograd.primx import Transform, topo_path, orig2prim, prim2orig
from paddle.autograd.primx import _gradients

from paddle.autograd.new_adam_optimizer import AdamOptimizer

def prog1(x, y):
t = paddle.matmul(x, y)
# z = paddle.sum(paddle.sqrt(x))
return t

class TestPyPrimOps(unittest.TestCase):
""" Test Python wrappers of primitive ops. """
@@ -190,6 +193,22 @@ def loss(y, x):
for op in topo_path(vs, grads):
print(op)


def test_first_order_gradients(self):
x = np.random.rand(100, 1, 2)
y = np.random.rand(100, 2, 5)
main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
X = paddle.static.data('X', shape=[100, 1, 2], dtype='float32')
Y = paddle.static.data('Y', shape=[100, 2, 5], dtype='float32')
Z = prog1(X, Y)
X_grad, W_grad = _gradients([Z], [X, Y])
exe = paddle.static.Executor()
exe.run(startup)
z = exe.run(main, feed={'X': x, 'Y': y}, fetch_list=[Z])
print(z)
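
The run above fetches only Z. A hedged extension, assuming the gradient vars produced by _gradients remain valid fetch targets after prim2orig, would also pull the gradient values:

    # Fetch the gradients alongside Z (uses the same feed dict as above).
    z, x_grad, w_grad = exe.run(main,
                                feed={'X': x, 'Y': y},
                                fetch_list=[Z, X_grad, W_grad])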

def test_lower(self):
main = paddle.static.Program()
startup = paddle.static.Program()
