-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Hackathon NO.73] 为 Paddle-TRT 添加 temporal_shift 算子 #51207
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
if (op_type == "temporal_shift") { | ||
#if !IS_TRT_VERSION_GE(8200) | ||
VLOG(3) << "temporal_shift is not supported when TensorRT < 8.5.1"; | ||
return false; | ||
#endif | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不支持静态shape可以在op teller里面说明
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在 paddle/fluid/framework/ir/trt_support_nhwc_pass.cc 中设置了保持该算子输入的维度,目前已支持静态shape输入
# # for static_shape | ||
# clear_dynamic_shape() | ||
# self.trt_param.precision = paddle_infer.PrecisionType.Float32 | ||
# yield self.create_inference_config(), (1, 3), 1e-5 | ||
# self.trt_param.precision = paddle_infer.PrecisionType.Half | ||
# yield self.create_inference_config(), (1, 3), 1e-3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这部分不应该删除,
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
目前已保留
const int NT = input_dims.d[0]; | ||
const int C = input_dims.d[1]; | ||
const int H = input_dims.d[2]; | ||
const int W = input_dims.d[3]; | ||
const int N = NT / T; | ||
|
||
// Reshape input to [N,T,C,H,W] | ||
auto reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); | ||
nvinfer1::Dims reshape_dims{5, {N, T, C, H, W}}; | ||
reshape_layer->setReshapeDimensions(reshape_dims); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
需考虑data_format为NHWC情况
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
在 op 代码中添加了 Permute 的代码,单测代码已添加 nhwc 输入的情况
slice_layer->setInput(2, *size); | ||
slice_layer->setMode(nvinfer1::SliceMode::kFILL); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Slice这种用法要求TRT 8.2+
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已添加 TRT 版本控制
def sample_program_configs(self): | ||
def generate_input1(attrs): | ||
T = attrs[0]["seg_num"] | ||
return np.ones([3 * T, 10, 64, 64]).astype(np.float32) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里改random值
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已修改
self.trt_param.precision = paddle_infer.PrecisionType.Float32 | ||
yield self.create_inference_config(), (0, 3), 1e-5 | ||
self.trt_param.precision = paddle_infer.PrecisionType.Half | ||
yield self.create_inference_config(), (0, 3), 1e-3 | ||
|
||
# for dynamic_shape | ||
generate_dynamic_shape(attrs) | ||
self.trt_param.precision = paddle_infer.PrecisionType.Float32 | ||
yield self.create_inference_config(), (0, 3), 1e-5 | ||
self.trt_param.precision = paddle_infer.PrecisionType.Half | ||
yield self.create_inference_config(), (0, 3), 1e-3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里写法有问题,0, 3分别表示trt支持op数量与非trt op数量 参考其他单测增加版本判断
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
其他值会报错,check 的时候提示 动态和静态 需要的值是一样的
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这种写法表示没有跑trt, 参考Paddle/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py或者test_trt_convert_gelu.py里面的写法
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
compile_version = paddle_infer.get_trt_compile_version() | ||
runtime_version = paddle_infer.get_trt_runtime_version() | ||
if ( | ||
compile_version[0] * 1000 | ||
+ compile_version[1] * 100 | ||
+ compile_version[2] * 10 | ||
< 8200 | ||
): | ||
return False | ||
if ( | ||
runtime_version[0] * 1000 | ||
+ runtime_version[1] * 100 | ||
+ runtime_version[2] * 10 | ||
< 8200 | ||
): | ||
return False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
generate_trt_nodes_num 有对应限制,这里就可以去掉
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除
namespace paddle { | ||
namespace framework { | ||
class Scope; | ||
|
||
namespace proto { | ||
class OpDesc; | ||
} // namespace proto | ||
} // namespace framework | ||
} // namespace paddle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里也删掉
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
改成2023
bool test_mode) override { | ||
#if IS_TRT_VERSION_GE(8200) | ||
|
||
VLOG(3) << "convert a fluid temporal shift op to tensorrt temporal layer"; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
去掉fluid
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM for const_cast
PR types
Others
PR changes
Others
Describe
为 Paddle-TRT 添加 temporal_shift 算子,代码思路如下所示(参考Paddle文档):
单测代码中添加 seg_num, shift_ratio 不同情况的测试,单测代码通过,本地单测结果如下图所示: