
Fix: Jamba batched generation #32914

Merged · 5 commits merged into huggingface:main from vasqu:jamba-batched-gen-fix on Aug 28, 2024

Conversation

@vasqu (Contributor) commented on Aug 21, 2024:

What does this PR do?

Basically a continuation of #32677, this time implementing the fixes for Jamba. The batched generation tests might need to be changed, especially the expected logits, but I'm not sure how to proceed there since the logits are hardware dependent.
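
For context, a minimal sketch of the kind of batched, left-padded generation this PR is meant to fix. It is illustrative only (the checkpoint name is taken from the test script further down; prompts and generation settings are arbitrary) and is not code from the PR itself:

import torch
from transformers import AutoTokenizer, JambaForCausalLM

model_id = "ai21labs/Jamba-tiny-random"
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.padding_side = "left"  # decoder-only models are left-padded for batched generation
model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda")

# Prompts of different lengths -> the shorter one receives left padding
prompts = ["Hey how are you doing on this lovely evening?", "What is the purpose of life?"]
inputs = tokenizer(prompts, padding=True, return_tensors="pt").to("cuda")

# Before the fix, the mamba layers did not properly account for padded positions,
# so batched results could diverge from the corresponding unbatched ones.
out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.batch_decode(out, skip_special_tokens=True))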

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.

@molbap @ArthurZucker

Comment on lines -461 to -505
def test_left_padding_compatibility(self):
    r"""
    Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences
    effect of the left padding discussed in the issue in the note. Using a more permissive tolerance value.
    """
    import inspect
    # NOTE: left-padding results in small numerical differences. This is expected.
    # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535

    # First, filter out models that don't support left padding - generative and decoder-only.
    # Jamba is a decoder-only architecture
    decoder_only_classes = self.all_generative_model_classes

    # Then, test left-padding
    def _prepare_model_kwargs(input_ids, attention_mask, signature):
        model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if "position_ids" in signature:
            position_ids = torch.cumsum(attention_mask, dim=-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            model_kwargs["position_ids"] = position_ids
        if "cache_position" in signature:
            cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
            model_kwargs["cache_position"] = cache_position
        return model_kwargs

    for model_class in decoder_only_classes:
        config, input_ids, attention_mask = self._get_input_ids_and_config()
        model = model_class(config).to(torch_device).eval()
        signature = inspect.signature(model.forward).parameters.keys()

        # Without padding
        model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
        next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]

        # With left-padding (length 32)
        pad_size = (input_ids.shape[0], 32)
        padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
        padded_input_ids = torch.cat((padding, input_ids), dim=1)
        padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
        model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
        next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]

        # They should result in very similar logits
        self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))

vasqu (Contributor, Author) commented:
Passed locally without the higher rtol/atol. Will see if the CI agrees.

vasqu (Contributor, Author) commented:
Seems like it does :D

vasqu (Contributor, Author) commented:
Keeping it open for visibility: left padding works fine now; the issue was in how padding has been handled in general (for mamba-related models).
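
For intuition, a minimal, hedged sketch of the kind of padding handling this refers to. The function name and tensor shapes are illustrative, not the PR's exact code: padded positions have to be masked out of the hidden states before they feed the mamba recurrence, otherwise the state accumulates contributions from pad tokens.

import torch

def mask_padding_states(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # hidden_states: (batch, seq_len, hidden_size); attention_mask: (batch, seq_len), 0 at pad positions.
    # Broadcasting the mask over the hidden dimension zeroes every padded time step,
    # so the selective-state-space recurrence never mixes pad tokens into its state.
    return hidden_states * attention_mask.unsqueeze(-1).to(hidden_states.dtype)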

@vasqu (Contributor, Author) commented on Aug 21, 2024:

The CI failure seems unrelated to the PR: some import issues from another model.

@ArthurZucker (Collaborator) left a comment:

What I would find weird is if this does not improve/change the results, especially for batched generation! The model is tiny random; it would be nice if we could run this with the big one 👀

Review threads (outdated, resolved): src/transformers/models/jamba/modeling_jamba.py, tests/models/jamba/test_modeling_jamba.py
@vasqu (Contributor, Author) commented on Aug 23, 2024:

Yeah, I think it should definitely improve batched generation. Especially since test_left_padding_compatibility no longer needs the higher atols, padding is not as big of a problem as before (I think).

Too GPU poor to run the Jamba models; IIRC they require at least an 80GB VRAM GPU 😢 Maybe we could notify the folks behind Jamba? I doubt they are aware of this issue.

@vasqu force-pushed the jamba-batched-gen-fix branch from cfc73d9 to e2c2341 on August 23, 2024, 10:08
@vasqu (Contributor, Author) commented on Aug 23, 2024:

#32250 seems to have changed the integration tests, cc @gante.

Guess we have to redo them again 👀

@@ -737,10 +692,11 @@ def test_simple_batched_generate_with_padding(self):
        with torch.no_grad():
            logits = self.model(input_ids=inputs["input_ids"]).logits

        # TODO fix logits
vasqu (Contributor, Author) commented:
For more visibility so that I don't forget about it.

@gante (Member) left a comment:

LGTM, thank you for fixing! 🙌

Added a nit to confirm. Pre-approving assuming the logits tests will be addressed (sorry about that :) )

# No need for zeroing states when
# 1. Cached forward
# 2. Attending to all inputs
if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
gante (Member) commented:

I suspect this line will fail at compilation time (data-dependent conditional branch). Can you confirm, i.e. try running a compiled forward pass?

If it fails, we can add a compile guard, i.e. start the if with not is_torchdynamo_compiling()
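
For reference, a minimal sketch of what such a guard could look like. can_skip_zeroing is a hypothetical helper, not code from this PR (the guard turned out not to be needed); the condition mirrors the snippet quoted above.

from typing import Optional

import torch
from transformers.utils import is_torchdynamo_compiling


def can_skip_zeroing(cache_position: torch.Tensor, attention_mask: Optional[torch.Tensor]) -> bool:
    # Skipping the state zeroing is only safe for a cached forward step or when
    # every position is attended to. The dynamo check keeps torch.compile from
    # tracing the data-dependent branch; while compiling we conservatively never skip.
    if is_torchdynamo_compiling():
        return False
    return bool(cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)))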

vasqu (Contributor, Author) replied:

Tested via the following mini script:

import torch
from transformers import JambaForCausalLM, AutoTokenizer

model_id = "ai21labs/Jamba-tiny-random"
model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to("cuda")
model = torch.compile(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# tested on both, batched or non-batched input
#input = tokenizer(["Hey how are you doing on this lovely evening?", "What is the purpose of life?"], padding=True, return_tensors="pt").to("cuda")
input = tokenizer(["What is the purpose of life?"], padding=True, return_tensors="pt").to("cuda")

# tested on both, forward call or generate
out = model(**input)
#out = model.generate(**input, do_sample=False, max_new_tokens=10)

Haven't encountered any compilation errors locally, so seems to be fine. Is this what you had in mind to test compilation?
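
(Side note, not part of this exchange: a stricter way to surface graph breaks in such a smoke test is to compile with fullgraph=True, which raises on any graph break instead of silently splitting the graph. This applies to the model variable from the script above.)

model = torch.compile(model, fullgraph=True)  # raises on graph breaks (e.g. data-dependent control flow) instead of falling back to eager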

gante (Member) replied:

Yes, that's it!

Perfect, thank you for confirming :)

@ArthurZucker (Collaborator) left a comment:

LGTM, thanks again @vasqu for your great contributions!

@ArthurZucker merged commit 3bfd3e4 into huggingface:main on Aug 28, 2024 (18 checks passed).
@vasqu deleted the jamba-batched-gen-fix branch on August 28, 2024, 11:15.
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Aug 30, 2024
* init fix

* fix mask during cached forward, move mask related stuff to own function

* adjust tests as left padding does not change logits as much anymore + batch gen (with todo on logits comp)

* revert overwriting new integration tests

* move some comments to docstring
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Aug 30, 2024 (same commits as above)
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024 (same commits as above)
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024 (same commits as above)
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024 (same commits as above)
Labels: none · Projects: none · 3 participants