
Commit

update test
levi131 committed Apr 14, 2022
1 parent 21c8113 commit 7bf1a96
Showing 1 changed file with 25 additions and 23 deletions.
python/paddle/fluid/tests/unittests/test_primops.py: 48 changes (25 additions & 23 deletions)
@@ -16,32 +16,31 @@
import numpy as np

import paddle
-from paddle.autograd.primops import (neg, add, sub, mul, div, sqrt, tanh,
-                                     reshape, broadcast, transpose, split,
-                                     concat, reduce, matmul, slice_select,
-                                     slice_assign, gather, scatter_add,
-                                     fill_const)
+from paddle.autograd.primops import (
+    neg, add, sub, mul, div, sqrt, tanh, reshape, broadcast, transpose, split,
+    concat, reduce, matmul, slice_select, slice_assign, gather, scatter_add,
+    fill_const)
from paddle.autograd.primx import Transform


class TestPyPrimOps(unittest.TestCase):
""" Test Python wrappers of primitive ops. """

def setUp(self):
paddle.enable_static()


def test_ops(self):
A = np.random.rand(1)
B = np.random.rand(2)
C = np.random.rand(2, 3)
D = np.random.rand(2, 3)
E = np.random.rand(3, 2)

-        a = paddle.static.data(name='A', shape=A.shape, dtype='float')
-        b = paddle.static.data(name='B', shape=B.shape, dtype='float')
-        c = paddle.static.data(name='C', shape=C.shape, dtype='float')
-        d = paddle.static.data(name='D', shape=D.shape, dtype='float')
-        e = paddle.static.data(name='E', shape=E.shape, dtype='float')
+        a = paddle.static.data(name='A', shape=A.shape, dtype='float32')
+        b = paddle.static.data(name='B', shape=B.shape, dtype='float32')
+        c = paddle.static.data(name='C', shape=C.shape, dtype='float32')
+        d = paddle.static.data(name='D', shape=D.shape, dtype='float32')
+        e = paddle.static.data(name='E', shape=E.shape, dtype='float32')

add_1 = add(a, a)
self.assertEqual(add_1.dtype, a.dtype)
@@ -94,39 +93,39 @@ def test_ops(self):

reduce_1 = reduce(d, axis=[1])
self.assertEqual(reduce_1.dtype, d.dtype)
-        self.assertEqual(reduce_1.shape, (2,))
+        self.assertEqual(reduce_1.shape, (2, ))

reduce_2 = reduce(c, axis=[0, 1])
self.assertEqual(reduce_2.dtype, c.dtype)
-        self.assertEqual(reduce_2.shape, (1,))
+        self.assertEqual(reduce_2.shape, (1, ))
# TODO: reduce + keepdim

matmul_1 = matmul(d, e)
self.assertEqual(matmul_1.dtype, d.dtype)
self.assertEqual(matmul_1.shape, (2, 2))

-        slice_select_1 = slice_select(e, axis=[0], starts=[0], ends=[2],
-                                      strides=[1])
+        slice_select_1 = slice_select(
+            e, axis=[0], starts=[0], ends=[2], strides=[1])
self.assertEqual(slice_select_1.dtype, e.dtype)
self.assertEqual(slice_select_1.shape, (2, 2))
-        slice_select_2 = slice_select(d, axis=[0, 1], starts=[0, 1],
-                                      ends=[2, 3], strides=[1, 2])
+        slice_select_2 = slice_select(
+            d, axis=[0, 1], starts=[0, 1], ends=[2, 3], strides=[1, 2])
self.assertEqual(slice_select_2.dtype, d.dtype)
self.assertEqual(slice_select_2.shape, (2, 1))

y = broadcast(b, [2, 2])
-        slice_assign_1 = slice_assign(d, y, axis=[1], starts=[1], ends=[3],
-                                      strides=[1])
+        slice_assign_1 = slice_assign(
+            d, y, axis=[1], starts=[1], ends=[3], strides=[1])
self.assertEqual(slice_assign_1.dtype, d.dtype)
self.assertEqual(slice_assign_1.shape, d.shape)

-        index = paddle.static.data('index', shape=[5], dtype='int')
+        index = paddle.static.data('index', shape=[5], dtype='int32')
gather_1 = gather(e, index, axis=0)
self.assertEqual(gather_1.dtype, e.dtype)
self.assertEqual(gather_1.shape, (5, 2))

-        y = paddle.rand([5, 2])
+        y = paddle.rand([5, 2], dtype='float32')
scatter_add_1 = scatter_add(e, y, index, axis=0)
self.assertEqual(scatter_add_1.dtype, e.dtype)
self.assertEqual(scatter_add_1.shape, e.shape)
@@ -142,13 +141,16 @@ def test_linearize(self):
X_ = reshape(X, shape=[100, 2, 1])
Z = tanh(matmul(W_, X_))
Y = reduce(Z, axis=[1, 2])

def loss(y, x):
ad = Transform(y.block)
xs_dot, ys_dot = ad.linearize([x], [y])
ys_bar, xs_bar = ad.transpose(ys_dot, xs_dot)
return xs_bar

grad, = loss(Y, W)
assert grad.shape == W.shape


if __name__ == '__main__':
-    unittest.main()
+    unittest.main()

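For readers unfamiliar with the module under test, the standalone sketch below (not part of this commit) shows the usage pattern that test_ops exercises. It assumes a Paddle development build in which the experimental paddle.autograd.primops module is importable, exactly as in the file above; the names out and row_sum are illustrative, and the shapes mirror the d and e placeholders in the test.

import paddle
from paddle.autograd.primops import matmul, reduce

paddle.enable_static()

# Static-graph placeholders analogous to d and e in test_ops.
d = paddle.static.data(name='D', shape=[2, 3], dtype='float32')
e = paddle.static.data(name='E', shape=[3, 2], dtype='float32')

out = matmul(d, e)               # primitive matmul: (2, 3) x (3, 2) -> (2, 2)
row_sum = reduce(out, axis=[1])  # primitive reduce over axis 1 -> shape (2,)

assert out.shape == (2, 2)
assert row_sum.shape == (2,)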