Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hackathon NO.73] 为 Paddle-TRT 添加 temporal_shift 算子 #51207

Merged
merged 19 commits into from
Mar 14, 2023

Conversation

AndSonder
Copy link
Contributor

PR types

Others

PR changes

Others

Describe

为 Paddle-TRT 添加 temporal_shift 算子,代码思路如下所示(参考Paddle文档):

image

单测代码中添加 seg_num, shift_ratio 不同情况的测试,单测代码通过,本地单测结果如下图所示:

image

image

@paddle-bot
Copy link

paddle-bot bot commented Mar 5, 2023

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

Comment on lines 2582 to 2587
if (op_type == "temporal_shift") {
#if !IS_TRT_VERSION_GE(8200)
VLOG(3) << "temporal_shift is not supported when TensorRT < 8.5.1";
return false;
#endif
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不支持静态shape可以在op teller里面说明

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在 paddle/fluid/framework/ir/trt_support_nhwc_pass.cc 中设置了保持该算子输入的维度,目前已支持静态shape输入

Comment on lines 86 to 91
# # for static_shape
# clear_dynamic_shape()
# self.trt_param.precision = paddle_infer.PrecisionType.Float32
# yield self.create_inference_config(), (1, 3), 1e-5
# self.trt_param.precision = paddle_infer.PrecisionType.Half
# yield self.create_inference_config(), (1, 3), 1e-3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这部分不应该删除,

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

目前已保留

Comment on lines 50 to 59
const int NT = input_dims.d[0];
const int C = input_dims.d[1];
const int H = input_dims.d[2];
const int W = input_dims.d[3];
const int N = NT / T;

// Reshape input to [N,T,C,H,W]
auto reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
nvinfer1::Dims reshape_dims{5, {N, T, C, H, W}};
reshape_layer->setReshapeDimensions(reshape_dims);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

需考虑data_format为NHWC情况

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

在 op 代码中添加了 Permute 的代码,单测代码已添加 nhwc 输入的情况

Comment on lines 113 to 114
slice_layer->setInput(2, *size);
slice_layer->setMode(nvinfer1::SliceMode::kFILL);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slice这种用法要求TRT 8.2+

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已添加 TRT 版本控制

def sample_program_configs(self):
def generate_input1(attrs):
T = attrs[0]["seg_num"]
return np.ones([3 * T, 10, 64, 64]).astype(np.float32)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里改random值

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改

Comment on lines 91 to 101
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (0, 3), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (0, 3), 1e-3

# for dynamic_shape
generate_dynamic_shape(attrs)
self.trt_param.precision = paddle_infer.PrecisionType.Float32
yield self.create_inference_config(), (0, 3), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
yield self.create_inference_config(), (0, 3), 1e-3
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里写法有问题,0, 3分别表示trt支持op数量与非trt op数量 参考其他单测增加版本判断

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

其他值会报错,check 的时候提示 动态和静态 需要的值是一样的

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种写法表示没有跑trt, 参考Paddle/python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_cast.py或者test_trt_convert_gelu.py里面的写法

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修正,红框中的内容为 converter 中打印的信息,说明代码有运行到trt的部分

image

@AndSonder AndSonder requested a review from zhangjun March 9, 2023 08:14
Comment on lines 28 to 43
compile_version = paddle_infer.get_trt_compile_version()
runtime_version = paddle_infer.get_trt_runtime_version()
if (
compile_version[0] * 1000
+ compile_version[1] * 100
+ compile_version[2] * 10
< 8200
):
return False
if (
runtime_version[0] * 1000
+ runtime_version[1] * 100
+ runtime_version[2] * 10
< 8200
):
return False
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

generate_trt_nodes_num 有对应限制,这里就可以去掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已删除

@AndSonder AndSonder requested a review from zhangjun March 13, 2023 00:38
Comment on lines 17 to 25
namespace paddle {
namespace framework {
class Scope;

namespace proto {
class OpDesc;
} // namespace proto
} // namespace framework
} // namespace paddle
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里也删掉

Comment on lines 1 to 2
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改成2023

bool test_mode) override {
#if IS_TRT_VERSION_GE(8200)

VLOG(3) << "convert a fluid temporal shift op to tensorrt temporal layer";
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

去掉fluid

Copy link
Contributor

@zhangjun zhangjun left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Copy link
Contributor

@luotao1 luotao1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM for const_cast

@luotao1 luotao1 merged commit e79699f into PaddlePaddle:develop Mar 14, 2023
@AndSonder AndSonder deleted the temporal_shift branch April 23, 2024 13:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor External developers
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants