Merge pull request PaddlePaddle#13 from lizexu123/add_trt
Added a marker and a converter for gelu; the dummy model now runs end-to-end.
lizexu123 authored Jul 31, 2024
2 parents 33548c5 + 4c86fbd commit 6fa5ef5
Showing 5 changed files with 73 additions and 3 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -56,7 +56,7 @@ DEFINE_GENERAL_PATTERN(Dropout, paddle::dialect::DropoutOp)
DEFINE_GENERAL_PATTERN(Bmm, paddle::dialect::BmmOp)
DEFINE_GENERAL_PATTERN(Concat, paddle::dialect::ConcatOp)
DEFINE_GENERAL_PATTERN(Nonzero, paddle::dialect::NonzeroOp)

DEFINE_GENERAL_PATTERN(Gelu, paddle::dialect::GeluOp)
DEFINE_GENERAL_PATTERN(Fused_gemm_epilogue,
paddle::dialect::FusedGemmEpilogueOp)
DEFINE_GENERAL_PATTERN(Layer_norm, paddle::dialect::LayerNormOp)
@@ -797,6 +797,7 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ADD_PATTERN(FusedConv2dAddAct)
ADD_PATTERN(DepthwiseConv2d)
ADD_PATTERN(Nonzero)
ADD_PATTERN(Gelu)

#undef ADD_PATTERN
ps.Add(std::make_unique<Pool2dOpPattern>(context));
16 changes: 16 additions & 0 deletions python/paddle/pp_tensorrt/converter_utils.py
@@ -17,6 +17,7 @@
import sys

import numpy as np
import tensorrt as trt

current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
@@ -108,3 +109,18 @@ def get_dynamic_dims(shape):
if s == -1:
dynamic_dims.append(i)
return dynamic_dims
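
# Usage sketch (illustrative, not part of this commit): get_dynamic_dims
# returns the indices of the wildcard (-1) dimensions of a shape, e.g.
# get_dynamic_dims((-1, 3, -1, 224)) == [0, 2], while a fully static
# shape such as (8, 16) yields [].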


def get_trt_plugin(plugin_name, field_collection, version, plugin_namespace=""):
plugin_registry = trt.get_plugin_registry()
plugin_creator = plugin_registry.get_plugin_creator(
plugin_name, version, plugin_namespace
)
assert (
plugin_creator
), f"Unabled to find plugin creator with name{plugin_name}"
plugin = plugin_creator.create_plugin(
name=plugin_name, field_collection=field_collection
)
    assert plugin is not None, f"Plugin {plugin_name} could not be fetched"
return plugin
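
For reference, a minimal usage sketch of this helper (illustrative, not part of the commit; it assumes TensorRT's bundled plugins, including CustomGeluPluginDynamic, have been registered via trt.init_libnvinfer_plugins):

import numpy as np
import tensorrt as trt

# Register TensorRT's built-in plugin library.
logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")

# The GELU plugin takes a single int32 field; type_id 0 is assumed to
# select the FP32 variant.
type_id = trt.PluginField(
    "type_id", np.array(0, dtype=np.int32), trt.PluginFieldType.INT32
)
fc = trt.PluginFieldCollection([type_id])

plugin = get_trt_plugin("CustomGeluPluginDynamic", fc, version="1")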
24 changes: 23 additions & 1 deletion python/paddle/pp_tensorrt/impls/core.py
@@ -37,6 +37,7 @@
get_axes_for_reduce_op,
get_dynamic_dims,
has_dynamic_shape,
get_trt_plugin,
)


@@ -477,7 +478,7 @@ def flatten_converter(network, paddle_op, inputs):


# In the converter, pd_op.concat has three inputs, because builtin.combine
# contributes two input tensors (the last input is the concat axis).
@converter_registry.register("pd_op.concat")
@converter_registry.register("pd_op.concat",trt_version="8.x")
def concat_converter(network, paddle_op, inputs):
input_tensors = inputs[:-1]
axis_tensor = inputs[-1]
@@ -494,3 +495,24 @@ def concat_converter(network, paddle_op, inputs):
concat_layer.axis = axis

return concat_layer

@converter_registry.register("pd_op.gelu", trt_version="8.x")
def gelu_converter(network,paddle_op,inputs):
input_val =inputs[0]
approximate =paddle_op.attrs()["approximate"]
if approximate !=False:
raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")

plugin_name ="CustomGeluPluginDynamic"
type_id =trt.PluginField("type_id",np.array(0,dtype=np.int32),trt.PluginFieldType.INT32)

filed_collection =trt.PluginFieldCollection([type_id])
plugin_version="1"

plugin=get_trt_plugin(plugin_name,filed_collection,plugin_version)

layer=network.add_plugin_v2([input_val],plugin)
return layer



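
As a design note, the converter delegates to TensorRT's CustomGeluPluginDynamic plugin. Exact (erf-based) GELU can also be composed from native TensorRT layers, which drops the plugin dependency; the sketch below is illustrative only (gelu_native_sketch and its broadcast-constant handling are assumptions, not part of this commit):

import numpy as np
import tensorrt as trt

def gelu_native_sketch(network, input_val):
    # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2))), built from native layers.
    rank = len(input_val.shape)
    shape = (1,) * rank  # rank-matched constants broadcast over the input

    def const(value):
        arr = np.full(shape, value, dtype=np.float32)
        return network.add_constant(shape, arr).get_output(0)

    scaled = network.add_elementwise(
        input_val, const(1.0 / np.sqrt(2.0)), trt.ElementWiseOperation.PROD
    ).get_output(0)
    erf = network.add_unary(scaled, trt.UnaryOperation.ERF).get_output(0)
    one_plus_erf = network.add_elementwise(
        const(1.0), erf, trt.ElementWiseOperation.SUM
    ).get_output(0)
    half_x = network.add_elementwise(
        input_val, const(0.5), trt.ElementWiseOperation.PROD
    ).get_output(0)
    # Return the final layer, mirroring the converter's contract.
    return network.add_elementwise(
        half_x, one_plus_erf, trt.ElementWiseOperation.PROD
    )
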
2 changes: 1 addition & 1 deletion python/paddle/pp_tensorrt/test_converter.py
@@ -38,7 +38,7 @@ def test_paddle_to_tensorrt_conversion_dummy():
with paddle.static.program_guard(program):
executor = paddle.static.Executor()
output_var = program.list_vars()[-1]
forbid_op_lower_trt(program, "pd_op.gelu")
# forbid_op_lower_trt(program, "pd_op.gelu")
# Run the program with input_data
for _ in range(1):
output_original = executor.run(
31 changes: 31 additions & 0 deletions test/ir/pir/fused_pass/test_pir_trt_op_marker_pass.py
@@ -545,5 +545,36 @@ def test_check_output(self):
self.check_pass_correct()


class TestGeluTRTPattern(PassTest):
def is_program_valid(self, program=None):
return True

def sample_program(self):
with paddle.pir_utils.IrGuard():
main_prog = paddle.static.Program()
start_prog = paddle.static.Program()
with paddle.pir.core.program_guard(main_prog, start_prog):
x = paddle.static.data(name='x', shape=[2, 2], dtype='float32')
m = paddle.nn.GELU()
out = m(x)
out = paddle.assign(out)
self.pass_attr_list = [{'trt_op_marker_pass': {}}]
self.feeds = {
"x": np.array([[-1, 0.5], [1, 1.5]]).astype("float32"),
}
self.fetch_list = [out]
self.valid_op_map = {
"pd_op.fusion_transpose_flatten_concat": 0,
}
yield [main_prog, start_prog], False

def setUp(self):
if core.is_compiled_with_cuda():
self.places.append(paddle.CUDAPlace(0))

def test_check_output(self):
self.check_pass_correct()


if __name__ == "__main__":
unittest.main()
