Commit

Merge pull request PaddlePaddle#14 from lizexu123/add_trt
Added a split_with_num marker, modified the unsqueeze converter, and added unsqueeze and squeeze markers
lizexu123 authored Aug 1, 2024
2 parents 6fa5ef5 + cf169f4 commit 3ea745a
Showing 7 changed files with 237 additions and 40 deletions.
113 changes: 106 additions & 7 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -63,10 +63,10 @@ DEFINE_GENERAL_PATTERN(Layer_norm, paddle::dialect::LayerNormOp)
DEFINE_GENERAL_PATTERN(Add, paddle::dialect::AddOp)
DEFINE_GENERAL_PATTERN(Full, paddle::dialect::FullOp)
DEFINE_GENERAL_PATTERN(Silu, paddle::dialect::SiluOp)

DEFINE_GENERAL_PATTERN(Conv2d, paddle::dialect::Conv2dOp)
DEFINE_GENERAL_PATTERN(FusedConv2dAddAct, paddle::dialect::FusedConv2dAddActOp)
DEFINE_GENERAL_PATTERN(DepthwiseConv2d, paddle::dialect::DepthwiseConv2dOp)
DEFINE_GENERAL_PATTERN(Sigmoid, paddle::dialect::SigmoidOp)

#undef DEFINE_GENERAL_PATTERN

@@ -492,11 +492,31 @@ class UnsqueezeOpPattern
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value axis = op.operand_source(1);
paddle::dialect::FullIntArrayOp full_int_array_op =
pir::GetDefiningOpForInput(op, 1)
->dyn_cast<paddle::dialect::FullIntArrayOp>();
auto axis = full_int_array_op->attribute<pir::ArrayAttribute>("value");

if (!axis) {
VLOG(3) << "The necessary attributes of the unsuqeeze axis is missing";
return false;
}
pir::Value x = op.operand_source(0);
auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto x_shape = x_type.dims();

std::vector<int32_t> dynamic_dims;
for (int i = 0; i < x_shape.size(); ++i) {
if (x_shape[i] == -1) {
dynamic_dims.push_back(i);
}
}
if (dynamic_dims.size() > 1) {
VLOG(3) << "Currently we don't support unsqueeze with more than one "
"dynamic dims";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
@@ -512,11 +532,31 @@ class Unsqueeze_OpPattern
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
pir::Value axis = op.operand_source(1);
paddle::dialect::FullIntArrayOp full_int_array_op =
pir::GetDefiningOpForInput(op, 1)
->dyn_cast<paddle::dialect::FullIntArrayOp>();
auto axis = full_int_array_op->attribute<pir::ArrayAttribute>("value");

if (!axis) {
VLOG(3) << "The necessary attributes of the unsuqeeze axis is missing";
return false;
}
pir::Value x = op.operand_source(0);
auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
auto x_shape = x_type.dims();

std::vector<int32_t> dynamic_dims;
for (int i = 0; i < x_shape.size(); ++i) {
if (x_shape[i] == -1) {
dynamic_dims.push_back(i);
}
}
if (dynamic_dims.size() > 1) {
VLOG(3) << "Currently we don't support unsqueeze with more than one "
"dynamic dims";
return false;
}

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
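
Note: both unsqueeze marker patterns above enforce the same restriction: at most one input dimension may be dynamic (-1). A minimal Python sketch of that acceptance check, with a hypothetical function name that is not part of the pass:

def can_mark_unsqueeze_for_trt(x_shape):
    # Collect the indices of dynamic (-1) dimensions; the marker rejects more than one.
    dynamic_dims = [i for i, d in enumerate(x_shape) if d == -1]
    return len(dynamic_dims) <= 1

print(can_mark_unsqueeze_for_trt([-1, 3, 224, 224]))   # True
print(can_mark_unsqueeze_for_trt([-1, -1, 224, 224]))  # False
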
@@ -704,6 +744,7 @@ class CastOpPattern : public pir::OpRewritePattern<paddle::dialect::CastOp> {
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;

}
};

@@ -732,10 +773,6 @@ class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims();
auto out_vector_type = op.result(0).type().dyn_cast<pir::VectorType>();
if (!out_vector_type) {
VLOG(3) << "Output is not a VectorType";
return false;
}

paddle::dialect::FullIntArrayOp full_sections_op =
pir::GetDefiningOpForInput(op, 1)
@@ -768,7 +805,67 @@ class SplitOpPattern : public pir::OpRewritePattern<paddle::dialect::SplitOp> {
return true;
}
};
class SplitWithNumOpPattern
: public pir::OpRewritePattern<paddle::dialect::SplitWithNumOp> {
public:
using pir::OpRewritePattern<paddle::dialect::SplitWithNumOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::SplitWithNumOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
paddle::dialect::FullOp full_op =
pir::GetDefiningOpForInput(op, 1)->dyn_cast<paddle::dialect::FullOp>();
if (!full_op) {
VLOG(3) << "Can not find full op";
return false;
} else {
auto axis = full_op->attribute<paddle::dialect::ScalarAttribute>("value")
.data()
.to<int>();
auto x_shape = op.operand_source(0)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims();
auto out_vector_type = op.result(0).type().dyn_cast<pir::VectorType>();

axis += (axis < 0) ? x_shape.size() : 0;
if (x_shape[axis] == -1) {
VLOG(3) << "The (" << axis << ") dim of input should not be -1";
return false;
}

if (!op->HasAttribute("num") ) {
VLOG(3)<< "split_with_num op must has num attributes";
return false;
}
int num = op->attribute<pir::Int32Attribute>("num").data();
std::vector<int64_t> output_lengths;
if (num > 0) {
int64_t in_axis_dim = x_shape[axis];
if (in_axis_dim % num != 0) {
VLOG(3) << "Invalid number to split. Tensor split does not result"
" in an equal division of dimensions. Axis dim = "
<< in_axis_dim << ", num = " << num;
return false;
}
size_t out_axis_dim = in_axis_dim / num;
for (int i = 0; i < num; ++i) {
output_lengths.push_back(out_axis_dim);
}
}

if (out_vector_type.size() != output_lengths.size()) {
VLOG(3) << "The number of outputs should be equal to the number of split sections.";
return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}

}
};
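
Note: the split_with_num marker above accepts an op only when the split axis is statically known and divides evenly by num. A rough Python sketch of the same acceptance test (illustrative names, not part of the pass):

def can_mark_split_with_num(x_shape, axis, num):
    # Normalize a negative axis the same way the C++ pattern does.
    axis = axis + len(x_shape) if axis < 0 else axis
    if x_shape[axis] == -1:  # the split axis must not be dynamic
        return False
    if num <= 0 or x_shape[axis] % num != 0:  # only equal divisions are accepted
        return False
    return True

print(can_mark_split_with_num([4, 6, 8], axis=1, num=3))  # True: 6 splits into 3 sections of 2
print(can_mark_split_with_num([4, 7, 8], axis=1, num=3))  # False: 7 % 3 != 0
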
class TrtOpMarkerPass : public pir::PatternRewritePass {
public:
TrtOpMarkerPass() : pir::PatternRewritePass("trt_op_marker_pass", 2) {}
@@ -798,6 +895,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ADD_PATTERN(DepthwiseConv2d)
ADD_PATTERN(Nonzero)
ADD_PATTERN(Gelu)
ADD_PATTERN(Sigmoid)

#undef ADD_PATTERN
ps.Add(std::make_unique<Pool2dOpPattern>(context));
@@ -820,6 +918,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<FlattenOpPattern>(context));
ps.Add(std::make_unique<CastOpPattern>(context));
ps.Add(std::make_unique<SplitOpPattern>(context));
ps.Add(std::make_unique<SplitWithNumOpPattern>(context));
return ps;
}
};
7 changes: 1 addition & 6 deletions python/paddle/pp_tensorrt/converter.py
@@ -25,7 +25,6 @@
from register import converter_registry
from impls.core import *


def get_cache_path():
home_path = os.path.expanduser("~")
cache_path = os.path.join(home_path, ".pp_trt_cache")
@@ -204,11 +203,7 @@ def convert_subgraph_to_trt(self, program, group_op):
raise RuntimeError(
f'{source_id} not found in value_to_trt_tensor'
)
# operands.append(value_to_trt_tensor[operand.source().id])
# operands = [
# value_to_trt_tensor[operand.source().id]
# for operand in op.operands()
# ]

layer = self.convert(network, op, operands)

# _logger.info(f"start convert {op}")
6 changes: 6 additions & 0 deletions python/paddle/pp_tensorrt/converter_utils.py
@@ -124,3 +124,9 @@ def get_trt_plugin(plugin_name, field_collection, version, plugin_namespace=""):
)
assert plugin is not None, f"Plugin:{plugin_name} could not be fetched"
return plugin


def get_positive_dim(dim, dim_size):
if dim < 0:
return dim % dim_size
return dim
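
A brief usage note: the new helper maps a possibly negative dimension index into the range [0, dim_size). For example, assuming get_positive_dim above is in scope:

assert get_positive_dim(-1, 4) == 3
assert get_positive_dim(2, 4) == 2
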
95 changes: 75 additions & 20 deletions python/paddle/pp_tensorrt/impls/core.py
@@ -36,8 +36,9 @@
broadcast,
get_axes_for_reduce_op,
get_dynamic_dims,
has_dynamic_shape,
get_positive_dim,
get_trt_plugin,
has_dynamic_shape,
)


@@ -118,11 +119,17 @@ def full_int_array_converter(network, paddle_op, inputs):
@converter_registry.register("pd_op.reshape", trt_version="8.x")
def reshape_converter(network, paddle_op, inputs):
input_tensor, shape_tensor = inputs
input_shape = paddle_op.operands()[0].source().shape

output_shape = paddle_op.results()[1].shape
if network.has_implicit_batch_dimension:
output_shape = output_shape[1:]

if type(input_tensor) == trt.Weights:
input_tensor = network.add_constant(
input_shape, input_tensor
).get_output(0)

shuffle_layer = network.add_shuffle(input_tensor)

try:
@@ -478,12 +485,13 @@ def flatten_converter(network, paddle_op, inputs):


# In the converter, pd_op.concat has three inputs because builtin.combine has two inputs
@converter_registry.register("pd_op.concat",trt_version="8.x")
@converter_registry.register("pd_op.concat", trt_version="8.x")
def concat_converter(network, paddle_op, inputs):
input_tensors = inputs[:-1]
axis_tensor = inputs[-1]
concat_layer = network.add_concatenation(inputs=input_tensors)

# This fetches the op
full_op = paddle_op.operands()[1]
# This fetches the value
full_value = full_op.source()
@@ -496,23 +504,70 @@ def concat_converter(network, paddle_op, inputs):

return concat_layer


@converter_registry.register("pd_op.gelu", trt_version="8.x")
def gelu_converter(network,paddle_op,inputs):
input_val =inputs[0]
approximate =paddle_op.attrs()["approximate"]
if approximate !=False:
raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")

plugin_name ="CustomGeluPluginDynamic"
type_id =trt.PluginField("type_id",np.array(0,dtype=np.int32),trt.PluginFieldType.INT32)

filed_collection =trt.PluginFieldCollection([type_id])
plugin_version="1"

plugin=get_trt_plugin(plugin_name,filed_collection,plugin_version)

layer=network.add_plugin_v2([input_val],plugin)
def gelu_converter(network, paddle_op, inputs):
input_val = inputs[0]
approximate = paddle_op.attrs()["approximate"]
if approximate:
raise RuntimeError(
"GeLU converter currently doesn't support fast gelu compute"
)

plugin_name = "CustomGeluPluginDynamic"
type_id = trt.PluginField(
"type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
)

field_collection = trt.PluginFieldCollection([type_id])
plugin_version = "1"

plugin = get_trt_plugin(plugin_name, field_collection, plugin_version)

layer = network.add_plugin_v2([input_val], plugin)
return layer


@converter_registry.register("pd_op.unsqueeze", trt_version="8.x")
@converter_registry.register("pd_op.unsqueeze_", trt_version="8.x")
def unsqueeze_converter(network, paddle_op, inputs):
input_val = inputs[0]
input_shape = paddle_op.operands()[0].source().shape
input_shape_size = len(input_shape)

if type(input_val) == trt.Weights:
input_val = network.add_constant(input_shape, input_val).get_output(0)
axis = paddle_op.operands()[1].source().get_defining_op().attrs()["value"]
axis = axis[0]

axis = get_positive_dim(axis, input_shape_size + 1)
layer = network.add_shuffle(input_val)
layer.reshape_dims = (
tuple(input_val.shape)[:axis] + (1,) + tuple(input_val.shape)[axis:]
)
return layer
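
For reference, the reshape_dims assigned above can be sketched without TensorRT as follows (hypothetical helper, not part of the converter; assumes a fully static input shape):

def unsqueeze_reshape_dims(input_shape, axis):
    # Normalize the axis against rank + 1, then insert a 1 at that position.
    rank = len(input_shape)
    axis = axis % (rank + 1) if axis < 0 else axis
    return tuple(input_shape[:axis]) + (1,) + tuple(input_shape[axis:])

print(unsqueeze_reshape_dims([2, 3, 4], 1))   # (2, 1, 3, 4)
print(unsqueeze_reshape_dims([2, 3, 4], -1))  # (2, 3, 4, 1)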


@converter_registry.register("pd_op.squeeze", trt_version="8.x")
@converter_registry.register("pd_op.squeeze_", trt_version="8.x")
def squeeze_converter(network, paddle_op, inputs):
input_val = inputs[0]
input_shape = paddle_op.operands()[0].source().shape
input_shape_size = len(input_shape)

if type(input_val) == trt.Weights:
input_val = network.add_constant(input_shape, input_val).get_output(0)

axis = paddle_op.operands()[1].source().get_defining_op().attrs()["value"]
axis = axis[0]

axis = get_positive_dim(axis, input_shape_size)
output_shape = []
for i, s in enumerate(input_shape):
if i == axis and s == 1:
continue
output_shape.append(s)

layer = network.add_shuffle(input_val)
layer.reshape_dims = tuple(output_shape)
return layer
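
Likewise, the squeeze output shape follows this sketch (hypothetical helper, not part of the converter; assumes a static shape and a single axis):

def squeeze_output_shape(input_shape, axis):
    # Drop the dimension at `axis` only when its extent is 1.
    axis = axis % len(input_shape) if axis < 0 else axis
    return tuple(s for i, s in enumerate(input_shape) if not (i == axis and s == 1))

print(squeeze_output_shape([2, 1, 3], 1))   # (2, 3)
print(squeeze_output_shape([2, 1, 3], -2))  # (2, 3)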



7 changes: 3 additions & 4 deletions python/paddle/pp_tensorrt/test_converter.py
@@ -37,8 +37,7 @@ def test_paddle_to_tensorrt_conversion_dummy():
with paddle.pir_utils.IrGuard():
with paddle.static.program_guard(program):
executor = paddle.static.Executor()
output_var = program.list_vars()[-1]
# forbid_op_lower_trt(program, "pd_op.gelu")
output_var = program.list_vars()[-2]
# Run the program with input_data
for _ in range(1):
output_original = executor.run(
@@ -51,14 +50,14 @@
feed={"input": input_data_max_shape},
fetch_list=[output_var],
)

# forbid_op_lower_trt(program,"pd_op.squeeze")
# Apply PIR pass to the program
program_with_pir = run_pir_pass(program, partition_mode=True)

# Convert the program to TensorRT
converter = PaddleToTensorRTConverter(program_with_pir, scope)
converter.convert_program_to_trt()
output_var = program_with_pir.list_vars()[-1]
output_var = program_with_pir.list_vars()[-2]

with paddle.pir_utils.IrGuard():
with paddle.static.program_guard(program_with_pir):