
[NewIR] No.10 Migrate silu into pir #57157

Merged 11 commits on Sep 14, 2023
5 changes: 2 additions & 3 deletions python/paddle/nn/functional/activation.py
@@ -14,7 +14,7 @@

import paddle
from paddle import _C_ops, _legacy_C_ops, in_dynamic_mode
-from paddle.framework import core
+from paddle.framework import core, in_dynamic_or_new_ir_mode
from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

from ...base.data_feeder import check_dtype, check_variable_and_dtype
@@ -1053,14 +1053,13 @@ def silu(x, name=None):
[0.73105860, 1.76159406, 2.85772228, 3.92805505])
"""

-    if in_dynamic_mode():
+    if in_dynamic_or_new_ir_mode():
        return _C_ops.silu(x)
    else:
        check_variable_and_dtype(
            x,
            'x',
            [
                'float16',
                'uint16',
                'float32',
                'float64',
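For context, a minimal usage sketch of the public API this hunk touches (the input values are taken from the silu docstring example, which already appears above and is not changed by this PR); after this change the same call reaches _C_ops.silu in both dynamic mode and the new-IR (PIR) static mode:

import paddle
import paddle.nn.functional as F

# silu(x) = x * sigmoid(x); input mirrors the docstring sample above.
x = paddle.to_tensor([1.0, 2.0, 3.0, 4.0])
out = F.silu(x)
print(out)  # approximately [0.73105860, 1.76159406, 2.85772228, 3.92805505]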
8 changes: 8 additions & 0 deletions test/legacy_test/test_activation_op.py
@@ -424,6 +424,14 @@ def init_dtype(self):
        self.dtype = np.complex128


class TestSilu_NewIR(TestSilu):
    def test_check_output(self):
        self.check_output(check_new_ir=True)

    def test_checkout_grad(self):
        self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


class TestSiluAPI(unittest.TestCase):
# test paddle.nn.Silu, paddle.nn.functional.silu
def setUp(self):
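To exercise only the new PIR test case locally, a minimal sketch (assuming it is run from test/legacy_test/ so the module is importable, and that the installed Paddle build supports the check_new_ir flag used above):

import unittest

from test_activation_op import TestSilu_NewIR

# Collect only the TestSilu_NewIR case and run it with unittest's text runner.
suite = unittest.TestLoader().loadTestsFromTestCase(TestSilu_NewIR)
unittest.TextTestRunner(verbosity=2).run(suite)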