diff --git a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
index 02ddfe45331c6d..ca02506347f13c 100644
--- a/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
+++ b/paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -704,6 +704,7 @@ class CastOpPattern : public pir::OpRewritePattern<paddle::dialect::CastOp> {
     }
     op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
     return true;
+  }
 };
@@ -732,10 +733,6 @@ class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
                        .dyn_cast<paddle::dialect::DenseTensorType>()
                        .dims();
     auto out_vector_type = op.result(0).type().dyn_cast<pir::VectorType>();
-    if (!out_vector_type) {
-      VLOG(3) << "Output is not a VectorType";
-      return false;
-    }
 
     paddle::dialect::FullIntArrayOp full_sections_op =
         pir::GetDefiningOpForInput(op, 1)
             ->dyn_cast<paddle::dialect::FullIntArrayOp>();
@@ -768,7 +765,67 @@
     return true;
   }
 };
+
+class SplitWithNumOpPattern
+    : public pir::OpRewritePattern<paddle::dialect::SplitWithNumOp> {
+ public:
+  using pir::OpRewritePattern<paddle::dialect::SplitWithNumOp>::OpRewritePattern;
+  bool MatchAndRewrite(paddle::dialect::SplitWithNumOp op,
+                       pir::PatternRewriter &rewriter) const override {
+    if (op->HasAttribute(kCanRunTrtAttr) &&
+        op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
+      return false;
+    }
+    paddle::dialect::FullOp full_op =
+        pir::GetDefiningOpForInput(op, 1)->dyn_cast<paddle::dialect::FullOp>();
+    if (!full_op) {
+      VLOG(3) << "Cannot find the defining full op of the axis input";
+      return false;
+    } else {
+      auto axis = full_op->attribute<paddle::dialect::ScalarAttribute>("value")
+                      .data()
+                      .to<int>();
+      auto x_shape = op.operand_source(0)
+                         .type()
+                         .dyn_cast<paddle::dialect::DenseTensorType>()
+                         .dims();
+      auto out_vector_type = op.result(0).type().dyn_cast<pir::VectorType>();
+
+      axis += (axis < 0) ? x_shape.size() : 0;
+      if (x_shape[axis] == -1) {
+        VLOG(3) << "The (" << axis << ") dim of input should not be -1";
+        return false;
+      }
+
+      if (!op->HasAttribute("num")) {
+        VLOG(3) << "split_with_num op must have the num attribute";
+        return false;
+      }
+      int num = op->attribute<pir::Int32Attribute>("num").data();
+      std::vector<int64_t> output_lengths;
+      if (num > 0) {
+        int64_t in_axis_dim = x_shape[axis];
+        if (in_axis_dim % num != 0) {
+          VLOG(3) << "Invalid number to split. Tensor split does not result"
+                     " in an equal division of dimensions. Axis dim = "
+                  << in_axis_dim << " is not divisible by num = " << num;
+          return false;
+        }
+        size_t out_axis_dim = in_axis_dim / num;
+        for (int i = 0; i < num; ++i) {
+          output_lengths.push_back(out_axis_dim);
+        }
+      }
+      if (out_vector_type.size() != output_lengths.size()) {
+        VLOG(3) << "The size of output_lengths should be equal to the output size.";
+        return false;
+      }
+      op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
+      return true;
+    }
+  }
+};
 
 class TrtOpMarkerPass : public pir::PatternRewritePass {
  public:
   TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
@@ -820,6 +877,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
     ps.Add(std::make_unique<FlattenOpPattern>(context));
     ps.Add(std::make_unique<CastOpPattern>(context));
     ps.Add(std::make_unique<SplitOpPattern>(context));
+    ps.Add(std::make_unique<SplitWithNumOpPattern>(context));
     return ps;
   }
 };
diff --git a/test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py b/test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
index 6f157b14afb158..a9c3d5fbfdfdb7 100644
--- a/test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
+++ b/test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
@@ -544,6 +544,41 @@ def setUp(self):
     def test_check_output(self):
         self.check_pass_correct()
 
+
+class TestSplitWithNumTRTPattern(PassTest):
+    def is_program_valid(self, program=None):
+        return True
+
+    def sample_program(self):
+        with paddle.pir_utils.IrGuard():
+            main_prog = paddle.static.Program()
+            start_prog = paddle.static.Program()
+            with paddle.pir.core.program_guard(main_prog, start_prog):
+                x = paddle.static.data(
+                    name='x', shape=[3, 9, 5], dtype='int64'
+                )
+                num_or_sections = 3
+                axis = 1
+                split_out = paddle.split(
+                    x, num_or_sections=num_or_sections, axis=axis
+                )
+                out = paddle.assign(split_out[0])
+                out1 = paddle.assign(split_out[1])
+                out2 = paddle.assign(split_out[2])
+                self.pass_attr_list = [{'trt_op_marker_pass': {}}]
+                self.feeds = {
+                    "x": np.random.random([3, 9, 5]).astype("int64"),
+                }
+                self.fetch_list = [out, out1, out2]
+                self.valid_op_map = {
+                    "pd_op.fusion_transpose_flatten_concat": 0,
+                }
+                yield [main_prog, start_prog], False
+
+    def setUp(self):
+        if core.is_compiled_with_cuda():
+            self.places.append(paddle.CUDAPlace(0))
+
+    def test_check_output(self):
+        self.check_pass_correct()
+
 
 class TestGeluTRTPattern(PassTest):
     def is_program_valid(self, program=None):
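
Reviewer note: the admissibility rule the new SplitWithNumOpPattern enforces
(split_with_num is only marked TRT-runnable when the axis dimension is static
and divides evenly by num) can be sketched in isolation as below. This is a
hypothetical, self-contained illustration; EvenSplitLengths is an invented
name and is not part of the patch or of the Paddle codebase.

    // Sketch of the even-split check: returns the per-output lengths for
    // split_with_num, or an empty vector when the split must be rejected
    // (dynamic dim, non-positive num, or uneven division).
    #include <cassert>
    #include <cstdint>
    #include <vector>

    std::vector<std::int64_t> EvenSplitLengths(std::int64_t axis_dim, int num) {
      if (axis_dim < 0 || num <= 0 || axis_dim % num != 0) {
        return {};
      }
      return std::vector<std::int64_t>(num, axis_dim / num);
    }

    int main() {
      // Mirrors the test case: shape [3, 9, 5], axis 1, num 3 -> three outputs of 3.
      assert((EvenSplitLengths(9, 3) == std::vector<std::int64_t>{3, 3, 3}));
      // 9 % 2 != 0 -> the pattern would return false (op stays off TensorRT).
      assert(EvenSplitLengths(9, 2).empty());
      // Dynamic axis dim (-1) -> rejected, matching the x_shape[axis] == -1 check.
      assert(EvenSplitLengths(-1, 3).empty());
      return 0;
    }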