File tree Expand file tree Collapse file tree 2 files changed +6
-2
lines changed
src/transformers/models/biogpt Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change @@ -854,6 +854,7 @@ def forward(
854854 output_hidden_states : Optional [bool ] = None ,
855855 return_dict : Optional [bool ] = None ,
856856 cache_position : Optional [torch .Tensor ] = None ,
857+ logits_to_keep : Union [int , torch .Tensor ] = 0 ,
857858 ) -> Union [tuple , SequenceClassifierOutputWithPast ]:
858859 r"""
859860 labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -876,7 +877,8 @@ def forward(
876877 cache_position = cache_position ,
877878 )
878879 hidden_states = transformer_outputs [0 ]
879- logits = self .score (hidden_states )
880+ slice_indices = slice (- logits_to_keep , None ) if isinstance (logits_to_keep , int ) else logits_to_keep
881+ logits = self .score (hidden_states [:, slice_indices , :])
880882
881883 if input_ids is not None :
882884 batch_size , sequence_length = input_ids .shape [:2 ]
Original file line number Diff line number Diff line change @@ -682,6 +682,7 @@ def forward(
682682 output_hidden_states : Optional [bool ] = None ,
683683 return_dict : Optional [bool ] = None ,
684684 cache_position : Optional [torch .Tensor ] = None ,
685+ logits_to_keep : Union [int , torch .Tensor ] = 0 ,
685686 ) -> Union [tuple , SequenceClassifierOutputWithPast ]:
686687 r"""
687688 labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -704,7 +705,8 @@ def forward(
704705 cache_position = cache_position ,
705706 )
706707 hidden_states = transformer_outputs [0 ]
707- logits = self .score (hidden_states )
708+ slice_indices = slice (- logits_to_keep , None ) if isinstance (logits_to_keep , int ) else logits_to_keep
709+ logits = self .score (hidden_states [:, slice_indices , :])
708710
709711 if input_ids is not None :
710712 batch_size , sequence_length = input_ids .shape [:2 ]
You can’t perform that action at this time.
0 commit comments