Skip to content

Commit bb71251

Browse files
authored
[API compatibility] support inplace and input parameter for silu api (#74788)
* support inplace and input parameter for silu api
* add test
* change position
* fix codestyle
* add print test for silu
* add test
1 parent 1daec12 commit bb71251

File tree

6 files changed

+419
-6
lines changed

6 files changed

+419
-6
lines changed

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,7 @@ OP_SAME_OPERANDS_AND_RESULT(Polygamma_)
195195
OP_SAME_OPERANDS_AND_RESULT(EnableCheckModelNanInf)
196196
OP_SAME_OPERANDS_AND_RESULT(ViewShape)
197197
OP_SAME_OPERANDS_AND_RESULT(Silu)
198+
OP_SAME_OPERANDS_AND_RESULT(Silu_)
198199
OP_SAME_OPERANDS_AND_RESULT(ViewDtype)
199200
OP_SAME_OPERANDS_AND_RESULT(FusedSoftmaxMaskUpperTriangle)
200201
OP_SAME_OPERANDS_AND_RESULT(Gammaln)

paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/same_operands_result.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShadowFeed)
151151
OP_DECLARE_INFER_SYMBOLIC_SHAPE(ShareData_)
152152
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sign)
153153
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Silu)
154+
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Silu_)
154155
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin)
155156
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sin_)
156157
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Sinh)

paddle/phi/ops/yaml/ops.yaml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5066,12 +5066,13 @@
50665066

50675067
- op : silu
50685068
args : (Tensor x)
5069-
output : Tensor
5069+
output : Tensor(out)
50705070
infer_meta :
50715071
func : UnchangedInferMeta
50725072
spmd_rule : ElementwiseUnaryInferSpmd
50735073
kernel :
50745074
func : silu
5075+
inplace : (x -> out)
50755076
backward : silu_grad
50765077
interfaces : paddle::dialect::LayoutTransformationInterface, paddle::dialect::InferSymbolicShapeInterface
50775078
traits: pir::UnaryElementWiseTrait

python/paddle/nn/functional/activation.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1078,7 +1078,7 @@ def selu(
10781078

10791079

10801080
@param_one_alias(["x", "input"])
1081-
def silu(x: Tensor, name: str | None = None) -> Tensor:
1081+
def silu(x: Tensor, inplace: bool = False, name: str | None = None) -> Tensor:
10821082
r"""
10831083
silu activation
10841084
@@ -1095,6 +1095,7 @@ def silu(x: Tensor, name: str | None = None) -> Tensor:
10951095
Parameters:
10961096
x (Tensor): The input Tensor with data type bfloat16, float16, float32, float64, complex64, complex128.
10971097
alias: ``input``.
1098+
inplace (bool, optional): Whether to use inplace operation. Default: False.
10981099
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
10991100
11001101
Returns:
@@ -1111,10 +1112,21 @@ def silu(x: Tensor, name: str | None = None) -> Tensor:
11111112
>>> print(out)
11121113
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
11131114
[0.73105860, 1.76159406, 2.85772228, 3.92805505])
1115+
1116+
>>> out = F.silu(x, True)
1117+
>>> print(out)
1118+
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
1119+
[0.73105860, 1.76159406, 2.85772228, 3.92805505])
1120+
>>> print(x)
1121+
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
1122+
[0.73105860, 1.76159406, 2.85772228, 3.92805505])
11141123
"""
11151124

11161125
if in_dynamic_or_pir_mode():
1117-
return _C_ops.silu(x)
1126+
if inplace:
1127+
return _C_ops.silu_(x)
1128+
else:
1129+
return _C_ops.silu(x)
11181130
else:
11191131
check_variable_and_dtype(
11201132
x,

python/paddle/nn/layer/activation.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1263,6 +1263,7 @@ class Silu(Layer):
12631263
Where :math:`x` is the input Tensor.
12641264
12651265
Parameters:
1266+
inplace (bool, optional): Whether to use inplace operation. Default: False.
12661267
name (str|None, optional): For details, please refer to :ref:`api_guide_Name`. Generally, no setting is required. Default: None.
12671268
12681269
Shape:
@@ -1280,17 +1281,29 @@ class Silu(Layer):
12801281
>>> print(out)
12811282
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
12821283
[0.73105860, 1.76159406, 2.85772228, 3.92805505])
1284+
1285+
>>> m = paddle.nn.Silu(True)
1286+
>>> out = m(x)
1287+
>>> print(out)
1288+
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
1289+
[0.73105860, 1.76159406, 2.85772228, 3.92805505])
1290+
>>> print(x)
1291+
Tensor(shape=[4], dtype=float32, place=Place(cpu), stop_gradient=True,
1292+
[0.73105860, 1.76159406, 2.85772228, 3.92805505])
12831293
"""
12841294

1285-
def __init__(self, name: str | None = None) -> str:
1295+
def __init__(self, inplace: bool = False, name: str | None = None) -> str:
12861296
super().__init__()
12871297
self._name = name
1298+
self._inplace = inplace
12881299

12891300
def forward(self, x: Tensor) -> Tensor:
1290-
return F.silu(x, self._name)
1301+
return F.silu(x, self._inplace, self._name)
12911302

12921303
def extra_repr(self) -> str:
1293-
name_str = f'name={self._name}' if self._name else ''
1304+
name_str = f'inplace={self._inplace}' + (
1305+
f', name={self._name}' if self._name else ''
1306+
)
12941307
return name_str
12951308

12961309

0 commit comments

Comments (0)