From a6db70841a471b110e4c6dd13fcda18c45d2031c Mon Sep 17 00:00:00 2001
From: huismiling
Date: Tue, 16 Jul 2024 20:33:22 +0800
Subject: [PATCH] Cambricon MLUs support SDPA and flash_attn (#31102)

* add Cambricon MLUs support

* fix mlu device rng state

* up for quality check

* up mlu to support fp16

* fix mlu device dependency error

* fix mlu device dependency error

* enable mlu device for bf16

* fix mlu device memory tracker

* Cambricon support SDPA and flash_attn
---
 src/transformers/utils/import_utils.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index f47d4b7a5fb7df..bd14dd8cd7530c 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -329,6 +329,9 @@ def is_torch_sdpa_available():
     # NOTE: We require torch>=2.1 (and not torch>=2.0) to use SDPA in Transformers for two reasons:
     # - Allow the global use of the `scale` argument introduced in https://github.com/pytorch/pytorch/pull/95259
     # - Memory-efficient attention supports arbitrary attention_mask: https://github.com/pytorch/pytorch/pull/104310
+    # NOTE: MLU is OK with non-contiguous inputs.
+    if is_torch_mlu_available():
+        return version.parse(_torch_version) >= version.parse("2.1.0")
     # NOTE: We require torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs: https://github.com/pytorch/pytorch/issues/112577
     return version.parse(_torch_version) >= version.parse("2.1.1")

@@ -795,7 +798,7 @@ def is_flash_attn_2_available():
     # Let's add an extra check to see if cuda is available
     import torch

-    if not torch.cuda.is_available():
+    if not (torch.cuda.is_available() or is_torch_mlu_available()):
         return False

     if torch.version.cuda:
@@ -803,6 +806,8 @@ def is_flash_attn_2_available():
     elif torch.version.hip:
         # TODO: Bump the requirement to 2.1.0 once released in https://github.com/ROCmSoftwarePlatform/flash-attention
         return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.0.4")
+    elif is_torch_mlu_available():
+        return version.parse(importlib.metadata.version("flash_attn")) >= version.parse("2.3.3")
     else:
         return False
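
Note (not part of the patch): a minimal sketch of how these availability checks can drive the
attn_implementation choice when loading a model, assuming transformers with this change applied;
the checkpoint name "model-id" is a placeholder.

    # Sketch: pick an attention backend based on the updated availability checks.
    from transformers import AutoModelForCausalLM
    from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

    if is_flash_attn_2_available():
        # On MLU this now requires flash_attn >= 2.3.3 (per this patch).
        attn_implementation = "flash_attention_2"
    elif is_torch_sdpa_available():
        # On MLU this now requires torch >= 2.1.0; elsewhere torch >= 2.1.1 (per this patch).
        attn_implementation = "sdpa"
    else:
        attn_implementation = "eager"

    model = AutoModelForCausalLM.from_pretrained("model-id", attn_implementation=attn_implementation)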