Fix device issue in OpenLlamaModelTest::test_model_parallelism #24195

Merged: 1 commit merged into main from fix_openllama on Jun 12, 2023

Conversation

@ydshieh (Collaborator) commented on Jun 12, 2023

What does this PR do?

See the comments in the changes.

Currently, CI has a failure:

```
src/transformers/models/open_llama/modeling_open_llama.py:740: in forward
    logits = torch.einsum("blh,vh->blv", hidden_states, self.model.embed_tokens.weight)
...
...
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument mat2 in method wrapper_CUDA_bmm)
```
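For context, a minimal sketch of how this mismatch arises under model parallelism and how the fix resolves it (hypothetical shapes; assumes a machine with at least two CUDA devices):

```python
import torch

# Under model parallelism, layers are spread across GPUs, so the final
# hidden states can end up on a different device than the tied embedding.
hidden_states = torch.randn(1, 4, 8, device="cuda:1")  # last-layer output on cuda:1
embed_weight = torch.randn(16, 8, device="cuda:0")     # embed_tokens.weight on cuda:0

# This line raises:
# RuntimeError: Expected all tensors to be on the same device ...
# logits = torch.einsum("blh,vh->blv", hidden_states, embed_weight)

# The fix: move the (lighter) hidden states onto the embedding's device first.
logits = torch.einsum("blh,vh->blv", hidden_states.to(embed_weight.device), embed_weight)
print(logits.shape)  # torch.Size([1, 4, 16]), on cuda:0
```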

src/transformers/models/open_llama/modeling_open_llama.py

```diff
@@ -736,12 +736,16 @@ def forward(
 
         hidden_states = outputs[0]
         if self.config.shared_input_output_embedding:
-            logits = torch.einsum("blh,vh->blv", hidden_states, self.model.embed_tokens.weight)
+            logits = torch.einsum(
+                "blh,vh->blv", hidden_states.to(self.model.embed_tokens.weight.device), self.model.embed_tokens.weight
+            )
```

@ydshieh (Collaborator, Author) commented on this change: send hidden_states (lighter) to the embedding's (heavy) device.
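To make the "lighter vs. heavy" point concrete, a back-of-the-envelope sketch (illustrative numbers, not from the PR): hidden_states has shape (batch, seq_len, hidden) while the tied embedding weight has shape (vocab_size, hidden), and vocab_size typically dwarfs batch * seq_len, so moving hidden_states copies far less data across devices:

```python
# Illustrative sizes (assumptions, not from the PR): small batch, typical vocab.
batch, seq_len, hidden, vocab = 2, 512, 4096, 32000

hidden_numel = batch * seq_len * hidden  # 4,194,304 elements in hidden_states
weight_numel = vocab * hidden            # 131,072,000 elements in embed_tokens.weight

# ~31x more data would cross devices if we moved the weight instead,
# hence hidden_states.to(weight.device) rather than the other way around.
print(weight_numel / hidden_numel)
```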

```diff
         else:
             logits = self.lm_head(hidden_states)
 
         loss = None
         if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
```
@ydshieh (Collaborator, Author) commented on this change: just copied from other modeling files.
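For reference, a sketch of that pattern as it appears in other causal-LM modeling files in the library (stand-in tensors here; the exact shift-and-flatten details vary by model):

```python
import torch
from torch.nn import CrossEntropyLoss

# Stand-ins for the model outputs (hypothetical shapes).
logits = torch.randn(2, 16, 32000)         # (batch, seq_len, vocab_size)
labels = torch.randint(0, 32000, (2, 16))  # (batch, seq_len)

loss = None
if labels is not None:
    # move labels to correct device to enable model parallelism
    labels = labels.to(logits.device)
    # shift so that tokens < n predict n
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    loss_fct = CrossEntropyLoss()
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
```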

@ydshieh requested a review from sgugger on Jun 12, 2023 at 12:25
@HuggingFaceDocBuilderDev commented on Jun 12, 2023

The documentation is not available anymore as the PR was closed or merged.

@sgugger (Collaborator) left a comment

Thanks for the fixes!

@ydshieh merged commit a9cdb05 into main on Jun 12, 2023
@ydshieh deleted the fix_openllama branch on Jun 12, 2023 at 13:21
novice03 pushed a commit to novice03/transformers that referenced this pull request on Jun 23, 2023

Fix device issue in OpenLlamaModelTest::test_model_parallelism (huggingface#24195)

fix

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>