Enable AMP for xla:gpu device in trainer class #15022
Conversation
Force-pushed from 60ddcda to b49ef4a (Compare)
Oh, interesting! Thanks for your contribution, pinging @sgugger on the issue.
I'm not entirely sure about this PR in the sense that PyTorch XLA support is mainly for TPU, and I don't know if traditional mixed precision training with the gradient scaler will work on TPUs.
So we should probably split the test to detect if we have GPUs available or TPUs. Some of the logic will stay common between the two, but the mixed precision part might only work for XLA GPUs?
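For illustration, a minimal sketch of one way to tell the two cases apart under torch_xla (this assumes torch_xla is installed; `xla_device_hw` is an existing torch_xla helper, but the wrapper function below is hypothetical, not code from this PR):

```python
import torch_xla.core.xla_model as xm


def xla_backend_is_gpu() -> bool:
    # xla_device_hw reports the hardware behind the current XLA device,
    # e.g. "TPU", "GPU" or "CPU", so the GPU and TPU paths can be split here.
    device = xm.xla_device()
    return xm.xla_device_hw(device) == "GPU"
```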
```python
if is_torch_tpu_available():
    xm.mark_step()
```
This part is done by the dataloader (which is wrapped in a `ParallelLoader`), so it shouldn't be here.
This is intentional, since `loss` will be materialized in `self._nested_gather(loss.repeat(batch_size))`, and adding a mark_step here can significantly improve the speed. For example, the evaluation time of `bert-base-uncased` using `run_mlm.py` is reduced from 32.53s to 18.73s by adding this mark_step.
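For context, a simplified sketch of the pattern being discussed (illustrative only; the function name is hypothetical and this is not the exact Trainer evaluation-loop code, though `Trainer._nested_gather` and `xm.mark_step` are real):

```python
import torch
import torch_xla.core.xla_model as xm


def gather_eval_loss(trainer, loss: torch.Tensor, batch_size: int) -> torch.Tensor:
    # Flush the pending XLA graph *before* the loss is materialized below.
    # Without this mark_step, the materialization in _nested_gather forces a
    # much larger graph execution on every evaluation step.
    xm.mark_step()
    return trainer._nested_gather(loss.repeat(batch_size))
```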
src/transformers/training_args.py (Outdated)

```diff
@@ -811,7 +811,7 @@ def __post_init__(self):
             raise ValueError("sharded_ddp is not supported with bf16")
         if (
             is_torch_available()
-            and self.device.type != "cuda"
+            and (self.device.type != "cuda" and self.device.type != "xla")
```
Will this be false on TPU? The test is there for that purpose since mixed precision training does not work on TPU.
You are right that this check won't filter out TPU. We can change it to `not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)`.
Right, XLA:TPU does not support AMP; only XLA:GPU supports it.
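For concreteness, a standalone sketch of the check being agreed on here (the helper name is illustrative; the real change lives inside `TrainingArguments.__post_init__`):

```python
import os


def amp_supported(device_type: str) -> bool:
    """AMP is allowed on native CUDA devices, and on XLA devices only when
    they are backed by GPUs (signalled, as in the discussion above, by the
    GPU_NUM_DEVICES environment variable)."""
    if device_type == "cuda":
        return True
    return device_type == "xla" and "GPU_NUM_DEVICES" in os.environ
```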
Force-pushed from 356f161 to 0e37806 (Compare)
So, as I said in my previous comment, could you add a new test?
@sgugger Maybe I'm missing something. Could you elaborate why the changes will make the Trainer stop working on TPU? Here is the command I tested:

```bash
python run_mlm.py \
    --model_name_or_path bert-base-uncased \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --overwrite_output_dir true \
    --output_dir /tmp/test-mlm \
    --per_device_train_batch_size 10 \
    --do_eval \
    --do_train
```

With the master branch:

```
WARNING:root:TPU has started up successfully with version pytorch-1.9
WARNING:__main__:Process rank: -1, device: xla:1, n_gpu: 0distributed training: False, 16-bits training: False
...
***** train metrics *****
  epoch                    =        3.0
  train_loss               =     1.7568
  train_runtime            = 0:12:23.47
  train_samples            =       4627
  train_samples_per_second =      18.67
  train_steps_per_second   =      1.868
```

With this PR:

```
WARNING:root:TPU has started up successfully with version pytorch-1.9
WARNING:__main__:Process rank: -1, device: xla:1, n_gpu: 0distributed training: False, 16-bits training: False
...
***** train metrics *****
  epoch                    =        3.0
  train_loss               =     1.7577
  train_runtime            = 0:10:19.70
  train_samples            =       4627
  train_samples_per_second =     22.399
  train_steps_per_second   =      2.241
```
Ah yes, you're right. Thanks for testing!
What does this PR do?
This PR enables AMP in the Trainer class for the xla:gpu device.
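For context, a generic sketch of the autocast/GradScaler pattern that "AMP" refers to here (placeholder model/optimizer/batch names; this is the standard `torch.cuda.amp` recipe, not the Trainer's exact code, and running it against an xla:gpu device is precisely what this PR is about):

```python
import torch


def amp_training_step(model, optimizer, scaler, batch):
    # Standard mixed-precision step: forward pass under autocast, scale the
    # loss before backward, then step the optimizer through the GradScaler.
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(**batch).loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    return loss.detach()


# Typical setup (sketch): scaler = torch.cuda.amp.GradScaler()
```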
Discussion
It looks like the torch_xla support in the Trainer class is primarily for the xla:tpu device.
I found the following features may be useful but not essential, and I can include them in this PR if necessary:
- Renaming `tpu` to `xla` in the codebase.
- `GPU_NUM_DEVICES` currently has to be set manually when using the xla:gpu device. It may be useful to set a default value for it when torch_xla and CUDA devices are available (see the sketch below).
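A hedged sketch of what such a default could look like (the helper name, its placement, and the use of `torch.cuda.device_count()` are assumptions, not part of this PR):

```python
import os

import torch


def maybe_default_gpu_num_devices() -> None:
    # Hypothetical helper: if CUDA devices are visible and the user has not
    # set GPU_NUM_DEVICES for torch_xla's GPU backend, default it to the
    # number of local CUDA devices.
    if torch.cuda.is_available() and "GPU_NUM_DEVICES" not in os.environ:
        os.environ["GPU_NUM_DEVICES"] = str(torch.cuda.device_count())
```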