-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【BUPT】【Paddle TensorRT No.18】Add pd_op.conv3d and pd_op.conv3d_transpose converter #69757
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
op->attribute<pir::BoolAttribute>(kCanRunTrtAttr).data()) { | ||
return false; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
padding_algorithm没检查
@@ -374,6 +375,42 @@ class DepthwiseConv2dTransposeOpPattern | |||
} | |||
}; | |||
|
|||
class Conv3dTransposeOpPattern | |||
: public pir::OpRewritePattern<paddle::dialect::Conv3dTransposeOp> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
你写的到底是conv3d还是conv3d_transpose
if (op->HasAttribute("padding_algorithm")) { | ||
auto padding_algorithm = | ||
op->attribute<pir::StrAttribute>("padding_algorithm").AsString(); | ||
if (padding_algorithm == "SAME") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个检查不对,conv3d_transpose且是same且不是动态shape才不让进trt
VLOG(3) << "In conv3d, paddings size must be less than or equal to 3"; | ||
return false; | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
看的仔细点,dilations属性没检查
dilation = paddle_op.attrs().get("dilations", [1, 1, 1]) | ||
groups = paddle_op.attrs().get("groups", 1) | ||
|
||
if has_dynamic_shape(input_shape): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个不用检查,一定是动态shape
|
||
output_padding = paddle_op.attrs().get("output_padding", [0, 0, 0]) | ||
padding_algorithm = paddle_op.attrs().get("padding_algorithm", "EXPLICIT") | ||
if padding_algorithm == "VALID": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这一句哪来的,我怎么没在旧ir里找到?
nv_dilations = trt.Dims3(dilation[0], dilation[1], dilation[2]) | ||
nv_strides = trt.Dims3(stride[0], stride[1], stride[2]) | ||
|
||
pre_paddings = [0, 0, 0] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
pre_paddings应该是trt.Dims(paddings...吧)
pre_paddings = [0, 0, 0] | ||
post_paddings = [0, 0, 0] | ||
|
||
if len(paddings) == 3: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这段代码在旧ir下没找到,怎么写的?
post_paddings[0] = paddings[0] | ||
post_paddings[1] = paddings[1] | ||
post_paddings[2] = paddings[2] | ||
elif len(paddings) == 6: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
} | ||
} | ||
|
||
if (op->HasAttribute("output_padding")) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这一块可以去掉了
if (op->HasAttribute("padding_algorithm")) { | ||
auto padding_algorithm = | ||
op->attribute<pir::StrAttribute>("padding_algorithm").AsString(); | ||
if (padding_algorithm == "SAME") { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里不对,如果是conv3d_transpose且是same且不是动态shape,则return false
if paddle_op.name() == "pd_op.conv3d": | ||
input_tensor, filter = inputs | ||
elif paddle_op.name() == "pd_op.conv3d_transpose": | ||
if len(inputs) == 3: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里哪来的输入长度为3?
input_tensor, filter, output_size = inputs | ||
elif len(inputs) == 2: | ||
input_tensor, filter = inputs | ||
output_size = None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
哪来的output_size?
return conv(x) | ||
|
||
|
||
class TestConv3dTRTPattern(TensorRTBaseTest): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单测太少了,再测试下padding_algorithm为SAME,为valid和EXPLICIT,把conv3d_transpose补上吧,
Sorry to inform you that 4e377ec's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually. |
PR Category
User Experience
PR Types
New features
Description
新增pd_op.conv3d and pd_op.conv3d_transpose