Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle Tensorrt] add pd_op.greater_than, pd_op.less_than marker and converter #68686

Merged
merged 13 commits into from
Oct 25, 2024
60 changes: 60 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,7 @@ class SplitWithNumOpPattern
return true;
}
};

class GreaterEqualOpPattern
: public pir::OpRewritePattern<paddle::dialect::GreaterEqualOp> {
public:
Expand Down Expand Up @@ -935,6 +936,63 @@ class GreaterEqualOpPattern
return true;
}
};

// Marks pd_op.greater_than as convertible to TensorRT (sets kCanRunTrtAttr)
// when the TRT version and operand dtypes allow it.
class GreaterThanOpPattern
    : public pir::OpRewritePattern<paddle::dialect::GreaterThanOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::GreaterThanOp>::OpRewritePattern;

  bool MatchAndRewrite(paddle::dialect::GreaterThanOp op,
                       pir::PatternRewriter &rewriter) const override {
    // Already marked as runnable by TensorRT: nothing left to do.
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
#if IS_TRT_VERSION_LT(8400)
    VLOG(3) << "pd_op.greater_than op is not supported when TensorRT < 8.4";
    return false;
#else
    pir::Value lhs = op.operand_source(0);
    pir::Value rhs = op.operand_source(1);
    const auto lhs_type = pir::GetDataTypeFromValue(lhs);
    const auto rhs_type = pir::GetDataTypeFromValue(rhs);
    // Bool operands are rejected (TRT elementwise GREATER has no bool path
    // here).
    const bool has_bool_operand =
        lhs_type.isa<pir::BoolType>() || rhs_type.isa<pir::BoolType>();
    if (has_bool_operand) {
      VLOG(3) << "pd_op.greater_than op do not support bool datatype";
      return false;
    }
#endif
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};

// Marks pd_op.less_than as convertible to TensorRT (sets kCanRunTrtAttr)
// when the TRT version and operand dtypes allow it.
class LessThanOpPattern
    : public pir::OpRewritePattern<paddle::dialect::LessThanOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::LessThanOp>::OpRewritePattern;

  bool MatchAndRewrite(paddle::dialect::LessThanOp op,
                       pir::PatternRewriter &rewriter) const override {
    // Already marked as runnable by TensorRT: nothing left to do.
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
#if IS_TRT_VERSION_LT(8400)
    VLOG(3) << "pd_op.less_than op is not supported when TensorRT < 8.4";
    return false;
#else
    pir::Value lhs = op.operand_source(0);
    pir::Value rhs = op.operand_source(1);
    const auto lhs_type = pir::GetDataTypeFromValue(lhs);
    const auto rhs_type = pir::GetDataTypeFromValue(rhs);
    // Bool operands are rejected (TRT elementwise LESS has no bool path
    // here).
    const bool has_bool_operand =
        lhs_type.isa<pir::BoolType>() || rhs_type.isa<pir::BoolType>();
    if (has_bool_operand) {
      VLOG(3) << "pd_op.less_than op do not support bool datatype";
      return false;
    }
#endif
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};

class MultiplyOpPattern
: public pir::OpRewritePattern<paddle::dialect::MultiplyOp> {
public:
Expand Down Expand Up @@ -1525,6 +1583,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<SplitOpPattern>(context));
ps.Add(std::make_unique<SplitWithNumOpPattern>(context));
ps.Add(std::make_unique<GreaterEqualOpPattern>(context));
ps.Add(std::make_unique<GreaterThanOpPattern>(context));
ps.Add(std::make_unique<LessThanOpPattern>(context));
ps.Add(std::make_unique<MultiplyOpPattern>(context));
ps.Add(std::make_unique<SubtractOpPattern>(context));
ps.Add(std::make_unique<DivideOpPattern>(context));
Expand Down
1 change: 1 addition & 0 deletions python/paddle/tensorrt/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from .impls.conv import * # noqa: F403
from .impls.creation import * # noqa: F403
from .impls.linalg import * # noqa: F403
from .impls.logic import * # noqa: F403
from .impls.manipulation import * # noqa: F403
from .impls.math import * # noqa: F403
from .impls.norm import * # noqa: F403
Expand Down
37 changes: 37 additions & 0 deletions python/paddle/tensorrt/impls/logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorrt as trt

from paddle.tensorrt.converter_utils import (
add_elementwise_layer,
trt_cast,
)
from paddle.tensorrt.register import converter_registry


