Commit 7b38073

Drop unnecessary tokens in GPT2Model generation (#39016)
Drop unnecessary tokens in GPT2Model generation. Co-authored-by: Yi Pan <conlesspan@outlook.com>
1 parent e212ff9 commit 7b38073
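
The core of the change is visible in the diff below: instead of always projecting every hidden state through the LM head, the forward pass now accepts a logits_to_keep argument and only projects the positions it names. The following standalone sketch is mine, not part of the commit (tensor sizes are made up); it just illustrates how the slice(-logits_to_keep, None) trick behaves for an int versus a tensor index.

import torch

def select_positions(hidden_states, logits_to_keep):
    # Same indexing logic as the diff: an int keeps the trailing k positions
    # (0 keeps everything, since slice(0, None) spans the whole axis),
    # while a tensor selects explicit positions along the sequence axis.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return hidden_states[:, slice_indices, :]

hidden_states = torch.randn(2, 10, 768)  # (batch, seq_len, hidden)

print(select_positions(hidden_states, 0).shape)                     # torch.Size([2, 10, 768]) - keep all
print(select_positions(hidden_states, 1).shape)                     # torch.Size([2, 1, 768])  - last position only
print(select_positions(hidden_states, torch.tensor([0, 9])).shape)  # torch.Size([2, 2, 768])  - explicit indices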

1 file changed: 6 additions, 4 deletions


src/transformers/models/gpt2/modeling_gpt2.py

@@ -1163,6 +1163,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
@@ -1208,25 +1209,26 @@ def forward(
             torch.cuda.set_device(self.transformer.first_device)
             hidden_states = hidden_states.to(self.lm_head.weight.device)
 
-        lm_logits = self.lm_head(hidden_states)
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])
 
         loss = None
         if labels is not None:
             # Flatten the tokens
             loss = self.loss_function(
-                lm_logits,
+                logits,
                 labels,
                 vocab_size=self.config.vocab_size,
                 **kwargs,
             )
 
         if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
+            output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
 
         return CausalLMOutputWithCrossAttentions(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
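
For context, here is a hedged usage sketch of what the new keyword buys a caller, assuming a transformers install that already contains this commit and the stock "gpt2" checkpoint. During generation only the last position's logits are sampled from, so keeping a single position avoids materializing a full (seq_len, vocab_size) logits tensor at every decoding step.

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2").eval()

inputs = tokenizer("The quick brown fox", return_tensors="pt")

with torch.no_grad():
    # Default behaviour (logits_to_keep=0): logits for every position, (batch, seq_len, vocab_size).
    full = model(**inputs)
    # Keep only the final position, which is all greedy/sampled decoding needs per step.
    last_only = model(**inputs, logits_to_keep=1)

print(full.logits.shape)       # (1, seq_len, vocab_size)
print(last_only.logits.shape)  # (1, 1, vocab_size)
# The kept slice should match the tail of the full projection (up to kernel-level float noise).
print(torch.allclose(full.logits[:, -1:, :], last_only.logits, atol=1e-5))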
