Skip to content

Commit 652e186

Browse files
Merge pull request #4 from alex-jw-brooks/conformer_pr_updates
Fix feature attention mask off-by-one bug
2 parents 78ef059 + 588fb49 commit 652e186

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

src/transformers/models/granite_speech/feature_extraction_granite_speech.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def __call__(
8383
audio_embed_sizes = self._get_num_audio_features(audio_lengths)
8484
speech_inputs["audio_embed_sizes"] = audio_embed_sizes
8585
# todo: input_features_mask is not a great name, because input_features and input_features_mask have different shapes (before/after the projector)
86-
speech_inputs["input_features_mask"] = torch.arange(max(audio_embed_sizes)).view(1, -1) <= torch.tensor(
86+
speech_inputs["input_features_mask"] = torch.arange(max(audio_embed_sizes)).view(1, -1) < torch.tensor(
8787
audio_embed_sizes
8888
).view(-1, 1)
8989
return BatchFeature(data=speech_inputs)

src/transformers/models/granite_speech/modeling_granite_speech.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1253,6 +1253,8 @@ def get_merged_audio_embeddings(self, input_ids, audio_features, input_features_
12531253
and potentially labels.
12541254
"""
12551255
is_audio_index = input_ids == self.config.audio_token_index
1256+
assert torch.all(is_audio_index.int().sum(dim=1) == input_features_mask.int().sum(dim=1)).item(), \
1257+
"number of features should align"
12561258
llm_input_ids = torch.where(is_audio_index, 0, input_ids)
12571259
inputs_embeds = self.language_model.get_input_embeddings()(llm_input_ids) # [bsz, # features, hidden size]
12581260

0 commit comments

Comments
 (0)