
Llama: fix batched generation #29109

Merged: 3 commits merged into huggingface:main from the batched_llama branch on Feb 20, 2024

Conversation

@gante (Member) commented on Feb 19, 2024

What does this PR do?

Fixes batched inference on Llama after the static cache changes were added. For instance, RUN_SLOW=1 py.test tests/test_cache_utils.py::CacheIntegrationTest::test_dynamic_cache_beam_search now passes.

What was wrong?

position_ids has shape [bsz, seq_len]. The line computing freqs was correct for batch size = 1 but incorrect for larger batch sizes: it summed the values across the different batch members. This PR therefore adds an extra dimension so that the per-batch values stay separate instead of being summed.
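As a rough, self-contained sketch (not the PR diff itself; toy shapes and inv_freq values assumed), the extra singleton dimensions turn the computation into a batched outer product, so nothing is summed across batch members:

import torch

bsz, seq_len, half_dim = 2, 6, 4
inv_freq = 1.0 / (10000 ** (torch.arange(0, half_dim).float() / half_dim))  # toy rotary inverse frequencies
position_ids = torch.arange(seq_len).repeat(bsz, 1)                         # [bsz, seq_len]

# [bsz, half_dim, 1] @ [bsz, 1, seq_len] -> [bsz, half_dim, seq_len]: one frequency table per batch member
freqs = inv_freq[None, :, None].expand(bsz, -1, 1) @ position_ids[:, None, :].float()
freqs = freqs.transpose(1, 2)                    # [bsz, seq_len, half_dim]
emb = torch.cat((freqs, freqs), dim=-1)          # [bsz, seq_len, dim]
cos, sin = emb.cos(), emb.sin()                  # per-batch-member cos/sin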

Throughput impact of changes

None 🙌 [Measured on my end, RTX3090 + TinyLlama/TinyLlama-1.1B-Chat-v1.0]

Before this PR
[screenshot: throughput benchmark, 2024-02-19 13:10]

After this PR
[screenshot: throughput benchmark, 2024-02-19 13:43]

@gante marked this pull request as ready for review on February 19, 2024 13:53
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -293,7 +293,7 @@ def test_sink_cache_iterative_prompts(self):
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
EXPECTED_GENERATION = [
"The best color is the one that complements the subject you are photograph",
"The best color is the one that complements the skin tone of the",
@gante (Member Author) commented on Feb 19, 2024

These changed test results were checked against 4b236aed7618d90546cd2e8797dab5b4a24c5dce (the commit before the static caches were introduced).

These tests do batched generation, hence the need for the change.

👉 the fact that this PR matches the commit before the static caches in this test means that we can now do left-padded batched generation with the same results!
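For illustration, a hypothetical minimal repro of what those slow tests exercise (checkpoint and prompts assumed, not the actual test code):

from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # assumed; any Llama checkpoint applies
tokenizer = AutoTokenizer.from_pretrained(model_id, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id)

# Prompts of different lengths -> left padding within the batch
inputs = tokenizer(["The best color is", "Here is a short list of colors:"], padding=True, return_tensors="pt")
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))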

@gante changed the title from "batched llama" to "Llama: fix batched generation" on Feb 19, 2024
@ArthurZucker (Collaborator)

I'll have to run the benchmark on the A100 to make sure everything is alright, but otherwise it should be good.

@ArthurZucker (Collaborator) left a review:

Great work, nice catch! I'll approve but let me run the benchmark on my side!

Comment on lines +187 to +188
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
@ArthurZucker (Collaborator):

let's unsqueeze in the rotary embedding, no? Or would that change the shape we previously had?

@gante (Member Author) replied on Feb 20, 2024:

Same shapes / no shape problems, but unsqueezing here is preferred by some users (see #27117).
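(For context, roughly how these two lines sit in the modeling code; a simplified sketch of apply_rotary_pos_emb:)

def rotate_half(x):
    # Rotates half the hidden dims of the input.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    # cos/sin arrive as [bsz, seq_len, head_dim]; unsqueezing on dim 1 lets them
    # broadcast over the heads dimension of q/k ([bsz, num_heads, seq_len, head_dim]).
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed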

freqs = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @ (
position_ids[:, None, :].float()
)
freqs = freqs.transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
@ArthurZucker (Collaborator):

BTW, for BC we could / should still cache the RoPE values, no?
With a property _sin_cached that emits logger.warning_once("will be removed in 4.39"). WDYT?
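(A hypothetical sketch of that suggestion; property name and message assumed:)

@property
def sin_cached(self):
    # Kept only for backward compatibility; the forward pass now returns cos/sin directly.
    logger.warning_once("The sin_cached attribute will be removed in v4.39.")
    return self._sin_cached
# ...and an equivalent cos_cached property for self._cos_cached.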

Comment on lines -1051 to +1059
causal_mask = torch.triu(mask, diagonal=1).to(dtype)
causal_mask = torch.triu(mask, diagonal=1)

causal_mask = causal_mask.to(dtype=dtype, device=device)
@ArthurZucker (Collaborator):

good catch!

@@ -333,18 +333,18 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
@parameterized.expand(["eager", "sdpa", "flash_attention_2"])
def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
EXPECTED_GENERATION = [
"The best color is\n\n\n\n\n\n\n\n\n\n",
"We should not undermind the issues at hand, but address them head on.\nI think",
"The best color isЋ the one that complements the skin tone of",
@ArthurZucker (Collaborator):

-isЋ t
+is t

seems strange 😅 but alright

@gante (Member Author):

hehe this weird one is a copy/paste

(it has right-padding, so we should expect weird things at generation time)

@ArthurZucker (Collaborator)

Alright, no significant slowdowns, so 🟢, but I can't do naive dynamic generation with the same script as before.
Probably because I passed position_ids = torch.arange(seq_length, device=device) and it is not unsqueezed:

  File "/home/arthur/transformers/../static-kv-cache/clean_bench.py", line 147, in <module>
    outputs = model(input_ids, past_key_values=past_key_values,position_ids=position_ids,cache_position=cache_position, return_dict=False, use_cache = True)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 1155, in forward
    outputs = self.model(
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 995, in forward
    layer_outputs = decoder_layer(
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 721, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 628, in forward
    cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1536, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/arthur/miniconda3/envs/py310/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1545, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/arthur/transformers/src/transformers/models/llama/modeling_llama.py", line 107, in forward
    position_ids[:, None, :].float()
IndexError: too many indices for tensor of dimension 1

@gante (Member Author) commented on Feb 20, 2024

@ArthurZucker regarding the benchmark error: position ids should be a 2D tensor, just like the input ids :D I also had to adapt it on my end
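(i.e., a hypothetical one-line adaptation of the benchmark script:)

# position_ids must be 2D, [batch_size, seq_len], matching input_ids
position_ids = torch.arange(seq_length, device=device).unsqueeze(0)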

@ArthurZucker (Collaborator) commented on Feb 20, 2024

Alright, if passing a 1D tensor was already erroring out before!

@gante merged commit 7d312ad into huggingface:main on Feb 20, 2024
19 checks passed
@gante deleted the batched_llama branch on February 20, 2024 10:23
@fxmarty (Contributor) commented on Feb 20, 2024

@gante thanks a lot for this

Comment on lines +129 to +130
self._cos_cached = cos
self._sin_cached = sin
@ArthurZucker (Collaborator):

We should not always overwrite them. We need them to be accessible, but not overwritten on every forward.
