diff --git a/docs/source/en/main_classes/trainer.mdx b/docs/source/en/main_classes/trainer.mdx index 44c9d1d4b01973..ab942a2c1a7de6 100644 --- a/docs/source/en/main_classes/trainer.mdx +++ b/docs/source/en/main_classes/trainer.mdx @@ -591,6 +591,66 @@ More details in this [issues](https://github.com/pytorch/pytorch/issues/75676). More details mentioned in this [issue](https://github.com/pytorch/pytorch/issues/76501) (`The original model parameters' .grads are not set, meaning that they cannot be optimized separately (which is why we cannot support multiple parameter groups)`). +### Using Trainer for accelerated PyTorch Training on Mac + +With PyTorch v1.12 release, developers and researchers can take advantage of Apple silicon GPUs for significantly faster model training. +This unlocks the ability to perform machine learning workflows like prototyping and fine-tuning locally, right on Mac. +Apple's Metal Performance Shaders (MPS) as a backend for PyTorch enables this and can be used via the new `"mps"` device. +This will map computational graphs and primitives on the MPS Graph framework and tuned kernels provided by MPS. +For more information please refer official documents [Introducing Accelerated PyTorch Training on Mac](https://pytorch.org/blog/introducing-accelerated-pytorch-training-on-mac/) +and [MPS BACKEND](https://pytorch.org/docs/stable/notes/mps.html). + + + +We strongly recommend to install PyTorch >= 1.13 (nightly version at the time of writing) on your MacOS machine. +It has major fixes related to model correctness and performance improvements for transformer based models. +Please refer to https://github.com/pytorch/pytorch/issues/82707 for more details. + + + +**Benefits of Training and Inference using Apple Silicon Chips** + +1. Enables users to train larger networks or batch sizes locally +2. Reduces data retrieval latency and provides the GPU with direct access to the full memory store due to unified memory architecture. +Therefore, improving end-to-end performance. +3. Reduces costs associated with cloud-based development or the need for additional local GPUs. + +**Pre-requisites**: To install torch with mps support, +please follow this nice medium article [GPU-Acceleration Comes to PyTorch on M1 Macs](https://medium.com/towards-data-science/gpu-acceleration-comes-to-pytorch-on-m1-macs-195c399efcc1). + +**Usage**: +User has to just pass `--use_mps_device` argument. +For example, you can run the offical Glue text classififcation task (from the root folder) using Apple Silicon GPU with below command: + +```bash +export TASK_NAME=mrpc + +python examples/pytorch/text-classification/run_glue.py \ + --model_name_or_path bert-base-cased \ + --task_name $TASK_NAME \ + --do_train \ + --do_eval \ + --max_seq_length 128 \ + --per_device_train_batch_size 32 \ + --learning_rate 2e-5 \ + --num_train_epochs 3 \ + --output_dir /tmp/$TASK_NAME/ \ + --use_mps_device \ + --overwrite_output_dir +``` + +**A few caveats to be aware of** + +1. Some PyTorch operations have not been implemented in mps and will throw an error. +One way to get around that is to set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1`, +which will fallback to CPU for these operations. It still throws a UserWarning however. +2. Distributed setups `gloo` and `nccl` are not working with `mps` device. +This means that currently only single GPU of `mps` device type can be used. + +Finally, please, remember that, 🤗 `Trainer` only integrates MPS backend, therefore if you +have any problems or questions with regards to MPS backend usage, please, +file an issue with [PyTorch GitHub](https://github.com/pytorch/pytorch/issues). + Sections that were moved: [ DeepSpeed diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index e9a9f8f0043a79..7a23281d82ee21 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -22,6 +22,8 @@ from pathlib import Path from typing import Any, Dict, List, Optional, Union +from packaging import version + from .debug_utils import DebugOption from .trainer_utils import ( EvaluationStrategy, @@ -478,6 +480,8 @@ class TrainingArguments: are also available. See the [Ray documentation]( https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for more options. + use_mps_device (`bool`, *optional*, defaults to `False`): + Whether to use Apple Silicon chip based `mps` device. """ output_dir: str = field( @@ -630,6 +634,9 @@ class TrainingArguments: }, ) no_cuda: bool = field(default=False, metadata={"help": "Do not use CUDA even when it is available"}) + use_mps_device: bool = field( + default=False, metadata={"help": "Whether to use Apple Silicon chip based `mps` device."} + ) seed: int = field(default=42, metadata={"help": "Random seed that will be set at the beginning of training."}) data_seed: Optional[int] = field(default=None, metadata={"help": "Random seed to be used with data samplers."}) jit_mode_eval: bool = field( @@ -1368,16 +1375,42 @@ def _setup_devices(self) -> "torch.device": device = torch.device("cuda", self.local_rank) self._n_gpu = 1 elif self.local_rank == -1: - # if n_gpu is > 1 we'll use nn.DataParallel. - # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` - # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will - # trigger an error that a device index is missing. Index 0 takes into account the - # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` - # will use the first GPU in that env, i.e. GPU#1 - device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at - # the default value. - self._n_gpu = torch.cuda.device_count() + if self.use_mps_device: + if not torch.backends.mps.is_available(): + if not torch.backends.mps.is_built(): + raise AssertionError( + "MPS not available because the current PyTorch install was not " + "built with MPS enabled. Please install torch version >=1.12.0 on " + "your Apple silicon Mac running macOS 12.3 or later with a native " + "version (arm64) of Python" + ) + else: + raise AssertionError( + "MPS not available because the current MacOS version is not 12.3+ " + "and/or you do not have an MPS-enabled device on this machine." + ) + else: + if not version.parse(version.parse(torch.__version__).base_version) > version.parse("1.12.0"): + warnings.warn( + "We strongly recommend to install PyTorch >= 1.13 (nightly version at the time of writing)" + " on your MacOS machine. It has major fixes related to model correctness and performance" + " improvements for transformer based models. Please refer to" + " https://github.com/pytorch/pytorch/issues/82707 for more details." + ) + device = torch.device("mps") + self._n_gpu = 1 + + else: + # if n_gpu is > 1 we'll use nn.DataParallel. + # If you only want to use a specific subset of GPUs use `CUDA_VISIBLE_DEVICES=0` + # Explicitly set CUDA to the first (index 0) CUDA device, otherwise `set_device` will + # trigger an error that a device index is missing. Index 0 takes into account the + # GPUs available in the environment, so `CUDA_VISIBLE_DEVICES=1,2` with `cuda:0` + # will use the first GPU in that env, i.e. GPU#1 + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + # Sometimes the line in the postinit has not been run before we end up here, so just checking we're not at + # the default value. + self._n_gpu = torch.cuda.device_count() else: # Here, we'll use torch.distributed. # Initializes the distributed backend which will take care of synchronizing nodes/GPUs