File tree (expand/collapse) — 2 files changed, +6 −2 lines changed
src/transformers/models/biogpt — 2 files changed, +6 −2 lines changed
@@ -871,6 +871,7 @@ def forward(
871871        output_hidden_states : Optional [bool ] =  None ,
872872        return_dict : Optional [bool ] =  None ,
873873        cache_position : Optional [torch .Tensor ] =  None ,
874+         logits_to_keep : Union [int , torch .Tensor ] =  0 ,
874875    ) ->  Union [tuple , SequenceClassifierOutputWithPast ]:
875876        r""" 
876877        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 
@@ -894,7 +895,8 @@ def forward(
894895            cache_position = cache_position ,
895896        )
896897        hidden_states  =  transformer_outputs [0 ]
897-         logits  =  self .score (hidden_states )
898+         slice_indices  =  slice (- logits_to_keep , None ) if  isinstance (logits_to_keep , int ) else  logits_to_keep 
899+         logits  =  self .score (hidden_states [:, slice_indices , :])
898900
899901        if  input_ids  is  not None :
900902            batch_size , sequence_length  =  input_ids .shape [:2 ]
@@ -693,6 +693,7 @@ def forward(
693693        output_hidden_states : Optional [bool ] =  None ,
694694        return_dict : Optional [bool ] =  None ,
695695        cache_position : Optional [torch .Tensor ] =  None ,
696+         logits_to_keep : Union [int , torch .Tensor ] =  0 ,
696697    ) ->  Union [tuple , SequenceClassifierOutputWithPast ]:
697698        r""" 
698699        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): 
@@ -716,7 +717,8 @@ def forward(
716717            cache_position = cache_position ,
717718        )
718719        hidden_states  =  transformer_outputs [0 ]
719-         logits  =  self .score (hidden_states )
720+         slice_indices  =  slice (- logits_to_keep , None ) if  isinstance (logits_to_keep , int ) else  logits_to_keep 
721+         logits  =  self .score (hidden_states [:, slice_indices , :])
720722
721723        if  input_ids  is  not None :
722724            batch_size , sequence_length  =  input_ids .shape [:2 ]
 
 
   
 
     
   
   
          
    
    
     
    
      
     
     
    You can’t perform that action at this time.
  
 
    
  
    
      
        
     
       
      
     
   
 
    
    
  
 
  
 
     
    
0 commit comments