Skip to content

Commit cfd9737

Browse files
committed
fix xpu
1 parent 7095b4a commit cfd9737

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

python/paddle/nn/functional/flash_attention.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2285,7 +2285,9 @@ def flashmask_attention(
22852285
f"Invalid shape of startend_row_indices, when causal is False, the last dimension should be either 2 or 4 but got {startend_row_indices.shape[-1]}"
22862286
)
22872287

2288-
if paddle.get_flags(["FLAGS_cudnn_deterministic"])[
2288+
if "xpu" in paddle.get_device():
2289+
fa_version = 2
2290+
elif paddle.get_flags(["FLAGS_cudnn_deterministic"])[
22892291
"FLAGS_cudnn_deterministic"
22902292
]:
22912293
fa_version = 2

0 commit comments

Comments
 (0)