
Commit 4f286fb

Biogptlogits (#41270)
Added logits slicing to the BioGpt sequence classifier. Signed-off-by: Aviral <aviralkamaljain@gmail.com>
Parent: 1d91a8a

2 files changed: +6 / −2 lines


src/transformers/models/biogpt/modeling_biogpt.py

Lines changed: 3 additions & 1 deletion
@@ -854,6 +854,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -876,7 +877,8 @@ def forward(
             cache_position=cache_position,
         )
         hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.score(hidden_states[:, slice_indices, :])
 
         if input_ids is not None:
             batch_size, sequence_length = input_ids.shape[:2]
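As a sanity check on the slice semantics, here is a minimal standalone sketch. The tensor sizes and the stand-in score head are illustrative assumptions, not taken from this commit; it only shows that the default logits_to_keep=0 keeps every position, while a positive integer keeps just the last N positions before the score projection.

import torch
from torch import nn

# Illustrative dimensions; in the model these come from BioGptConfig.
batch_size, seq_len, hidden_size, num_labels = 2, 10, 16, 3
hidden_states = torch.randn(batch_size, seq_len, hidden_size)
score = nn.Linear(hidden_size, num_labels, bias=False)  # stands in for self.score

for logits_to_keep in (0, 1, 4):
    # Same expression as the patch: -0 is 0, so slice(0, None) keeps all positions.
    slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
    logits = score(hidden_states[:, slice_indices, :])
    print(logits_to_keep, tuple(logits.shape))
# 0 -> (2, 10, 3): unchanged pre-patch behaviour
# 1 -> (2, 1, 3):  only the final position is scored
# 4 -> (2, 4, 3):  only the last four positions are scored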

src/transformers/models/biogpt/modular_biogpt.py

Lines changed: 3 additions & 1 deletion
@@ -682,6 +682,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -704,7 +705,8 @@ def forward(
             cache_position=cache_position,
         )
         hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.score(hidden_states[:, slice_indices, :])
 
         if input_ids is not None:
             batch_size, sequence_length = input_ids.shape[:2]
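For orientation, a hedged usage sketch of how the new keyword reaches the classifier's forward pass. The checkpoint name, input text, and label count are assumptions for illustration and are not part of this commit; the default value 0 is used so the result matches the pre-patch behaviour.

import torch
from transformers import AutoTokenizer, BioGptForSequenceClassification

# Assumed checkpoint and label count, purely illustrative.
tokenizer = AutoTokenizer.from_pretrained("microsoft/biogpt")
model = BioGptForSequenceClassification.from_pretrained("microsoft/biogpt", num_labels=2)

inputs = tokenizer("Aspirin reduces the risk of myocardial infarction.", return_tensors="pt")
with torch.no_grad():
    # logits_to_keep=0 (the default) expands to slice(0, None), so every position
    # is scored and the output matches the behaviour before this commit.
    outputs = model(**inputs, logits_to_keep=0)
print(outputs.logits.shape)  # (batch_size, num_labels) after pooling

Passing a positive integer instead restricts the score projection to the last N hidden states, mirroring the logits_to_keep convention already used by the library's causal-LM heads.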
