orig2prim segfault.
tongxin committed Apr 16, 2022
1 parent d60492d commit c203867
Showing 2 changed files with 51 additions and 0 deletions.
30 changes: 30 additions & 0 deletions python/paddle/autograd/primx.py
@@ -291,6 +291,36 @@ def transpose(self, ys_dot, xs_dot, ys_bar=None, retain_fwd=False):

        return ys_bar, xs_bar

def _gradients(ys, xs, ys_bar=None):
    """ A drop-in replacement of paddle.gradients for computing
        the gradients of `ys` with respect to `xs` using
        primitive-op-based AD rules.

        Args:
            ys: the target tensor or tensors
            xs: the input tensor or tensors
            ys_bar: the optional gradient tensors of `ys`

        Returns:
            xs_bar: a list of gradients of the input `xs`
    """

    ys, xs = to_tensors(ys), to_tensors(xs)
    block = ys[0].block

    # TODO(Tongxin): without prior knowledge of whether the program is
    # already fully lowered to primitive ops, the lowering pass has to
    # be run every time. This is obviously inefficient and needs to be
    # optimized.
    orig2prim(block)

    ad = Transform(block)
    xs_dot, ys_dot = ad.linearize(xs, ys)
    ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot, ys_bar)

    prim2orig(block)
    return xs_bar


def orig2prim(block=None):
    _lower(block, reverse=False)
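For context, a minimal sketch of how `_gradients` might be driven, including the optional `ys_bar` seed for the reverse pass. The names and shapes are illustrative, and `ys_bar` is assumed to parallel `ys` as a list; neither detail is spelled out in this commit.

import paddle
from paddle.autograd.primx import _gradients

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    X = paddle.static.data('X', shape=[3, 3], dtype='float32')
    Y = paddle.matmul(X, X)
    # Seed the reverse pass with a custom cotangent of `Y`;
    # assumed to be passed as a list matching `ys`.
    Y_bar = paddle.static.data('Y_bar', shape=[3, 3], dtype='float32')
    X_grad, = _gradients([Y], [X], ys_bar=[Y_bar])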
21 changes: 21 additions & 0 deletions python/paddle/fluid/tests/unittests/test_primops.py
@@ -21,7 +21,12 @@
    concat, reduce, matmul, slice_select, slice_assign, gather, scatter_add,
    fill_const)
from paddle.autograd.primx import Transform, topo_path, orig2prim, prim2orig
from paddle.autograd.primx import _gradients

def prog1(x, y):
    t = paddle.matmul(x, y)
    # z = paddle.sum(paddle.sqrt(x))
    return t

class TestPyPrimOps(unittest.TestCase):
    """ Test Python wrappers of primitive ops. """
@@ -175,6 +180,22 @@ def loss(y, x):
        for op in topo_path(vs, grads):
            print(op)


    def test_first_order_gradients(self):
        # Cast to float32 so the feed matches the declared data dtype.
        x = np.random.rand(100, 1, 2).astype('float32')
        w = np.random.rand(100, 2, 5).astype('float32')
        main = paddle.static.Program()
        startup = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            # Data names must match the feed keys below ('X' and 'W').
            X = paddle.static.data('X', shape=[100, 1, 2], dtype='float32')
            W = paddle.static.data('W', shape=[100, 2, 5], dtype='float32')
            Z = prog1(X, W)
            X_grad, W_grad = _gradients([Z], [X, W])
        exe = paddle.static.Executor()
        exe.run(startup)
        z = exe.run(main, feed={'X': x, 'W': w}, fetch_list=[Z])
        print(z)

    def test_lower(self):
        main = paddle.static.Program()
        startup = paddle.static.Program()
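For reference, the orig2prim/prim2orig round-trip that `_gradients` performs internally can also be driven by hand on a program's block; a sketch under the same illustrative assumptions (shapes and names are not from this commit):

paddle.enable_static()
main, startup = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main, startup):
    X = paddle.static.data('X', shape=[2, 2], dtype='float32')
    Y = paddle.matmul(X, X)
    block = main.current_block()
    orig2prim(block)  # lower original ops to primitive ops
    prim2orig(block)  # rewrite primitive ops back to original ops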
