-
Notifications
You must be signed in to change notification settings - Fork 5.9k
修复paddle.incubate.nn.functional.fused_rotary_position_embedding的非法地址访问问题 #74347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
修复paddle.incubate.nn.functional.fused_rotary_position_embedding的非法地址访问问题 #74347
Conversation
|
你的PR提交成功,感谢你对开源项目的贡献! |
| common::errors::InvalidArgument("The batch_size of q (%d) must be less " | ||
| "than or equal to k's (%d) to " | ||
| "prevent out-of-bounds memory access.", | ||
| batch_size, | ||
| k_batch_size)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这样限制的目的,到底是因为我们算子实现不完备,所以不得不进行限制,还是从算法原理上就不允许超过?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
我们的文档是没有详细描述的,我自己了解以后的判断是k和v的batch_size小于q的batch_size是不合理的,标准的应该是严格相等,这里我选择了保留大于的情况是因为q_batch_size < k/v_batch_size的测例可以pass,当前改法应该是改动最小的修改方法。
对于q_batch_size > k/v_batch_size这种情况,应该需要类似于广播机制这种额外的处理。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
那这个报错信息就不够清晰,应该报错这是违背定义的情况,让用户知道自己写错了;而不是为了防止越界,这样变成是好像是我们错了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
了解了
|
/re-run all-failed |
PR Category
Execute Infrastructure
PR Types
Bug fixes
Description
本次 PR 主要解决了
fused_rotary_position_embedding函数中出现的 CUDA error 700 (illegal address) 问题。问题根源:
该错误是由于在计算时,代码默认使用了
query(q) 的batch_size作为key(k) 和value(v) 张量的batch_size_stride。当q的batch_size大于k或v的batch_size时,会导致显存的非法地址访问,从而引发 CUDA 错误。此问题并非仅限于大Tensor的场景,以下测例同样可以复现该错误:
解决方案:
限制q的batch_size不能超过k,v的batch_size大小。