Fix XLA fp16 and bf16 error checking #18913
Conversation
Thanks for your PR, but this will result in the error not being raised on TPUs, which is the reason for this error message. There needs to be something that differentiates XLA on GPU from XLA on TPUs here.
XLA already identifies the device type and publishes it in an environment variable for distributed training, e.g. XRT_MULTI_PROCESSING_DEVICE=GPU:0. See the relevant device-specific setup in PT-XLA: https://github.com/pytorch/xla/blob/master/torch_xla/distributed/xla_multiprocessing.py#L219-L276. Looking into single-worker training now.
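As a quick illustration of that env-var approach, here is a minimal sketch (the helper name is hypothetical, and it assumes the variable always has the form TYPE:ordinal as in the example above):

```python
import os

def xla_backend_from_env():
    """Hypothetical helper: infer the XLA backend from the variable that
    torch_xla's xla_multiprocessing sets per worker, e.g.
    XRT_MULTI_PROCESSING_DEVICE=GPU:0. Returns "GPU", "TPU", ... or None."""
    device = os.environ.get("XRT_MULTI_PROCESSING_DEVICE", "")
    # The value looks like "<TYPE>:<ordinal>"; keep only the type prefix.
    return device.split(":")[0] or None
```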
There might be a much easier solution:
The most systematic logic would be to check which real backend the XLA device maps to, and only raise the error when it is actually a TPU. Inspired by this logic, it might be better to have an API that returns the current torch_xla device type so that we could use it here.
I'm fine with both solutions :-)
I just realized torch_xla already has an API to distinguish the different backends: xm.xla_real_devices. For GPU it returns a device string like GPU:0, and for TPU a string like TPU:0. I'll try to implement it with this API.
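A minimal sketch of how that API could be used, based on the xm.xla_real_devices call that ends up in the diff below (the helper name is mine, and it assumes real device strings have the form TYPE:ordinal):

```python
import torch_xla.core.xla_model as xm

def xla_device_type(device):
    """Map an xla:N device to its real backend type, e.g. "GPU" or "TPU"."""
    # xm.xla_real_devices translates XLA device handles into real device
    # strings such as "GPU:0" or "TPU:0"; the prefix identifies the backend.
    real_device = xm.xla_real_devices([device])[0]
    return real_device.split(":")[0]
```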
src/transformers/training_args.py (outdated)
@@ -1108,7 +1108,7 @@ def __post_init__(self):
             self.framework == "pt"
             and is_torch_available()
             and (self.device.type != "cuda")
-            and not (self.device.type == "xla" and "GPU_NUM_DEVICES" in os.environ)
+            and not (self.device.type == "xla" and xm.xla_real_devices([self.device])[0].split(":")[0] == "GPU")
xm would be undefined if is_torch_tpu_available is False, but that's fine as long as the device type won't be xla in that case. I'm not sure if that's guaranteed, though.
Good point. I think "xm" is guaranteed to be imported if the device type is "xla". I just removed torch_xla and tested with native pytorch DDP and it works.
Still, let's wrap this in a function get_xla_device which will return None, GPU or TPU, to be sure?
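Such a wrapper could look roughly like this (a sketch under the assumptions discussed above: torch_xla is importable whenever the device type is xla, and real device strings look like GPU:0/TPU:0; the exact name and signature in the merged code may differ):

```python
def get_xla_device(device):
    """Return "GPU" or "TPU" for an XLA device, and None for anything else."""
    if device.type != "xla":
        return None
    # Safe to import here: an "xla" device implies torch_xla is installed.
    import torch_xla.core.xla_model as xm

    return xm.xla_real_devices([device])[0].split(":")[0]
```

The fp16/bf16 check in __post_init__ could then skip raising whenever this returns "GPU", matching the change in the diff above.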
Commits baca7c4 to 5865dfe
LGTM. Thanks!
Perfect, thanks a lot for iterating on this!
It looks like "torch.device" as a type hint can cause a CI failure if pytorch is not installed. I've removed it.
* Fix XLA fp16 and bf16 error checking
* Update src/transformers/training_args.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
This PR fixes a bug introduced in #15022 that wrongfully throws an error when training with an XLA device + fp16. GPU_NUM_DEVICES is unset by torch_xla in distributed training here. Tested using the following scripts:
Thanks to @Lokiiiiii @comaniac for reporting this issue.
cc @sgugger