Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【SCU】【Paddle TensorRT No.56】Add pd_op.tanh_shrink converter #69693

Open
wants to merge 25 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
c159fbf
add_tanhshrink
PolaKuma Nov 25, 2024
a08c06d
Merge branch 'PaddlePaddle:develop' into add_Tanhshrink
PolaKuma Nov 25, 2024
e7d0e8b
fix
PolaKuma Nov 26, 2024
eccbe09
Merge branch 'develop' into add_Tanhshrink
PolaKuma Nov 26, 2024
a88cd02
fix codestyle
PolaKuma Nov 26, 2024
70aa755
Merge branch 'develop' into add_Tanhshrink
PolaKuma Nov 27, 2024
3bfafb3
fix codestyle
PolaKuma Nov 28, 2024
8ff507d
fix codestyle
PolaKuma Nov 28, 2024
257b3ec
Merge branch 'PaddlePaddle:develop' into add_Tanhshrink
PolaKuma Dec 2, 2024
7d6b9cc
Merge branch 'develop' into add_Tanhshrink
PolaKuma Dec 4, 2024
5ac1af2
Merge branch 'develop' into add_Tanhshrink
PolaKuma Dec 11, 2024
95ca775
Update test_converter_activation.py
PolaKuma Dec 11, 2024
581c220
fix codestyle
PolaKuma Dec 11, 2024
c9e9b6d
add_fp16
PolaKuma Dec 16, 2024
31b0c56
fix
PolaKuma Dec 16, 2024
b0dd283
Merge branch 'PaddlePaddle:develop' into add_Tanhshrink
PolaKuma Dec 18, 2024
d1f967f
Merge branch 'develop' into add_Tanhshrink
PolaKuma Dec 23, 2024
786c025
Update test_converter_activation.py
PolaKuma Dec 24, 2024
de96546
Update trt_op_marker_pass.cc
PolaKuma Dec 24, 2024
b12c5a7
Merge branch 'PaddlePaddle:develop' into add_Tanhshrink
PolaKuma Dec 24, 2024
8d3867e
Update activation.py
PolaKuma Dec 24, 2024
4cf64e5
Update activation.py
PolaKuma Dec 25, 2024
75fec01
update
PolaKuma Dec 26, 2024
00603b9
Merge branch 'PaddlePaddle:develop' into add_Tanhshrink
PolaKuma Jan 3, 2025
2728c5f
Update test_converter_activation.py
PolaKuma Jan 3, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ class ActOpPattern : public pir::OpRewritePattern<OpType> {
};
using TanhOpPattern = ActOpPattern<paddle::dialect::TanhOp>;
using CeluOpPattern = ActOpPattern<paddle::dialect::CeluOp>;
using TanhShrinkOpPattern = ActOpPattern<paddle::dialect::TanhShrinkOp>;
using LogicalNotOpPattern = ActOpPattern<paddle::dialect::LogicalNotOp>;
using LogicalNot_OpPattern = ActOpPattern<paddle::dialect::LogicalNot_Op>;

Expand Down Expand Up @@ -2302,6 +2303,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<ClipPattern>(context));
ps.Add(std::make_unique<GridSampleOpPattern>(context));
ps.Add(std::make_unique<StackOpPattern>(context));
ps.Add(std::make_unique<TanhShrinkOpPattern>(context));
ps.Add(std::make_unique<WherePattern>(context));
ps.Add(std::make_unique<FullLikeOpPattern>(context));
ps.Add(std::make_unique<FullWithTensorPattern>(context));
Expand Down
10 changes: 10 additions & 0 deletions python/paddle/tensorrt/impls/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,16 @@ def swish_silu_converter(network, paddle_op, inputs):
return trt_prod(network, inputs[0], layer_output)


@converter_registry.register("pd_op.tanh_shrink", trt_version="8.x")
def tanh_shrink_converter(network, paddle_op, inputs):
x = inputs[0]
tanh_layer = network.add_activation(x, trt.ActivationType.TANH)
subtract_layer = network.add_elementwise(
x, tanh_layer.get_output(0), trt.ElementWiseOperation.SUB
)
return subtract_layer.get_output(0)


@converter_registry.register("pd_op.stanh", trt_version="8.x")
def stanh_converter(network, paddle_op, inputs):
x = inputs[0]
Expand Down
17 changes: 17 additions & 0 deletions test/tensorrt/test_converter_activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,23 @@ def test_trt_result(self):
self.check_trt_result()


class TestTanhShrinkOpFloatTRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle._C_ops.tanh_shrink
self.api_args = {
"x": np.random.randn(2, 3).astype("float32"),
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [1, 3]}
self.max_shape = {"x": [5, 3]}

def test_trt_result_fp16(self):
self.check_trt_result(precision_mode="fp16")

def test_trt_result_fp32(self):
self.check_trt_result()


class TestStanhFloatTRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.stanh
Expand Down
Loading