Fix paddle.where #9652

Merged 1 commit on Dec 18, 2024
13 changes: 8 additions & 5 deletions in paddlenlp/transformers/llama/modeling.py

@@ -1593,13 +1593,16 @@
 if get_env_device() in ["npu", "mlu", "intel_hpu"]:
     x = paddle.to_tensor(0.0, dtype="float32")
     y = paddle.to_tensor(paddle.finfo(dtype).min, dtype="float32")
-    expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
-elif get_env_device() in ["xpu", "gcu"]:
-    min_val = paddle.finfo(dtype).min if get_env_device() == "gcu" else -1e37  # mask value for xpu
+    expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)
+elif get_env_device() == "xpu":
+    x = paddle.to_tensor(0.0, dtype="float32")
+    y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")
Review thread on the new xpu branch (a short fp32 sketch follows the diff below):

Collaborator: Is fp32 the default here? The trailing .astype(dtype) is missing; is the result still consistent?

Contributor (author): Yes, the XPU flash-attention (fa) operator requires the input mask to be in fp32.

Collaborator: Confirmed, XPU uses fp32.
+    expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y)
+elif get_env_device() == "gcu":
+    min_val = paddle.finfo(dtype).min

[Codecov / codecov/patch warning: added lines #L1596 - L1602 were not covered by tests]
     x = paddle.to_tensor(0.0, dtype=dtype)
     y = paddle.to_tensor(min_val, dtype=dtype)
-    expanded_attn_mask = expanded_attn_mask.astype(dtype)
-    expanded_attn_mask = paddle.where(expanded_attn_mask, x, y).astype(dtype)
+    expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), x, y).astype(dtype)

[Codecov / codecov/patch warning: added line #L1605 was not covered by tests]
 else:
     expanded_attn_mask = paddle.where(expanded_attn_mask.cast("bool"), 0.0, paddle.finfo(dtype).min)
     expanded_attn_mask = expanded_attn_mask.astype(dtype)
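
Why the cast matters: paddle.where(cond, x, y) selects with a boolean condition, while the incoming expanded_attn_mask is a numeric 0/1 tensor, so every branch now casts it explicitly. Below is a minimal standalone sketch of the pattern the PR converges on; the dtype and mask values are illustrative, not taken from the repo.

import paddle

dtype = "float16"  # illustrative compute dtype; the model supplies its own
mask = paddle.to_tensor([[1, 1, 0, 0]], dtype="int64")  # hypothetical 0/1 attention mask

# Mirror the default branch of the diff: cast the numeric mask to bool before
# paddle.where, then move the additive mask to the compute dtype.
additive_mask = paddle.where(
    mask.cast("bool"), 0.0, paddle.finfo(paddle.float16).min
).astype(dtype)

print(additive_mask)  # [[0., 0., -65504., -65504.]] in float16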
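
As for the fp32 question in the review thread above: the xpu branch deliberately omits the final .astype(dtype) so the mask stays in float32 for the XPU flash-attention operator. A small hypothetical check, with the mask shape and values invented for illustration:

import paddle

expanded_attn_mask = paddle.to_tensor([[1.0, 0.0]])  # stand-in 0/1 mask

x = paddle.to_tensor(0.0, dtype="float32")
y = paddle.to_tensor(-1.7005809656952787e38, dtype="float32")  # xpu fill value from the diff

out = paddle.where(expanded_attn_mask.cast("bool"), x, y)
print(out.dtype)  # paddle.float32: no cast back to the model dtype; the XPU fa op consumes fp32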