diff --git a/python/paddle/tensor/ops.py b/python/paddle/tensor/ops.py index 3585c48b0afffb..18771c147b76ad 100644 --- a/python/paddle/tensor/ops.py +++ b/python/paddle/tensor/ops.py @@ -17,7 +17,7 @@ from .. import _C_ops from ..base.data_feeder import check_variable_and_dtype -from ..framework import LayerHelper, in_dynamic_mode +from ..framework import LayerHelper, in_dynamic_mode, in_dynamic_or_pir_mode from .layer_function_generator import ( add_sample_code, generate_activation_fn, @@ -878,7 +878,7 @@ def rsqrt(x, name=None): Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True, [3.16227770, 2.23606801, 1.82574177, 1.58113885]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.rsqrt(x) else: check_variable_and_dtype( diff --git a/test/legacy_test/test_activation_op.py b/test/legacy_test/test_activation_op.py index 77dcacc270cfde..7da773d16327fe 100644 --- a/test/legacy_test/test_activation_op.py +++ b/test/legacy_test/test_activation_op.py @@ -1578,7 +1578,7 @@ def if_enable_cinn(self): pass def test_check_output(self): - self.check_output(check_prim=True) + self.check_output(check_prim=True, check_new_ir=True) def test_check_grad(self): if self.dtype == np.float16: @@ -1588,6 +1588,7 @@ def test_check_grad(self): 'Out', max_relative_error=0.0005, check_prim=True, + check_new_ir=True, ) @@ -4508,7 +4509,9 @@ def test_check_grad(self): TestLeakyReluAlpha3, check_prim=True, enable_cinn=True ) create_test_act_fp16_class(TestLeakyRelu_ZeroDim, check_prim=True) -create_test_act_fp16_class(TestRsqrt, check_prim=True, enable_cinn=True) +create_test_act_fp16_class( + TestRsqrt, check_prim=True, enable_cinn=True, check_new_ir=True +) def create_test_act_bf16_class( @@ -4631,7 +4634,7 @@ def test_check_grad(self): create_test_act_bf16_class(TestLeakyReluAlpha2, check_prim=True) create_test_act_bf16_class(TestLeakyReluAlpha3, check_prim=True) create_test_act_bf16_class(TestLeakyRelu_ZeroDim, check_prim=True) -create_test_act_bf16_class(TestRsqrt, check_prim=True) +create_test_act_bf16_class(TestRsqrt, check_prim=True, check_new_ir=True) if __name__ == "__main__": unittest.main()