Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Paddle TensorRT No.9-10] Add pd_op.(argmin,argsort) converter #69261

Conversation

ooooo-create
Copy link
Contributor

@ooooo-create ooooo-create commented Nov 9, 2024

PR Category

Inference

PR Types

New features

Description

Copy link

paddle-bot bot commented Nov 9, 2024

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@paddle-bot paddle-bot bot added the contributor External developers label Nov 9, 2024
@ooooo-create ooooo-create changed the title [Paddle TensorRT] No.8,9 Add pd_op.(argmin,argsort) converter [Paddle TensorRT No.8-9] Add pd_op.(argmin,argsort) converter Nov 9, 2024
@ooooo-create ooooo-create changed the title [Paddle TensorRT No.8-9] Add pd_op.(argmin,argsort) converter [Paddle TensorRT No.9-10] Add pd_op.(argmin,argsort) converter Nov 11, 2024
return false;
}
auto x = op.x();
auto x_tensor_type = x.type().dyn_cast<paddle::dialect::DenseTensorType>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

获取dtype使用pir::GetDataTypeFromValue(x),可参考ScaleOpPattern

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你有如流账号吗,或者加vx,方便沟通下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有如流的账号,但是不知道怎么加,手机vx是 18268023940

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

没查到

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

data_type == phi::DataType::FLOAT64)) {
return false;
}
int axis = static_cast<int>(op.axis()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里首先判断下,pir::GetDefiningOpForInput(op,1)->isapaddle:::dialect::FullOp,然后再去做下面的限制

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

前面已经判断过了,如果不是,就返回 false

phi::DataType dtype =
op.attribute<paddle::dialect::DataTypeAttribute>("dtype").data();
if (axis == 0 || flatten ||
(dtype != phi::DataType::INT32 && dtype != phi::DataType::INT64))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里面加一个VLOG(3)的打印,pd_op.argmin因为什么条件不能进入trt

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已添加

op->set_attribute(kCanRunTrtAttr, rewriter.bool_attr(true));
return true;
}
};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同理把argmax改一下吧

const std::vector<std::string> required_attrs = {"axis", "descending"};
for (const auto &attr : required_attrs) {
if (!op->HasAttribute(attr)) {
VLOG(3) << "Argsort " << attr << " attribute does not exist";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pd_op.argsort

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

axis += x_shape.size();
}
if (x_shape[axis] > 3840 || x_shape[axis] < 0) {
VLOG(3) << "The axis dim of input should be less than 3840 and greater "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

加一个pd_op.argsort吧vlog里面

x = inputs[0]
input_dims = x.shape
rank = len(input_dims)
axis = int(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里还需要支持axis为pir::value的输入,也需要进入trt,同理可以把pd_op.argmax补充一下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Marker Pass 除了 full op 产生的 Value 情况会不进入 Tensorrt

descending = paddle_op.attrs().get("descending", False)
if axis < 0:
axis += len(input_shape)
topk_op = trt.TopKOperation.MAX if descending else trt.TopKOperation.MIN
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里converter可以参考argsort_op.cc,这里应该是少了很多情况

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已完善

@@ -51,6 +63,64 @@ def test_trt_result(self):
self.check_trt_result()


class TestArgminCase2TRTPattern(TensorRTBaseTest):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里再补充一下axis为pir::value的场景,使用np.array([1]),feed_list中加入axis,但是min_shape,和max_shape不需要写

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Marker Pass 除了 full op 产生的 Value 情况会不进入 Tensorrt

def setUp(self):
self.python_api = paddle.argsort
self.api_args = {
"x": np.random.randn(2, 3).astype(np.float32),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

单测里面改成"float32",其余的同理

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@luotao1 luotao1 added the HappyOpenSource 快乐开源活动issue与PR label Nov 13, 2024
lizexu123
lizexu123 previously approved these changes Nov 13, 2024
Copy link
Contributor

@YuanRisheng YuanRisheng left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需要追加单测,完善一下c++和python代码覆盖率,可以参考单测TestMulticlassNMS3Marker内容写法,完善一下marker代码的覆盖率,具体覆盖率可以打开看一下ci-coverage的情况,覆盖率要达到90%以上才行

self.check_trt_result()


class TestArgsortCase4TRTPattern(TensorRTBaseTest):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个其实不用再加一个吧,能不能把self.check_marker(expected_result=False)放在前面某个单测里面检查下

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

前面的 chekc_marker 应该都是 True 的吧,False 可以放在前面吗

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

哦对,也是

@lizexu123 lizexu123 merged commit 5c4e4b4 into PaddlePaddle:develop Nov 15, 2024
27 of 28 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers HappyOpenSource 快乐开源活动issue与PR
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants