Starcoder2 : KVCache and flash attention (FusedSDPA) enablement #1149

Merged: 21 commits merged on Aug 6, 2024
Changes from 1 commit
Update GaudiStarcoder2ForCausalLM
12010486 committed Jun 26, 2024

Verified: this commit was created on GitHub.com and signed with GitHub’s verified signature.
commit fd023f3612d655eff44249a215eb12cec30dad32
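
This commit threads a set of Gaudi-specific generation kwargs (trim_logits, attn_softmax_bf16, reuse_cache, use_flash_attention, flash_attention_recompute, flash_attention_causal_mask, cache_idx, lazy_mode) through GaudiStarcoder2ForCausalLM.forward and prepare_inputs_for_generation. Because prepare_inputs_for_generation reads them from kwargs, they can reach the model as extra model kwargs to generate(). The sketch below is illustrative only and is not part of the PR: it assumes a Gaudi/HPU machine with optimum-habana installed, uses bigcode/starcoder2-3b and adapt_transformers_to_gaudi() purely as an example, and leaves token_idx to optimum-habana's generation loop. In practice these flags are usually wired through optimum-habana's Gaudi generation config (as in its text-generation example) rather than passed by hand.

# Illustrative sketch only (not part of this PR). Assumes the Habana PyTorch bridge is
# available so .to("hpu") works, and that extra kwargs to generate() are routed to
# prepare_inputs_for_generation as model kwargs (standard transformers behavior).
import torch
from optimum.habana.transformers.modeling_utils import adapt_transformers_to_gaudi
from transformers import AutoModelForCausalLM, AutoTokenizer

adapt_transformers_to_gaudi()  # swap in the Gaudi-patched model classes

tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-3b")
model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder2-3b", torch_dtype=torch.bfloat16)
model = model.eval().to("hpu")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to("hpu")
outputs = model.generate(
    **inputs,
    max_new_tokens=64,
    use_cache=True,
    # Flags added by this PR; prepare_inputs_for_generation picks them up from kwargs.
    reuse_cache=True,                  # pre-allocated (static) KV cache
    use_flash_attention=True,          # attention runs through FusedSDPA
    flash_attention_causal_mask=True,
    trim_logits=True,                  # lm_head only on the last position during prefill
    lazy_mode=True,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
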
@@ -377,6 +377,15 @@ def gaudi_starcoder2_model_forward(


class GaudiStarcoder2ForCausalLM(Starcoder2ForCausalLM):
    def allocate_kv_cache(self, batch_size, max_seq_len, inp_seq_len):
        self.model.allocate_kv_cache(batch_size, max_seq_len, inp_seq_len)

    def reorder_kv_cache(self, beam_idx: torch.LongTensor):
        return self.model.reorder_kv_cache(beam_idx)

    def update_sincos_cache(self, seq_len):
        self.model.update_sincos_cache(seq_len)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
@@ -390,19 +399,28 @@ def forward(
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        token_idx: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        trim_logits: Optional[bool] = False,
        attn_softmax_bf16: Optional[bool] = False,
        reuse_cache: Optional[bool] = False,
        use_flash_attention: Optional[bool] = False,
        flash_attention_recompute: Optional[bool] = False,
        flash_attention_causal_mask: Optional[bool] = False,
        cache_idx: int = None,
        lazy_mode: Optional[bool] = True,
    ) -> Union[Tuple, CausalLMOutputWithPast]:

        """
        Inherits from Starcoder2ForCausalLM: https://github.com/huggingface/transformers/blob/v4.40.2/src/transformers/models/starcoder2/modeling_starcoder2.py
        The only differences are:
        - add new args token_idx
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if not hasattr(self.config, "_attn_implementation"):
            setattr(self.config, "_attn_implementation", "eager")
        else:
            self.config._attn_implementation = "eager"

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
@@ -415,12 +433,27 @@ def forward(
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            token_idx=token_idx,
            cache_position=cache_position,
            attn_softmax_bf16=attn_softmax_bf16,
            reuse_cache=reuse_cache,
            use_flash_attention=use_flash_attention,
            flash_attention_recompute=flash_attention_recompute,
            flash_attention_causal_mask=flash_attention_causal_mask,
            cache_idx=cache_idx,
            lazy_mode=lazy_mode,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        logits = logits.float()

        _, seq_len, _ = hidden_states.shape
        if seq_len > 1 and trim_logits and not self.training:
            if token_idx is not None:
                hidden_states = hidden_states.index_select(1, token_idx - 1)
            else:
                hidden_states = hidden_states[:, -1, :]

        logits = self.lm_head(hidden_states).float()

        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
@@ -431,7 +464,7 @@ def forward(
            shift_labels = shift_labels.view(-1)
            # Ensure tensors are on the same device
            shift_labels = shift_labels.to(shift_logits.device)
            loss_fct = CrossEntropyLoss()
            loss_fct = torch.nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
@@ -447,20 +480,15 @@ def forward(
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, token_idx=None, **kwargs
    ):
        """
        Inherits from Starcoder2ForCausalLM: https://github.com/huggingface/transformers/blob/v4.39.0/src/transformers/models/starcoder2/modeling_starcoder2.py
        The only differences are:
        - add new args token_idx
        - add token_idx into model_inputs
        - from step2 when enable KV cache, slice next_input_ids from input_ids base on the token_idx
        - from step2 when enable KV cache, slice next_position_ids from position_ids base on the token_idx
        """
        token_idx = kwargs.get("token_idx", None)
        past_length = 0
        # Omit tokens covered by past_key_values
        reuse_cache = kwargs.get("reuse_cache")
        if past_key_values is not None:
            if token_idx is None:
            if token_idx is not None:
                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
            else:
                if isinstance(past_key_values, Cache):
                    cache_length = past_key_values.get_seq_length()
                    past_length = past_key_values.seen_tokens
@@ -488,8 +516,10 @@ def prepare_inputs_for_generation(
                    and cache_length + input_ids.shape[1] > max_cache_length
                ):
                    attention_mask = attention_mask[:, -max_cache_length:]
            else:
                input_ids = torch.index_select(input_ids, 1, token_idx - 1)
        elif reuse_cache and token_idx is not None:
            # With reuse_cache, KV cache is pre allocated hence for the 1st token we can slice the inputs till token idx for the fwd pass
            input_ids = input_ids[:, :token_idx]
            attention_mask = attention_mask[:, :token_idx]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
@@ -501,7 +531,8 @@ def prepare_inputs_for_generation(
                    position_ids = torch.index_select(position_ids, 1, token_idx - 1)
                else:
                    position_ids = position_ids[:, -input_ids.shape[1] :]


        cache_position = None
        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
@@ -510,11 +541,20 @@ def prepare_inputs_for_generation(

        model_inputs.update(
            {
                "position_ids": position_ids,
                "position_ids": position_ids.contiguous(),
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "token_idx": token_idx,
                "trim_logits": kwargs.get("trim_logits"),
                "attn_softmax_bf16": kwargs.get("attn_softmax_bf16"),
                "reuse_cache": reuse_cache,
                "use_flash_attention": kwargs.get("use_flash_attention"),
                "flash_attention_recompute": kwargs.get("flash_attention_recompute"),
                "flash_attention_causal_mask": kwargs.get("flash_attention_causal_mask"),
                "cache_idx": kwargs.get("cache_idx"),
                "lazy_mode": kwargs.get("lazy_mode"),
            }
        )
        return model_inputs
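
The token_idx handling in prepare_inputs_for_generation assumes a statically shaped, pre-allocated KV cache: on the first pass with reuse_cache the padded prompt is sliced to [:, :token_idx], and on every later step only the token at position token_idx - 1 is fed back, with its position id derived from the cumulative attention mask. The toy snippet below (plain PyTorch, no Habana dependencies, not from the PR; token_idx is a one-element tensor purely for illustration) demonstrates that indexing pattern.

# Toy illustration of the token_idx slicing used above; not taken from the PR.
import torch

# A 3-token prompt, right-padded to a fixed length of 8 so shapes stay static.
input_ids = torch.tensor([[11, 12, 13, 0, 0, 0, 0, 0]])
attention_mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0]])
token_idx = torch.tensor([3])  # first free slot in the pre-allocated cache

# Step 1 (prefill) with reuse_cache: slice the inputs up to token_idx.
step1_ids = input_ids[:, :token_idx]
step1_mask = attention_mask[:, :token_idx]

# Decode step: pretend token 42 was just generated and written at slot token_idx.
input_ids[0, token_idx] = 42
attention_mask[0, token_idx] = 1
token_idx = token_idx + 1

# Only the newly generated token (at token_idx - 1) is fed back to the model.
next_ids = torch.index_select(input_ids, 1, token_idx - 1)
position_ids = attention_mask.long().cumsum(-1) - 1
next_pos = torch.index_select(position_ids, 1, token_idx - 1)

print(step1_ids.shape, next_ids, next_pos)  # torch.Size([1, 3]) tensor([[42]]) tensor([[3]])
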