Commit

Added gelu converter
lizexu123 committed Jul 31, 2024
1 parent 99d0484 commit 4c86fbd
Showing 3 changed files with 40 additions and 2 deletions.
16 changes: 16 additions & 0 deletions python/paddle/pp_tensorrt/converter_utils.py
@@ -17,6 +17,7 @@
import sys

import numpy as np
import tensorrt as trt

current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))
@@ -108,3 +109,18 @@ def get_dynamic_dims(shape):
if s == -1:
dynamic_dims.append(i)
return dynamic_dims


def get_trt_plugin(plugin_name, field_collection, version, plugin_namespace=""):
    plugin_registry = trt.get_plugin_registry()
    plugin_creator = plugin_registry.get_plugin_creator(
        plugin_name, version, plugin_namespace
    )
    assert (
        plugin_creator
    ), f"Unable to find plugin creator with name {plugin_name}"
    plugin = plugin_creator.create_plugin(
        name=plugin_name, field_collection=field_collection
    )
    assert plugin is not None, f"Plugin: {plugin_name} could not be fetched"
    return plugin
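
Note (not part of this commit): trt.get_plugin_registry() typically only lists TensorRT's bundled plugin creators, such as the CustomGeluPluginDynamic used below, after the plugin library has been initialized. A minimal sketch of that initialization, assuming a trt.Logger is available at engine-build time:

import tensorrt as trt

# Assumption: register TensorRT's bundled plugin creators once per process,
# before get_trt_plugin() is called; otherwise get_plugin_creator() may return None.
logger = trt.Logger(trt.Logger.WARNING)
trt.init_libnvinfer_plugins(logger, "")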
24 changes: 23 additions & 1 deletion python/paddle/pp_tensorrt/impls/core.py
@@ -37,6 +37,7 @@
get_axes_for_reduce_op,
get_dynamic_dims,
has_dynamic_shape,
get_trt_plugin,
)


@@ -477,7 +478,7 @@ def flatten_converter(network, paddle_op, inputs):


# In the converter, pd_op.concat has three inputs, because builtin.combine contributes two inputs
@converter_registry.register("pd_op.concat")
@converter_registry.register("pd_op.concat",trt_version="8.x")
def concat_converter(network, paddle_op, inputs):
input_tensors = inputs[:-1]
axis_tensor = inputs[-1]
@@ -494,3 +495,24 @@ def concat_converter(network, paddle_op, inputs):
concat_layer.axis = axis

return concat_layer

@converter_registry.register("pd_op.gelu", trt_version="8.x")
def gelu_converter(network,paddle_op,inputs):
input_val =inputs[0]
approximate =paddle_op.attrs()["approximate"]
if approximate !=False:
raise RuntimeError("GeLU converter currently doesn't support fast gelu compute")

plugin_name ="CustomGeluPluginDynamic"
type_id =trt.PluginField("type_id",np.array(0,dtype=np.int32),trt.PluginFieldType.INT32)

filed_collection =trt.PluginFieldCollection([type_id])
plugin_version="1"

plugin=get_trt_plugin(plugin_name,filed_collection,plugin_version)

layer=network.add_plugin_v2([input_val],plugin)
return layer
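
For reference, a hedged sketch of how the same field collection could select the plugin's precision: in the TensorRT OSS GeLU plugin the type_id field is conventionally 0 for fp32 and 1 for fp16. That convention is an assumption here, not something this commit relies on, and the helper name below is hypothetical:

import numpy as np
import tensorrt as trt

def make_gelu_field_collection(use_fp16=False):
    # Assumption: for CustomGeluPluginDynamic, type_id 0 selects fp32 and 1 selects fp16.
    type_id = trt.PluginField(
        "type_id",
        np.array(1 if use_fp16 else 0, dtype=np.int32),
        trt.PluginFieldType.INT32,
    )
    return trt.PluginFieldCollection([type_id])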



2 changes: 1 addition & 1 deletion python/paddle/pp_tensorrt/test_converter.py
@@ -38,7 +38,7 @@ def test_paddle_to_tensorrt_conversion_dummy():
with paddle.static.program_guard(program):
executor = paddle.static.Executor()
output_var = program.list_vars()[-1]
forbid_op_lower_trt(program, "pd_op.gelu")
# forbid_op_lower_trt(program, "pd_op.gelu")
# Run the program with input_data
for _ in range(1):
output_original = executor.run(
