Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SCU][Paddle TensorRT No.31]Add greater_equal and greater_equal_ converter #69770

Open
wants to merge 6 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 22 additions & 16 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1146,27 +1146,26 @@ class SplitWithNumOpPattern
}
};

class GreaterEqualOpPattern
: public pir::OpRewritePattern<paddle::dialect::GreaterEqualOp> {
class GreaterThanOpPattern
: public pir::OpRewritePattern<paddle::dialect::GreaterThanOp> {
public:
using pir::OpRewritePattern<
paddle::dialect::GreaterEqualOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::GreaterEqualOp op,
using pir::OpRewritePattern<paddle::dialect::GreaterThanOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::GreaterThanOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
#if IS_TRT_VERSION_LT(8400)
VLOG(3) << "GreaterEqualOp is not supported when TensorRT < 8.4";
VLOG(3) << "pd_op.greater_than op is not supported when TensorRT < 8.4";
return false;
#else
pir::Value x = op.operand_source(0);
pir::Value y = op.operand_source(1);
auto x_dtype = pir::GetDataTypeFromValue(x);
auto y_dtype = pir::GetDataTypeFromValue(y);
if (x_dtype.isa<pir::BoolType>() || y_dtype.isa<pir::BoolType>()) {
VLOG(3) << "Greate_equal op do not support bool datatype";
VLOG(3) << "pd_op.greater_than op do not support bool datatype";
return false;
}
#endif
Expand All @@ -1175,33 +1174,39 @@ class GreaterEqualOpPattern
}
};

class GreaterThanOpPattern
: public pir::OpRewritePattern<paddle::dialect::GreaterThanOp> {
// Marks greater_equal-family ops as runnable by TensorRT (sets
// kCanRunTrtAttr). Templated on the op type so the out-of-place
// (GreaterEqualOp) and inplace (GreaterEqual_Op) variants share one
// implementation — both have identical TensorRT support constraints.
template <typename OpType>
class GreaterEqualOpPattern : public pir::OpRewritePattern<OpType> {
 public:
  using pir::OpRewritePattern<OpType>::OpRewritePattern;

  bool MatchAndRewrite(OpType op,
                       pir::PatternRewriter &rewriter) const override {
    // Already marked as TRT-runnable: nothing to do.
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->template attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
#if IS_TRT_VERSION_LT(8400)
    // This op is only converted when building against TensorRT >= 8.4.
    VLOG(3) << op->name() << " is not supported when TensorRT < 8.4";
    return false;
#else
    pir::Value x = op.operand_source(0);
    pir::Value y = op.operand_source(1);
    auto x_dtype = pir::GetDataTypeFromValue(x);
    auto y_dtype = pir::GetDataTypeFromValue(y);
    // `template` disambiguator required: `isa` is a member template of a
    // type that depends on OpType.
    if (x_dtype.template isa<pir::BoolType>() ||
        y_dtype.template isa<pir::BoolType>()) {
      VLOG(3) << op->name() << " does not support bool datatype";
      return false;
    }
#endif
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};
// Concrete instantiations of GreaterEqualOpPattern for the out-of-place
// (pd_op.greater_equal) and inplace (pd_op.greater_equal_) ops.
using GreaterEqual1OpPattern =
    GreaterEqualOpPattern<paddle::dialect::GreaterEqualOp>;
using GreaterEqual2OpPattern =
    GreaterEqualOpPattern<paddle::dialect::GreaterEqual_Op>;

class LessThanOpPattern
: public pir::OpRewritePattern<paddle::dialect::LessThanOp> {
Expand Down Expand Up @@ -2222,7 +2227,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<CastOpPattern>(context));
ps.Add(std::make_unique<SplitOpPattern>(context));
ps.Add(std::make_unique<SplitWithNumOpPattern>(context));
ps.Add(std::make_unique<GreaterEqualOpPattern>(context));
ps.Add(std::make_unique<GreaterEqual1OpPattern>(context));
ps.Add(std::make_unique<GreaterEqual2OpPattern>(context));
ps.Add(std::make_unique<GreaterThanOpPattern>(context));
ps.Add(std::make_unique<LessThanOpPattern>(context));
ps.Add(std::make_unique<MultiplyOpPattern>(context));
Expand Down
18 changes: 18 additions & 0 deletions python/paddle/tensorrt/impls/logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,21 @@ def not_equal_converter(network, paddle_op, inputs):
not_layer = network.add_unary(layer_output, trt.UnaryOperation.NOT)
layer_output = not_layer.get_output(0)
return layer_output


@converter_registry.register("pd_op.greater_equal", trt_version="8.x")
@converter_registry.register("pd_op.greater_equal_", trt_version="8.x")
def greater_equal_converter(network, paddle_op, inputs):
    """Convert pd_op.greater_equal / pd_op.greater_equal_ to TRT layers.

    Lowered as NOT(x < y): `add_elementwise_layer` (which resolves the
    op's inputs and broadcasting) must be called exactly once per op,
    so the result is obtained by negating a single LESS comparison —
    the same structure as `not_equal_converter` above.
    """
    # x >= y  <=>  not (x < y); LESS yields a bool tensor.
    less_layer_output = add_elementwise_layer(
        network, paddle_op, inputs, trt.ElementWiseOperation.LESS
    )
    not_layer = network.add_unary(less_layer_output, trt.UnaryOperation.NOT)
    return not_layer.get_output(0)
30 changes: 30 additions & 0 deletions test/tensorrt/test_converter_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,36 @@ def test_trt_result(self):
self.check_trt_result()


class TestGreaterEqualRTPattern(TensorRTBaseTest):
    """Exercises the TensorRT path for paddle.greater_equal (float32)."""

    def setUp(self):
        self.python_api = paddle.greater_equal
        # Two random float32 operands with matching static shapes.
        lhs = np.random.randn(2, 3).astype("float32")
        rhs = np.random.randn(2, 3).astype("float32")
        self.api_args = {"x": lhs, "y": rhs}
        self.program_config = {"feed_list": ["x", "y"]}
        # Dynamic-shape profile: leading dim may vary between 1 and 5.
        self.min_shape = {"x": [1, 3], "y": [1, 3]}
        self.max_shape = {"x": [5, 3], "y": [5, 3]}

    def test_trt_result(self):
        # Delegates to the base harness' TRT-result check.
        self.check_trt_result()


class TestGreaterEqual_RTPattern(TensorRTBaseTest):
    """Exercises the TensorRT path for the inplace paddle.greater_equal_."""

    def setUp(self):
        self.python_api = paddle.greater_equal_
        # Two random float32 operands with matching static shapes.
        lhs = np.random.randn(2, 3).astype("float32")
        rhs = np.random.randn(2, 3).astype("float32")
        self.api_args = {"x": lhs, "y": rhs}
        self.program_config = {"feed_list": ["x", "y"]}
        # Dynamic-shape profile: leading dim may vary between 1 and 5.
        self.min_shape = {"x": [1, 3], "y": [1, 3]}
        self.max_shape = {"x": [5, 3], "y": [5, 3]}

    def test_trt_result(self):
        # Delegates to the base harness' TRT-result check.
        self.check_trt_result()


Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加一个Marker单测,为了过ci

class TestLessThanFloat32TRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.less_than
Expand Down