Commit 402dbf0

Add todos for input proc refactoring
1 parent 76499ce commit 402dbf0

File tree

2 files changed: +13 −3 lines changed

src/transformers/models/granite_speech/feature_extraction_granite_speech.py
src/transformers/models/granite_speech/processing_granite_speech.py

src/transformers/models/granite_speech/feature_extraction_granite_speech.py

Lines changed: 9 additions & 3 deletions
@@ -74,9 +74,15 @@ def __call__(
         )
         audio_embed_sizes = self._get_num_audio_features(audio_lengths)
         speech_inputs["audio_embed_sizes"] = audio_embed_sizes
-        # TODO: input_features_mask is not a great name, because
-        # input_features and input_features_mask have different shapes
-        # (before/after the projector)
+        # TODO (@alex-jw-brooks): Currently input_features_mask is not
+        # a great name, because input_features and input_features_mask
+        # have different shapes (before/after the projector).
+        #
+        # We should align this with other multimodal models, e.g., llava
+        # and qwen2audio, and refactor this to ensure input_features_mask
+        # has the same dimensionality as input_features, or compute it in
+        # the model based on the audio embedding sizes (since we do not
+        # have an attention mask for the audio features to infer padding from).
         speech_inputs["input_features_mask"] = torch.arange(max(audio_embed_sizes)).view(1, -1) < torch.tensor(
             audio_embed_sizes
         ).view(-1, 1)
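For context, the mask built on the unchanged lines above is a plain length-based padding mask: torch.arange over the longest sequence is broadcast against the per-sample embed sizes, so each row is True up to that sample's length and False afterwards. A minimal standalone sketch of the same broadcast comparison (the embed sizes here are made up for illustration):

    import torch

    # Hypothetical per-sample audio embedding sizes (one entry per audio clip).
    audio_embed_sizes = [4, 2, 3]

    # (1, max_len) position indices compared against (num_samples, 1) lengths;
    # broadcasting yields a (num_samples, max_len) boolean mask where True marks
    # valid positions and False marks padding.
    input_features_mask = torch.arange(max(audio_embed_sizes)).view(1, -1) < torch.tensor(
        audio_embed_sizes
    ).view(-1, 1)

    print(input_features_mask)
    # tensor([[ True,  True,  True,  True],
    #         [ True,  True, False, False],
    #         [ True,  True,  True, False]])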

src/transformers/models/granite_speech/processing_granite_speech.py

Lines changed: 4 additions & 0 deletions
@@ -62,6 +62,10 @@ def __call__(
         # trigger the conditions due to the way they call multimodal
         # processors, e.g., vLLM.
         audio_inputs = self.audio_processor(audio, device=device)
+
+        # TODO (@alex-jw-brooks): we should add a util to get_num_audio_tokens
+        # from feature lengths and call it here, rather than returning it
+        # from the feature extractor.
         audio_embed_sizes = audio_inputs.pop("audio_embed_sizes")

         # Expand the audio placeholders to match the feature dims; this
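The TODO added here asks for a utility that derives the number of audio tokens from the raw feature lengths inside the processor, rather than having the feature extractor return audio_embed_sizes. No such helper exists yet; the sketch below only illustrates the idea, and the name get_num_audio_tokens, its signature, and the fixed downsampling factor are assumptions, not anything defined in this commit:

    from typing import List

    def get_num_audio_tokens(feature_lengths: List[int], downsample_factor: int = 4) -> List[int]:
        """Hypothetical helper: map per-sample feature lengths to the number of
        audio embeddings the projector would produce for each sample.

        The real mapping depends on the Granite Speech encoder/projector
        configuration; a fixed integer downsampling factor is only a placeholder.
        """
        return [(length + downsample_factor - 1) // downsample_factor for length in feature_lengths]

    # Example: feature lengths 10 and 7 map to 3 and 2 audio tokens with a factor of 4.
    print(get_num_audio_tokens([10, 7]))  # [3, 2]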
