-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix bug when dtype=fp16 in deformable conv #46975
Conversation
你的PR提交成功,感谢你对开源项目的贡献! |
2ce33ef
to
6ab7413
Compare
DenseTensor mt_dx = phi::EmptyLike<MT, Context>(dev_ctx, *dx); | ||
MT* mt_dx_ptr = (x.dtype() == DataType::FLOAT16) | ||
? (dev_ctx.template Alloc<MT>(&mt_dx)) | ||
: (dev_ctx.template Alloc<MT>(dx)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@zhangting2020 这里已修改 从L305 ~ L324,对于fp16的情况做了结果转换。本地测试结果符合预期
由于前向加速的先合入了,导致这个代码有一行conflict。已修复重新跑CI流程 |
目前这个实现方式你创建了一个新的tensor,会增加内存,同时又插入了一个cast op,也会引入额外的代价。这个方案需要再斟酌,你可以借鉴下lookup_table_v2算子的梯度实现,其中也使用了原子操作,但我们过去也为它做了fp16的优化。 |
好的 我参考下实现,试一下速度是否可以满足要求 |
@zhangting2020 paddle/fluid/operators/lookup_table_v2_op.cu |
VectorizedAtomicAddPerBlock看了下实现也是调用了CudaAtomicAdd,在fp16时速度和fp32差距还是会比较大。 |
@zhangting2020 现在的代码会在fp32和fp64的时候也多增加一个新的tensor内存,这部分确实是浪费的。如果只在fp16是增加新的tensor和Cast_op,这样用内存换时间是可以接受的吗? |
lookup_table_v2的优化方式我理解是在将AtomAdd的过程单独作为一个kernel,且总数据量是偶数时,将两个数据构造为__half2,利用一次计算实现。 |
在Paddle框架中混合精度训练机制会在算子计算前将输入cast到fp32,计算用fp32,计算结果cast为fp16,但这样会引入额外的cast算子的开销,并且fp16相比fp32没有显著加速。正是为了提升混合精度训练的性能,才会设计这项任务。目前这版的方案恐怕无法满足预期,因为这个任务的重点之一就是对fp16性能完成优化 |
这项任务总共设计四部分计算,前向output使用fp16速度有提升,后向grad_offset, grad_filter的速度和fp32持平,只有grad_input由于计算的特殊引入cast增加了大概1.6%的计算时间。四部分整体上是和fp32持平的,测试中各有胜负,当时题目中要求也是性能不差于fp32。 |
benchmark中例子速度基本一致,但我换用更大的测试用例后cast过程确实很耗时无法忽略 |
@zhangting2020 之前只依赖benchmark中唯一的例子我以为cast的开销是很小的,换更大用例后发现还是挺高的。dx这里的计算我后边再继续想想办法,参考其他算子只使用更快的atomAdd也无法满足优于fp32的需要,可能需要改下dx的计算过程,今天比赛截止前应该无法满足更优实现。 |
PR types
New features
PR changes
OPs
Describe
修复FP16中dx精度bug,基于#46111