Skip to content

Commit

Permalink
Remove redundant backend checks in training_args.py (huggingface#30999)
Browse files Browse the repository at this point in the history
* Remove backend checks in training_args.py

* Explicitly initialize the device

---------

Co-authored-by: tonghengwen <tonghengwen@cambricon.com>
  • Loading branch information
2 people authored and vasqu committed Jun 1, 2024
1 parent 46b606e commit 7c472e6
Showing 1 changed file with 4 additions and 33 deletions.
37 changes: 4 additions & 33 deletions src/transformers/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
import torch
import torch.distributed as dist

from .pytorch_utils import is_torch_greater_or_equal_than_2_0, is_torch_greater_or_equal_than_2_3
from .pytorch_utils import is_torch_greater_or_equal_than_2_0

if is_accelerate_available():
from accelerate.state import AcceleratorState, PartialState
Expand Down Expand Up @@ -1677,38 +1677,9 @@ def __post_init__(self):
)
self.accelerator_config.split_batches = self.split_batches

if (
self.framework == "pt"
and is_torch_available()
and (self.device.type == "cpu" and not is_torch_greater_or_equal_than_2_3)
and (self.device.type != "cuda")
and (self.device.type != "mlu")
and (self.device.type != "npu")
and (self.device.type != "xpu")
and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
and (self.fp16 or self.fp16_full_eval)
):
raise ValueError(
"FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation"
" (`--fp16_full_eval`) can only be used on CUDA or MLU devices or NPU devices or certain XPU devices (with IPEX)."
)

if (
self.framework == "pt"
and is_torch_available()
and (self.device.type != "cuda")
and (self.device.type != "mlu")
and (self.device.type != "npu")
and (self.device.type != "xpu")
and (get_xla_device_type(self.device) not in ["GPU", "CUDA"])
and (get_xla_device_type(self.device) != "TPU")
and (self.device.type != "cpu")
and (self.bf16 or self.bf16_full_eval)
):
raise ValueError(
"BF16 Mixed precision training with AMP (`--bf16`) and BF16 half precision evaluation"
" (`--bf16_full_eval`) can only be used on CUDA, XPU (with IPEX), NPU, MLU or CPU/TPU/NeuronCore devices."
)
# Initialize device before we proceed
if self.framework == "pt" and is_torch_available():
self.device

if self.torchdynamo is not None:
warnings.warn(
Expand Down

0 comments on commit 7c472e6

Please sign in to comment.