
Commit aaf185b

Authored by Deivanayaki-S and deivanayakisankaralingam
[Relax][PyTorch] Support softshrink op for ExportedProgram (#17786)
* softshrink op support into exported program and test script code added
* fix lint issue
* update the formatting to fix lint issues
* modify the code format to fix lint issue

Co-authored-by: deivanayakisankaralingam <deiva@Deivanayaki>
1 parent: 3f16ec2

File tree

3 files changed, +85 −0 lines changed

python/tvm/relax/frontend/torch/base_fx_graph_translator.py

Lines changed: 33 additions & 0 deletions

@@ -307,6 +307,39 @@ def _softmax(self, node: fx.Node) -> relax.Var:
        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", -1)
        return self.block_builder.emit(relax.op.nn.softmax(x, dim))

    def _softshrink(self, node: fx.Node) -> relax.Var:
        """
        Applies the Softshrink activation function in Relax.

        Softshrink(x) =
            x - λ  if x > λ
            x + λ  if x < -λ
            0      otherwise

        Args:
            node (fx.Node): The input node containing the tensor and lambda value.

        Returns:
            relax.Var: The resulting tensor after applying Softshrink.
        """
        args = self.retrieve_args(node)
        x = args[0]
        lambd = relax.const(args[1] if len(args) > 1 else 0.5, x.struct_info.dtype)

        # Apply Softshrink transformation with masking
        shrink_pos = relax.op.multiply(
            relax.op.subtract(x, lambd),
            relax.op.astype(relax.op.greater(x, lambd), x.struct_info.dtype),
        )

        shrink_neg = relax.op.multiply(
            relax.op.add(x, lambd),
            relax.op.astype(relax.op.less(x, relax.op.negative(lambd)), x.struct_info.dtype),
        )

        # Combine the positive and negative shrink results
        return self.block_builder.emit(relax.op.add(shrink_pos, shrink_neg))

    def _selu(self, node: fx.Node) -> relax.Var:
        x = self.env[node.args[0]]
        alpha = node.args[1] if len(node.args) > 1 else node.kwargs.get("alpha", 1.6732631921768188)
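Because Relax has no dedicated softshrink operator here, the translator decomposes it into two masked branches that are summed. As a sanity check (not part of the commit), the same decomposition can be verified against PyTorch's reference softshrink; the input shape and lambd value below are arbitrary:

import torch
import torch.nn.functional as F

x = torch.randn(1, 3, 10, 10)
lambd = 0.5

# Positive branch: (x - lambd) where x > lambd, zero elsewhere.
pos = (x - lambd) * (x > lambd).to(x.dtype)
# Negative branch: (x + lambd) where x < -lambd, zero elsewhere.
neg = (x + lambd) * (x < -lambd).to(x.dtype)

# Elements with |x| <= lambd contribute zero from both branches,
# so the sum matches torch.nn.functional.softshrink exactly.
assert torch.allclose(pos + neg, F.softshrink(x, lambd))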

python/tvm/relax/frontend/torch/exported_program_translator.py

Lines changed: 1 addition & 0 deletions

@@ -282,6 +282,7 @@ def create_convert_map(
            "sin.default": self._unary_op(relax.op.sin),
            "sinh.default": self._unary_op(relax.op.sinh),
            "softmax.int": self._softmax,
            "softshrink.default": self._softshrink,
            "sqrt.default": self._unary_op(relax.op.sqrt),
            "square.default": self._unary_op(relax.op.square),
            "tan.default": self._unary_op(relax.op.tan),

tests/python/relax/test_frontend_from_exported_program.py

Lines changed: 51 additions & 0 deletions

@@ -607,6 +607,9 @@ def main(
    # softmax
    test_softmax()

    # softshrink
    test_softshrink()

    # tril, triu
    test_tril_triu()

@@ -741,6 +744,54 @@ def main(
    verify_model(Softmax2(), example_args, {}, expected1)


def test_softshrink():
    class Softshrink(Module):
        def __init__(self):
            super().__init__()
            self.softshrink = torch.nn.Softshrink(lambd=0.5)

        def forward(self, input):
            return self.softshrink(input)

    class Softshrink2(Module):
        def forward(self, input):
            return torch.nn.functional.softshrink(input, lambd=0.5)

    @tvm.script.ir_module
    class expected_softshrink:
        @R.function
        def main(
            input: R.Tensor((1, 3, 10, 10), dtype="float32"),
        ) -> R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")):
            with R.dataflow():
                lv: R.Tensor((1, 3, 10, 10), dtype="float32") = R.subtract(
                    input, R.const(0.5, "float32")
                )
                lv1: R.Tensor((1, 3, 10, 10), dtype="bool") = R.greater(
                    input, R.const(0.5, "float32")
                )
                lv2: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv1, "float32")
                lv3: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv, lv2)

                lv4: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(
                    input, R.const(0.5, "float32")
                )
                lv5: R.Tensor((), dtype="float32") = R.negative(R.const(0.5, "float32"))
                lv6: R.Tensor((1, 3, 10, 10), dtype="bool") = R.less(input, lv5)
                lv7: R.Tensor((1, 3, 10, 10), dtype="float32") = R.astype(lv6, "float32")
                lv8: R.Tensor((1, 3, 10, 10), dtype="float32") = R.multiply(lv4, lv7)

                lv9: R.Tensor((1, 3, 10, 10), dtype="float32") = R.add(lv3, lv8)

                gv: R.Tuple(R.Tensor((1, 3, 10, 10), dtype="float32")) = (lv9,)
                R.output(gv)
            return gv

    example_args = (torch.randn(1, 3, 10, 10, dtype=torch.float32),)
    verify_model(Softshrink(), example_args, {}, expected_softshrink)
    verify_model(Softshrink2(), example_args, {}, expected_softshrink)


def test_tril_triu():
    example_args = (torch.randn(10, 10, dtype=torch.float32),)
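For reference, the dataflow block asserted in expected_softshrink (up to the final tuple wrapping) can also be constructed by hand with relax.BlockBuilder. This is only an illustrative sketch of the bindings the translator emits, not code from the commit:

import tvm
from tvm import relax

bb = relax.BlockBuilder()
x = relax.Var("input", relax.TensorStructInfo((1, 3, 10, 10), "float32"))
lambd = relax.const(0.5, "float32")

with bb.function("main", [x]):
    with bb.dataflow():
        # Mirrors lv..lv3: masked positive branch (x - lambd) * (x > lambd).
        pos = bb.emit(
            relax.op.multiply(
                relax.op.subtract(x, lambd),
                relax.op.astype(relax.op.greater(x, lambd), "float32"),
            )
        )
        # Mirrors lv4..lv8: masked negative branch (x + lambd) * (x < -lambd).
        neg = bb.emit(
            relax.op.multiply(
                relax.op.add(x, lambd),
                relax.op.astype(relax.op.less(x, relax.op.negative(lambd)), "float32"),
            )
        )
        # Mirrors lv9: the sum of the two branches is the Softshrink output.
        out = bb.emit_output(relax.op.add(pos, neg))
    bb.emit_func_output(out)

bb.get().show()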
