From 40ca8423ceef23e90304005b041139683e34268c Mon Sep 17 00:00:00 2001 From: hirwa Date: Sat, 6 Apr 2024 10:57:06 +0530 Subject: [PATCH 1/3] adding env variable for mps and is_torch_mps_available for Pipeline --- src/transformers/pipelines/base.py | 3 +++ src/transformers/utils/import_utils.py | 9 ++++++++- 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index fa1f2fcf5dfa06..4d24f6e9094716 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -44,6 +44,7 @@ is_torch_mlu_available, is_torch_npu_available, is_torch_xpu_available, + is_torch_mps_available, logging, ) @@ -860,6 +861,8 @@ def __init__( self.device = torch.device(f"npu:{device}") elif is_torch_xpu_available(check_device=True): self.device = torch.device(f"xpu:{device}") + elif is_torch_mps_available(): + self.device = torch.device(f"mps:{device}") else: raise ValueError(f"{device} unrecognized or not available.") else: diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 486df111856ae9..9e29806f8f39f2 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -33,10 +33,17 @@ from . import logging - logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Some functions in torch don't support MPS, +# So we need to enable fallback to CPU for those functions +# Check https://github.com/pytorch/pytorch/issues/77764 for more details +os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" +warnings.filterwarnings("ignore", message=".*is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.*") + + + # TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better. def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: # Check if the package spec exists and grab its version to avoid importing a local directory From a7bcbd4a6c9db2c7ddc739ae04b4c663b1bc0401 Mon Sep 17 00:00:00 2001 From: hirwa Date: Sat, 6 Apr 2024 10:59:47 +0530 Subject: [PATCH 2/3] fix linting errors --- src/transformers/pipelines/base.py | 2 +- src/transformers/utils/import_utils.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/src/transformers/pipelines/base.py b/src/transformers/pipelines/base.py index 4d24f6e9094716..7225a6136e48ad 100644 --- a/src/transformers/pipelines/base.py +++ b/src/transformers/pipelines/base.py @@ -42,9 +42,9 @@ is_torch_available, is_torch_cuda_available, is_torch_mlu_available, + is_torch_mps_available, is_torch_npu_available, is_torch_xpu_available, - is_torch_mps_available, logging, ) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 9e29806f8f39f2..0b61b8ff2b1731 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -33,6 +33,7 @@ from . import logging + logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -40,8 +41,10 @@ # So we need to enable fallback to CPU for those functions # Check https://github.com/pytorch/pytorch/issues/77764 for more details os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -warnings.filterwarnings("ignore", message=".*is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.*") - +warnings.filterwarnings( + "ignore", + message=".*is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.*", +) # TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better. From 6ae2f265f962369729bdfcc16a4791f7e8f06889 Mon Sep 17 00:00:00 2001 From: Felix Hirwa Nshuti Date: Mon, 8 Apr 2024 20:06:58 +0530 Subject: [PATCH 3/3] Remove environment overide Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com> --- src/transformers/utils/import_utils.py | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 0b61b8ff2b1731..486df111856ae9 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -37,16 +37,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -# Some functions in torch don't support MPS, -# So we need to enable fallback to CPU for those functions -# Check https://github.com/pytorch/pytorch/issues/77764 for more details -os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -warnings.filterwarnings( - "ignore", - message=".*is not currently supported on the MPS backend and will fall back to run on the CPU. This may have performance implications.*", -) - - # TODO: This doesn't work for all packages (`bs4`, `faiss`, etc.) Talk to Sylvain to see how to do with it better. def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: # Check if the package spec exists and grab its version to avoid importing a local directory