From 915181802bf4b7a15d9f3483c7026f1c9e175512 Mon Sep 17 00:00:00 2001
From: Fanli Lin
Date: Tue, 27 Feb 2024 17:44:48 +0800
Subject: [PATCH] [tests] enable benchmark unit tests on XPU (#29284)

* add xpu for benchmark

* no auto_map

* use require_torch_gpu

* use gpu

* revert

* revert

* fix style
---
 src/transformers/benchmark/benchmark_args.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/src/transformers/benchmark/benchmark_args.py b/src/transformers/benchmark/benchmark_args.py
index b5887e4a9bcb4b..c20683e416843b 100644
--- a/src/transformers/benchmark/benchmark_args.py
+++ b/src/transformers/benchmark/benchmark_args.py
@@ -17,7 +17,14 @@
 from dataclasses import dataclass, field
 from typing import Tuple
 
-from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
+from ..utils import (
+    cached_property,
+    is_torch_available,
+    is_torch_tpu_available,
+    is_torch_xpu_available,
+    logging,
+    requires_backends,
+)
 from .benchmark_args_utils import BenchmarkArguments
 
 
@@ -84,6 +91,9 @@ def _setup_devices(self) -> Tuple["torch.device", int]:
         elif is_torch_tpu_available():
             device = xm.xla_device()
             n_gpu = 0
+        elif is_torch_xpu_available():
+            device = torch.device("xpu")
+            n_gpu = torch.xpu.device_count()
         else:
             device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
             n_gpu = torch.cuda.device_count()
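
Note (not part of the patch): a minimal standalone sketch of the device-selection
order the modified _setup_devices ends up with: TPU first, then XPU, then CUDA,
falling back to CPU. pick_device is a hypothetical helper for illustration only;
it assumes torch.xpu is importable whenever is_torch_xpu_available() returns True,
and it omits the rest of the method's argument handling.

    import torch
    from transformers.utils import is_torch_tpu_available, is_torch_xpu_available

    def pick_device():
        """Mirror the patched branch order: TPU -> XPU -> CUDA -> CPU."""
        if is_torch_tpu_available():
            # torch_xla is only importable on TPU-enabled installs
            import torch_xla.core.xla_model as xm
            return xm.xla_device(), 0
        if is_torch_xpu_available():
            # XPU devices are counted via torch.xpu, analogous to torch.cuda,
            # so they feed the same n_gpu bookkeeping the benchmark uses
            return torch.device("xpu"), torch.xpu.device_count()
        if torch.cuda.is_available():
            return torch.device("cuda"), torch.cuda.device_count()
        return torch.device("cpu"), 0

    device, n_gpu = pick_device()
    print(device, n_gpu)

Placing the XPU check before the CUDA/CPU fallback means an XPU machine no longer
silently drops to CPU, which is what lets the benchmark unit tests run on XPU.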