|
16 | 16 | import unittest |
17 | 17 |
|
18 | 18 | import numpy as np |
19 | | -from op_test import OpTest |
| 19 | +from op_test import OpTest, get_places |
20 | 20 |
|
21 | 21 | import paddle |
22 | 22 | import paddle.base.dygraph as dg |
@@ -331,5 +331,55 @@ def init_dtype(self): |
331 | 331 | self.dtype = np.complex128 |
332 | 332 |
|
333 | 333 |
|
| 334 | +class TestSiluAPI(unittest.TestCase): |
| 335 | + def setUp(self): |
| 336 | + np.random.seed(0) |
| 337 | + self.shape = [10, 10] |
| 338 | + self.x_np = np.random.random(self.shape).astype(np.float32) |
| 339 | + self.place = get_places() |
| 340 | + self.x_feed = copy.deepcopy(self.x_np) |
| 341 | + |
| 342 | + def test_api_static(self): |
| 343 | + paddle.enable_static() |
| 344 | + |
| 345 | + def run(place, inplace): |
| 346 | + with paddle.static.program_guard(paddle.static.Program()): |
| 347 | + x = paddle.static.data('X', self.shape) |
| 348 | + out = F.silu(x, inplace) |
| 349 | + exe = paddle.static.Executor(self.place[0]) |
| 350 | + res = exe.run( |
| 351 | + feed={ |
| 352 | + 'X': self.x_feed, |
| 353 | + }, |
| 354 | + fetch_list=[out], |
| 355 | + ) |
| 356 | + target = copy.deepcopy(self.x_np) |
| 357 | + out_ref = silu(target) |
| 358 | + |
| 359 | + for out in res: |
| 360 | + np.testing.assert_allclose(out, out_ref, rtol=0.001) |
| 361 | + |
| 362 | + for place in self.place: |
| 363 | + run(place, True) |
| 364 | + run(place, False) |
| 365 | + |
| 366 | + def test_api_dygraph(self): |
| 367 | + def run(place, inplace): |
| 368 | + paddle.disable_static(place) |
| 369 | + x_tensor = paddle.to_tensor(self.x_np) |
| 370 | + out = F.silu(x_tensor, inplace) |
| 371 | + |
| 372 | + target = copy.deepcopy(self.x_np) |
| 373 | + out_ref = silu(target) |
| 374 | + |
| 375 | + np.testing.assert_allclose(out.numpy(), out_ref, rtol=0.001) |
| 376 | + |
| 377 | + paddle.enable_static() |
| 378 | + |
| 379 | + for place in self.place: |
| 380 | + run(place, True) |
| 381 | + run(place, False) |
| 382 | + |
| 383 | + |
334 | 384 | if __name__ == '__main__': |
335 | 385 | unittest.main() |
0 commit comments