Merge pull request #4 from lizexu123/add_trt
Add trt
lizexu123 authored Jul 12, 2024
2 parents f8e7d37 + 5fbba5c commit b2da6f3
Showing 2 changed files with 209 additions and 1 deletion.
131 changes: 131 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -15,6 +15,7 @@
#include "paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.h"
#include <glog/logging.h>
#include <bitset>
#include <vector>
#include "paddle/fluid/inference/tensorrt/helper.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
@@ -596,6 +597,133 @@ class ScaleOpPattern : public pir::OpRewritePattern<paddle::dialect::ScaleOp> {
  }
};

class UnsqueezeOpPattern
    : public pir::OpRewritePattern<paddle::dialect::UnsqueezeOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::UnsqueezeOp>::OpRewritePattern;
  bool MatchAndRewrite(paddle::dialect::UnsqueezeOp op,
                       pir::PatternRewriter &rewriter) const override {
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
    pir::Value axis = op.operand_source(1);
    if (!axis) {
      VLOG(3) << "The required axis input of the unsqueeze op is missing";
      return false;
    }
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};

class Unsqueeze_OpPattern
    : public pir::OpRewritePattern<paddle::dialect::Unsqueeze_Op> {
 public:
  using pir::OpRewritePattern<paddle::dialect::Unsqueeze_Op>::OpRewritePattern;
  bool MatchAndRewrite(paddle::dialect::Unsqueeze_Op op,
                       pir::PatternRewriter &rewriter) const override {
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
    pir::Value axis = op.operand_source(1);
    if (!axis) {
      VLOG(3) << "The required axis input of the unsqueeze_ op is missing";
      return false;
    }
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};

class SqueezeOpPattern
    : public pir::OpRewritePattern<paddle::dialect::SqueezeOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::SqueezeOp>::OpRewritePattern;
  bool MatchAndRewrite(paddle::dialect::SqueezeOp op,
                       pir::PatternRewriter &rewriter) const override {
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }

    pir::Value axis_ = op.operand_source(1);
    std::vector<int64_t> axes;

    if (axis_) {
      bool is_from_tensor = false;
      phi::IntArray axis = phi::IntArray(
          paddle::dialect::ParseValueShape(axis_, &is_from_tensor));
      for (auto a : axis.GetData()) {
        axes.push_back(a);
      }
    }

    // If no axes were given, infer them from the input shape: every
    // dimension whose static extent is 1 is a candidate for squeezing.
    if (axes.empty()) {
      auto input_var_name = op.operand_source(0);
      auto input_var_name_type =
          input_var_name.type().dyn_cast<paddle::dialect::DenseTensorType>();
      auto input_var_name_shape = input_var_name_type.dims();

      for (int i = 0; i < input_var_name_shape.size(); ++i) {
        int64_t s = input_var_name_shape[i];
        if (s == -1) {
          VLOG(3) << "The input of the squeeze op has a dynamic dimension "
                     "(-1), so the axes to squeeze cannot be determined.";
          return false;
        } else if (s == 1) {
          // Record the index of each dimension whose extent is 1.
          axes.push_back(i);
        }
      }

      if (axes.empty()) {
        VLOG(3) << "The necessary axes of the squeeze op are missing.";
        return false;
      }
    }

    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};

class SliceOpPattern : public pir::OpRewritePattern<paddle::dialect::SliceOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::SliceOp>::OpRewritePattern;
  bool MatchAndRewrite(paddle::dialect::SliceOp op,
                       pir::PatternRewriter &rewriter) const override {
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
    if (!op->HasAttribute("axes")) {
      VLOG(3) << "The necessary axes attribute of the slice op is missing.";
      return false;
    }

    auto axes_attr = op->attribute<pir::ArrayAttribute>("axes");

    std::vector<int64_t> axes;
    for (const auto &attr : axes_attr.AsVector()) {
      axes.push_back(attr.dyn_cast<pir::Int64Attribute>().data());
    }
    pir::Value input = op.operand_source(0);

    auto inputs = input.type().dyn_cast<paddle::dialect::DenseTensorType>();
    auto inputs_shape = inputs.dims();
    if (axes.size() != static_cast<size_t>(inputs_shape.size())) {
      VLOG(3) << "The number of axes of the slice op does not match the "
                 "rank of its input.";
      return false;
    }
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};

class TrtOpMarkerPass : public pir::PatternRewritePass {
 public:
  TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
@@ -638,6 +766,9 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
    ps.Add(std::make_unique<GatherOpPattern>(context));
    ps.Add(std::make_unique<GatherNdOpPattern>(context));
    ps.Add(std::make_unique<ScaleOpPattern>(context));
    ps.Add(std::make_unique<UnsqueezeOpPattern>(context));
    ps.Add(std::make_unique<Unsqueeze_OpPattern>(context));
    ps.Add(std::make_unique<SqueezeOpPattern>(context));
    ps.Add(std::make_unique<SliceOpPattern>(context));
    return ps;
  }
};
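The patterns above only tag ops with kCanRunTrtAttr; a later TensorRT conversion step consumes the marks. A minimal sketch of driving just this pass from Python, assuming the paddle.pir.PassManager binding that the PassTest harness below relies on (main_prog stands for a PIR program built as in those tests):

import paddle

# Run only the marker pass, then list each op with its attributes;
# ops the pass accepted carry the attribute defined as kCanRunTrtAttr
# in the C++ file above.
pm = paddle.pir.PassManager(opt_level=4)
pm.add_pass('trt_op_marker_pass', {})
pm.run(main_prog)
for op in main_prog.global_block().ops:
    print(op.name(), op.attrs())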
79 changes: 78 additions & 1 deletion test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
@@ -235,7 +235,7 @@ def test_check_output(self):
        self.check_pass_correct()


class TestFlattenConCatPattern(PassTest):
class TestFlattenConCatTRTPattern(PassTest):
    def is_program_valid(self, program=None):
        return True

@@ -272,5 +272,82 @@ def test_check_output(self):
        self.check_pass_correct()


class TestGatherNdTRTPattern(PassTest):
    def is_program_valid(self, program=None):
        return True

    def sample_program(self):
        with paddle.pir_utils.IrGuard():
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.pir.core.program_guard(main_prog, start_prog):
                x = paddle.static.data(
                    name='x', shape=[1, 3, 4], dtype='float32'
                )
                index = paddle.static.data(
                    name='index', shape=[1, 2, 2], dtype='int32'
                )
                gather_nd_out = paddle.gather_nd(x, index)
                out = paddle.assign(gather_nd_out)
                self.pass_attr_list = [{'trt_op_marker_pass': {}}]
                self.feeds = {
                    "x": np.random.random([1, 3, 4]).astype("float32"),
                    # Indices must be valid for x's first two dims
                    # (dim 0 < 1, dim 1 < 3); casting uniform floats in
                    # [0, 1) to int32 would silently yield all zeros.
                    "index": np.random.randint(
                        0, [1, 3], size=[1, 2, 2]
                    ).astype("int32"),
                }

                self.fetch_list = [out]
                self.valid_op_map = {
                    "pd_op.fusion_transpose_flatten_concat": 0,
                }
                yield [main_prog, start_prog], False

    def setUp(self):
        if core.is_compiled_with_cuda():
            self.places.append(paddle.CUDAPlace(0))

    def test_check_output(self):
        self.check_pass_correct()


class TestSliceTRTPattern(PassTest):
    def is_program_valid(self, program=None):
        return True

    def sample_program(self):
        with paddle.pir_utils.IrGuard():
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.pir.core.program_guard(main_prog, start_prog):
                x = paddle.static.data(
                    name='x', shape=[4, 5, 6], dtype='float32'
                )

                # Slice all three axes; negative starts count from the end.
                axes = [0, 1, 2]
                starts = [-3, 0, 2]
                ends = [3, 2, 4]

                sliced_1 = paddle.slice(
                    x, axes=axes, starts=starts, ends=ends
                )

                out = paddle.assign(sliced_1)
                self.pass_attr_list = [{'trt_op_marker_pass': {}}]
                self.feeds = {
                    "x": np.random.random([4, 5, 6]).astype("float32"),
                }
                self.fetch_list = [out]
                self.valid_op_map = {
                    "pd_op.conv2d": 0,
                }
                yield [main_prog, start_prog], False

    def setUp(self):
        if core.is_compiled_with_cuda():
            self.places.append(paddle.CUDAPlace(0))

    def test_check_output(self):
        self.check_pass_correct()

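# The unsqueeze patterns registered above have no test in this diff;
# the sketch below follows the PassTest conventions of the cases
# above. The class name and the op checked in valid_op_map are
# illustrative assumptions, not part of the original change.
class TestUnsqueezeTRTPattern(PassTest):
    def is_program_valid(self, program=None):
        return True

    def sample_program(self):
        with paddle.pir_utils.IrGuard():
            main_prog = paddle.static.Program()
            start_prog = paddle.static.Program()
            with paddle.pir.core.program_guard(main_prog, start_prog):
                x = paddle.static.data(
                    name='x', shape=[2, 3], dtype='float32'
                )
                unsqueeze_out = paddle.unsqueeze(x, axis=[1])
                out = paddle.assign(unsqueeze_out)
                self.pass_attr_list = [{'trt_op_marker_pass': {}}]
                self.feeds = {
                    "x": np.random.random([2, 3]).astype("float32"),
                }
                self.fetch_list = [out]
                self.valid_op_map = {
                    "pd_op.fusion_transpose_flatten_concat": 0,
                }
                yield [main_prog, start_prog], False

    def setUp(self):
        if core.is_compiled_with_cuda():
            self.places.append(paddle.CUDAPlace(0))

    def test_check_output(self):
        self.check_pass_correct()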

if __name__ == "__main__":
    unittest.main()
