
Fix XLA fp16 and bf16 error checking #18913

Merged: 2 commits merged into huggingface:main from fix_fp16_error_checking on Sep 7, 2022

Conversation

ymwangg (Contributor) commented Sep 7, 2022

This PR fixes a bug introduced in #15022 that wrongly throws an error when training with an XLA device + fp16: GPU_NUM_DEVICES is unset by torch_xla in distributed training here.
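
As a minimal illustration (simplified names, not the exact training_args.py code), the pre-PR guard only treated an "xla" device as GPU-backed when GPU_NUM_DEVICES was present in the environment, so it misfires once the launcher removes that variable in worker processes:

import os

# Simplified sketch of the check this PR replaces (illustrative only):
# an "xla" device was assumed to be GPU-backed only if GPU_NUM_DEVICES is set.
def xla_device_is_gpu_old(device_type):
    return device_type == "xla" and "GPU_NUM_DEVICES" in os.environ

# In workers spawned by torch_xla the variable is no longer set, so the check
# returns False and the fp16/bf16 error is raised even on XLA:GPU.
os.environ.pop("GPU_NUM_DEVICES", None)
print(xla_device_is_gpu_old("xla"))  # False -> error wrongly raised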

Tested with the following script:

GPU_NUM_DEVICES=8 python -m torch_xla.distributed.xla_spawn --num_gpus 8 language-modeling/run_mlm.py \
    --model_name_or_path bert-base-uncased \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --overwrite_output_dir true \
    --output_dir /tmp/test-mlm \
    --per_device_train_batch_size 10 \
    --do_eval \
    --fp16 true \
    --do_train \
    --num_train_epochs 3 \
    --optim adamw_torch_xla

Thanks to @Lokiiiiii @comaniac for reporting this issue.

cc @sgugger

HuggingFaceDocBuilderDev commented Sep 7, 2022

The documentation is not available anymore as the PR was closed or merged.

sgugger (Collaborator) left a comment

Thanks for your PR, but this would result in the error not being raised on TPUs, which is exactly what this error message is for. There needs to be something here that differentiates XLA on GPU from XLA on TPU.

Lokiiiiii commented Sep 7, 2022

XLA already identifies the device type and publishes it in an environment variable for distributed training:

XRT_MULTI_PROCESSING_DEVICE="device:ordinal"

Eg: XRT_MULTI_PROCESSING_DEVICE=GPU:0
Eg: XRT_MULTI_PROCESSING_DEVICE=TPU:0

Refer to the relevant device-specific setup in PT-XLA: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_multiprocessing.py#L219-L276

Looking into single worker training now.
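
A minimal sketch of how that variable could be parsed (illustrative only; not the approach this PR ultimately takes):

import os

# Read the backend ("GPU" or "TPU") from the variable torch_xla's
# multiprocessing setup exports, e.g. XRT_MULTI_PROCESSING_DEVICE=GPU:0.
def xla_backend_from_env():
    value = os.environ.get("XRT_MULTI_PROCESSING_DEVICE")
    if value is None:
        return None  # not running under torch_xla multiprocessing
    return value.split(":")[0]

print(xla_backend_from_env())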

Lokiiiiii commented

There might be a much easier solution: the presence of the environment variables TPU_NUM_DEVICES or XRT_TPU_CONFIG indicates a TPU environment, while the presence of the environment variable GPU_NUM_DEVICES indicates a GPU environment.

comaniac (Contributor) commented Sep 7, 2022

The most systematic logic would be something like:

and not (self.device.type == "xla" and is_torch_tpu_available() and xm.xla_device() == gpu)

Inspired by this logic, it might be better to have an API to return the current torch_xla device so that we could use it here:

and not (self.device.type == "xla" and get_torch_xla_device() != gpu)

sgugger (Collaborator) commented Sep 7, 2022

I'm fine with both solutions :-)

ymwangg (Contributor, Author) commented Sep 7, 2022

I just realized torch_xla already has an API to distinguish the different backends.

torch_xla._XLAC._xla_real_devices([str(device)])

For GPU, it returns

['GPU:0']

For TPU, it returns

['TPU:0']

I'll try to implement it with this API.
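
A small usage sketch of that probe (assumes torch_xla is installed and an XLA device is reachable):

import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
# Prints e.g. ['GPU:0'] on XLA:GPU and ['TPU:0'] on TPU.
print(torch_xla._XLAC._xla_real_devices([str(device)]))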

@@ -1108,7 +1108,7 @@ def __post_init__(self):
     self.framework == "pt"
     and is_torch_available()
     and (self.device.type != "cuda")
-    and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
+    and not (self.device.type == "xla" and xm.xla_real_devices([self.device])[0].split(":")[0] == "GPU")
Contributor commented on the diff:

xm would be undefined if is_torch_tpu_available() is False, but that's fine as long as the device type can't be "xla" in that case. I'm not sure that's guaranteed, though.

ymwangg (Contributor, Author) replied:

Good point. I think xm is guaranteed to be imported if the device type is "xla". I just removed torch_xla and tested with native PyTorch DDP, and it works.

Collaborator replied:

Still, let's wrap this in a get_xla_device function that returns None, GPU, or TPU, just to be sure?
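
A rough sketch of the kind of helper being suggested here (the name get_xla_device and its exact behavior are illustrative, not necessarily what was merged):

# Illustrative helper: returns None when torch_xla is unavailable,
# otherwise "GPU" or "TPU" for the current XLA device.
def get_xla_device():
    try:
        import torch_xla.core.xla_model as xm
    except ImportError:
        return None
    device = xm.xla_device()
    return xm.xla_real_devices([device])[0].split(":")[0]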

ymwangg force-pushed the fix_fp16_error_checking branch from baca7c4 to 5865dfe on September 7, 2022, 19:05
comaniac (Contributor) left a comment

LGTM. Thanks!

sgugger (Collaborator) left a comment

Perfect, thanks a lot for iterating on this!

ymwangg (Contributor, Author) commented Sep 7, 2022

It looks like using "torch.device" as a type hint can cause a CI failure if PyTorch is not installed. I've removed it.

sgugger merged commit 6394221 into huggingface:main on Sep 7, 2022
oneraghavan pushed a commit to oneraghavan/transformers that referenced this pull request Sep 26, 2022
* Fix XLA fp16 and bf16 error checking

* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>