Better error when device mismatches when calling gather() on CUDA #2180

Merged: muellerzr merged 2 commits into main from check-device on Nov 29, 2023

muellerzr
Collaborator

What does this PR do?

This PR adds an explicit error for when a user calls .gather() in a GPU scenario and the device of the passed-in tensor does not match the device in PartialState (i.e. CUDA). This spares users from an error like:

  File "/home/student/Experiemnts/MultiLabelClassification_LLMs/.mlc/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2897, in all_gather_into_tensor
    return honor_type(
  File "/home/student/Experiemnts/MultiLabelClassification_LLMs/.mlc/lib/python3.10/site-packages/accelerate/utils/operations.py", line 83, in honor_type
    return type(obj)(generator)
  File "/home/student/Experiemnts/MultiLabelClassification_LLMs/.mlc/lib/python3.10/site-packages/accelerate/utils/operations.py", line 112, in <genexpr>
    recursively_apply(
  File "/home/student/Experiemnts/MultiLabelClassification_LLMs/.mlc/lib/python3.10/site-packages/accelerate/utils/operations.py", line 128, in recursively_apply
    return func(data, *args, **kwargs)
  File "/home/student/Experiemnts/MultiLabelClassification_LLMs/.mlc/lib/python3.10/site-packages/accelerate/utils/operations.py", line 307, in _gpu_gather_one
    gather_op(output_tensors, tensor)
  File "/home/student/Experiemnts/MultiLabelClassification_LLMs/.mlc/lib/python3.10/site-packages/torch/distributed/c10d_logger.py", line 47, in wrapper
    return func(*args, **kwargs)
  File "/home/student/Experiemnts/MultiLabelClassification_LLMs/.mlc/lib/python3.10/site-packages/torch/distributed/distributed_c10d.py", line 2897, in all_gather_into_tensor
    work = group._allgather_base(output_tensor, input_tensor)
RuntimeError: Tensors must be CUDA and dense

Instead, they now get a much clearer error:

"One or more of the tensors passed to `gather` was not on the GPU while the `Accelerator` is configured for CUDA. "
"Please move it to the GPU before calling `gather`."
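
A minimal sketch of the kind of check described above, for illustration only. The function name `check_gather_device` and the choice of `RuntimeError` are assumptions, not taken from the merged code; the condition and the message text are the ones quoted in this PR. The stand-in objects only mimic the `.device.type` attribute of a `torch.Tensor`, so the sketch runs without PyTorch or a GPU.

```python
from types import SimpleNamespace


def check_gather_device(state_device_type, tensor):
    """Raise a readable error if `tensor` is off-GPU while the run is configured for CUDA.

    `tensor` only needs a `.device.type` attribute, as `torch.Tensor` has.
    """
    # Same condition as the line under review in this PR.
    if state_device_type == "cuda" and tensor.device.type != "cuda":
        # Assumption: the exception type is chosen here for illustration.
        raise RuntimeError(
            "One or more of the tensors passed to `gather` was not on the GPU "
            "while the `Accelerator` is configured for CUDA. "
            "Please move it to the GPU before calling `gather`."
        )


# Stand-ins for tensors (hypothetical, not real torch objects).
cpu_tensor = SimpleNamespace(device=SimpleNamespace(type="cpu"))
cuda_tensor = SimpleNamespace(device=SimpleNamespace(type="cuda"))

check_gather_device("cuda", cuda_tensor)  # matching device: no error
try:
    check_gather_device("cuda", cpu_tensor)  # mismatch: fails early and clearly
except RuntimeError as err:
    print(err)
```

Failing before the collective call is what keeps the user out of the long `all_gather_into_tensor` traceback shown earlier.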

Fixes # (issue)

https://discuss.huggingface.co/t/problem-with-model-inference-using-accelerate/63078
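
On the user side, the new error message asks for tensors to be moved to the GPU before calling `gather`. The sketch below shows one way to do that for nested containers. `move_to_device` and `FakeTensor` are hypothetical names used only for this example; the helper relies on nothing more than a `.to(device)` method, which `torch.Tensor` provides, and it handles plain lists, tuples, and dicts only.

```python
from collections.abc import Mapping


def move_to_device(data, device):
    """Recursively call `.to(device)` on every leaf that has such a method."""
    if isinstance(data, Mapping):
        return type(data)({k: move_to_device(v, device) for k, v in data.items()})
    if isinstance(data, (list, tuple)):
        # Note: this simple form does not support namedtuples.
        return type(data)(move_to_device(v, device) for v in data)
    if hasattr(data, "to"):
        return data.to(device)
    return data  # leave non-tensor leaves (ints, strings, ...) untouched


class FakeTensor:
    """Hypothetical stand-in that mimics `tensor.to(device)` without PyTorch."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        return FakeTensor(device)


batch = {"logits": FakeTensor(), "labels": [FakeTensor(), 3]}
moved = move_to_device(batch, "cuda")
print(moved["logits"].device, moved["labels"][0].device, moved["labels"][1])
```

With real tensors, a single tensor can usually be handled inline, for example `accelerator.gather(tensor.to(accelerator.device))`. Accelerate also ships a `send_to_device` utility in `accelerate.utils` that covers the nested case.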

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@BenjaminBossan @SunMarc


HuggingFaceDocBuilderDev commented Nov 22, 2023

The documentation is not available anymore as the PR was closed or merged.

BenjaminBossan (Member) left a comment

👍 for better error messages.

I have two nits, not blockers for the PR.

src/accelerate/utils/operations.py (outdated)
@@ -298,6 +298,13 @@ def _gpu_gather_one(tensor):
    if not tensor.is_contiguous():
        tensor = tensor.contiguous()

    # Check if `tensor` is not on CUDA
    if state.device.type == "cuda" and tensor.device.type != "cuda":
Member

Are there other device mismatches that could be checked here?

Contributor

@muellerzr @BenjaminBossan Can this logic be extended to other devices? It seems like a generic exception-handling case.

Collaborator Author

Yes, it can. For now the check only covers CUDA, but if it's useful for other devices, that can be added. This is just a known base case.
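
One possible shape for the device-agnostic extension discussed in this thread, as a sketch only; it is not what this PR merged. The names `find_device_mismatches` and `verify_devices` are hypothetical, and `ValueError` is an arbitrary choice. The helper walks nested lists, tuples, and dicts, compares each leaf's `.device.type` against the expected device type, and reports where each mismatch sits.

```python
from collections.abc import Mapping
from types import SimpleNamespace


def find_device_mismatches(data, expected_type, path="data"):
    """Return one description per tensor-like leaf whose device type differs."""
    if isinstance(data, Mapping):
        return [
            msg
            for key, value in data.items()
            for msg in find_device_mismatches(value, expected_type, f"{path}[{key!r}]")
        ]
    if isinstance(data, (list, tuple)):
        return [
            msg
            for index, value in enumerate(data)
            for msg in find_device_mismatches(value, expected_type, f"{path}[{index}]")
        ]
    device = getattr(data, "device", None)  # non-tensor leaves have no device
    if device is not None and device.type != expected_type:
        return [f"{path} is on {device.type!r}, expected {expected_type!r}"]
    return []


def verify_devices(data, expected_type):
    """Raise a single error listing every mismatched leaf."""
    mismatches = find_device_mismatches(data, expected_type)
    if mismatches:
        raise ValueError("Device mismatch before `gather`: " + "; ".join(mismatches))


def fake_tensor(device_type):
    # Hypothetical stand-in exposing only `.device.type`, like `torch.Tensor`.
    return SimpleNamespace(device=SimpleNamespace(type=device_type))


batch = {"logits": fake_tensor("xpu"), "labels": [fake_tensor("cpu"), fake_tensor("xpu")]}
print(find_device_mismatches(batch, "xpu"))
```

Comparing against the configured device type, rather than hard-coding "cuda", is what makes the check generic; whether a mismatch is actually an error on a given backend would still need to be decided per device.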

SunMarc (Member) left a comment

LGTM!

Co-authored-by: Benjamin Bossan <BenjaminBossan@users.noreply.github.com>
muellerzr merged commit 1516379 into main on Nov 29, 2023
25 checks passed
muellerzr deleted the check-device branch on November 29, 2023, 17:11

5 participants