From 37bb168272e05476f9b72f1d0287301a1ce6b99a Mon Sep 17 00:00:00 2001 From: Hanyonggong <1229369094@qq.com> Date: Sun, 13 Oct 2024 08:53:42 +0000 Subject: [PATCH 01/10] add op_greater_less_than --- .../transforms/tensorrt/trt_op_marker_pass.cc | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index e8890d6156deb7..9465393eec5c81 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -931,6 +931,60 @@ class GreaterEqualOpPattern return true; } }; +class GreaterThanOpPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + bool MatchAndRewrite(paddle::dialect::GreaterThanOp op, + pir::PatternRewriter &rewriter) const override { + if (op->HasAttribute(kCanRunTrtAttr) && + op->attribute(kCanRunTrtAttr).data()) { + return false; + } +#if IS_TRT_VERSION_LT(8400) + VLOG(3) << "GreaterThanOp is not supported when TensorRT < 8.4"; + return false; +#else + pir::Value x = op.operand_source(0); + pir::Value y = op.operand_source(1); + auto x_dtype = pir::GetDataTypeFromValue(x); + auto y_dtype = pir::GetDataTypeFromValue(y); + if (x_dtype.isa() || y_dtype.isa()) { + VLOG(3) << "Greater_than op do not support bool datatype"; + return false; + } +#endif + op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); + return true; + } +}; +class LessThanOpPattern + : public pir::OpRewritePattern { + public: + using pir::OpRewritePattern::OpRewritePattern; + bool MatchAndRewrite(paddle::dialect::LessThanOp op, + pir::PatternRewriter &rewriter) const override { + if (op->HasAttribute(kCanRunTrtAttr) && + op->attribute(kCanRunTrtAttr).data()) { + return false; + } +#if IS_TRT_VERSION_LT(8400) + VLOG(3) << "LessThanOp is not supported when TensorRT < 8.4"; + return false; +#else + pir::Value x = op.operand_source(0); + pir::Value y = op.operand_source(1); + auto x_dtype = pir::GetDataTypeFromValue(x); + auto y_dtype = pir::GetDataTypeFromValue(y); + if (x_dtype.isa() || y_dtype.isa()) { + VLOG(3) << "Less_than op do not support bool datatype"; + return false; + } +#endif + op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); + return true; + } +}; class MultiplyOpPattern : public pir::OpRewritePattern { public: @@ -1521,6 +1575,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass { ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); + ps.Add(std::make_unique(context)); + ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); ps.Add(std::make_unique(context)); From 55ca34614e70cb74106107b3702bd4f85e621495 Mon Sep 17 00:00:00 2001 From: Hanyonggong <1229369094@qq.com> Date: Mon, 14 Oct 2024 02:06:19 +0000 Subject: [PATCH 02/10] add converter --- python/paddle/tensorrt/impls/logic.py | 34 +++++++++++++++++ test/tensorrt/test_converter_logic.py | 54 +++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 python/paddle/tensorrt/impls/logic.py create mode 100644 test/tensorrt/test_converter_logic.py diff --git a/python/paddle/tensorrt/impls/logic.py b/python/paddle/tensorrt/impls/logic.py new file mode 100644 index 00000000000000..a4917e20f98965 --- /dev/null +++ b/python/paddle/tensorrt/impls/logic.py @@ -0,0 +1,34 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorrt as trt + +from paddle.tensorrt.converter_utils import ( + add_elementwise_layer, +) +from paddle.tensorrt.register import converter_registry + + +@converter_registry.register("pd_op.greater_than", trt_version="8.x") +def substract_converter(network, paddle_op, inputs): + return add_elementwise_layer( + network, paddle_op, inputs, trt.ElementWiseOperation.GREATER + ) + + +@converter_registry.register("pd_op.less_than", trt_version="8.x") +def multiply_converter(network, paddle_op, inputs): + return add_elementwise_layer( + network, paddle_op, inputs, trt.ElementWiseOperation.LESS + ) diff --git a/test/tensorrt/test_converter_logic.py b/test/tensorrt/test_converter_logic.py new file mode 100644 index 00000000000000..8a3c6460caf5e3 --- /dev/null +++ b/test/tensorrt/test_converter_logic.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from tensorrt_test_base import TensorRTBaseTest + +import paddle + + +class TestGreaterThanTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.greater_than + self.api_args = { + "x": np.random.randn(2, 3).astype(np.float32), + "y": np.random.randn(2, 3).astype(np.float32), + } + self.program_config = {"feed_list": ["x", "y"]} + self.min_shape = {"x": [1, 3], "y": [1, 3]} + self.max_shape = {"x": [5, 3], "y": [5, 3]} + + def test_trt_result(self): + self.check_trt_result() + + +class TestLessThanTRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.less_than + self.api_args = { + "x": np.random.randn(2, 3).astype(np.float32), + "y": np.random.randn(2, 3).astype(np.float32), + } + self.program_config = {"feed_list": ["x", "y"]} + self.min_shape = {"x": [1, 3], "y": [1, 3]} + self.max_shape = {"x": [5, 3], "y": [5, 3]} + + def test_trt_result(self): + self.check_trt_result() + + +if __name__ == '__main__': + unittest.main() From c17330d813b5be538f03f9dc5a5ec2027adfc525 Mon Sep 17 00:00:00 2001 From: Hanyonggong <1229369094@qq.com> Date: Mon, 14 Oct 2024 08:05:42 +0000 Subject: [PATCH 03/10] mod logic converter --- python/paddle/tensorrt/converter.py | 1 + python/paddle/tensorrt/impls/logic.py | 11 +++++++---- test/tensorrt/test_converter_logic.py | 16 ++++++++-------- 3 files changed, 16 insertions(+), 12 deletions(-) diff --git a/python/paddle/tensorrt/converter.py b/python/paddle/tensorrt/converter.py index 19f3aebb07116f..d04831b59bc2d9 100644 --- a/python/paddle/tensorrt/converter.py +++ b/python/paddle/tensorrt/converter.py @@ -35,6 +35,7 @@ from .impls.conv import * # noqa: F403 from .impls.creation import * # noqa: F403 from .impls.linalg import * # noqa: F403 +from .impls.logic import * # noqa: F403 from .impls.manipulation import * # noqa: F403 from .impls.math import * # noqa: F403 from .impls.norm import * # noqa: F403 diff --git a/python/paddle/tensorrt/impls/logic.py b/python/paddle/tensorrt/impls/logic.py index a4917e20f98965..98fd7064bf7d03 100644 --- a/python/paddle/tensorrt/impls/logic.py +++ b/python/paddle/tensorrt/impls/logic.py @@ -16,19 +16,22 @@ from paddle.tensorrt.converter_utils import ( add_elementwise_layer, + trt_cast, ) from paddle.tensorrt.register import converter_registry @converter_registry.register("pd_op.greater_than", trt_version="8.x") -def substract_converter(network, paddle_op, inputs): - return add_elementwise_layer( +def greater_than_converter(network, paddle_op, inputs): + layer_output = add_elementwise_layer( network, paddle_op, inputs, trt.ElementWiseOperation.GREATER ) + return trt_cast(network, layer_output, trt.float32) @converter_registry.register("pd_op.less_than", trt_version="8.x") -def multiply_converter(network, paddle_op, inputs): - return add_elementwise_layer( +def less_than_converter(network, paddle_op, inputs): + layer_output = add_elementwise_layer( network, paddle_op, inputs, trt.ElementWiseOperation.LESS ) + return trt_cast(network, layer_output, trt.float32) diff --git a/test/tensorrt/test_converter_logic.py b/test/tensorrt/test_converter_logic.py index 8a3c6460caf5e3..e607a4ec059424 100644 --- a/test/tensorrt/test_converter_logic.py +++ b/test/tensorrt/test_converter_logic.py @@ -24,12 +24,12 @@ class TestGreaterThanTRTPattern(TensorRTBaseTest): def setUp(self): self.python_api = paddle.greater_than self.api_args = { - "x": np.random.randn(2, 3).astype(np.float32), - "y": np.random.randn(2, 3).astype(np.float32), + "x": np.random.randn(3).astype(np.float32), + "y": np.random.randn(3).astype(np.float32), } self.program_config = {"feed_list": ["x", "y"]} - self.min_shape = {"x": [1, 3], "y": [1, 3]} - self.max_shape = {"x": [5, 3], "y": [5, 3]} + self.min_shape = {"x": [1], "y": [1]} + self.max_shape = {"x": [5], "y": [5]} def test_trt_result(self): self.check_trt_result() @@ -39,12 +39,12 @@ class TestLessThanTRTPattern(TensorRTBaseTest): def setUp(self): self.python_api = paddle.less_than self.api_args = { - "x": np.random.randn(2, 3).astype(np.float32), - "y": np.random.randn(2, 3).astype(np.float32), + "x": np.random.randn(3).astype(np.float32), + "y": np.random.randn(3).astype(np.float32), } self.program_config = {"feed_list": ["x", "y"]} - self.min_shape = {"x": [1, 3], "y": [1, 3]} - self.max_shape = {"x": [5, 3], "y": [5, 3]} + self.min_shape = {"x": [1], "y": [1]} + self.max_shape = {"x": [5], "y": [5]} def test_trt_result(self): self.check_trt_result() From 13ae2d0f028855cb00bccecbb6e72d20e5c8407e Mon Sep 17 00:00:00 2001 From: Hanyonggong <1229369094@qq.com> Date: Mon, 14 Oct 2024 13:01:09 +0000 Subject: [PATCH 04/10] mod converter_marker --- .../transforms/tensorrt/trt_op_marker_pass.cc | 80 ++++++------------- python/paddle/tensorrt/impls/logic.py | 4 +- 2 files changed, 25 insertions(+), 59 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 3aeb040452a299..bd43526a8187e5 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -907,19 +907,21 @@ class SplitWithNumOpPattern return true; } }; -class GreaterEqualOpPattern - : public pir::OpRewritePattern { + +template +class LogicOpPattern : public pir::OpRewritePattern { public: - using pir::OpRewritePattern< - paddle::dialect::GreaterEqualOp>::OpRewritePattern; - bool MatchAndRewrite(paddle::dialect::GreaterEqualOp op, + using pir::OpRewritePattern::OpRewritePattern; + + bool MatchAndRewrite(OpType op, pir::PatternRewriter &rewriter) const override { if (op->HasAttribute(kCanRunTrtAttr) && - op->attribute(kCanRunTrtAttr).data()) { + op->template attribute(kCanRunTrtAttr).data()) { return false; } #if IS_TRT_VERSION_LT(8400) - VLOG(3) << "GreaterEqualOp is not supported when TensorRT < 8.4"; + VLOG(3) << OpType::getOperationName() + << " is not supported when TensorRT < 8.4"; return false; #else pir::Value x = op.operand_source(0); @@ -927,7 +929,7 @@ class GreaterEqualOpPattern auto x_dtype = pir::GetDataTypeFromValue(x); auto y_dtype = pir::GetDataTypeFromValue(y); if (x_dtype.isa() || y_dtype.isa()) { - VLOG(3) << "Greate_equal op do not support bool datatype"; + VLOG(3) << op.name() << " do not support bool datatype"; return false; } #endif @@ -935,60 +937,24 @@ class GreaterEqualOpPattern return true; } }; + +class GreaterEqualOpPattern + : public LogicOpPattern { + public: + using LogicOpPattern::LogicOpPattern; +}; + class GreaterThanOpPattern - : public pir::OpRewritePattern { + : public LogicOpPattern { public: - using pir::OpRewritePattern::OpRewritePattern; - bool MatchAndRewrite(paddle::dialect::GreaterThanOp op, - pir::PatternRewriter &rewriter) const override { - if (op->HasAttribute(kCanRunTrtAttr) && - op->attribute(kCanRunTrtAttr).data()) { - return false; - } -#if IS_TRT_VERSION_LT(8400) - VLOG(3) << "GreaterThanOp is not supported when TensorRT < 8.4"; - return false; -#else - pir::Value x = op.operand_source(0); - pir::Value y = op.operand_source(1); - auto x_dtype = pir::GetDataTypeFromValue(x); - auto y_dtype = pir::GetDataTypeFromValue(y); - if (x_dtype.isa() || y_dtype.isa()) { - VLOG(3) << "Greater_than op do not support bool datatype"; - return false; - } -#endif - op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); - return true; - } + using LogicOpPattern::LogicOpPattern; }; -class LessThanOpPattern - : public pir::OpRewritePattern { + +class LessThanOpPattern : public LogicOpPattern { public: - using pir::OpRewritePattern::OpRewritePattern; - bool MatchAndRewrite(paddle::dialect::LessThanOp op, - pir::PatternRewriter &rewriter) const override { - if (op->HasAttribute(kCanRunTrtAttr) && - op->attribute(kCanRunTrtAttr).data()) { - return false; - } -#if IS_TRT_VERSION_LT(8400) - VLOG(3) << "LessThanOp is not supported when TensorRT < 8.4"; - return false; -#else - pir::Value x = op.operand_source(0); - pir::Value y = op.operand_source(1); - auto x_dtype = pir::GetDataTypeFromValue(x); - auto y_dtype = pir::GetDataTypeFromValue(y); - if (x_dtype.isa() || y_dtype.isa()) { - VLOG(3) << "Less_than op do not support bool datatype"; - return false; - } -#endif - op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); - return true; - } + using LogicOpPattern::LogicOpPattern; }; + class MultiplyOpPattern : public pir::OpRewritePattern { public: diff --git a/python/paddle/tensorrt/impls/logic.py b/python/paddle/tensorrt/impls/logic.py index 98fd7064bf7d03..aaae63392c10a6 100644 --- a/python/paddle/tensorrt/impls/logic.py +++ b/python/paddle/tensorrt/impls/logic.py @@ -26,7 +26,7 @@ def greater_than_converter(network, paddle_op, inputs): layer_output = add_elementwise_layer( network, paddle_op, inputs, trt.ElementWiseOperation.GREATER ) - return trt_cast(network, layer_output, trt.float32) + return trt_cast(network, layer_output, inputs[0].dtype) @converter_registry.register("pd_op.less_than", trt_version="8.x") @@ -34,4 +34,4 @@ def less_than_converter(network, paddle_op, inputs): layer_output = add_elementwise_layer( network, paddle_op, inputs, trt.ElementWiseOperation.LESS ) - return trt_cast(network, layer_output, trt.float32) + return trt_cast(network, layer_output, inputs[0].dtype) From 2d0fbb2e48e8dd1432f407c313029330888ba50a Mon Sep 17 00:00:00 2001 From: Hanyonggong <1229369094@qq.com> Date: Mon, 14 Oct 2024 13:23:09 +0000 Subject: [PATCH 05/10] restore --- .../transforms/tensorrt/trt_op_marker_pass.cc | 76 ++++++++++++++----- 1 file changed, 57 insertions(+), 19 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index bd43526a8187e5..61e60f45c71408 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -908,20 +908,19 @@ class SplitWithNumOpPattern } }; -template -class LogicOpPattern : public pir::OpRewritePattern { +class GreaterEqualOpPattern + : public pir::OpRewritePattern { public: - using pir::OpRewritePattern::OpRewritePattern; - - bool MatchAndRewrite(OpType op, + using pir::OpRewritePattern< + paddle::dialect::GreaterEqualOp>::OpRewritePattern; + bool MatchAndRewrite(paddle::dialect::GreaterEqualOp op, pir::PatternRewriter &rewriter) const override { if (op->HasAttribute(kCanRunTrtAttr) && - op->template attribute(kCanRunTrtAttr).data()) { + op->attribute(kCanRunTrtAttr).data()) { return false; } #if IS_TRT_VERSION_LT(8400) - VLOG(3) << OpType::getOperationName() - << " is not supported when TensorRT < 8.4"; + VLOG(3) << "GreaterEqualOp is not supported when TensorRT < 8.4"; return false; #else pir::Value x = op.operand_source(0); @@ -929,7 +928,7 @@ class LogicOpPattern : public pir::OpRewritePattern { auto x_dtype = pir::GetDataTypeFromValue(x); auto y_dtype = pir::GetDataTypeFromValue(y); if (x_dtype.isa() || y_dtype.isa()) { - VLOG(3) << op.name() << " do not support bool datatype"; + VLOG(3) << "Greate_equal op do not support bool datatype"; return false; } #endif @@ -938,21 +937,60 @@ class LogicOpPattern : public pir::OpRewritePattern { } }; -class GreaterEqualOpPattern - : public LogicOpPattern { - public: - using LogicOpPattern::LogicOpPattern; -}; - class GreaterThanOpPattern - : public LogicOpPattern { + : public pir::OpRewritePattern { public: - using LogicOpPattern::LogicOpPattern; + using pir::OpRewritePattern::OpRewritePattern; + bool MatchAndRewrite(paddle::dialect::GreaterThanOp op, + pir::PatternRewriter &rewriter) const override { + if (op->HasAttribute(kCanRunTrtAttr) && + op->attribute(kCanRunTrtAttr).data()) { + return false; + } +#if IS_TRT_VERSION_LT(8400) + VLOG(3) << "GreaterThanOp is not supported when TensorRT < 8.4"; + return false; +#else + pir::Value x = op.operand_source(0); + pir::Value y = op.operand_source(1); + auto x_dtype = pir::GetDataTypeFromValue(x); + auto y_dtype = pir::GetDataTypeFromValue(y); + if (x_dtype.isa() || y_dtype.isa()) { + VLOG(3) << "Greater_than op do not support bool datatype"; + return false; + } +#endif + op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); + return true; + } }; -class LessThanOpPattern : public LogicOpPattern { +class LessThanOpPattern + : public pir::OpRewritePattern { public: - using LogicOpPattern::LogicOpPattern; + using pir::OpRewritePattern::OpRewritePattern; + bool MatchAndRewrite(paddle::dialect::LessThanOp op, + pir::PatternRewriter &rewriter) const override { + if (op->HasAttribute(kCanRunTrtAttr) && + op->attribute(kCanRunTrtAttr).data()) { + return false; + } +#if IS_TRT_VERSION_LT(8400) + VLOG(3) << "LessThanOp is not supported when TensorRT < 8.4"; + return false; +#else + pir::Value x = op.operand_source(0); + pir::Value y = op.operand_source(1); + auto x_dtype = pir::GetDataTypeFromValue(x); + auto y_dtype = pir::GetDataTypeFromValue(y); + if (x_dtype.isa() || y_dtype.isa()) { + VLOG(3) << "Less_than op do not support bool datatype"; + return false; + } +#endif + op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true)); + return true; + } }; class MultiplyOpPattern From 5118ff1c793434c57046ab4c185da64f703ede13 Mon Sep 17 00:00:00 2001 From: Hanyonggong <1229369094@qq.com> Date: Tue, 15 Oct 2024 12:30:07 +0000 Subject: [PATCH 06/10] merge converter --- python/paddle/tensorrt/converter_utils.py | 20 ++++++++++++++++++++ python/paddle/tensorrt/impls/logic.py | 18 +++--------------- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/python/paddle/tensorrt/converter_utils.py b/python/paddle/tensorrt/converter_utils.py index 337d876b3df2a1..b35fc74845ec9e 100644 --- a/python/paddle/tensorrt/converter_utils.py +++ b/python/paddle/tensorrt/converter_utils.py @@ -322,3 +322,23 @@ def build_size_tensor( ).get_output(0) return size_tensor + + +def trt_greater_than(network, paddle_op, inputs): + layer_output = add_elementwise_layer( + network, paddle_op, inputs, trt.ElementWiseOperation.GREATER + ) + return trt_cast(network, layer_output, inputs[0].dtype) + + +def trt_less_than(network, paddle_op, inputs): + layer_output = add_elementwise_layer( + network, paddle_op, inputs, trt.ElementWiseOperation.LESS + ) + return trt_cast(network, layer_output, inputs[0].dtype) + + +elementwise_map = { + "pd_op.greater_than": trt_greater_than, + "pd_op.less_than": trt_less_than, +} diff --git a/python/paddle/tensorrt/impls/logic.py b/python/paddle/tensorrt/impls/logic.py index aaae63392c10a6..399e26fc5cb906 100644 --- a/python/paddle/tensorrt/impls/logic.py +++ b/python/paddle/tensorrt/impls/logic.py @@ -12,26 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -import tensorrt as trt from paddle.tensorrt.converter_utils import ( - add_elementwise_layer, - trt_cast, + elementwise_map, ) from paddle.tensorrt.register import converter_registry @converter_registry.register("pd_op.greater_than", trt_version="8.x") -def greater_than_converter(network, paddle_op, inputs): - layer_output = add_elementwise_layer( - network, paddle_op, inputs, trt.ElementWiseOperation.GREATER - ) - return trt_cast(network, layer_output, inputs[0].dtype) - - @converter_registry.register("pd_op.less_than", trt_version="8.x") -def less_than_converter(network, paddle_op, inputs): - layer_output = add_elementwise_layer( - network, paddle_op, inputs, trt.ElementWiseOperation.LESS - ) - return trt_cast(network, layer_output, inputs[0].dtype) +def logic_converter(network, paddle_op, inputs): + return elementwise_map[paddle_op.name()](network, paddle_op, inputs) From 84baf4b346f1fb5f7f59ccdd335caa8521a985c6 Mon Sep 17 00:00:00 2001 From: Hanyonggong <1229369094@qq.com> Date: Wed, 16 Oct 2024 08:12:59 +0000 Subject: [PATCH 07/10] rm utils --- python/paddle/tensorrt/converter_utils.py | 20 -------------------- python/paddle/tensorrt/impls/logic.py | 14 ++++++++++++-- 2 files changed, 12 insertions(+), 22 deletions(-) diff --git a/python/paddle/tensorrt/converter_utils.py b/python/paddle/tensorrt/converter_utils.py index b35fc74845ec9e..337d876b3df2a1 100644 --- a/python/paddle/tensorrt/converter_utils.py +++ b/python/paddle/tensorrt/converter_utils.py @@ -322,23 +322,3 @@ def build_size_tensor( ).get_output(0) return size_tensor - - -def trt_greater_than(network, paddle_op, inputs): - layer_output = add_elementwise_layer( - network, paddle_op, inputs, trt.ElementWiseOperation.GREATER - ) - return trt_cast(network, layer_output, inputs[0].dtype) - - -def trt_less_than(network, paddle_op, inputs): - layer_output = add_elementwise_layer( - network, paddle_op, inputs, trt.ElementWiseOperation.LESS - ) - return trt_cast(network, layer_output, inputs[0].dtype) - - -elementwise_map = { - "pd_op.greater_than": trt_greater_than, - "pd_op.less_than": trt_less_than, -} diff --git a/python/paddle/tensorrt/impls/logic.py b/python/paddle/tensorrt/impls/logic.py index 399e26fc5cb906..5bc5ffee277ff3 100644 --- a/python/paddle/tensorrt/impls/logic.py +++ b/python/paddle/tensorrt/impls/logic.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +import tensorrt as trt from paddle.tensorrt.converter_utils import ( - elementwise_map, + add_elementwise_layer, + trt_cast, ) from paddle.tensorrt.register import converter_registry @@ -22,4 +24,12 @@ @converter_registry.register("pd_op.greater_than", trt_version="8.x") @converter_registry.register("pd_op.less_than", trt_version="8.x") def logic_converter(network, paddle_op, inputs): - return elementwise_map[paddle_op.name()](network, paddle_op, inputs) + if paddle_op.name() == "pd_op.greater_than": + layer_output = add_elementwise_layer( + network, paddle_op, inputs, trt.ElementWiseOperation.GREATER + ) + else: + layer_output = add_elementwise_layer( + network, paddle_op, inputs, trt.ElementWiseOperation.LESS + ) + return trt_cast(network, layer_output, inputs[0].dtype) From b8a9bc9cb0402e18c6024a32721b54fee7fe8e0d Mon Sep 17 00:00:00 2001 From: Hanyonggong <1229369094@qq.com> Date: Thu, 17 Oct 2024 07:36:05 +0000 Subject: [PATCH 08/10] add cmakelists --- test/tensorrt/CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/test/tensorrt/CMakeLists.txt b/test/tensorrt/CMakeLists.txt index 15fa3c0e4af1bd..d2ed5dc316e8ea 100644 --- a/test/tensorrt/CMakeLists.txt +++ b/test/tensorrt/CMakeLists.txt @@ -23,4 +23,5 @@ if(NOT WIN32 AND TENSORRT_FOUND) set_tests_properties(test_converter_creation PROPERTIES TIMEOUT "100") set_tests_properties(test_converter_attribute PROPERTIES TIMEOUT "100") set_tests_properties(test_converter_common PROPERTIES TIMEOUT "300") + set_tests_properties(test_converter_logic PROPERTIES TIMEOUT "100") endif() From b8f00a3e75f6645ff750ba2f433f815c4501d753 Mon Sep 17 00:00:00 2001 From: Hanyonggong <1229369094@qq.com> Date: Tue, 22 Oct 2024 04:34:41 +0000 Subject: [PATCH 09/10] add test --- test/tensorrt/test_converter_logic.py | 38 ++++++++++++++++++++++++--- 1 file changed, 34 insertions(+), 4 deletions(-) diff --git a/test/tensorrt/test_converter_logic.py b/test/tensorrt/test_converter_logic.py index e607a4ec059424..00ace97ab27af4 100644 --- a/test/tensorrt/test_converter_logic.py +++ b/test/tensorrt/test_converter_logic.py @@ -20,14 +20,29 @@ import paddle -class TestGreaterThanTRTPattern(TensorRTBaseTest): +class TestGreaterThanFloat32TRTPattern(TensorRTBaseTest): def setUp(self): self.python_api = paddle.greater_than self.api_args = { - "x": np.random.randn(3).astype(np.float32), + "x": np.random.randn(2, 3).astype(np.float32), "y": np.random.randn(3).astype(np.float32), } self.program_config = {"feed_list": ["x", "y"]} + self.min_shape = {"x": [1, 3], "y": [3]} + self.max_shape = {"x": [5, 3], "y": [3]} + + def test_trt_result(self): + self.check_trt_result() + + +class TestGreaterThanInt32TRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.greater_than + self.api_args = { + "x": np.random.randn(3).astype(np.int32), + "y": np.random.randn(3).astype(np.int32), + } + self.program_config = {"feed_list": ["x", "y"]} self.min_shape = {"x": [1], "y": [1]} self.max_shape = {"x": [5], "y": [5]} @@ -35,14 +50,29 @@ def test_trt_result(self): self.check_trt_result() -class TestLessThanTRTPattern(TensorRTBaseTest): +class TestLessThanFloat32TRTPattern(TensorRTBaseTest): def setUp(self): self.python_api = paddle.less_than self.api_args = { - "x": np.random.randn(3).astype(np.float32), + "x": np.random.randn(2, 3).astype(np.float32), "y": np.random.randn(3).astype(np.float32), } self.program_config = {"feed_list": ["x", "y"]} + self.min_shape = {"x": [1, 3], "y": [3]} + self.max_shape = {"x": [5, 3], "y": [3]} + + def test_trt_result(self): + self.check_trt_result() + + +class TestLessThanInt32TRTPattern(TensorRTBaseTest): + def setUp(self): + self.python_api = paddle.less_than + self.api_args = { + "x": np.random.randn(3).astype(np.int32), + "y": np.random.randn(3).astype(np.int32), + } + self.program_config = {"feed_list": ["x", "y"]} self.min_shape = {"x": [1], "y": [1]} self.max_shape = {"x": [5], "y": [5]} From 62f57eba94bf08c4ebe32c6d8f1f6c8acdaac4f1 Mon Sep 17 00:00:00 2001 From: Hanyonggong <1229369094@qq.com> Date: Wed, 23 Oct 2024 07:50:55 +0000 Subject: [PATCH 10/10] mod file --- .../fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc | 8 ++++---- python/paddle/tensorrt/impls/logic.py | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc index 61e60f45c71408..09a9724f613e71 100644 --- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc +++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc @@ -948,7 +948,7 @@ class GreaterThanOpPattern return false; } #if IS_TRT_VERSION_LT(8400) - VLOG(3) << "GreaterThanOp is not supported when TensorRT < 8.4"; + VLOG(3) << "pd_op.greater_than op is not supported when TensorRT < 8.4"; return false; #else pir::Value x = op.operand_source(0); @@ -956,7 +956,7 @@ class GreaterThanOpPattern auto x_dtype = pir::GetDataTypeFromValue(x); auto y_dtype = pir::GetDataTypeFromValue(y); if (x_dtype.isa() || y_dtype.isa()) { - VLOG(3) << "Greater_than op do not support bool datatype"; + VLOG(3) << "pd_op.greater_than op do not support bool datatype"; return false; } #endif @@ -976,7 +976,7 @@ class LessThanOpPattern return false; } #if IS_TRT_VERSION_LT(8400) - VLOG(3) << "LessThanOp is not supported when TensorRT < 8.4"; + VLOG(3) << "pd_op.less_than op is not supported when TensorRT < 8.4"; return false; #else pir::Value x = op.operand_source(0); @@ -984,7 +984,7 @@ class LessThanOpPattern auto x_dtype = pir::GetDataTypeFromValue(x); auto y_dtype = pir::GetDataTypeFromValue(y); if (x_dtype.isa() || y_dtype.isa()) { - VLOG(3) << "Less_than op do not support bool datatype"; + VLOG(3) << "pd_op.less_than op do not support bool datatype"; return false; } #endif diff --git a/python/paddle/tensorrt/impls/logic.py b/python/paddle/tensorrt/impls/logic.py index 5bc5ffee277ff3..7b5fa1a6e8c92d 100644 --- a/python/paddle/tensorrt/impls/logic.py +++ b/python/paddle/tensorrt/impls/logic.py @@ -28,8 +28,10 @@ def logic_converter(network, paddle_op, inputs): layer_output = add_elementwise_layer( network, paddle_op, inputs, trt.ElementWiseOperation.GREATER ) - else: + elif paddle_op.name() == "pd_op.less_than": layer_output = add_elementwise_layer( network, paddle_op, inputs, trt.ElementWiseOperation.LESS ) + else: + raise ValueError(f"Unexpected paddle_op: {paddle_op.name()}") return trt_cast(network, layer_output, inputs[0].dtype)