diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 0c1da8334ea706..463f134217582c 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1844,11 +1844,6 @@ def _setup_devices(self) -> "torch.device":
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
             torch.cuda.set_device(device)
-        elif is_torch_xpu_available() and "ACCELERATE_USE_XPU" not in os.environ:
-            os.environ["ACCELERATE_USE_XPU"] = "true"
-            self.distributed_state = PartialState(timeout=timedelta(seconds=self.ddp_timeout))
-            device = torch.device("xpu:0")
-            self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
             self.distributed_state = PartialState(_use_sagemaker_dp=True)
             self._n_gpu = 1
@@ -1877,12 +1872,6 @@ def _setup_devices(self) -> "torch.device":
         elif is_sagemaker_dp_enabled() or is_sagemaker_mp_enabled():
             # Already set _n_gpu
             pass
-        elif self.distributed_state.distributed_type == DistributedType.MULTI_XPU:
-            if "ACCELERATE_USE_XPU" not in os.environ:
-                os.environ["ACCELERATE_USE_XPU"] = "true"
-            self._n_gpu = 1
-            device = torch.device("xpu:0")
-            torch.xpu.set_device(device)
         elif self.distributed_state.distributed_type == DistributedType.NO:
             if self.use_mps_device:
                 warnings.warn(