Llama: fix batched generation #29109

Merged 3 commits on Feb 20, 2024
Changes from 2 commits
11 changes: 9 additions & 2 deletions src/transformers/models/llama/modeling_llama.py
@@ -103,7 +103,10 @@ def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):

     def forward(self, x, position_ids, seq_len=None):
         # x: [bs, num_attention_heads, seq_len, head_size]
-        freqs = (self.inv_freq[:, None].float().expand(-1, position_ids.shape[0]) @ (position_ids.float())).t()
+        freqs = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) @ (
+            position_ids[:, None, :].float()
+        )
+        freqs = freqs.transpose(1, 2)
         emb = torch.cat((freqs, freqs), dim=-1)
         return emb.cos().to(dtype=x.dtype), emb.sin().to(dtype=x.dtype)
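As a side note, the new computation can be reproduced in isolation. The sketch below (an illustration with toy sizes, not the library code) shows how broadcasting `inv_freq` per batch row lets every sequence carry its own `position_ids`, which is what batched generation with padding needs:

```python
import torch

# Toy sizes; the real model uses head_dim and its configured rope base.
batch_size, seq_len, dim, base = 2, 5, 8, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))      # [dim/2]
position_ids = torch.arange(seq_len).repeat(batch_size, 1)              # [bs, seq_len]

# [bs, dim/2, 1] @ [bs, 1, seq_len] -> [bs, dim/2, seq_len]
freqs = inv_freq[None, :, None].float().expand(batch_size, -1, 1) @ position_ids[:, None, :].float()
freqs = freqs.transpose(1, 2)                                           # [bs, seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                                 # [bs, seq_len, dim]
cos, sin = emb.cos(), emb.sin()
print(cos.shape)  # torch.Size([2, 5, 8])
```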
Collaborator
BTW, for BC we could/should still cache the rope, no?
With a property `_sin_cache` that logs `logger.warning_once` ("will be removed in 4.39")? WDYT?
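A minimal sketch of what such a BC shim could look like, assuming a `_sin_cache` attribute and a deprecated `sin_cached` property (both names are assumptions, not code from this PR):

```python
import torch
from transformers.utils import logging

logger = logging.get_logger(__name__)

class RotaryEmbeddingWithBC(torch.nn.Module):
    """Toy module illustrating the suggested deprecation pattern."""

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Precompute the old-style cache purely for backward compatibility.
        t = torch.arange(max_position_embeddings).float()
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self._sin_cache = emb.sin()

    @property
    def sin_cached(self):
        logger.warning_once("`sin_cached` is deprecated and will be removed in v4.39.")
        return self._sin_cache

rope = RotaryEmbeddingWithBC(dim=8)
_ = rope.sin_cached  # warns once, then returns the cached tensor
```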


@@ -181,6 +184,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
     Returns:
         `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
     """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
Comment on lines +207 to +208
Collaborator
Let's unsqueeze in the rotary embedding, no? Or would that change the shape we previously had?

Copy link
Member Author

@gante gante Feb 20, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same shapes, no shape problems, but unsqueezing here is preferred by some users (see #27117)

     q_embed = (q * cos) + (rotate_half(q) * sin)
     k_embed = (k * cos) + (rotate_half(k) * sin)
     return q_embed, k_embed
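To see what the unsqueeze buys, the toy example below (not the library code; `rotate_half` is re-implemented with `chunk` for brevity) shows how `unsqueeze_dim=1` inserts the heads axis so `cos`/`sin` broadcast against `q` and `k` of shape `[bs, num_heads, seq_len, head_dim]`:

```python
import torch

def rotate_half(x):
    # Chunk-based equivalent of the modeling file's rotate_half helper.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

bs, num_heads, seq_len, head_dim = 2, 4, 5, 8
q = torch.randn(bs, num_heads, seq_len, head_dim)
k = torch.randn(bs, num_heads, seq_len, head_dim)
# cos/sin as returned by the rotary embedding above: [bs, seq_len, head_dim]
cos = torch.randn(bs, seq_len, head_dim)
sin = torch.randn(bs, seq_len, head_dim)

# unsqueeze_dim=1 inserts the heads axis so cos/sin broadcast over num_heads.
cos = cos.unsqueeze(1)   # [bs, 1, seq_len, head_dim]
sin = sin.unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
print(q_embed.shape)     # torch.Size([2, 4, 5, 8])
```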
@@ -1033,6 +1038,7 @@ def _update_causal_mask(self, attention_mask, input_tensor):

         batch_size, seq_length = input_tensor.shape[:2]
         dtype = input_tensor.dtype
+        device = input_tensor.device
 
         # support going beyond cached `max_position_embedding`
         if seq_length > self.causal_mask.shape[-1]:
@@ -1048,8 +1054,9 @@ def _update_causal_mask(self, attention_mask, input_tensor):
             (self.config.max_position_embeddings, self.config.max_position_embeddings),
             fill_value=torch.finfo(dtype).min,
         )
-        causal_mask = torch.triu(mask, diagonal=1).to(dtype)
+        causal_mask = torch.triu(mask, diagonal=1)
 
+        causal_mask = causal_mask.to(dtype=dtype, device=device)
Comment on lines -1051 to +1079
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch!

         if attention_mask is not None and attention_mask.dim() == 2:
             mask_length = attention_mask.shape[-1]
             padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
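A rough standalone sketch of the fixed flow (toy sizes, simplified final masking step; not the exact library code): build the causal mask, cast and move it to the input's dtype and device in one place, then fold in the 2D padding mask.

```python
import torch

# Toy sizes; the real code uses config.max_position_embeddings and the model's dtype/device.
max_pos, dtype, device = 8, torch.float16, "cpu"

mask = torch.full((max_pos, max_pos), fill_value=torch.finfo(dtype).min)
causal_mask = torch.triu(mask, diagonal=1)                  # zeros on and below the diagonal

# The fix: cast AND move together, after building the mask.
causal_mask = causal_mask.to(dtype=dtype, device=device)

# Fold a 2D padding mask (1 = keep, 0 = pad) into the causal mask, as in the context lines above.
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1, 1]])   # one left-padded sequence
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask = causal_mask[None, None, ...].masked_fill(padding_mask, torch.finfo(dtype).min)
print(causal_mask.shape)  # torch.Size([1, 1, 8, 8])
```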
10 changes: 5 additions & 5 deletions tests/test_cache_utils.py
@@ -293,7 +293,7 @@ def test_sink_cache_iterative_prompts(self):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is the one that complements the subject you are photograph",
+            "The best color is the one that complements the skin tone of the",
Member Author @gante · Feb 19, 2024
These changed test results were checked against 4b236aed7618d90546cd2e8797dab5b4a24c5dce (the commit before the static caches were introduced).

These tests do batched generation, hence the need to change them.

👉 the fact that this PR matches the pre-static-cache commit in this test means that we can now do left-padded batched generation with the same results!

"We should not undermind the issues at hand.\nWe should not undermind the issues",
]

@@ -333,18 +333,18 @@ def test_static_cache_greedy_sampling_pad_left(self, attn_implementation):
     @parameterized.expand(["eager", "sdpa", "flash_attention_2"])
     def test_static_cache_greedy_sampling_pad_right(self, attn_implementation):
         EXPECTED_GENERATION = [
-            "The best color is\n\n\n\n\n\n\n\n\n\n",
-            "We should not undermind the issues at hand, but address them head on.\nI think",
+            "The best color isЋ the one that complements the skin tone of",
Collaborator
-isЋ t
+is t

seems strange 😅 but alright

Member Author @gante
hehe this weird one is a copy/paste

(it has right-padding, so we should expect weird things at generation time)

"We should not undermind the issues at hand.\nWe should not undermind the issues",
]

tokenizer = AutoTokenizer.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
"NousResearch/Llama-2-7b-chat-hf", padding_side="right", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained(
"NousResearch/Llama-2-7b-chat-hf",
torch_dtype=torch.bfloat16,
attn_implementation=attn_implementation,
).to("cuda:1")
).to(torch_device)
inputs = tokenizer(
["The best color is", "We should not undermind the issues at hand"], padding=True, return_tensors="pt"
).to(model.device)
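For context, here is a minimal usage sketch of what these tests exercise: left-padded batched greedy generation with the checkpoint used above. This is an illustration, not the test code itself, and it assumes a CUDA device is available.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", padding_side="left", pad_token="<s>"
)
model = AutoModelForCausalLM.from_pretrained(
    "NousResearch/Llama-2-7b-chat-hf", torch_dtype=torch.bfloat16
).to("cuda")

# Two prompts of different lengths -> left padding, which this PR makes consistent
# with unbatched generation.
inputs = tokenizer(
    ["The best color is", "We should not undermind the issues at hand"],
    padding=True,
    return_tensors="pt",
).to(model.device)
gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=10)
print(tokenizer.batch_decode(gen_out, skip_special_tokens=True))
```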