@@ -1163,6 +1163,7 @@ def forward(
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
         **kwargs,
     ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
         r"""
@@ -1208,25 +1209,26 @@ def forward(
             torch.cuda.set_device(self.transformer.first_device)
             hidden_states = hidden_states.to(self.lm_head.weight.device)

-        lm_logits = self.lm_head(hidden_states)
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.lm_head(hidden_states[:, slice_indices, :])

         loss = None
         if labels is not None:
             # Flatten the tokens
             loss = self.loss_function(
-                lm_logits,
+                logits,
                 labels,
                 vocab_size=self.config.vocab_size,
                 **kwargs,
             )

         if not return_dict:
-            output = (lm_logits,) + transformer_outputs[1:]
+            output = (logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output

         return CausalLMOutputWithCrossAttentions(
             loss=loss,
-            logits=lm_logits,
+            logits=logits,
             past_key_values=transformer_outputs.past_key_values,
             hidden_states=transformer_outputs.hidden_states,
             attentions=transformer_outputs.attentions,
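For context on how the new slicing behaves, here is a minimal standalone sketch; the tensor shapes and the toy lm_head are illustrative assumptions, not the model's actual configuration. With an int, logits_to_keep keeps only the last N sequence positions (0 keeps everything, since slice(-0, None) equals slice(0, None)); with an index tensor, it gathers exactly the listed positions before the lm_head projection.

import torch

# Assumed dummy dimensions for illustration only.
batch, seq_len, hidden, vocab = 2, 8, 16, 32
hidden_states = torch.randn(batch, seq_len, hidden)
lm_head = torch.nn.Linear(hidden, vocab, bias=False)

def project(logits_to_keep):
    # int: keep the last `logits_to_keep` positions (0 keeps every position,
    # because slice(-0, None) == slice(0, None));
    # tensor: gather exactly the listed sequence positions.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    return lm_head(hidden_states[:, slice_indices, :])

print(project(0).shape)                        # torch.Size([2, 8, 32]) -> all positions
print(project(1).shape)                        # torch.Size([2, 1, 32]) -> only the last token
print(project(torch.tensor([0, 3, 7])).shape)  # torch.Size([2, 3, 32]) -> selected positions

Keeping only the final position's logits during generation avoids materializing a full [batch, seq_len, vocab] tensor, which is the usual motivation for a parameter like this.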