
Enable AMP for xla:gpu device in trainer class #15022

Merged · 3 commits into huggingface:master · Jan 13, 2022

Conversation

@ymwangg (Contributor) commented Jan 4, 2022:

What does this PR do?

This PR enables AMP in the Trainer class for the xla:gpu device.

Discussion

It looks like the torch_xla support in the Trainer class is primarily aimed at the xla:tpu device.
The following features may be useful but are not essential; I can include them in this PR if necessary:

  1. Rename tpu to xla in the codebase.
  2. Currently the xla device is always used when torch_xla is installed. It may be useful to let users turn it off without uninstalling torch_xla.
  3. Currently users need to set GPU_NUM_DEVICES manually when using the xla:gpu device. It may be useful to set a default value for it when torch_xla and CUDA devices are available (a minimal sketch follows this list).
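
For item 3, a minimal sketch of what such a default could look like, assuming torch is importable and using torch.cuda.device_count() as the fallback value (the helper name is hypothetical and not part of this PR):

```python
import os

import torch


def maybe_set_gpu_num_devices():
    # Only provide a default when the user has not set GPU_NUM_DEVICES themselves
    # and CUDA devices are actually visible to this process.
    if "GPU_NUM_DEVICES" not in os.environ and torch.cuda.is_available():
        os.environ["GPU_NUM_DEVICES"] = str(torch.cuda.device_count())
```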

@LysandreJik (Member) commented:

Oh, interesting! Thanks for your contribution, pinging @sgugger on the issue.

@LysandreJik requested a review from sgugger on January 6, 2022 at 10:17.
@sgugger (Collaborator) left a comment:

I'm not entirely sure about this PR in the sense that PyTorch XLA support is mainly for TPU, and I don't know if traditional mixed precision training with the gradient scaler will work on TPUs.

So we should probably split the test to detect if we have GPUs available or TPUs. Some of the logic will stay common between the two, but the mixed precision part might only work for XLA GPUs?
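
As a rough sketch of the split being described here (hypothetical names, not the Trainer's actual code; whether a CUDA-style GradScaler is the right scaler for xla:gpu is exactly the open question above):

```python
import torch


def maybe_create_grad_scaler(fp16: bool, device_type: str, is_xla_gpu: bool):
    # The common XLA setup would stay shared elsewhere; a scaler is only created
    # where traditional AMP is expected to work (CUDA, and possibly XLA backed by
    # GPUs), never on TPU.
    if fp16 and (device_type == "cuda" or is_xla_gpu):
        return torch.cuda.amp.GradScaler()
    return None
```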

Comment on lines +2345 to +2346:

if is_torch_tpu_available():
    xm.mark_step()
@sgugger (Collaborator) commented:

This part is done by the dataloader (which is wrapped in a ParallelLoader), so it shouldn't be here.

@ymwangg (Contributor, Author) replied:

This is intentional: the loss will be materialized in self._nested_gather(loss.repeat(batch_size)), and adding a mark_step here significantly improves speed. For example, the evaluation time of bert-base-uncased with run_mlm.py drops from 32.53s to 18.73s with this mark_step.
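
For context, a simplified sketch of the pattern under discussion (illustrative names, not the exact Trainer code): the explicit mark_step() makes XLA execute the graph accumulated during the forward pass, so the subsequent gather materializes an already-computed loss instead of triggering execution itself.

```python
import torch
import torch_xla.core.xla_model as xm


def eval_step_loss(model, inputs, batch_size, gather_fn):
    # gather_fn stands in for Trainer._nested_gather in this sketch.
    with torch.no_grad():
        loss = model(**inputs).loss.detach()
    xm.mark_step()  # run the pending XLA graph before the loss is consumed on the host
    return gather_fn(loss.repeat(batch_size))
```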

@@ -811,7 +811,7 @@ def __post_init__(self):
             raise ValueError("sharded_ddp is not supported with bf16")
         if (
             is_torch_available()
-            and self.device.type != "cuda"
+            and (self.device.type != "cuda" and self.device.type != "xla")
@sgugger (Collaborator) commented:

Will this be false on TPU? The test is there for that purpose since mixed precision training does not work on TPU.

@ymwangg (Contributor, Author) replied on Jan 10, 2022:

You are right that this check won't filter out TPU. We can change it to not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ).
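
As a hedged sketch, that revised guard could be expressed as a standalone check along these lines (names and error message are illustrative, not the exact transformers code):

```python
import os


def check_amp_device(device_type: str, fp16_requested: bool) -> None:
    # An xla device only counts as AMP-capable when GPU_NUM_DEVICES is set,
    # i.e. when torch_xla is driving GPUs rather than TPUs.
    is_xla_gpu = device_type == "xla" and "GPU_NUM_DEVICES" in os.environ
    if fp16_requested and device_type != "cuda" and not is_xla_gpu:
        raise ValueError("Mixed precision training requires a CUDA or XLA:GPU device.")
```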

@ymwangg (Contributor, Author) commented Jan 11, 2022:

> I'm not entirely sure about this PR in the sense that PyTorch XLA support is mainly for TPU, and I don't know if traditional mixed precision training with the gradient scaler will work on TPUs.
>
> So we should probably split the test to detect if we have GPUs available or TPUs. Some of the logic will stay common between the two, but the mixed precision part might only work for XLA GPUs?

Right, XLA:TPU does not support AMP; only XLA:GPU supports it.

@sgugger (Collaborator) commented Jan 13, 2022:

> Right, XLA:TPU does not support AMP; only XLA:GPU supports it.

So as I said in my previous comment, could you add a new test is_gpu_xla_available and use this one for the part where you add grad scalers? Otherwise the changes will make the Trainer stop working on TPU.
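
One possible shape for such a helper, assuming the GPU_NUM_DEVICES heuristic already used in this thread (a sketch, not the implementation that eventually landed):

```python
import importlib.util
import os


def is_gpu_xla_available() -> bool:
    # torch_xla must be importable and the user must have pointed it at GPU devices.
    return importlib.util.find_spec("torch_xla") is not None and "GPU_NUM_DEVICES" in os.environ
```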

@ymwangg (Contributor, Author) commented Jan 13, 2022:

@sgugger Maybe I'm missing something. Could you elaborate on why the changes would make the Trainer stop working on TPU? The code guarded by self.do_grad_scaling is unreachable when running on TPU, since either the --fp16 or --bf16 option will raise an error on TPU. I tested the following training script on TPU:

python run_mlm.py \
    --model_name_or_path bert-base-uncased \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --overwrite_output_dir true \
    --output_dir /tmp/test-mlm \
    --per_device_train_batch_size 10 \
    --do_eval \
    --do_train

With the master branch:

WARNING:root:TPU has started up successfully with version pytorch-1.9
WARNING:__main__:Process rank: -1, device: xla:1, n_gpu: 0distributed training: False, 16-bits training: False
...
***** train metrics *****
  epoch                    =        3.0
  train_loss               =     1.7568
  train_runtime            = 0:12:23.47
  train_samples            =       4627
  train_samples_per_second =      18.67
  train_steps_per_second   =      1.868

With this PR:

WARNING:root:TPU has started up successfully with version pytorch-1.9
WARNING:__main__:Process rank: -1, device: xla:1, n_gpu: 0distributed training: False, 16-bits training: False
...
***** train metrics *****
  epoch                    =        3.0
  train_loss               =     1.7577
  train_runtime            = 0:10:19.70
  train_samples            =       4627
  train_samples_per_second =     22.399
  train_steps_per_second   =      2.241

@sgugger
Copy link
Collaborator

sgugger commented Jan 13, 2022

Ah yes, you're right. Thanks for testing!

@sgugger merged commit 6e058e8 into huggingface:master on Jan 13, 2022.