Fix: Jamba batched generation #32914
Conversation
def test_left_padding_compatibility(self):
    r"""
    Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical
    differences caused by left padding (discussed in the issue linked in the note below). Using a more
    permissive tolerance value.
    """
    import inspect
    # NOTE: left-padding results in small numerical differences. This is expected.
    # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535

    # First, filter out models that don't support left padding - generative and decoder-only.
    # Jamba is a decoder-only architecture
    decoder_only_classes = self.all_generative_model_classes

    # Then, test left-padding
    def _prepare_model_kwargs(input_ids, attention_mask, signature):
        model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
        if "position_ids" in signature:
            position_ids = torch.cumsum(attention_mask, dim=-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            model_kwargs["position_ids"] = position_ids
        if "cache_position" in signature:
            cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
            model_kwargs["cache_position"] = cache_position
        return model_kwargs

    for model_class in decoder_only_classes:
        config, input_ids, attention_mask = self._get_input_ids_and_config()
        model = model_class(config).to(torch_device).eval()
        signature = inspect.signature(model.forward).parameters.keys()

        # Without padding
        model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
        next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]

        # With left-padding (length 32)
        pad_size = (input_ids.shape[0], 32)
        padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
        padded_input_ids = torch.cat((padding, input_ids), dim=1)
        padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
        model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
        next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]

        # They should result in very similar logits
        self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
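If the tolerance ever needs re-tuning, a quick way to check the actual drift (a hypothetical debugging snippet, not part of the test) is to print the largest absolute logit difference:

# Hypothetical debugging aid, not part of the test: measure how large the
# left-padding drift actually is before settling on an atol value.
max_abs_diff = (next_logits_wo_padding - next_logits_with_padding).abs().max().item()
print(f"max |logit diff| with left padding: {max_abs_diff:.2e}")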
Passed locally without the higher rtol/atol. Will see if the CI agrees.
Seems like it does :D
Keeping it open for visibility: Left padding works fine now, it was an issue of how padding has been handled in general (for mamba-related models).
CI failure seems unrelated to the PR, some import issues from another model.
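As a rough illustration of what that fix looks like (a simplified sketch based on the PR description, not the exact merged code), the mamba path masks out left-padded positions so they cannot leak into the recurrent state:

# Simplified sketch, not the exact merged code: zero out hidden states at
# padded positions before the mamba mixer so padding cannot pollute the state.
if attention_mask is not None and not torch.all(attention_mask == 1):
    hidden_states = hidden_states * attention_mask.unsqueeze(-1).to(hidden_states.dtype)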
What I would find weird is if this did not improve / change the results, especially for batched generation! The test model is a tiny random one; it would be nice if we could run this with the big one 👀
Yea, I think it should definitely improve batched generation. Too GPU poor to run the Jamba models though; iirc they require at least an 80GB VRAM GPU 😢 Maybe we could notify the people behind Jamba? I doubt they are aware of this issue.
Force-pushed … batch gen (with todo on logits comp) from cfc73d9 to e2c2341.
@@ -737,10 +692,11 @@ def test_simple_batched_generate_with_padding(self):
    with torch.no_grad():
        logits = self.model(input_ids=inputs["input_ids"]).logits

    # TODO fix logits
For more visibility so that I don't forget about it.
LGTM, thank you for fixing! 🙌
Added a nit to confirm. Pre-approving assuming the logits tests will be addressed (sorry about that :) )
# No need for zeroing states when
# 1. Cached forward
# 2. Attending to all inputs
if cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1)):
I suspect this line will fail at compilation time (data-dependent conditional branch). Can you confirm, i.e. try running a compiled forward pass?
If it fails, we can add a compile guard, i.e. start the if with not is_torchdynamo_compiling().
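For reference, such a guard could look roughly like this (a sketch of the suggestion only, assuming the surrounding mixer variables; not code from this PR):

from transformers.utils import is_torchdynamo_compiling

# Sketch of the suggested compile guard: skip the data-dependent branch
# entirely while tracing under torch.compile.
if not is_torchdynamo_compiling() and (
    cache_position[0] > 0 or (attention_mask is not None and torch.all(attention_mask == 1))
):
    ...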
Tested via the following mini script:
import torch
from transformers import JambaForCausalLM, AutoTokenizer
model_id = "ai21labs/Jamba-tiny-random"
model = JambaForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True).to("cuda")
model = torch.compile(model)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# tested with both batched and non-batched input
#input = tokenizer(["Hey how are you doing on this lovely evening?", "What is the purpose of life?"], padding=True, return_tensors="pt").to("cuda")
input = tokenizer(["What is the purpose of life?"], padding=True, return_tensors="pt").to("cuda")
# tested with both a plain forward call and generate
out = model(**input)
#out = model.generate(**input, do_sample=False, max_new_tokens=10)
Haven't encountered any compilation errors locally, so seems to be fine. Is this what you had in mind to test compilation?
Yes, that's it!
Perfect, thank you for confirming :)
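If a stricter check is ever wanted, one option (just a suggestion, not something done in this PR) is to compile with fullgraph=True so that any graph break from a data-dependent branch raises an error instead of silently falling back to eager:

# Optional stricter check (suggestion only): fail loudly on graph breaks.
compiled_model = torch.compile(model, fullgraph=True)
out = compiled_model(**input)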
LGTM thanks again @vasqu for your great contributions!
* init fix
* fix mask during cached forward, move mask related stuff to own function
* adjust tests as left padding does not change logits as much anymore + batch gen (with todo on logits comp)
* revert overwriting new integration tests
* move some comments to docstring
What does this PR do?
Basically a continuation of #32677, this time implementing the fixes for Jamba. The batched generation tests might need to be changed, especially the expected logits, but I am not sure how to proceed there as the logits are hardware dependent.
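One possible way to handle the hardware-dependent logits (an assumption on my side, not something this PR implements) is to compare a small slice against reference values with a tolerance instead of exact equality, e.g. via torch.testing.assert_close:

# Hypothetical tolerance-based check; EXPECTED_LOGITS is a placeholder for
# reference values recorded on one machine, compared with a loose tolerance.
torch.testing.assert_close(logits[0, -1, :10].cpu(), EXPECTED_LOGITS, rtol=1e-3, atol=1e-3)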
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@molbap @ArthurZucker