-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【SCU】【Paddle TensorRT No.36】Add pd_op.flip
converter
#69724
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
@@ -2093,6 +2093,29 @@ class AssignValueOpPattern | |||
} | |||
}; | |||
|
|||
class FlipOpPattern : public pir::OpRewritePattern<paddle::dialect::FlipOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
无条件进入trt,pir-trt只支持trt8.0以上版本
input_shape = input_shape_layer.get_output(0) | ||
rank = len(input_tensor.shape) | ||
|
||
axis = paddle_op.attrs()["axis"] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
axis是一个std::vector
stride_tensors = [] | ||
size_tensors = [] | ||
|
||
for i in range(rank): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
直接遍历axis.size(),严格按照flip_op.cc来写
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
按照flip_op.cc重新写了下逻辑
self.python_api = paddle.flip | ||
self.api_args = { | ||
"x": np.random.randn(2, 3, 4).astype("float32"), | ||
"axis": [0, 2], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
axis也需要测有负数
one_tensor = add_1D_constant_layer(network, [1]) | ||
iRec_layer = loop_layer.add_recurrence(zero_tensor) | ||
iCur = iRec_layer.get_output(0) | ||
iNext_layer = network.add_elementwise( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
trt_sum
def setUp(self): | ||
self.python_api = paddle.flip | ||
self.api_args = { | ||
"x": np.random.randn(2, 3, 4).astype("float32"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
加一个测试int64的单测
PR Category
User Experience
PR Types
New features
Description
新增了
pd_op.flip
Marker和Converter(好像题目写成flipe了....)