[Paddle TensorRT No.9-10] Add pd_op.(argmin,argsort) converter
#69261
Changes from 8 commits
@@ -1213,12 +1213,57 @@ class ArgmaxOpPattern
                 "data in arg_max.";
      return false;
    }
    auto x = op.x();
    auto x_tensor_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
    auto data_type = paddle::dialect::TransToPhiDataType(x_tensor_type.dtype());
    if (!(data_type == phi::DataType::FLOAT32 ||
          data_type == phi::DataType::FLOAT16 ||
          data_type == phi::DataType::FLOAT64)) {
    pir::Value x = op.x();
    auto data_type = pir::GetDataTypeFromValue(x);
    if (!(data_type.isa<pir::Float32Type>() ||
          data_type.isa<pir::Float16Type>() ||
          data_type.isa<pir::Float64Type>())) {
      VLOG(3) << "At present, pd_op.argmax only support float32 or float16 or "
                 "float64 into trt.";
      return false;
    }
    int axis = static_cast<int>(op.axis()
                                    .defining_op()
                                    ->attribute<pir::DoubleAttribute>("value")
                                    .data());

    bool flatten = op.attribute<pir::BoolAttribute>("flatten").data();
    phi::DataType dtype =
        op.attribute<paddle::dialect::DataTypeAttribute>("dtype").data();
    if (axis == 0 || flatten ||
        (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64)) {
      VLOG(3) << "Skipping TRT conversion in pd_op.argmax: axis is zero, "
                 "flatten is True, or "
                 "dtype isn't int32/int64";
      return false;
    }
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};

class ArgminOpPattern
    : public pir::OpRewritePattern<paddle::dialect::ArgminOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::ArgminOp>::OpRewritePattern;
  bool MatchAndRewrite(paddle::dialect::ArgminOp op,
                       pir::PatternRewriter &rewriter) const override {
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
    if (!op.axis().defining_op()->isa<paddle::dialect::FullOp>()) {
      VLOG(3) << "Skip to convert into TRT while found axis is not a constant "
                 "data in arg_min.";
      return false;
    }
    pir::Value x = op.x();
    auto data_type = pir::GetDataTypeFromValue(x);
    if (!(data_type.isa<pir::Float32Type>() ||
          data_type.isa<pir::Float16Type>() ||
          data_type.isa<pir::Float64Type>())) {
      VLOG(3) << "At present, pd_op.argmin only support float32 or float16 or "
                 "float64 into trt.";
      return false;
    }
    int axis = static_cast<int>(op.axis()
@@ -1230,13 +1275,52 @@ class ArgmaxOpPattern
    phi::DataType dtype =
        op.attribute<paddle::dialect::DataTypeAttribute>("dtype").data();
    if (axis == 0 || flatten ||
        (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64))
        (dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64)) {
      VLOG(3) << "Skipping TRT conversion in pd_op.argmin: axis is zero, "
                 "flatten is True, or "
                 "dtype isn't int32/int64";
      return false;
    }

    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};
Review comment: Please update argmax in the same way.
class ArgsortOpPattern
    : public pir::OpRewritePattern<paddle::dialect::ArgsortOp> {
 public:
  using pir::OpRewritePattern<paddle::dialect::ArgsortOp>::OpRewritePattern;
  bool MatchAndRewrite(paddle::dialect::ArgsortOp op,
                       pir::PatternRewriter &rewriter) const override {
    if (op->HasAttribute(kCanRunTrtAttr) &&
        op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
      return false;
    }
    const std::vector<std::string> required_attrs = {"axis", "descending"};
    for (const auto &attr : required_attrs) {
      if (!op->HasAttribute(attr)) {
        VLOG(3) << "pd_op.argsort " << attr << " attribute does not exist";
        return false;
      }
    }
    pir::Value x = op.x();
    auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();

Review comment (on the dyn_cast line): use pir::GetDataTypeFromValue(x).
Reply: This isn't fetching the datatype, so that helper shouldn't be needed here.

    auto x_shape = x_type.dims();
    int axis = op->attribute<pir::Int32Attribute>("axis").data();
    if (axis < 0) {
      axis += x_shape.size();
    }
    if (x_shape[axis] > 3840 || x_shape[axis] < 0) {
      VLOG(3) << "In pd_op.argsort, the axis dim of input should be less than "
                 "3840 and greater "
                 "than 0 in Tensorrt";
      return false;
    }
    op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
    return true;
  }
};
class BilinearInterpV2Pattern
    : public pir::OpRewritePattern<paddle::dialect::BilinearInterpOp> {
 public:
@@ -1682,6 +1766,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
    ps.Add(std::make_unique<RemainderOpPattern>(context));
    ps.Add(std::make_unique<MulticlassNms3OpPattern>(context));
    ps.Add(std::make_unique<ArgmaxOpPattern>(context));
    ps.Add(std::make_unique<ArgminOpPattern>(context));
    ps.Add(std::make_unique<ArgsortOpPattern>(context));
    ps.Add(std::make_unique<MaxOpPattern>(context));
    ps.Add(std::make_unique<MinOpPattern>(context));
    ps.Add(std::make_unique<AllOpPattern>(context));
@@ -16,8 +16,11 @@
import tensorrt as trt

from paddle.tensorrt.converter_utils import (
    get_shape_tensor_element,
    squeeze_trt,
    trt_cast,
    trt_reshape,
    trt_shape,
    unsqueeze_trt,
)
from paddle.tensorrt.register import converter_registry

@@ -66,6 +69,77 @@ def argmax_converter(network, paddle_op, inputs):
    return squeeze_layer.get_output(0)

@converter_registry.register("pd_op.argmin", trt_version="8.x") | ||
def argmin_converter(network, paddle_op, inputs): | ||
x = inputs[0] | ||
input_dims = x.shape | ||
rank = len(input_dims) | ||
axis = int( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里还需要支持axis为pir::value的输入,也需要进入trt,同理可以把pd_op.argmax补充一下 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Marker Pass 除了 full op 产生的 Value 情况会不进入 Tensorrt |
||
paddle_op.operands()[1] | ||
.source() | ||
.get_defining_op() | ||
.attrs() | ||
.get("value", -1) | ||
) | ||
keepdims = paddle_op.attrs()["keepdims"] | ||
|
||
if axis < 0: | ||
axis += rank | ||
|
||
topk_layer = network.add_topk( | ||
input=x, op=trt.TopKOperation.MIN, k=1, axes=(1 << axis) | ||
) | ||
|
||
if keepdims: | ||
return topk_layer.get_output(1) | ||
else: | ||
squeeze_layer = network.add_shuffle(topk_layer.get_output(1)) | ||
output_dims = [] | ||
for i in range(len(input_dims)): | ||
if i == axis: | ||
continue | ||
output_dims.append(input_dims[i]) | ||
squeeze_layer.reshape_dims = tuple(output_dims) | ||
return squeeze_layer.get_output(0) | ||
|
||
|
||
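As an aside for readers unfamiliar with the TopK trick: the converter above lowers pd_op.argmin to a single TensorRT TopK layer whose index output is the argmin. The following is a minimal standalone sketch of that lowering, illustrative only and not part of this PR; it assumes the TensorRT 8.x Python API and picks an arbitrary static input shape.

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (1, 8))  # hypothetical static shape
axis = 1
# TopK with op=MIN and k=1 over the chosen axis; output 0 is the min value,
# output 1 is its index, which is exactly the argmin result the converter returns.
topk = network.add_topk(x, trt.TopKOperation.MIN, 1, 1 << axis)
network.mark_output(topk.get_output(1))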
@converter_registry.register("pd_op.argsort", trt_version="8.x") | ||
def argsort_converter(network, paddle_op, inputs): | ||
input_tensor = inputs[0] | ||
input_shape = input_tensor.shape | ||
in_type = input_tensor.dtype | ||
in_rank = len(input_shape) | ||
axis = paddle_op.attrs()["axis"] | ||
descending = paddle_op.attrs()["descending"] | ||
if axis < 0: | ||
axis += len(input_shape) | ||
topk_op = trt.TopKOperation.MAX if descending else trt.TopKOperation.MIN | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里converter可以参考argsort_op.cc,这里应该是少了很多情况 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 已完善 |
||
need_cast = True if in_type != trt.DataType.FLOAT else False | ||
if in_rank == 1: | ||
unsqueeze_shape = trt.Dims([1, -1]) | ||
input_tensor = trt_reshape( | ||
network, input_tensor, unsqueeze_shape, is_shape_tensor=True | ||
) | ||
axis = 1 | ||
if need_cast: | ||
input_tensor = trt_cast(network, input_tensor, trt.DataType.FLOAT) | ||
topk_layer = network.add_topk(input_tensor, topk_op, 1, 1 << axis) | ||
shape = trt_shape(network, input_tensor) | ||
k_tensor = get_shape_tensor_element(network, shape, axis, True) | ||
topk_layer.set_input(1, k_tensor) | ||
out = topk_layer.get_output(0) | ||
indices = topk_layer.get_output(1) | ||
if in_rank == 1: | ||
squeeze_shape = trt.Dims([-1]) | ||
out = trt_reshape(network, out, squeeze_shape, is_shape_tensor=True) | ||
indices = trt_reshape( | ||
network, indices, squeeze_shape, is_shape_tensor=True | ||
) | ||
out_tensor = trt_cast(network, out, in_type) | ||
indices_tensor = trt_cast(network, indices, indices.dtype) | ||
return out_tensor, indices_tensor | ||
|
||
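Similarly illustrative and not part of the PR: a full-axis argsort corresponds to a TopK whose k equals the length of the sorted axis, which is what the converter builds above (supplying k dynamically via set_input). TensorRT caps TopK's k at 3840, which is what the 3840 guard in the marker pattern reflects. A minimal static-k sketch under the same TensorRT 8.x assumption:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)
x = network.add_input("x", trt.float32, (2, 6))  # hypothetical static shape
axis = 1
# Sorting the whole axis is TopK with k equal to the axis length.
k = x.shape[axis]
topk = network.add_topk(x, trt.TopKOperation.MIN, k, 1 << axis)
network.mark_output(topk.get_output(0))  # values in ascending order
network.mark_output(topk.get_output(1))  # indices, i.e. the argsort result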

@converter_registry.register("pd_op.where", trt_version="8.x")
def where_converter(network, paddle_op, inputs):
    condition = inputs[0]
Review comment: First check pir::GetDefiningOpForInput(op, 1)->isa<paddle::dialect::FullOp>() here, and only then apply the restrictions below.
Reply: That check is already done earlier; if the op is not a FullOp, the pattern returns false.
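To make the exchange above concrete, a hypothetical defensive variant of the axis read used by the converters could look like the sketch below; defining_op.name() and attrs() are assumed to be available on the pir Python Operation binding, and the helper name is invented for illustration.

# Only trust a constant axis when the operand's defining op is pd_op.full,
# mirroring the FullOp check the marker pattern performs on the C++ side.
def constant_axis_or_none(paddle_op, operand_index=1, default=-1):
    defining_op = paddle_op.operands()[operand_index].source().get_defining_op()
    if defining_op is None or defining_op.name() != "pd_op.full":
        return None  # dynamic axis: the marker pass already keeps this case off TRT
    return int(defining_op.attrs().get("value", default))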