Skip to content

Commit

Permalink
Generate: support for left-padding on GPTNeoX and Llama (huggingface#…
Browse files Browse the repository at this point in the history
  • Loading branch information
gante authored and novice03 committed Jun 23, 2023
1 parent 3f746bd commit ca51a94
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 124 deletions.
80 changes: 56 additions & 24 deletions src/transformers/models/gpt_neox/modeling_gpt_neox.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,13 @@ def __init__(self, config):

def forward(
self,
hidden_states,
attention_mask,
head_mask=None,
layer_past=None,
use_cache=False,
output_attentions=False,
hidden_states: torch.FloatTensor,
attention_mask: torch.FloatTensor,
position_ids: torch.LongTensor,
head_mask: Optional[torch.FloatTensor] = None,
layer_past: Optional[Tuple[torch.Tensor]] = None,
use_cache: Optional[bool] = False,
output_attentions: Optional[bool] = False,
):
has_layer_past = layer_past is not None

Expand All @@ -132,12 +133,10 @@ def forward(

# Compute token offset for rotary embeddings (when decoding)
seq_len = key.shape[-2]
offset = 0
if has_layer_past:
offset = layer_past[0].shape[-2]
seq_len += offset
seq_len += layer_past[0].shape[-2]
cos, sin = self.rotary_emb(value, seq_len=seq_len)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)
query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, position_ids)
query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)

Expand Down Expand Up @@ -275,9 +274,11 @@ def rotate_half(x):
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
cos = cos[..., offset : q.shape[-2] + offset, :]
sin = sin[..., offset : q.shape[-2] + offset, :]
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
gather_indices = position_ids[:, None, :, None] # [bs, 1, seq_len, 1]
gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
Expand Down Expand Up @@ -308,16 +309,18 @@ def __init__(self, config):

def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
use_cache=False,
layer_past=None,
output_attentions=False,
hidden_states: Optional[torch.FloatTensor],
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = False,
layer_past: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
):
attention_layer_outputs = self.attention(
self.input_layernorm(hidden_states),
attention_mask=attention_mask,
position_ids=position_ids,
layer_past=layer_past,
head_mask=head_mask,
use_cache=use_cache,
Expand Down Expand Up @@ -374,6 +377,11 @@ def forward(
- 0 for tokens that are **masked**.
[What are attention masks?](../glossary#attention-mask)
position_ids (`torch.LongTensor` of shape `({0})`, *optional*):
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
config.n_positions - 1]`.
[What are position IDs?](../glossary#position-ids)
head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
Mask to nullify selected heads of the self-attention modules. Mask values selected in `[0, 1]`:
Expand Down Expand Up @@ -430,6 +438,7 @@ def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
Expand Down Expand Up @@ -467,7 +476,17 @@ def forward(
batch_size, seq_length = input_shape

if past_key_values is None:
past_length = 0
past_key_values = tuple([None] * self.config.num_hidden_layers)
else:
past_length = past_key_values[0][0].size(-2)

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(past_length, seq_length + past_length, dtype=torch.long, device=device)
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
else:
position_ids = position_ids.view(-1, seq_length).long()

# Attention mask.
if attention_mask is not None:
Expand Down Expand Up @@ -527,12 +546,14 @@ def custom_forward(*inputs):
create_custom_forward(layer),
hidden_states,
attention_mask,
position_ids,
head_mask[i],
)
else:
outputs = layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask[i],
layer_past=layer_past,
use_cache=use_cache,
Expand Down Expand Up @@ -587,6 +608,7 @@ def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
Expand Down Expand Up @@ -640,6 +662,7 @@ def forward(
outputs = self.gpt_neox(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
past_key_values=past_key_values,
Expand Down Expand Up @@ -672,20 +695,29 @@ def forward(
attentions=outputs.attentions,
)

def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs):
def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, **kwargs):
input_shape = input_ids.shape

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

# cut decoder_input_ids if past is used
if past_key_values and past_key_values[0] is not None:
input_ids = input_ids[:, -1:]

position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# create position_ids on the fly for batch generation
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
position_ids = position_ids[:, -1].unsqueeze(-1)

# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
if attention_mask is None:
attention_mask = input_ids.new_ones(input_shape)

return {
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
"past_key_values": past_key_values,
}

Expand Down
2 changes: 1 addition & 1 deletion src/transformers/models/gptj/modeling_gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _get_embed_positions(self, position_ids):

def forward(
self,
hidden_states: Optional[torch.FloatTensor],
hidden_states: torch.FloatTensor,
layer_past: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.FloatTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
Expand Down
Loading

0 comments on commit ca51a94

Please sign in to comment.