Fix: Jamba batched generation #32914
Changes from all commits
c02bf38
fcd6d20
e2c2341
e81eee2
43e08dd
@@ -458,51 +458,6 @@ def test_attention_outputs(self):
                [self.model_tester.num_attention_heads, encoder_seq_length, encoder_key_length],
            )

-    def test_left_padding_compatibility(self):
-        r"""
-        Overriding the test_left_padding_compatibility test as the mamba layers accentuate the numerical differences
-        effect of the left padding discussed in the issue in the note. Using a more permissive tolerance value.
-        """
-        import inspect
-        # NOTE: left-padding results in small numerical differences. This is expected.
-        # See https://github.com/huggingface/transformers/issues/25420#issuecomment-1775317535
-
-        # First, filter out models that don't support left padding - generative and decoder-only.
-        # Jamba is a decoder-only architecture
-        decoder_only_classes = self.all_generative_model_classes
-
-        # Then, test left-padding
-        def _prepare_model_kwargs(input_ids, attention_mask, signature):
-            model_kwargs = {"input_ids": input_ids, "attention_mask": attention_mask}
-            if "position_ids" in signature:
-                position_ids = torch.cumsum(attention_mask, dim=-1) - 1
-                position_ids.masked_fill_(attention_mask == 0, 1)
-                model_kwargs["position_ids"] = position_ids
-            if "cache_position" in signature:
-                cache_position = torch.arange(input_ids.shape[-1], device=torch_device)
-                model_kwargs["cache_position"] = cache_position
-            return model_kwargs
-
-        for model_class in decoder_only_classes:
-            config, input_ids, attention_mask = self._get_input_ids_and_config()
-            model = model_class(config).to(torch_device).eval()
-            signature = inspect.signature(model.forward).parameters.keys()
-
-            # Without padding
-            model_kwargs = _prepare_model_kwargs(input_ids, attention_mask, signature)
-            next_logits_wo_padding = model(**model_kwargs).logits[:, -1, :]
-
-            # With left-padding (length 32)
-            pad_size = (input_ids.shape[0], 32)
-            padding = torch.ones(pad_size, dtype=input_ids.dtype, device=torch_device) * config.pad_token_id
-            padded_input_ids = torch.cat((padding, input_ids), dim=1)
-            padded_attention_mask = torch.cat((torch.zeros_like(padding), attention_mask), dim=1)
-            model_kwargs = _prepare_model_kwargs(padded_input_ids, padded_attention_mask, signature)
-            next_logits_with_padding = model(**model_kwargs).logits[:, -1, :]
-
-            # They should result in very similar logits
-            self.assertTrue(torch.allclose(next_logits_wo_padding, next_logits_with_padding, atol=3e-3))
-
Comment on lines -461 to -505

Passed locally without the higher rtol/atol. Will see if the CI agrees.

Seems like it does :D

Keeping it open for visibility: left padding works fine now; it was an issue with how padding has been handled in general (for mamba-related models).
    @require_flash_attn
    @require_torch_gpu
    @require_bitsandbytes
@@ -692,7 +647,7 @@ def test_simple_generate(self):
        EXPECTED_LOGITS_NO_GRAD = torch.tensor(
            [
                0.0134, -0.2197, 0.0396, -0.1011, 0.0459, 0.2793, -0.1465, 0.1660,
                -0.2930, -0.0278, 0.0269, -0.5586, -0.2109, -0.1426, -0.1553, 0.1279,
                0.0713, 0.2246, 0.1660, -0.2314, -0.1187, -0.1162, -0.1377, 0.0292,
                0.1245, 0.2275, 0.0374, 0.1089, -0.1348, -0.2305, 0.1484, -0.3906,
                0.1709, -0.4590, -0.0447, 0.2422, 0.1592, -0.1855, 0.2441, -0.0562
@@ -737,10 +692,11 @@ def test_simple_batched_generate_with_padding(self):
        with torch.no_grad():
            logits = self.model(input_ids=inputs["input_ids"]).logits

+        # TODO fix logits

For more visibility so that I don't forget about it.

        EXPECTED_LOGITS_NO_GRAD_0 = torch.tensor(
            [
                0.0166, -0.2227, 0.0396, -0.1035, 0.0459, 0.2754, -0.1445, 0.1641,
                -0.2910, -0.0273, 0.0227, -0.5547, -0.2139, -0.1396, -0.1582, 0.1289,
                0.0713, 0.2256, 0.1699, -0.2295, -0.1182, -0.1167, -0.1387, 0.0261,
                0.1270, 0.2285, 0.0403, 0.1108, -0.1318, -0.2334, 0.1455, -0.3945,
                0.1729, -0.4609, -0.0410, 0.2412, 0.1572, -0.1895, 0.2402, -0.0583
@@ -749,7 +705,7 @@ def test_simple_batched_generate_with_padding(self):

        EXPECTED_LOGITS_NO_GRAD_1 = torch.tensor(
            [
                -0.1318, 0.2354, -0.4160, -0.0325, -0.0461, 0.0342, 0.2578, 0.0874,
                0.1484, 0.2266, -0.1182, -0.1396, -0.1494, -0.1089, -0.0019, -0.2852,
                0.1973, -0.2676, 0.0586, -0.1992, -0.2520, -0.1147, -0.1973, 0.2129,
                0.0520, 0.1699, 0.1816, 0.1289, 0.1699, -0.1216, -0.2656, -0.2891,
I suspect this line will fail at compilation time (data-dependent conditional branch). Can you confirm, i.e. try running a compiled forward pass?

If it fails, we can add a compile guard, i.e. start the `if` with `not is_torchdynamo_compiling()`.
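To make the suggestion concrete, here is a minimal sketch of that guard pattern; the function, tensors, and the condition itself are illustrative placeholders, not the actual Jamba modeling code:

```python
import torch

from transformers.utils import is_torchdynamo_compiling


def maybe_rescale(hidden_states: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Branching on tensor *values* (a data-dependent condition) is what can break
    # under torch.compile. Prefixing the condition with `not is_torchdynamo_compiling()`
    # simply skips the branch while the graph is being traced.
    if not is_torchdynamo_compiling() and (attention_mask == 0).any():
        hidden_states = hidden_states * attention_mask.unsqueeze(-1).to(hidden_states.dtype)
    return hidden_states
```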
Tested via the following mini script:
Haven't encountered any compilation errors locally, so seems to be fine. Is this what you had in mind to test compilation?
Yes, that's it!
Perfect, thank you for confirming :)
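The mini script itself is not reproduced in the thread. For reference, a rough sketch of the kind of compiled forward-pass check being discussed might look like the following; the checkpoint id and prompt are assumptions, not taken from the actual script:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Assumption: a small Jamba checkpoint; substitute whichever checkpoint is under test.
model_id = "ai21labs/Jamba-tiny-dev"

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)
model.eval()

inputs = tokenizer("Hey how are you doing on this lovely evening?", return_tensors="pt")

# Compile the forward pass and run it once. A data-dependent branch that dynamo cannot
# trace would typically surface here (pass fullgraph=True to turn graph breaks into errors).
compiled_forward = torch.compile(model.forward)
with torch.no_grad():
    logits = compiled_forward(**inputs).logits

print(logits.shape)
```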