Skip to content

Commit 3fffaab

Browse files
committed
add test
1 parent c924834 commit 3fffaab

File tree

1 file changed

+51
-1
lines changed

1 file changed

+51
-1
lines changed

test/legacy_test/test_silu_op.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import unittest
1717

1818
import numpy as np
19-
from op_test import OpTest
19+
from op_test import OpTest, get_places
2020

2121
import paddle
2222
import paddle.base.dygraph as dg
@@ -331,5 +331,55 @@ def init_dtype(self):
331331
self.dtype = np.complex128
332332

333333

334+
class TestSiluAPI(unittest.TestCase):
335+
def setUp(self):
336+
np.random.seed(0)
337+
self.shape = [10, 10]
338+
self.x_np = np.random.random(self.shape).astype(np.float32)
339+
self.place = get_places()
340+
self.x_feed = copy.deepcopy(self.x_np)
341+
342+
def test_api_static(self):
343+
paddle.enable_static()
344+
345+
def run(place, inplace):
346+
with paddle.static.program_guard(paddle.static.Program()):
347+
x = paddle.static.data('X', self.shape)
348+
out = F.silu(x, inplace)
349+
exe = paddle.static.Executor(self.place[0])
350+
res = exe.run(
351+
feed={
352+
'X': self.x_feed,
353+
},
354+
fetch_list=[out],
355+
)
356+
target = copy.deepcopy(self.x_np)
357+
out_ref = silu(target)
358+
359+
for out in res:
360+
np.testing.assert_allclose(out, out_ref, rtol=0.001)
361+
362+
for place in self.place:
363+
run(place, True)
364+
run(place, False)
365+
366+
def test_api_dygraph(self):
367+
def run(place, inplace):
368+
paddle.disable_static(place)
369+
x_tensor = paddle.to_tensor(self.x_np)
370+
out = F.silu(x_tensor, inplace)
371+
372+
target = copy.deepcopy(self.x_np)
373+
out_ref = silu(target)
374+
375+
np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001)
376+
377+
paddle.enable_static()
378+
379+
for place in self.place:
380+
run(place, True)
381+
run(place, False)
382+
383+
334384
if __name__ == '__main__':
335385
unittest.main()

0 commit comments

Comments
 (0)