Cambricon MLUs support SDPA and flash_attn (huggingface#31102)
* add Cambricon MLUs support
* fix MLU device RNG state
* update for quality checks
* update MLU to support fp16
* fix MLU device dependency error
* fix MLU device dependency error
* enable MLU device for bf16
* fix MLU device memory tracker
* Cambricon support for SDPA and flash_attn
huismiling authored and zucchini-nlp committed Jul 24, 2024
1 parent 40c4026 commit a6db708
Showing 1 changed file with 6 additions and 1 deletion.
src/transformers/utils/import_utils.py
@@ -329,6 +329,9 @@ def is_torch_sdpa_available():
     # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
     # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
     # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
+    # NOTE: MLU is OK with non-contiguous inputs.
+    if is_torch_mlu_available():
+        return version.parse(_torch_version) >= version.parse("2.1.0")
     # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
     return version.parse(_torch_version) >= version.parse("2.1.1")
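For reference, the gate above is a plain `packaging.version` comparison against a backend-dependent minimum torch version. A minimal standalone sketch of that logic (the function names and hard-coded version strings here are illustrative, not taken from the commit):

```python
from packaging import version

def sdpa_min_torch(on_mlu: bool) -> str:
    # MLU tolerates non-contiguous inputs, so torch 2.1.0 is enough there;
    # other backends need 2.1.1 to avoid the numerical issue in pytorch#112577.
    return "2.1.0" if on_mlu else "2.1.1"

def sdpa_supported(torch_version: str, on_mlu: bool) -> bool:
    # Mirrors the comparison in the hunk above.
    return version.parse(torch_version) >= version.parse(sdpa_min_torch(on_mlu))

assert sdpa_supported("2.1.0", on_mlu=True)       # MLU: 2.1.0 is enough
assert not sdpa_supported("2.1.0", on_mlu=False)  # elsewhere 2.1.1 is required
```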

@@ -795,14 +798,16 @@ def is_flash_attn_2_available():
     # Let's add an extra check to see if cuda is available
     import torch

-    if not torch.cuda.is_available():
+    if not (torch.cuda.is_available() or is_torch_mlu_available()):
         return False

     if torch.version.cuda:
         return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.1.0")
     elif torch.version.hip:
         # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
         return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
+    elif is_torch_mlu_available():
+        return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.3.3")
     else:
         return False
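The per-backend `flash_attn` minimums checked above (CUDA 2.1.0, ROCm/HIP 2.0.4, MLU 2.3.3) amount to a small lookup against the installed package version. A hedged sketch of that check, not the library's actual code (`flash_attn_meets_minimum` and the `backend` strings are illustrative):

```python
import importlib.metadata

from packaging import version

# Minimum flash_attn per backend, as encoded in the hunk above.
_FLASH_ATTN_MINIMUM = {"cuda": "2.1.0", "hip": "2.0.4", "mlu": "2.3.3"}

def flash_attn_meets_minimum(backend: str) -> bool:
    """Return True if the installed flash_attn satisfies the backend's minimum."""
    minimum = _FLASH_ATTN_MINIMUM.get(backend)
    if minimum is None:
        return False  # unknown backend, mirroring the diff's final `return False`
    try:
        installed = importlib.metadata.version("flash_attn")
    except importlib.metadata.PackageNotFoundError:
        return False  # flash_attn is not installed
    return version.parse(installed) >= version.parse(minimum)
```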

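As a usage note, downstream code rarely calls these helpers directly; it requests an attention implementation and lets the checks above gate what is actually usable. A hedged example, assuming the helpers are imported from the module this commit touches (the checkpoint name is a placeholder):

```python
from transformers import AutoModelForCausalLM
from transformers.utils.import_utils import is_flash_attn_2_available, is_torch_sdpa_available

# Prefer FlashAttention-2 when the checks above say it is usable (CUDA, ROCm,
# and now MLU), fall back to SDPA, then to the eager implementation.
if is_flash_attn_2_available():
    attn_implementation = "flash_attention_2"
elif is_torch_sdpa_available():
    attn_implementation = "sdpa"
else:
    attn_implementation = "eager"

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # placeholder checkpoint
    attn_implementation=attn_implementation,
)
```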