@converter_registry.register("pd_op.greater_than", trt_version="8.x")
@converter_registry.register("pd_op.less_than", trt_version="8.x")
def logic_converter(network, paddle_op, inputs):
    """Convert a Paddle elementwise comparison op to a TRT elementwise layer.

    The layer's bool output is cast back to the dtype of the first input
    via trt_cast before being returned.
    """
    # Dispatch table: Paddle op name -> TRT elementwise comparison op.
    op_map = {
        "pd_op.greater_than": trt.ElementWiseOperation.GREATER,
        "pd_op.less_than": trt.ElementWiseOperation.LESS,
    }
    op_name = paddle_op.name()
    if op_name not in op_map:
        raise ValueError(f"Unexpected paddle_op: {op_name}")
    layer_output = add_elementwise_layer(
        network, paddle_op, inputs, op_map[op_name]
    )
    return trt_cast(network, layer_output, inputs[0].dtype)
1 change: 1 addition & 0 deletions test/tensorrt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ if(NOT WIN32 AND TENSORRT_FOUND)
set_tests_properties(test_converter_creation PROPERTIES TIMEOUT "100")
set_tests_properties(test_converter_attribute PROPERTIES TIMEOUT "100")
set_tests_properties(test_converter_common PROPERTIES TIMEOUT "300")
set_tests_properties(test_converter_logic PROPERTIES TIMEOUT "100")
endif()
84 changes: 84 additions & 0 deletions test/tensorrt/test_converter_logic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
from tensorrt_test_base import TensorRTBaseTest

import paddle


class TestGreaterThanFloat32TRTPattern(TensorRTBaseTest):
    """TRT converter test for paddle.greater_than on float32 inputs,
    using a broadcast (2, 3) vs (3,) shape pair."""

    def setUp(self):
        x = np.random.randn(2, 3).astype(np.float32)
        y = np.random.randn(3).astype(np.float32)
        self.python_api = paddle.greater_than
        self.api_args = {"x": x, "y": y}
        self.program_config = {"feed_list": ["x", "y"]}
        # Dynamic-shape profile: batch dim of x ranges from 1 to 5.
        self.min_shape = {"x": [1, 3], "y": [3]}
        self.max_shape = {"x": [5, 3], "y": [3]}

    def test_trt_result(self):
        self.check_trt_result()


class TestGreaterThanInt32TRTPattern(TensorRTBaseTest):
    """TRT converter test for paddle.greater_than on int32 inputs."""

    def setUp(self):
        self.python_api = paddle.greater_than
        # Use randint rather than randn(...).astype(np.int32): truncating a
        # standard normal to int32 collapses almost every sample into
        # {-2, ..., 2} (about 38% exact zeros), so the comparison becomes
        # nearly degenerate and exercises the int path poorly.
        self.api_args = {
            "x": np.random.randint(-10, 10, (3,)).astype(np.int32),
            "y": np.random.randint(-10, 10, (3,)).astype(np.int32),
        }
        self.program_config = {"feed_list": ["x", "y"]}
        # Dynamic-shape profile: both 1-D inputs range from 1 to 5 elements.
        self.min_shape = {"x": [1], "y": [1]}
        self.max_shape = {"x": [5], "y": [5]}

    def test_trt_result(self):
        self.check_trt_result()


class TestLessThanFloat32TRTPattern(TensorRTBaseTest):
    """TRT converter test for paddle.less_than on float32 inputs,
    using a broadcast (2, 3) vs (3,) shape pair."""

    def setUp(self):
        x = np.random.randn(2, 3).astype(np.float32)
        y = np.random.randn(3).astype(np.float32)
        self.python_api = paddle.less_than
        self.api_args = {"x": x, "y": y}
        self.program_config = {"feed_list": ["x", "y"]}
        # Dynamic-shape profile: batch dim of x ranges from 1 to 5.
        self.min_shape = {"x": [1, 3], "y": [3]}
        self.max_shape = {"x": [5, 3], "y": [3]}

    def test_trt_result(self):
        self.check_trt_result()


class TestLessThanInt32TRTPattern(TensorRTBaseTest):
    """TRT converter test for paddle.less_than on int32 inputs."""

    def setUp(self):
        self.python_api = paddle.less_than
        # Use randint rather than randn(...).astype(np.int32): truncating a
        # standard normal to int32 collapses almost every sample into
        # {-2, ..., 2} (about 38% exact zeros), so the comparison becomes
        # nearly degenerate and exercises the int path poorly.
        self.api_args = {
            "x": np.random.randint(-10, 10, (3,)).astype(np.int32),
            "y": np.random.randint(-10, 10, (3,)).astype(np.int32),
        }
        self.program_config = {"feed_list": ["x", "y"]}
        # Dynamic-shape profile: both 1-D inputs range from 1 to 5 elements.
        self.min_shape = {"x": [1], "y": [1]}
        self.max_shape = {"x": [5], "y": [5]}

    def test_trt_result(self):
        self.check_trt_result()


if __name__ == '__main__':
unittest.main()