From b4aab30151ac0367bb1d140925ae9cba94e9e86d Mon Sep 17 00:00:00 2001 From: WangZhen <23097963+0x45f@users.noreply.github.com> Date: Fri, 15 Sep 2023 10:34:52 +0800 Subject: [PATCH] [PIR]Migrate paddle.gelu into pir (#57317) --- python/paddle/nn/functional/activation.py | 2 +- test/legacy_test/test_activation_op.py | 10 ++++++---- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/paddle/nn/functional/activation.py b/python/paddle/nn/functional/activation.py index 4ba784f3b2d97..e02a47d7bf8dd 100644 --- a/python/paddle/nn/functional/activation.py +++ b/python/paddle/nn/functional/activation.py @@ -189,7 +189,7 @@ def gelu(x, approximate=False, name=None): [ 0.84119201, 1.39957154]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.gelu(x, approximate) else: check_variable_and_dtype( diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index 295e1dfd9a8e8..77dcacc270cfd 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -2442,12 +2442,12 @@ def setUp(self): self.cinn_atol = 1e-8 def test_check_output(self): - self.check_output(check_prim=True) + self.check_output(check_prim=True, check_new_ir=True) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) class TestGelu(TestActivation): @@ -2480,12 +2480,12 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True) + self.check_output(check_prim=True, check_new_ir=True) def test_check_grad(self): if self.dtype == np.float16: return - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) class TestGelu_ZeroDim(TestGelu): @@ -4466,6 +4466,7 @@ def test_check_grad(self): create_test_act_fp16_class( TestGelu, check_prim=True, + check_new_ir=True, enable_cinn=True, rev_comp_rtol=1e-3, rev_comp_atol=1e-3, @@ -4595,6 +4596,7 @@ def test_check_grad(self): create_test_act_bf16_class( TestGelu, check_prim=True, + check_new_ir=True, rev_comp_rtol=1e-2, rev_comp_atol=1e-2, cinn_rtol=1e-2,