diff --git a/test/legacy_test/test_deg2rad.py b/test/legacy_test/test_deg2rad.py index 350471f896e69..b74033af894b2 100644 --- a/test/legacy_test/test_deg2rad.py +++ b/test/legacy_test/test_deg2rad.py @@ -19,6 +19,7 @@ import paddle from paddle import base from paddle.base import core +from paddle.pir_utils import test_with_pir_api paddle.enable_static() @@ -32,10 +33,11 @@ def setUp(self): self.x_shape = [6] self.out_np = np.deg2rad(self.x_np) + @test_with_pir_api def test_static_graph(self): - startup_program = base.Program() - train_program = base.Program() - with base.program_guard(startup_program, train_program): + startup_program = paddle.static.Program() + train_program = paddle.static.Program() + with paddle.static.program_guard(startup_program, train_program): x = paddle.static.data( name='input', dtype=self.x_dtype, shape=self.x_shape ) @@ -48,11 +50,12 @@ def test_static_graph(self): ) exe = base.Executor(place) res = exe.run( - base.default_main_program(), feed={'input': self.x_np}, fetch_list=[out], ) - self.assertTrue((np.array(out[0]) == self.out_np).all()) + np.testing.assert_allclose( + np.array(res[0]), self.out_np, rtol=1e-05 + ) def test_dygraph(self): paddle.disable_static() @@ -79,3 +82,7 @@ def test_dygraph(self): np.testing.assert_allclose(np.pi, result2.numpy(), rtol=1e-05) paddle.enable_static() + + +if __name__ == '__main__': + unittest.main()