Skip to content

Commit

Permalink
[tests] enable benchmark unit tests on XPU (#29284)
Browse files Browse the repository at this point in the history
* add xpu for benchmark

* no auto_map

* use require_torch_gpu

* use gpu

* revert

* revert

* fix style
  • Loading branch information
faaany authored and Ita Zaporozhets committed May 14, 2024
1 parent 12b25a0 commit 9151818
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion src/transformers/benchmark/benchmark_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,14 @@
from dataclasses import dataclass, field
from typing import Tuple

from ..utils import cached_property, is_torch_available, is_torch_tpu_available, logging, requires_backends
from ..utils import (
cached_property,
is_torch_available,
is_torch_tpu_available,
is_torch_xpu_available,
logging,
requires_backends,
)
from .benchmark_args_utils import BenchmarkArguments


Expand Down Expand Up @@ -84,6 +91,9 @@ def _setup_devices(self) -> Tuple["torch.device", int]:
elif is_torch_tpu_available():
device = xm.xla_device()
n_gpu = 0
elif is_torch_xpu_available():
device = torch.device("xpu")
n_gpu = torch.xpu.device_count()
else:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
Expand Down

0 comments on commit 9151818

Please sign in to comment.