-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Hackathon NO.74] 为 Paddle-TRT 添加 grid_sampler 算子 #50934
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
paddle/fluid/inference/tensorrt/dynamic_shape_infermeta_registry.h
Outdated
Show resolved
Hide resolved
python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_grid_sampler.py
Outdated
Show resolved
Hide resolved
class GridSamplerOpConverter : public OpConverter { | ||
public: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个类在trt 8.5以下会有编译问题,可以参考one_hot处理方式
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已增加宏,防止编译出错
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR-CI-Coverage,没有过,这个CI无法点击重新运行,应该怎么重新跑呢
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PR-CI-Coverage,没有过,这个CI无法点击重新运行,应该怎么重新跑呢
这个ci Coverage问题,由于单测环境没有trt 8.5,可以在本地用trt8.5验证下python python/paddle/fluid/tests/unittests/ir/inference/test_trt_convert_grid_sampler.py
。并附上正确结果截图
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
4cbc13d
to
a166427
Compare
a166427
to
a017a1d
Compare
K2-F7 是在bd 科技园办公吗?加个vx: shaoshuai2100 一起组队? 我在K4 |
desc = { | ||
"mode": "bilinear", | ||
"padding_mode": "border", | ||
"align_corners": True, | ||
} | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
单测可能需要补全,支持align_corners
、mode
、padding_mode
属性组合,另外输入会有[N, C, D, H, W]输入形式。定义见https://www.paddlepaddle.org.cn/documentation/docs/zh/api/paddle/nn/functional/grid_sample_cn.html#grid-sample,实现可参考https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/cpu/grid_sample_kernel.cc
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
Others
Describe
grid_sampler之前采用通用plugin的方式convert到TRT,现在TRT8.5以上支持了的grid_sampler的layer,因此补充TRT的映射方式,如果用户安装TRT版本小于8.5,则依旧通过通用plugin的方式转换grid_sampler,python的单测已经通过。
![0b914f72b223e89c5169a51fc46306d5](https://user-images.githubusercontent.com/88373061/221940022-6ac08750-3420-4c2b-b318-fb1d70b95583.jpg)
下图为在本地TRT8.5下python单测的验证结果,"in in in"为grid_sampler.cc中的打印信息(该打印信息仅作为测试使用,提PR时已经删除),可见在TRT8.5中通过convert layer成功转换了grid_sampler算子(此时没有使用通用plugin)。
另外TRT8.5中,gird_sampler layer只支持四维,在op_teller中加入了判断,如果input和grid不为四维,就return false,通过通用plugin的方式,通用plugin的grid_sampler中还有一个小bug,一起修改了。
![image](https://user-images.githubusercontent.com/88373061/221940252-cf05bc58-26fb-4109-b2c3-659853b2821b.png)