
Generate: support for left-padding on GPTNeoX and Llama #22382

Merged: 6 commits into huggingface:main from llama_gptneox_left_padding on Mar 27, 2023

Conversation

gante (Member) commented Mar 26, 2023

What does this PR do?

As the title indicates, this PR adds left-padding support for GPTNeoX and Llama.

It adds the position_ids input, propagates it all the way to the position embeddings, and gathers the position embeddings according to the values in position_ids. All slow tests are now passing in both models, including the newly added left-padding support test and the GPTNeoX integration test.

Also makes a few changes to Llama to make it more similar to other models 🤗
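
For readers unfamiliar with the mechanics, here is a minimal sketch of the idea (illustrative only, not the exact code in this PR; make_position_ids and select_rotary are hypothetical helper names):

import torch

# Derive position_ids from the attention mask so that left padding does not
# shift the positions of the real tokens (padding slots get a dummy position).
def make_position_ids(attention_mask: torch.Tensor) -> torch.Tensor:
    position_ids = attention_mask.long().cumsum(-1) - 1
    position_ids.masked_fill_(attention_mask == 0, 0)
    return position_ids

# cos/sin: precomputed rotary tables of shape [1, 1, max_seq_len, head_dim].
# Indexing them with position_ids, instead of assuming positions 0..seq_len-1,
# is what lets left-padded batches produce the same embeddings as unpadded ones.
def select_rotary(cos: torch.Tensor, sin: torch.Tensor, position_ids: torch.Tensor):
    cos = cos.squeeze(1).squeeze(0)[position_ids]  # [batch, seq_len, head_dim]
    sin = sin.squeeze(1).squeeze(0)[position_ids]
    return cos, sin

mask = torch.tensor([[0, 0, 1, 1, 1],   # left-padded sequence
                     [1, 1, 1, 1, 1]])  # unpadded sequence
print(make_position_ids(mask))
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])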

HuggingFaceDocBuilderDev commented Mar 26, 2023

The documentation is not available anymore as the PR was closed or merged.

gante (Member, Author) commented Mar 26, 2023

The failing CI is fixed by #22383 :)

ArthurZucker (Collaborator) left a comment

Works for me! I like the addition of the type hints 😉

@@ -649,18 +629,18 @@ class LlamaForCausalLM(LlamaPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
-        self.model = LlamaModel(config)
+        self.llama = LlamaModel(config)
Collaborator

This is breaking with regard to the checkpoints on the Hub + the conversion script (which renames using model.xxx), so if this is accepted, you can also update the checkpoints!
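
For context, a small sketch of why the rename breaks existing checkpoints: PyTorch derives state_dict keys from attribute names, so weights saved under model.xxx no longer match a submodule named llama (the classes below are illustrative stand-ins, not the real ones):

import torch.nn as nn

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(10, 4)

class OldWrapper(nn.Module):      # attribute named "model" -> key "model.embed.weight"
    def __init__(self):
        super().__init__()
        self.model = Inner()

class NewWrapper(nn.Module):      # attribute named "llama" -> key "llama.embed.weight"
    def __init__(self):
        super().__init__()
        self.llama = Inner()

old_state = OldWrapper().state_dict()
missing, unexpected = NewWrapper().load_state_dict(old_state, strict=False)
print(missing)     # ['llama.embed.weight'] -> checkpoints saved with the old name no longer load
print(unexpected)  # ['model.embed.weight']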

Collaborator

Yes, this is a big no no no no.

Collaborator

There don't seem to be any changes to this file other than this one; should it be part of the PR?

gante (Member, Author)

Not necessarily -- it's a little typo (for which I probably wouldn't spend the time to open a PR 😅 )

Collaborator

Not entirely sure if this also applies here, but the cross_pt_flax test might end up failing, as happened when I tried to fix GPT-J 😉

@@ -237,7 +237,7 @@ def test_feed_forward_chunking(self):
 @require_torch
 class GPTNeoXLanguageGenerationTest(unittest.TestCase):
     @slow
-    def test_lm_generate_codegen(self):
+    def test_lm_generate_gptneox(self):
Collaborator

nice catch

gante (Member, Author)

the error originates from me, though :(

sgugger (Collaborator) left a comment

Thanks for working on this. The change model->llama needs to be reverted as it will break all existing repos of Llama models on the Hub.

Comment on lines 358 to 361
-    base_model_prefix = "model"
+    base_model_prefix = "llama"
Collaborator

Absolutely not. We are not breaking all repos on the Hub with a Llama model.


gante (Member, Author) commented Mar 27, 2023

@ArthurZucker @sgugger woopsie, I forgot that it affected the weight loading code -- I come from a place where weight names have to be specified 👼 Reverted (self.llama is self.model again)!

gante merged commit 7dcd870 into huggingface:main on Mar 27, 2023
gante deleted the llama_gptneox_left_padding branch on Mar 27, 2023 at 14:48
jquesnelle commented
It appears this may have broken FSDP. For example, as specified in the Alpaca repo, finetuning with --fsdp "full_shard auto_wrap" --fsdp_transformer_layer_cls_to_wrap LlamaDecoderLayer worked before this commit, but afterwards it gives an error such as:

File "/home/fsuser/.local/lib/python3.8/site-packages/transformers/models/llama/modeling_llama.py", line 313, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/fsuser/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: forward() got an unexpected keyword argument 'position_ids'

Reverting the commit fixes it, although perhaps the problem is with accelerate not supporting position_ids? cc: @ArthurZucker

gante (Member, Author) commented Mar 29, 2023

@jquesnelle can you paste the full stack trace? It would allow us to find the root cause :D (maybe, as you mention, the problem is in accelerate... or maybe it comes from the Alpaca repo!)

raghavanone pushed a commit to raghavanone/transformers that referenced this pull request Apr 5, 2023
xloem pushed a commit to xloem/transformers that referenced this pull request Apr 9, 2023
xloem pushed a commit to xloem/transformers that referenced this pull request Apr 10, 2023
novice03 pushed a commit to novice03/transformers that referenced this pull request Jun 23, 2023
neggert commented Jun 30, 2023

I'm seeing a pretty significant performance hit on RedPajama-7b-chat that I think is due to this change. I ran the PyTorch profiler and all of the repeat operators in apply_rotary_pos_emb are expensive and run mostly on CPU. Reverting to transformers 4.27.x resolves the performance issue.

ArthurZucker (Collaborator)

You should try the main branch; #22785 removed the repeat, solving this.
