Support XPU for auto-parallel LLaMa #9796
base: develop
Conversation
Codecov Report

❌ Attention: Your patch check has failed because the patch coverage (6.45%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

```
@@            Coverage Diff             @@
##           develop    #9796     +/-   ##
===========================================
+ Coverage    52.06%   52.20%   +0.14%
===========================================
  Files          734      730       -4
  Lines       116591   115836     -755
===========================================
- Hits         60703    60475     -228
+ Misses       55888    55361     -527
```

☔ View full report in Codecov by Sentry.
```python
if get_env_device() in ["npu", "mlu", "intel_hpu"]:
    x = paddle.to_tensor(0.0, dtype="float32")
    y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
    expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
elif get_env_device() == "xpu":
    # XPU uses a fixed fill value and keeps the mask in float32
    # (see the review discussion below).
    x = paddle.to_tensor(0.0, dtype="float32")
    y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
    expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
elif get_env_device() == "gcu":
    min_val = paddle.finfo(dtype).min
    x = paddle.to_tensor(0.0, dtype=dtype)
    y = paddle.to_tensor(min_val, dtype=dtype)
    expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
```
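As a toy illustration (not part of the diff; the tensor values are made up), the `paddle.where` pattern above turns a boolean attention mask into an additive float mask:

```python
import paddle

# Boolean mask: True = position may be attended, False = masked out.
mask = paddle.to_tensor([[True, True, False]])

x = paddle.to_tensor(0.0, dtype="float32")
# XPU-specific fill value from the diff above; other devices use
# paddle.finfo(dtype).min instead.
y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")

# Attended positions become 0.0, masked positions become the large
# negative value; the result stays float32 (no .astype), as XPU requires.
additive_mask = paddle.where(mask, x, y)
print(additive_mask)  # [[0., 0., -1.7005810e+38]]
```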
How does the mask generation differ between the different devices?
The mask generation logic is the same as here: https://github.com/PaddlePaddle/PaddleNLP/blob/develop/paddlenlp/transformers/llama/modeling.py#L1606.

XPU needs a different mask for the following two reasons (see the sketch after this list):

- The flash_attention kernel implemented on XPU differs from the GPU one, which may lead to numeric overflow when the mask value is too small. Therefore a specific mask value `-1.7005809656952787e38` is needed. @runzhech is fixing this issue, and once it is fixed we can use `paddle.finfo(dtype).min` as on GPU.
- The flash_attention kernel on XPU requires the mask input to be `float32`, so the `astype(dtype)` cast cannot be added in the XPU mask generation.
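A minimal sketch (a hypothetical helper, not code from the PR) that condenses these per-device rules; `device` stands in for the result of `get_env_device()`:

```python
import paddle

def mask_fill_values(device: str, dtype):
    """Return (kept_value, masked_value, cast_dtype) for attention-mask
    generation; cast_dtype is None when the mask must stay float32."""
    if device == "xpu":
        # XPU flash_attention may overflow with finfo(dtype).min, so a
        # fixed float32 fill value is used until the kernel fix lands,
        # and the mask is left in float32 (no astype).
        return (paddle.to_tensor(0.0, dtype="float32"),
                paddle.to_tensor(-1.7005809656952787e38, dtype="float32"),
                None)
    # GPU-like default: the dtype's minimum, with a cast back to dtype.
    return (paddle.to_tensor(0.0, dtype="float32"),
            paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32"),
            dtype)
```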
PR types
New features
PR changes
Models
Description
Adapts the Llama model for XPU auto-parallel training. Currently only dynamic graph mode with pure data parallelism (allreduce communication only) is supported.

Depends on the framework PR: PaddlePaddle/Paddle#70997
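For context, a minimal sketch of what pure data parallelism looks like under Paddle's dynamic-graph auto-parallel API (the mesh size and tensor shapes are illustrative; the actual training entry point lives in the PaddleNLP scripts):

```python
import paddle
import paddle.distributed as dist

# 1-D process mesh with a single "dp" axis: pure data parallelism.
mesh = dist.ProcessMesh([0, 1, 2, 3], dim_names=["dp"])

# Shard only the batch dimension of the input across the mesh; model
# weights stay replicated, so gradient allreduce is the only
# communication -- matching the "pure dp" scope described above.
x = paddle.randn([8, 1024], dtype="float32")
x = dist.shard_tensor(x, mesh, [dist.Shard(0)])
```

Such a script would be launched with `python -m paddle.distributed.launch` and requires the framework changes from PaddlePaddle/Paddle#70997.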