diff --git a/src/peft/utils/other.py b/src/peft/utils/other.py index 05daf3d64e..18d46bfc35 100644 --- a/src/peft/utils/other.py +++ b/src/peft/utils/other.py @@ -20,7 +20,7 @@ import accelerate import torch from accelerate.hooks import add_hook_to_module, remove_hook_from_module -from accelerate.utils import is_xpu_available +from accelerate.utils import is_npu_available, is_xpu_available # Get current device name based on available devices @@ -29,6 +29,8 @@ def infer_device(): torch_device = "cuda" elif is_xpu_available(): torch_device = "xpu" + elif is_npu_available(): + torch_device = "npu" else: torch_device = "cpu" return torch_device