Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama: allow custom 4d masks #29618

Merged
merged 1 commit into from
Mar 13, 2024
Merged

Llama: allow custom 4d masks #29618

merged 1 commit into from
Mar 13, 2024

Conversation

gante
Copy link
Member

@gante gante commented Mar 12, 2024

What does this PR do?

Fixes #29525

Reintroduces the ability to pass custom 4D attention masks, which was removed in the static cache transition. The following tests are now passing

RUN_SLOW=1 python -m pytest -v ./tests/test_modeling_utils.py::Mask4DTestFP32
RUN_SLOW=1 python -m pytest -v ./tests/test_modeling_utils.py::Mask4DTestFP16

cc @ArthurZucker after you come back from holidays, have a look at this PR :)

@gante gante requested a review from amyeroberts March 12, 2024 17:42

hid_0 = self.model.model.embed_tokens(input_0)
outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0]
outs_0 = self.model.model.layers[0].self_attn.forward(hid_0, position_ids=position_ids_0)[0]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

position_ids is now a "mandatory" input to the attention layer forward

# outs_0.shape == torch.Size([3, 4, 768])

hid_1 = self.model.model.embed_tokens(input_1)
outs_1 = self.model.model.layers[0].self_attn.forward(
hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
hid_1, attention_mask=causal_mask_1, position_ids=position_ids_1
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the attention layer forward now expects numerical 4D causal masks (as opposed to 2D boolean masks)

)[0]
# outs_1.shape == torch.Size([1, 6, 768])

outs_0_last_tokens = outs_0[:, -1, :] # last tokens in each batch line
outs_1_last_tokens = outs_1[0, -3:, :] # last three tokens
assert torch.allclose(outs_0_last_tokens, outs_1_last_tokens)

def test_inner_model(self):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test was a copy of the test below 🤔

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Collaborator

@amyeroberts amyeroberts left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for reenabling this!

Only question before merge is how come this is only needed for the gemma and llama models?

@gante
Copy link
Member Author

gante commented Mar 13, 2024

Only question before merge is how come this is only needed for the gemma and llama models?

@amyeroberts They are the only models that have received the static cache treatment. The static cache transition did not foresee this case in the original diff :)

We are finalizing support on the generate side before we propagate this pattern across the library! (#29374)

@gante gante merged commit 1e21c4f into huggingface:main Mar 13, 2024
19 checks passed
@gante gante deleted the fix_29525 branch March 13, 2024 15:07
itazap pushed a commit that referenced this pull request May 14, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

custom 4d attention masks broken by #28937
3 participants