Llama: fix batched generation #29109
@@ -101,11 +101,34 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
         self.register_buffer("inv_freq", inv_freq, persistent=False)

+    @property
+    def sin_cached(self):
+        logger.warning_once(
+            "The sin_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
+            "the forward method of RoPE from now on instead."
+        )
+        return self._sin_cached
+
+    @property
+    def cos_cached(self):
+        logger.warning_once(
+            "The cos_cached attribute will be removed in 4.40. Bear in mind that its contents changed in v4.38. Use "
+            "the forward method of RoPE from now on instead."
+        )
+        return self._cos_cached
+
     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        freqs = (self.inv_freq[:, None].float().expand(-1, position_ids.shape[0]) @ (position_ids.float())).t()
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
         emb = torch.cat((freqs, freqs), dim=-1)
-        return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
+        cos = emb.cos().to(dtype=x.dtype)
+        sin = emb.sin().to(dtype=x.dtype)
+        # backwards compatibility
+        self._cos_cached = cos
+        self._sin_cached = sin
+        return cos, sin


 class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
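For reference, a minimal standalone sketch of the batched cos/sin computation the new forward performs (toy sizes; dim, base, and the position_ids values here are illustrative, not from the PR):

import torch

# Toy sketch (assumed dim=8, base=10000): per-example position_ids now drive the
# frequencies, so left-padded batches can use shifted positions per row.
dim, base = 8, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))  # [dim/2]

position_ids = torch.tensor([[0, 1, 2, 3, 4], [2, 3, 4, 5, 6]])  # [bs=2, seq_len=5]

inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)  # [bs, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float()                                  # [bs, 1, seq_len]
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)                       # [bs, seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                                                   # [bs, seq_len, dim]
cos, sin = emb.cos(), emb.sin()
print(cos.shape, sin.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5, 8])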
@@ -181,6 +204,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed

Review comment (lines +207 to +208): let's unsqueeze in the rotary embedding no? or that changes the shape we previously had?
Reply: Same shapes/no shape problems, but unsqueezing here is preferable by some users (see #27117)
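To illustrate the shape reasoning in the thread above, a hedged sketch of how the unsqueeze lets the batched cos/sin broadcast over the head dimension (toy sizes; names illustrative):

import torch

# q/k are [bs, num_heads, seq_len, head_dim]; the batched RoPE now returns cos/sin as
# [bs, seq_len, head_dim]. unsqueeze_dim=1 inserts the missing head axis.
bs, num_heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(bs, num_heads, seq_len, head_dim)
cos = torch.randn(bs, seq_len, head_dim)

cos = cos.unsqueeze(1)          # [bs, 1, seq_len, head_dim], broadcasts over all heads
print((q * cos).shape)          # torch.Size([2, 4, 5, 8])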
@@ -1033,6 +1058,7 @@ def _update_causal_mask(self, attention_mask, input_tensor):

         batch_size, seq_length = input_tensor.shape[:2]
         dtype = input_tensor.dtype
+        device = input_tensor.device

         # support going beyond cached `max_position_embedding`
         if seq_length > self.causal_mask.shape[-1]:

@@ -1048,8 +1074,9 @@ def _update_causal_mask(self, attention_mask, input_tensor):
                 (self.config.max_position_embeddings, self.config.max_position_embeddings),
                 fill_value=torch.finfo(dtype).min,
             )
-            causal_mask = torch.triu(mask, diagonal=1).to(dtype)
+            causal_mask = torch.triu(mask, diagonal=1)

+        causal_mask = causal_mask.to(dtype=dtype, device=device)
         if attention_mask is not None and attention_mask.dim() == 2:
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)

Review comment (lines -1051 to +1079): good catch!
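A hedged sketch of what the change above amounts to: build the mask once with triu, then cast and move it in a single .to(dtype=..., device=...) call (toy sizes; not the module's full logic):

import torch

max_pos, dtype, device = 6, torch.float16, "cpu"  # illustrative values
mask = torch.full((max_pos, max_pos), fill_value=torch.finfo(dtype).min)
causal_mask = torch.triu(mask, diagonal=1)                  # upper triangle = min value, rest = 0
causal_mask = causal_mask.to(dtype=dtype, device=device)    # cast + move once, after triu

# Folding in a 2D padding mask, as in the last context lines of the hunk above:
attention_mask = torch.tensor([[0, 1, 1, 1, 1, 1]])         # one left-padded row
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
print(padding_mask.shape)                                   # torch.Size([1, 1, 6, 6])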
@@ -293,7 +293,7 @@ def test_sink_cache_iterative_prompts(self):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is the one that complements the subject you are photograph",
+            "The best color is the one that complements the skin tone of the",
             "We should not undermind the issues at hand.\nWe should not undermind the issues",
         ]

Review comment: These changed test results were checked against the commit before static caches were introduced. These tests do batched generation, hence the need to change. 👉 The fact that this PR matches that commit in this test means that we can now do left-padded batched generation with the same results!
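As a usage note, the left-padded batched generation the comment refers to looks roughly like this (a sketch: the device, generate arguments, and output length are assumptions, not taken from the test verbatim):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16
).to("cuda")
inputs = tokenizer(
    ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)

# Greedy decoding; prompts of different lengths are left-padded so each row
# generates from its true last token.
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))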
@@ -333,18 +333,18 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is\n\n\n\n\n\n\n\n\n\n",
-            "We should not undermind the issues at hand, but address them head on.\nI think",
+            "The best color isЋ the one that complements the skin tone of",
+            "We should not undermind the issues at hand.\nWe should not undermind the issues",
         ]

         tokenizer = AutoTokenizer.from_pretrained(
-            "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
+            "NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token="<s>"
         )
         model = AutoModelForCausalLM.from_pretrained(
             "NousResearch/Llama-2-7b-chat-hf",
             torch_dtype=torch.bfloat16,
             attn_implementation=attn_implementation,
-        ).to("cuda:1")
+        ).to(torch_device)
         inputs = tokenizer(
             ["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
         ).to(model.device)

Review comment (quoting "isЋ t"): seems strange 😅 but alright
Reply: hehe this weird one is a copy/paste (it has right-padding, so we should expect weird things at generation time)
Review comment (on the backwards-compatibility _cos_cached / _sin_cached assignment in forward): we should not always overwrite them. We need them accessible, but not overwritten in the forward.
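For context on that concern, a hedged sketch of the pattern being discussed (a deprecated property exposing a value that the forward pass keeps refreshing; the class and warning helper here are illustrative, not the library's):

import warnings
import torch


class RotaryCacheSketch(torch.nn.Module):
    # Illustrative only: keeps the old *_cached attribute readable for backwards
    # compatibility while forward() keeps overwriting it -- the trade-off raised above.

    @property
    def sin_cached(self):
        warnings.warn("sin_cached is deprecated; use the value returned by forward() instead.")
        return self._sin_cached

    def forward(self, emb: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
        sin = emb.sin().to(dtype=dtype)
        self._sin_cached = sin  # overwritten on every call
        return sin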