Commit 4f1faa0

0x-avi authored and ArthurZucker committed

Biogptlogits (#41270)

added logits slicing to BioGpt for seq classifier

Signed-off-by: Aviral <aviralkamaljain@gmail.com>
1 parent 91e1bdd commit 4f1faa0

File tree

2 files changed (+6, −2 lines)


src/transformers/models/biogpt/modeling_biogpt.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -871,6 +871,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -894,7 +895,8 @@ def forward(
             cache_position=cache_position,
         )
         hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.score(hidden_states[:, slice_indices, :])

         if input_ids is not None:
             batch_size, sequence_length = input_ids.shape[:2]
```

src/transformers/models/biogpt/modular_biogpt.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -693,6 +693,7 @@ def forward(
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
     ) -> Union[tuple, SequenceClassifierOutputWithPast]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
@@ -716,7 +717,8 @@ def forward(
             cache_position=cache_position,
         )
         hidden_states = transformer_outputs[0]
-        logits = self.score(hidden_states)
+        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+        logits = self.score(hidden_states[:, slice_indices, :])

         if input_ids is not None:
             batch_size, sequence_length = input_ids.shape[:2]
```
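The added lines follow the `logits_to_keep` convention used elsewhere in Transformers: an `int` keeps only the last `n` sequence positions (the default `0` keeps everything, since `slice(-0, None)` is `slice(0, None)` and spans the full sequence), while a tensor is used as explicit position indices. A minimal sketch of just the slicing logic, extracted from the diff into a standalone helper (`slice_hidden_states` and the dummy tensor are illustrative, not part of the model code):

```python
import torch


def slice_hidden_states(hidden_states: torch.Tensor, logits_to_keep=0) -> torch.Tensor:
    # int: keep the last `logits_to_keep` positions; 0 keeps the full sequence,
    # because slice(-0, None) == slice(0, None).
    # Tensor: used directly as explicit position indices along the seq dimension.
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )
    return hidden_states[:, slice_indices, :]


batch, seq_len, hidden = 2, 5, 8
h = torch.randn(batch, seq_len, hidden)

print(slice_hidden_states(h, 0).shape)                    # full sequence: (2, 5, 8)
print(slice_hidden_states(h, 1).shape)                    # last position: (2, 1, 8)
print(slice_hidden_states(h, torch.tensor([0, 4])).shape) # indices 0 and 4: (2, 2, 8)
```

For sequence classification this mainly saves compute: when only the last position's logits are needed, passing `logits_to_keep=1` avoids projecting every hidden state through `self.score`.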

0 commit comments