
Add position ids in forward pass to opt model #33121

Merged

Conversation

@avishaiElmakies (Contributor) commented Aug 26, 2024

What does this PR do?

This pull request adds position_ids to the forward pass of OPT, in a similar fashion to Gemma and Llama. #32937

Some models do not have a position_ids argument in their forward pass.

There are two main reasons we would like all LM models to accept position_ids:

  1. To keep the API consistent across models.
  2. position_ids are very important if you want to use flash-attention without padding during training. If I want to pack two or more sentences into the same sequence, the model needs to handle them accordingly and treat each sentence as its own separate sentence. The flash-attention code uses position_ids to check whether sequences are packed and, if so, runs an appropriate function that prevents cross-example contamination; without this argument, the model can't use that feature. The code always checks whether position_ids is not None (see the short sketch below the link):

https://github.com/huggingface/transformers/blob/v4.44.1/src/transformers/modeling_flash_attention_utils.py#L270
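For illustration, here is a minimal sketch, using my own toy example rather than the library's code, of what packed position_ids look like and how per-sentence boundaries can be recovered from them:

```python
import torch

# Two sentences of lengths 3 and 4 packed into a single row of the batch.
# position_ids restart from 0 at every sentence boundary:
position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 3]])

# A packing-aware attention path can recover the boundaries by looking for
# positions that reset to 0, and then attend within each segment only.
starts = (position_ids[0] == 0).nonzero().flatten()                   # tensor([0, 3])
lengths = torch.diff(starts, append=torch.tensor([position_ids.shape[1]]))
print(lengths)  # tensor([3, 4]) -> one length per packed sentence
```

The flash-attention utility linked above does something along these lines internally; without position_ids, the model has no way to tell the packed sentences apart.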

This PR handles only OPT, so I can start small and get some feedback.

Changes:

  • Changed the OPTLearnedPositionalEmbedding forward method to take position_ids instead of attention_mask.
  • Changed all forward passes in the file to take position_ids and pass them along.
  • Updated prepare_inputs_for_generation to provide position_ids.
  • If position_ids is None, create it from the attention mask (similar to the original version, so behavior is unchanged when position_ids is not given); see the sketch after this list.
  • Updated the relevant docs.
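A minimal sketch of that fallback, assuming the same cumulative-sum trick the old OPT embedding used internally (the helper name make_position_ids is mine, not the PR's):

```python
import torch

def make_position_ids(attention_mask: torch.LongTensor, past_key_values_length: int = 0):
    # Running count of non-padding tokens; padding positions end up at -1.
    position_ids = torch.cumsum(attention_mask, dim=1) * attention_mask - 1
    # With a cache, keep only the positions of the newly fed tokens.
    return position_ids[:, past_key_values_length:]

mask = torch.tensor([[0, 0, 1, 1, 1]])   # left-padded example
print(make_position_ids(mask))           # tensor([[-1, -1,  0,  1,  2]])
```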

A few notes:

  • Because of the offset used in OPTLearnedPositionalEmbedding, this model needs to represent the position of its padding tokens as -1. I think this is fine, since it keeps compatibility with the existing weights (see the sketch below these notes).
  • BioGPT copies OPT's positional embedding class, so I had to remove the "Copied from" comment to disable that check (since it no longer copies); this is needed for make fixup.
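To make the -1 convention concrete, here is a rough sketch, under my own class name, of how the offset interacts with padding positions (the real OPT class reserves the first two rows of the embedding table, which is why the checkpoint weights stay compatible):

```python
import torch
from torch import nn

class LearnedPositionalEmbeddingSketch(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT checkpoints reserve the first `offset` rows of the table.
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, position_ids: torch.LongTensor):
        # Padding positions are encoded as -1, so after the shift they land on
        # reserved row 1 instead of colliding with a real position.
        return super().forward(position_ids + self.offset)

emb = LearnedPositionalEmbeddingSketch(num_embeddings=8, embedding_dim=4)
out = emb(torch.tensor([[-1, -1, 0, 1, 2]]))
print(out.shape)  # torch.Size([1, 5, 4])
```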

feature-request #32937

Before submitting

  • [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • [ ] Did you write any new necessary tests?

@ArthurZucker I would love some feedback.

@ArthurZucker (Collaborator) left a comment

Hey! In general, the trick is that a new argument needs to be added at the end of the forward signature; otherwise you break the model for people who call the model directly with positional arguments.
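To illustrate the concern with hypothetical signatures (not the actual OPT ones): inserting the new argument in the middle shifts every later positional argument, while appending it keeps old positional call sites working.

```python
# Hypothetical signatures, only to illustrate the argument-ordering concern.
def forward_old(input_ids, attention_mask=None, head_mask=None): ...

# Inserting position_ids in the middle breaks positional callers:
# forward_inserted(ids, mask, head_mask) would bind mask to position_ids.
def forward_inserted(input_ids, position_ids=None, attention_mask=None, head_mask=None): ...

# Appending it at the end keeps existing positional calls valid:
def forward_appended(input_ids, attention_mask=None, head_mask=None, position_ids=None): ...
```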

src/transformers/models/opt/modeling_opt.py (outdated review thread)
@@ -46,7 +46,6 @@
_CONFIG_FOR_DOC = "BioGptConfig"


# Copied from transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding with OPT->BioGpt
ArthurZucker (Collaborator):

Let's just add a TODO here, or also update that model.

avishaiElmakies (Contributor, Author):

Should I put the comment back as well? It causes a failure when running make fixup.
I looked into updating BioGPT, but it seems to contain a lot of code copied from different models, so I didn't know whether I should touch it. I can work on it next.

ArthurZucker (Collaborator):

Yes!

@@ -71,17 +79,10 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)

- def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+ def forward(self, position_ids: torch.LongTensor):
ArthurZucker (Collaborator):

That is kind of a breaking change for this module 😓

avishaiElmakies (Contributor, Author):

Why is this a problem? The weights are the same, so loading should work, and this module shouldn't be used by outside code, so it shouldn't break anything.

ArthurZucker (Collaborator):

Yeah, but it has caused issues in the past 😉

avishaiElmakies (Contributor, Author):

OK, how do you think I should do it? The module needs position_ids to work with packed sentences. Should I add position_ids as the last argument, with None as the default?

src/transformers/models/opt/modeling_opt.py (outdated review thread)
@avishaiElmakies (Contributor, Author) commented Aug 27, 2024

I tried to keep the API as similar as possible to Llama and Gemma; they both put position_ids third. Isn't it kind of a problem for AutoModel to have the argument in different places across models? It would mean you can't load a model with AutoModel and call it without keyword arguments.

Thanks for the feedback!

@ArthurZucker (Collaborator) left a comment

cc @gante WDYT about this? In general, IMO we should just run the basic position_ids init. Though taking padding into account should be alright; it's already done in generate, and this would help for forward and training.

We just need to be careful, as we also want to support packing.

src/transformers/models/opt/modeling_opt.py (outdated review thread)
@@ -71,17 +79,10 @@ def __init__(self, num_embeddings: int, embedding_dim: int):
self.offset = 2
super().__init__(num_embeddings + self.offset, embedding_dim)

- def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+ def forward(self, position_ids: torch.LongTensor):
ArthurZucker (Collaborator):

Yeah, but it has caused issues in the past 😉

src/transformers/models/opt/modeling_opt.py (outdated review thread)
src/transformers/models/opt/modeling_opt.py (outdated review thread)
@avishaiElmakies (Contributor, Author):

About the forward call for the embedding layer: I think it has to take position_ids as an argument, otherwise it won't work with packed sentences.

@avishaiElmakies (Contributor, Author):

@ArthurZucker I'm thinking that the best solution for the embedding layer may be to add position_ids as an argument to its forward pass with a default of None. That is probably backward compatible but would still help with packed sentences; the downside is that the code probably won't be very clean.
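For concreteness, a rough sketch of the kind of backward-compatible signature being discussed, written with my own class name and under the assumption that the old attention-mask path stays the default (not necessarily the code that was eventually merged):

```python
import torch
from torch import nn

class OPTLearnedPositionalEmbeddingSketch(nn.Embedding):
    def __init__(self, num_embeddings: int, embedding_dim: int):
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(
        self,
        attention_mask: torch.LongTensor,
        past_key_values_length: int = 0,
        position_ids: torch.LongTensor = None,
    ):
        # Old behaviour: derive positions from the attention mask, as before.
        if position_ids is None:
            position_ids = torch.cumsum(attention_mask, dim=1) * attention_mask - 1
            position_ids = position_ids[:, past_key_values_length:]
        # New behaviour: caller-supplied position_ids (e.g. for packed sequences)
        # are used directly, so existing call sites keep working unchanged.
        return super().forward(position_ids + self.offset)
```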

@gante (Member) left a comment

@ArthurZucker I'm pro position_ids as it standardizes OPT wrt other models 🙌

@avishaiElmakies Thank you for adding the fix 🤗 Have a look at the unresolved comments (you'd be surprised how easy it is to break code for other external libraries; Hyrum's law definitely applies to transformers).

@avishaiElmakies (Contributor, Author):
@gante, thanks! Happy to contribute.

I would love some guidance on the last two open comments. What should I do with position_ids in the embedding module? In my opinion, it should be able to take position_ids so it can work with packed sentences; maybe as a last argument with a default of None, plus a check?

And I would love some guidance on the one-liners as well.

@avishaiElmakies (Contributor, Author):

@ArthurZucker I would love some guidance here so I can finish and move on to other models.

@avishaiElmakies mentioned this pull request on Sep 22, 2024
@avishaiElmakies (Contributor, Author):

@ArthurZucker I made the changes you suggested and refactored the embedding class to be backward compatible. I would love some feedback.

@ArthurZucker (Collaborator) left a comment

Feel free to merge, @gante, if it's alright with you! 🤗 And thanks for your contribution!

@@ -46,7 +46,6 @@
_CONFIG_FOR_DOC = "BioGptConfig"


# Copied from transformers.models.opt.modeling_opt.OPTLearnedPositionalEmbedding with OPT->BioGpt
ArthurZucker (Collaborator):

Yes!

src/transformers/models/opt/modeling_opt.py (review thread)
src/transformers/models/opt/modeling_opt.py (outdated review thread)
@ArthurZucker merged commit 4953ddf into huggingface:main on Oct 7, 2024
16 of 18 checks passed
@avishaiElmakies deleted the add_position_ids_to_opt branch on October 7, 2024 at 07:40
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

NielsRogge pushed a commit to NielsRogge/transformers that referenced this pull request Oct 21, 2024
* start working on adding position ids

* add docs

* Refactor modeling_biogpt.py and modeling_opt.py for code consistency

* fix 2 PR comments

* move position_ids to end of args

* remove trailing white space

* add comment with TODO

* bug fix gradient checkpointing

* fixup

* missed on position_ids

* remove _attention_to_position_ids and refactor embedding class

* remove redundent code

---------

Co-authored-by: Avishai Elmakies <avishai.elma@cs.huji.ac.il>
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024

BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024