Merge pull request PaddlePaddle#4 from Layssy/lw_trt
Add split op development and unit tests
lizexu123 authored Aug 1, 2024
2 parents 4c86fbd + 76e2d57 commit e6ca18f
Showing 2 changed files with 97 additions and 4 deletions.
66 changes: 62 additions & 4 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -704,6 +704,7 @@ class CastOpPattern : public pir::OpRewritePattern<paddle::dialect::CastOp> {
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;

}
};

@@ -732,10 +733,6 @@ class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims();
auto out_vector_type = op.result(0).type().dyn_cast<pir::VectorType>();
if (!out_vector_type) {
VLOG(3) << "Output is not a VectorType";
return false;
}

paddle::dialect::FullIntArrayOp full_sections_op =
pir::GetDefiningOpForInput(op, 1)
@@ -768,7 +765,67 @@ class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
return true;
}
};
class SplitWithNumOpPattern
: public pir::OpRewritePattern<paddle::dialect::SplitWithNumOp> {
public:
using pir::OpRewritePattern<paddle::dialect::SplitWithNumOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::SplitWithNumOp op,
pir::PatternRewriter &rewriter) const override {
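// Skip ops that have already been marked as TensorRT-convertible.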
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
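// The split axis comes in through operand 1 and must be produced by a
// constant full op.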
paddle::dialect::FullOp full_op =
pir::GetDefiningOpForInput(op, 1)->dyn_cast<paddle::dialect::FullOp>();
if (!full_op) {
VLOG(3) << "Can not find full op";
return false;
} else {
auto axis = full_op->attribute<paddle::dialect::ScalarAttribute>("value")
.data()
.to<int>();
auto x_shape = op.operand_source(0)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims();
auto out_vector_type = op.result(0).type().dyn_cast<pir::VectorType>();

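// A negative axis counts from the last dimension; normalize it into [0, rank).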
axis += (axis < 0) ? x_shape.size() : 0;
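// Splitting along a dynamic (-1) dimension cannot be validated at marking time.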
if (x_shape[axis] == -1) {
VLOG(3) << "The (" << axis << ") dim of input should not be -1";
return false;
}

if (!op->HasAttribute("num")) {
VLOG(3) << "split_with_num op must have the num attribute";
return false;
}
int num = op->attribute<pir::Int32Attribute>("num").data();
std::vector<int64_t> output_lengths;
if (num > 0) {
int64_t in_axis_dim = x_shape[axis];
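// split_with_num requires num to evenly divide the extent of the split axis.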
if (in_axis_dim % num != 0) {
VLOG(3) << "Invalid number to split. Tensor split does not result"
" in an equal division of dimensions. Axis dim = "
<< in_axis_dim << " num = " << num << "!= 0";
return false;
}
int64_t out_axis_dim = in_axis_dim / num;
for (int i = 0; i < num; ++i) {
output_lengths.push_back(out_axis_dim);
}
}

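// The op must produce exactly one result per computed section.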
if (out_vector_type.size() != output_lengths.size()) {
VLOG(3) << "The number of split sections should equal the number of outputs.";
return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
}
};
class TrtOpMarkerPass : public pir::PatternRewritePass {
public:
TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
@@ -820,6 +877,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<FlattenOpPattern>(context));
ps.Add(std::make_unique<CastOpPattern>(context));
ps.Add(std::make_unique<SplitOpPattern>(context));
ps.Add(std::make_unique<SplitWithNumOpPattern>(context));
return ps;
}
};
35 changes: 35 additions & 0 deletions test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
@@ -544,6 +544,41 @@ def setUp(self):
def test_check_output(self):
self.check_pass_correct()

class TestSplitWithNumTRTPattern(PassTest):
def is_program_valid(self, program=None):
return True
def sample_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.pir.core.program_guard(main_prog, start_prog):
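# Using an int for num_or_sections should lower paddle.split to
# pd_op.split_with_num, the op targeted by the new marker pattern.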
x = paddle.static.data(
name='x', shape=[3, 9, 5], dtype='int64'
)
num_or_sections = 3
axis = 1
split_out = paddle.split(x, num_or_sections=num_or_sections, axis=axis)
out = paddle.assign(split_out[0])
out1 = paddle.assign(split_out[1])
out2 = paddle.assign(split_out[2])
self.pass_attr_list = [{'trt_op_marker_pass': {}}]
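# The marker pass only stamps kCanRunTrtAttr on supported ops;
# the program's outputs are unchanged.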
self.feeds = {
"x": np.random.random([3,9,5]).astype("int64"),
}

self.fetch_list = [out, out1, out2]
self.valid_op_map = {
"pd_op.fusion_transpose_flatten_concat": 0,
}
yield [main_prog, start_prog], False

def setUp(self):
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))

def test_check_output(self):
self.check_pass_correct()

class TestGeluTRTPattern(PassTest):
def is_program_valid(self, program=None):
