Update input types in the unit test.
tongxin committed Apr 15, 2022
1 parent 485fc65 commit 2c82691
Showing 2 changed files with 87 additions and 50 deletions.
96 changes: 66 additions & 30 deletions python/paddle/autograd/primx.py
@@ -22,41 +22,72 @@


def topo_path(xs, ys, block=None):
""" Returns the ops in topological on the paths from input `xs` to
output `ys`. """
""" Returns the list of ops on the path from `xs` to `ys` in topological
order.
TODO(Tongxin): supporting control flow and nested blocks.
Args:
xs: a list|tuple of vars as source
ys: a list|tuple of vars as sink
block: the program block containing the path, optional
Returns:
path: a list of ops
"""

if block is None:
block = default_main_program().current_block()
path = []
def_vars = list(xs)
reached_vars = list(xs)
sink_vars = list(ys)
sink_ops = {}

# block.ops are supposedly in topological order as of now
for op in block.ops:
if len(sink_ops) == len(ys):
break
ins = set(get_input_vars(op))
if any(ins.intersection(def_vars)):
if any(ins.intersection(reached_vars)):
path.append(op)
outs = set(get_output_vars(op))
for out in outs:
if any(out is y for y in ys):
if any(out is y for y in sink_vars):
# Found an output op
assert not any(out is y for y in sink_ops)
sink_ops[id(out)] = op
# TODO(Tongxin): handling cases where sink vars
# have dependencies themselves
else:
def_vars.append(out)
if len(sink_ops) != len(ys):
not_reachable = (var for var in ys if id(var) not in sink_ops)
raise f"Output vars: {' '.join(not_reachable)} are not reachable from inputs."
reached_vars.append(out)

if len(sink_ops) != len(sink_vars):
for var in sink_vars:
assert id(var) in sink_ops, (
f"{var} is not reachable from input vars.")

return path


def vars_on_path(xs, ys, block=None):
def output_vars_on_path(xs, ys, block=None):
""" Returns the output variables of all the ops on the path from `xs`
to `ys`.
Args:
xs: a list|tuple of vars as source
ys: a list|tuple of vars as sink
block: the program block containing the path, optional
Returns:
vars: the output vars
"""
if block is None:
block = default_main_program().current_block()

vars = OrderedDict()

for var in xs + ys:
vars[id(var)] = var

@@ -68,21 +99,30 @@ def vars_on_path(xs, ys, block=None):
vars[id(out)] = out

return vars


class VarMap(object):
""" A general map data structure for linking variables to variables.
An example is linking variables to their gradients.
"""

__slots__ = ['name', 'varset', 'tab']

def __init__(self, name, varset):
self.name = name
self.varset = varset
self.tab = OrderedDict()

def add(self, key_var, value_var):
self.tab[id(key_var)] = id(value_var)

def add_rec(self, key_vars, value_vars):
if isinstance(key_vars, paddle.fluid.framework.Variable):
assert isinstance(value_vars, paddle.fluid.framework.Variable)
self.tab[id(key_vars)] = id(value_vars)
else:
assert len(key_vars) == len(value_vars)
for key_var, value_var in zip(key_vars, value_vars):
self.add_rec(key_var, value_var)

@@ -112,6 +152,7 @@ def contain_var(self, key_var):
def contain_value(self, value_var):
return id(value_var) in self.tab.values()


class Transform(object):
""" An object that maintains the state of transformations applied to a
primitve program. """
@@ -129,11 +170,11 @@ def init_vars(self, block):
return vars

def add_vars(self, new_vars):
self.vars.update({id(v) : v for v in new_vars if v is not None})
self.vars.update({id(v): v for v in new_vars if v is not None})

def add_vars_rec(self, new_vars):
if isinstance(new_vars, paddle.fluid.framework.Variable):
self.vars.update({id(new_vars) : new_vars})
self.vars.update({id(new_vars): new_vars})
return
assert isinstance(new_vars, list)
for var in new_vars:
@@ -148,12 +189,6 @@ def erase_dots(self, vars_to_erase):
for var in vars_to_erase:
del var.block.vars[var.name]

# def is_dot(self, var):
# return self.var2dot.contain_value(var)

def lower2prim(self):
pass

def var2dot_rec(self, vars, defaults=None):

if isinstance(vars, paddle.fluid.framework.Variable):
@@ -165,9 +200,11 @@ def var2dot_rec(self, vars, defaults=None):
if defaults is None:
defaults = [None for _ in range(vars)]

dots = [self.var2dot_rec(var, default) for var, default
in zip(vars, defaults)]

dots = [
self.var2dot_rec(var, default)
for var, default in zip(vars, defaults)
]

return dots

def linearize(self, xs, ys, xs_dot=None):
@@ -194,7 +231,7 @@ def linearize(self, xs, ys, xs_dot=None):
self.add_vars_rec(outs_dot)
outs = op_position_output(op)
self.var2dot.add_rec(outs, outs_dot)

ys_dot = [self.var2dot.lookup(y) for y in ys]
return xs_dot, ys_dot

@@ -212,9 +249,9 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):

for dot, bar in zip(ys_dot, ys_bar):
self.dot2bar.add(dot, bar)

# find all the relevant forward gradients
dotvars = vars_on_path(xs_dot, ys_dot)
dotvars = output_vars_on_path(xs_dot, ys_dot)

is_dot = lambda v: id(v) in dotvars

@@ -233,11 +270,11 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
grad = self.dot2bar.lookup(dot)
if grad is None:
self.dot2bar.add(dot, bar)
else:
else:
grad = add(grad, bar)
self.add_vars([grad])
self.dot2bar.add(dot, grad)

xs_bar = [self.dot2bar.lookup(x) for x in xs_dot]

if not retain_fwd:
@@ -253,4 +290,3 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):
self.erase_dots(dots_to_remove)

return ys_bar, xs_bar
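
For orientation, and not part of this commit: below is a minimal sketch of how the reworked topo_path and the Transform linearize/transpose pair fit together, mirroring test_vjp_set1/test_vjp_set2 in the test file that follows. The variable names, shapes and dtypes are illustrative assumptions.

import paddle
from paddle.autograd.primops import matmul, tanh
from paddle.autograd.primx import Transform, topo_path

paddle.enable_static()

# Illustrative static-graph inputs; names, shapes and dtypes are assumed.
X = paddle.static.data('X', shape=[100, 2], dtype='float32')
W = paddle.static.data('W', shape=[2, 2], dtype='float32')
Y = tanh(matmul(X, W))

ad = Transform(Y.block)
# linearize builds the forward (JVP) pass and returns the dot variables.
xs_dot, ys_dot = ad.linearize([X, W], [Y])
# transpose turns the linearized pass into a VJP and returns the bars,
# i.e. the gradients with respect to the inputs.
ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot)

# topo_path lists the ops connecting the sources to the sinks,
# in topological order.
for op in topo_path(ys_bar, xs_bar):
    print(op)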

41 changes: 21 additions & 20 deletions python/paddle/fluid/tests/unittests/test_primops.py
@@ -16,11 +16,10 @@
import numpy as np

import paddle
from paddle.autograd.primops import (neg, add, sub, mul, div, sqrt, tanh,
reshape, broadcast, transpose, split,
concat, reduce, matmul, slice_select,
slice_assign, gather, scatter_add,
fill_const)
from paddle.autograd.primops import (
neg, add, sub, mul, div, sqrt, tanh, reshape, broadcast, transpose, split,
concat, reduce, matmul, slice_select, slice_assign, gather, scatter_add,
fill_const)
from paddle.autograd.primx import Transform, topo_path


@@ -30,7 +29,6 @@ class TestPyPrimOps(unittest.TestCase):
def setUp(self):
paddle.enable_static()


def test_ops(self):
A = np.random.rand(1)
B = np.random.rand(2)
@@ -95,39 +93,39 @@ def test_ops(self):

reduce_1 = reduce(d, axis=[1])
self.assertEqual(reduce_1.dtype, d.dtype)
self.assertEqual(reduce_1.shape, (2,))
self.assertEqual(reduce_1.shape, (2, ))

reduce_2 = reduce(c, axis=[0, 1])
self.assertEqual(reduce_2.dtype, c.dtype)
self.assertEqual(reduce_2.shape, (1,))
self.assertEqual(reduce_2.shape, (1, ))
# TODO: reduce + keepdim

matmul_1 = matmul(d, e)
self.assertEqual(matmul_1.dtype, d.dtype)
self.assertEqual(matmul_1.shape, (2, 2))

slice_select_1 = slice_select(e, axis=[0], starts=[0], ends=[2],
strides=[1])
slice_select_1 = slice_select(
e, axis=[0], starts=[0], ends=[2], strides=[1])
self.assertEqual(slice_select_1.dtype, e.dtype)
self.assertEqual(slice_select_1.shape, (2, 2))
slice_select_2 = slice_select(d, axis=[0, 1], starts=[0, 1],
ends=[2, 3], strides=[1, 2])

slice_select_2 = slice_select(
d, axis=[0, 1], starts=[0, 1], ends=[2, 3], strides=[1, 2])
self.assertEqual(slice_select_2.dtype, d.dtype)
self.assertEqual(slice_select_2.shape, (2, 1))

y = broadcast(b, [2, 2])
slice_assign_1 = slice_assign(d, y, axis=[1], starts=[1], ends=[3],
strides=[1])
slice_assign_1 = slice_assign(
d, y, axis=[1], starts=[1], ends=[3], strides=[1])
self.assertEqual(slice_assign_1.dtype, d.dtype)
self.assertEqual(slice_assign_1.shape, d.shape)

index = paddle.static.data('index', shape=[5], dtype='int')
index = paddle.static.data('index', shape=[5], dtype='int32')
gather_1 = gather(e, index, axis=0)
self.assertEqual(gather_1.dtype, e.dtype)
self.assertEqual(gather_1.shape, (5, 2))

y = paddle.rand([5, 2])
y = paddle.rand([5, 2], dtype='float')
scatter_add_1 = scatter_add(e, y, index, axis=0)
self.assertEqual(scatter_add_1.dtype, e.dtype)
self.assertEqual(scatter_add_1.shape, e.shape)
@@ -145,7 +143,7 @@ def test_vjp_set1(self):
ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot)
assert xs_bar[0].shape == X.shape
assert xs_bar[1].shape == W.shape

print(f'-------test_vjp_set1-------')
for op in topo_path(ys_bar, xs_bar):
print(op)
@@ -158,22 +156,25 @@ def test_vjp_set2(self):
X_ = reshape(X, shape=[100, 2, 1])
Z = tanh(matmul(W_, X_))
Y = reduce(Z, axis=[1, 2])

def loss(y, x):
ad = Transform(y.block)
xs_dot, ys_dot = ad.linearize([x], [y])
ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot)
# ad = Transform(y.block)
# xs_dot, ys_dot = ad.linearize([x], xs_bar)
# for op in topo_path(xs_dot, ys_dot):
# print(op)
# print(op)
# ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot)
return ys_bar, xs_bar

vs, grads = loss(Y, W)
assert grads[0].shape == W.shape

print(f'-------test_vjp_set2-------')
for op in topo_path(vs, grads):
print(op)


if __name__ == '__main__':
unittest.main()
unittest.main()
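
Also for orientation, and not part of this commit: a minimal sketch of VarMap on its own, the id-keyed table that Transform uses to link variables to their dots and bars. The variable names and the hand-built varset are illustrative assumptions; how lookups resolve ids back to variables through the varset is inferred from the surrounding code.

import paddle
from paddle.autograd.primx import VarMap

paddle.enable_static()

# Illustrative variables; in Transform the varset is the id -> variable
# dict the object maintains for the whole block.
x = paddle.static.data('x', shape=[2, 2], dtype='float32')
x_dot = paddle.static.data('x_dot', shape=[2, 2], dtype='float32')
varset = {id(v): v for v in (x, x_dot)}

var2dot = VarMap('var2dot', varset)
var2dot.add(x, x_dot)           # stores id(x) -> id(x_dot)
assert var2dot.contain_var(x)
assert var2dot.contain_value(x_dot)

# add_rec links (possibly nested) lists of variables element-wise.
var2dot.add_rec([x], [x_dot])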
