
[Paddle TensorRT No.9-10] Add pd_op.(argmin,argsort) converter #69261

74 changes: 74 additions & 0 deletions paddle/fluid/pir/transforms/tensorrt/trt_op_marker_pass.cc
@@ -1237,6 +1237,78 @@ class ArgmaxOpPattern
}
};

class ArgminOpPattern
: public pir::OpRewritePattern<paddle::dialect::ArgminOp> {
public:
using pir::OpRewritePattern<paddle::dialect::ArgminOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::ArgminOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
if (!op.axis().defining_op()->isa<paddle::dialect::FullOp>()) {
VLOG(3) << "Skip to convert into TRT while found axis is not a constant "
"data in arg_max.";
return false;
}
auto x = op.x();
auto x_tensor_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
Contributor: Use pir::GetDataTypeFromValue(x) to get the dtype; see ScaleOpPattern for reference.

Contributor: Do you have a Ruliu account? Or add me on WeChat so we can discuss this more easily.

Author: I have a Ruliu account but don't know how to add you there; my WeChat (phone) is 18268023940.

Contributor: Couldn't find it.

Author: Fixed.

auto data_type = paddle::dialect::TransToPhiDataType(x_tensor_type.dtype());
if (!(data_type == phi::DataType::FLOAT32 ||
data_type == phi::DataType::FLOAT16 ||
data_type == phi::DataType::FLOAT64)) {
return false;
}
int axis = static_cast<int>(op.axis()
.defining_op()
->attribute<pir::DoubleAttribute>("value")
.data());

Contributor: First check here that pir::GetDefiningOpForInput(op, 1)->isa<paddle::dialect::FullOp>() holds, then apply the constraints below.

Author: This is already checked above; if it is not a FullOp, we return false.

bool flatten = op.attribute<pir::BoolAttribute>("flatten").data();
phi::DataType dtype =
op.attribute<paddle::dialect::DataTypeAttribute>("dtype").data();
if (axis == 0 || flatten ||
(dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64))
return false;

Contributor: Add a VLOG(3) here that prints which condition keeps pd_op.argmin from entering TRT.

Author: Added.
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
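
For intuition, here is an illustrative Python sketch (not part of the PR) of which paddle.argmin calls the checks above admit: float input, a constant non-zero axis, no flatten, and an int32/int64 result dtype.

import paddle

x = paddle.randn([2, 3], dtype="float32")

# Admitted by the pattern: constant axis != 0, no flatten, int64/int32 result.
paddle.argmin(x, axis=-1)
paddle.argmin(x, axis=1, dtype="int32")

# Rejected by the pattern: axis == 0, or axis=None (which sets flatten).
paddle.argmin(x, axis=0)
paddle.argmin(x)
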
Contributor: Please update argmax in the same way.

class ArgsortOpPattern
: public pir::OpRewritePattern<paddle::dialect::ArgsortOp> {
public:
using pir::OpRewritePattern<paddle::dialect::ArgsortOp>::OpRewritePattern;
bool MatchAndRewrite(paddle::dialect::ArgsortOp op,
pir::PatternRewriter &rewriter) const override {
if (op->HasAttribute(kCanRunTrtAttr) &&
op.attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) {
return false;
}
const std::vector<std::string> required_attrs = {"axis", "descending"};
for (const auto &attr : required_attrs) {
if (!op->HasAttribute(attr)) {
VLOG(3) << "Argsort " << attr << " attribute does not exist";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pd_op.argsort

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

return false;
}
}
auto x = op.x();
auto x_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
Contributor: Use pir::GetDataTypeFromValue(x).

Author: This isn't fetching a dtype here, so that function shouldn't be needed.
auto x_shape = x_type.dims();
int axis = op->attribute<pir::Int32Attribute>("axis").data();
if (axis < 0) {
axis += x_shape.size();
}
if (x_shape[axis] > 3840 || x_shape[axis] < 0) {
VLOG(3) << "The axis dim of input should be less than 3840 and greater "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加一个pd_op.argsort吧vlog里面

"than 0 in Tensorrt argsort";
return false;
}
op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
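
An illustrative sketch (not from the PR) of the size gate above: the converter lowers pd_op.argsort to TensorRT's TopK layer, whose k is capped, hence the 3840 bound on the axis dim.

import paddle

x = paddle.randn([2, 3])
paddle.argsort(x, axis=-1)  # axis dim 3 <= 3840: can be marked for TRT

y = paddle.randn([2, 4096])
paddle.argsort(y, axis=-1)  # axis dim 4096 > 3840: the pattern returns false
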
class BilinearInterpV2Pattern
: public pir::OpRewritePattern<paddle::dialect::BilinearInterpOp> {
public:
@@ -1682,6 +1754,8 @@ class TrtOpMarkerPass : public pir::PatternRewritePass {
ps.Add(std::make_unique<RemainderOpPattern>(context));
ps.Add(std::make_unique<MulticlassNms3OpPattern>(context));
ps.Add(std::make_unique<ArgmaxOpPattern>(context));
ps.Add(std::make_unique<ArgminOpPattern>(context));
ps.Add(std::make_unique<ArgsortOpPattern>(context));
ps.Add(std::make_unique<MaxOpPattern>(context));
ps.Add(std::make_unique<MinOpPattern>(context));
ps.Add(std::make_unique<AllOpPattern>(context));
50 changes: 50 additions & 0 deletions python/paddle/tensorrt/impls/search.py
@@ -66,6 +66,56 @@ def argmax_converter(network, paddle_op, inputs):
return squeeze_layer.get_output(0)


@converter_registry.register("pd_op.argmin", trt_version="8.x")
def argmin_converter(network, paddle_op, inputs):
x = inputs[0]
input_dims = x.shape
rank = len(input_dims)
axis = int(
paddle_op.operands()[1]
.source()
.get_defining_op()
.attrs()
.get("value", -1)
)

Contributor: This should also support axis arriving as a pir::Value input and still enter TRT; pd_op.argmax can be extended the same way.

Author: In the marker pass, only Values produced by a full op enter TensorRT; anything else is kept out.
keepdims = paddle_op.attrs()["keepdims"]

if axis < 0:
axis += rank

topk_layer = network.add_topk(
input=x, op=trt.TopKOperation.Min, k=1, axes=(1 << axis)
)

if keepdims:
return topk_layer.get_output(1)
else:
squeeze_layer = network.add_shuffle(topk_layer.get_output(1))
output_dims = []
for i in range(len(input_dims)):
if i == axis:
continue
output_dims.append(input_dims[i])
squeeze_layer.reshape_dims = tuple(output_dims)
return squeeze_layer.get_output(0)
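
As a sanity check on the conversion above, here is a minimal NumPy reference for what TopK(Min, k=1) plus the optional squeeze computes; argmin_reference is a hypothetical helper, not part of the PR.

import numpy as np

def argmin_reference(x, axis, keepdims):
    # TopK(Min, k=1) output(1) holds the index of the minimum along `axis`.
    idx = np.argmin(x, axis=axis)
    if keepdims:
        idx = np.expand_dims(idx, axis)  # keep the reduced k=1 axis in place
    return idx

x = np.random.randn(2, 3).astype(np.float32)
assert argmin_reference(x, axis=1, keepdims=False).shape == (2,)
assert argmin_reference(x, axis=1, keepdims=True).shape == (2, 1)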


@converter_registry.register("pd_op.argsort", trt_version="8.x")
def argsort_converter(network, paddle_op, inputs):
input_tensor = inputs[0]
input_shape = input_tensor.shape
# The following two attributes are validated in the marker pass, so the
# defaults here may be redundant.
axis = paddle_op.attrs().get("axis", -1)
descending = paddle_op.attrs().get("descending", False)
if axis < 0:
axis += len(input_shape)
topk_op = trt.TopKOperation.MAX if descending else trt.TopKOperation.MIN
Contributor: This converter can follow argsort_op.cc; it looks like quite a few cases are missing here.

Author: Completed.

k = input_shape[axis]
topk_layer = network.add_topk(input_tensor, topk_op, k, 1 << axis)
return topk_layer.get_output(1)
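
For reference, a NumPy sketch of what this TopK-based conversion returns: all indices along the axis in sorted order. argsort_reference is a hypothetical helper, and TensorRT's ordering of ties may differ.

import numpy as np

def argsort_reference(x, axis, descending):
    # TopK with k equal to the axis dim yields every index in sorted order:
    # TopKOperation.MIN gives ascending, TopKOperation.MAX gives descending.
    idx = np.argsort(x, axis=axis)
    if descending:
        idx = np.flip(idx, axis=axis)
    return idx

x = np.random.randn(2, 3).astype(np.float32)
assert argsort_reference(x, axis=-1, descending=False).shape == (2, 3)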


@converter_registry.register("pd_op.where", trt_version="8.x")
def where_converter(network, paddle_op, inputs):
condition = inputs[0]
70 changes: 70 additions & 0 deletions test/tensorrt/test_converter_search.py
@@ -35,6 +35,18 @@ def test_trt_result(self):
self.check_trt_result()


class TestArgminCase1TRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.argmin
self.api_args = {
"x": np.random.randn(2, 3).astype(np.float32),
"axis": -1,
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [1, 3]}
self.max_shape = {"x": [5, 3]}


class TestWhereTRTPatternCase1(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.where
@@ -51,6 +63,64 @@ def test_trt_result(self):
self.check_trt_result()


Contributor: Also add a case here where axis is a pir::Value: use np.array([1]) and add axis to feed_list; min_shape and max_shape are not needed.

Author: In the marker pass, only Values produced by a full op enter TensorRT; anything else is kept out.

class TestArgminCase2TRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.argmin
self.api_args = {
"x": np.random.randn(2, 3).astype(np.int64),
"axis": -1,
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [1, 3]}
self.max_shape = {"x": [5, 3]}

def test_trt_result(self):
self.check_trt_result()
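
For completeness, a sketch of the test the reviewer describes above (hypothetical class name, reusing this file's imports); per the author's reply, an axis fed as a plain Value may be kept out of TRT by the marker pass.

class TestArgminAxisValueTRTPattern(TensorRTBaseTest):
    def setUp(self):
        self.python_api = paddle.argmin
        self.api_args = {
            "x": np.random.randn(2, 3).astype("float32"),
            "axis": np.array([1], dtype="int64"),
        }
        self.program_config = {"feed_list": ["x", "axis"]}

    def test_trt_result(self):
        self.check_trt_result()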


class TestArgsortCase1TRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.argsort
self.api_args = {
"x": np.random.randn(2, 3).astype(np.float32),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测里面改成"float32",其余的同理

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"axis": -1,
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [1, 3]}
self.max_shape = {"x": [5, 3]}

def test_trt_result(self):
self.check_trt_result()


class TestArgsortCase2TRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.argsort
self.api_args = {
"x": np.random.randn(2, 3).astype(np.int64),
"axis": -1,
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [1, 3]}
self.max_shape = {"x": [5, 3]}

def test_trt_result(self):
self.check_trt_result()


class TestArgsortCase3TRTPattern(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.argsort
self.api_args = {
"x": np.random.randn(2, 3).astype(np.int64),
"axis": -1,
"descending": True,
}
self.program_config = {"feed_list": ["x"]}
self.min_shape = {"x": [1, 3]}
self.max_shape = {"x": [5, 3]}


class TestWhereTRTPatternCase2(TensorRTBaseTest):
def setUp(self):
self.python_api = paddle.where