
Commit 623d395

Authored by alex-jw-brooks, avihu111, and eustlb
Add Granite Speech Support (#36801)
* First pass at speech granite: add encoder / projector, rename things
* Combine into one model file with causal lm outputs for forward
* Add loss calc
* Fix config loading (Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>)
* Split new / old loading logic
* Use transformers integration for loading peft adapters
* Add generation wrapper for selective lora enablement
* Add note for qformer encoder automodel
* Guard torch/audio imports in feature extractor
* Handle granite speech autoclasses
* Handle optional deps in package structure for granite speech
* Add granite pretrained model def for init
* Add dummy objects for torch/torchaudio
* Add tests for granite speech processor
* Minor formatting fixes and refactoring
* Add options for falling back to config in forward
* Tentative model docstrings for granite speech
* Fix config type
* Remove legacy load
* Allow non-lora variants for granite speech
* Override weight tying for llm
* Use text config instead of llm config
* Add output embeddings getter to fix weight tying
* Fix relative imports
* Compute the number of audio features based on the raw audio sequence
* Collate audio inputs, keeping the original lengths
* Assert we have text; otherwise we can't specify the audio special token
* Assert the number of audio symbols/audios match correctly; run get_validated_audios only when audio is present
* Indentation bugfix + support different feature lengths when expanding audio
* Redundant, done in _get_validated_text
* Adapt the tests: we must have text (not either audio or text); _get_num_audio_features takes a list of raw lengths, provided it instead
* Minor cleanup, remove unused import
* Add more tests for batch feature processing
* Allow setting offset in rel position embeddings
* Add config option for warning if peft is not installed w/ lora
* Port blip2 qformer code into granite speech
* Add sad test for numpy arr processing
* Allow numpy arrays / tuples in granite speech processor
* Fix config type for projector
* Pad instead of creating a zeros tensor, to keep the original dtype/device (support bfloat16); cast input_features to the model dtype (support bfloat16)
* Merge Blip2QFormerConfig into GraniteSpeechProjectorConfig
* Prevent a crash when re-saving/loading the model (line 109)
* Consider additional edge cases during preprocessing
* Consider additional edge cases during preprocessing
* Add features mask for batched inference (bugfix)
* Minor refactor, remove multiaudio processor tests
* Add set input/output embeddings for granite speech
* Fix feature dim check in processor test
* Pop input features in embed test for granite speech
* Small fixes for test edge cases; add granite speech to seq2seq causal lm mapping names
* Add small tests for granite speech model
* Fix data parallelism test
* Standardize model class names
* Fix check for copies
* Fix misaligned init check
* Skip granite speech in checkpoint check
* Use default for tie_word_embeddings in granite speech
* Fix non-documentation granite speech repo issues
* Fix comments and docstring checks
* Add placeholder docs for granite speech
* Fix test naming collision
* Code formatting
* Rerun torch dummy obj regen
* Fix save pretrained for granite speech
* Import sorting
* Fix tests typo
* Remove offset hack
* Pass args through encoder config
* Remove unused prune heads from blip2
* Remove einsum; replace with explicit multiplication (relative positional encodings) and sdpa attention
* Remove Sequential from ConformerFeedForward and ConformerConvModule + fix for sdpa attention
* Remove GraniteSpeechConformerScale
* Rename to hidden_states
* Rename conformer layers to self.layers; remove the first linear from the list to keep the list homogeneous
* Move pre-norm to the attention/feedforward blocks (avoid complex module wrapping)
* Add pre_norm into forward
* Refactor feature extractor to resemble how it's done in phi4multimodal
* Rename feature_extractor to audio_processor
* Bugfix: fix input_feature_mask to get the exact number of tokens
* Fix pytest decorator in processor test
* Add (disabled) integration tests for granite speech
* Fix handling of optional feature masking
* Loosen validation in processing for vLLM compatibility
* Formatting fixes
* Update init structure to mirror llama
* Make granite speech projector generic
* Update test config to reflect generic projector
* Formatting fixes
* Fix typos, add license
* Fix undefined var in input processing
* Cleanup and expose ctc encoder
* Add missing config docstrings
* Better var names, type hints, etc
* Set attn context size in init
* Add max pos emb to encoder config
* Cleanup feature extractor
* Add granite speech architecture details
* Remove granite speech qformer ref
* Add paper link, explicit calc for qkv
* Calculate padding directly in depthwise conv1d init
* Raise value error instead of asserting
* Reorder class defs (classes used at top)
* Precompute relpos distances
* Run formatting
* Pass attention distances through forward
* Apply suggestions from code review (Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>)
* Add todo for using common batch feature extraction
* Rename audios/features
* Ensure chat template may be provided to processor
* Move granite speech docs to audio models
* Add todos for input proc refactoring
* Fix import order
* Guard torch import
* Use relative imports
* Require torch backend for processor in granite speech
* Add backend guards in feature extractor

---------

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Co-authored-by: Avihu Dekel <avihu.dekel@ibm.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
1 parent 435f88f · commit 623d395

18 files changed: +1924 −0 lines

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions

@@ -823,6 +823,8 @@
       title: EnCodec
     - local: model_doc/fastspeech2_conformer
       title: FastSpeech2Conformer
+    - local: model_doc/granite_speech
+      title: GraniteSpeech
     - local: model_doc/hubert
       title: Hubert
     - local: model_doc/mctct

docs/source/en/model_doc/granite_speech.md

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Granite Speech

<div class="flex flex-wrap space-x-1">
    <img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview

The Granite Speech model is a multimodal language model consisting of a speech encoder, a speech projector, a large language model, and LoRA adapter(s). More details on each component of the current (Granite 3.2 Speech) model architecture are given below.

1. Speech Encoder: A [Conformer](https://arxiv.org/abs/2005.08100) encoder trained with Connectionist Temporal Classification (CTC) on character-level targets over ASR corpora. The encoder uses block attention and self-conditioned CTC from the middle layer.

2. Speech Projector: A query transformer (Q-Former) operating on the outputs of the last encoder block. The encoder and projector temporally downsample the audio features before they are merged into the multimodal embeddings processed by the LLM.

3. Large Language Model: The Granite Speech model leverages Granite LLMs, which were originally proposed in [this paper](https://arxiv.org/abs/2408.13359).

4. LoRA adapter(s): The Granite Speech model contains a modality-specific LoRA adapter, which is enabled when audio features are provided and disabled otherwise.

Note that most of the aforementioned components are implemented generically to enable compatibility and potential integration with other model architectures in transformers.

This model was contributed by [Alexander Brooks](https://huggingface.co/abrooks9944), [Avihu Dekel](https://huggingface.co/Avihu), and [George Saon](https://huggingface.co/gsaon).

## Usage tips
- This model bundles its own LoRA adapter, which will be automatically loaded and enabled/disabled as needed during inference calls. Be sure to install [PEFT](https://github.com/huggingface/peft) to ensure the LoRA is correctly applied!

<!-- TODO (@alex-jw-brooks) Add an example here once the model compatible with the transformers implementation is released -->

## GraniteSpeechConfig

[[autodoc]] GraniteSpeechConfig

## GraniteSpeechEncoderConfig

[[autodoc]] GraniteSpeechEncoderConfig

## GraniteSpeechProcessor

[[autodoc]] GraniteSpeechProcessor

## GraniteSpeechFeatureExtractor

[[autodoc]] GraniteSpeechFeatureExtractor

## GraniteSpeechForConditionalGeneration

[[autodoc]] GraniteSpeechForConditionalGeneration
    - forward
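
Pending the official example referenced in the TODO above, the following is a minimal, hypothetical usage sketch. The checkpoint id, the `<|audio|>` placeholder token, and the processor's `text=`/`audio=` keyword names are assumptions for illustration only and are not confirmed by this commit.

```python
import torch
from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq

# Hypothetical checkpoint id -- replace with a released Granite Speech checkpoint.
model_id = "ibm-granite/granite-speech-placeholder"

processor = AutoProcessor.from_pretrained(model_id)  # requires torch / torchaudio backends
model = AutoModelForSpeechSeq2Seq.from_pretrained(model_id, torch_dtype=torch.bfloat16)

# Dummy 1-second, 16 kHz mono waveform; in practice, load real audio (e.g., with torchaudio).
waveform = torch.randn(16000)

# The "<|audio|>" placeholder and the `audio=` kwarg are assumptions about the processor API.
prompt = "<|audio|> Transcribe the audio."
inputs = processor(text=prompt, audio=waveform, return_tensors="pt").to(model.device)

with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=128)
print(processor.tokenizer.decode(output_ids[0], skip_special_tokens=True))
```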

src/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -125,6 +125,7 @@
 from .gpt_sw3 import *
 from .gptj import *
 from .granite import *
+from .granite_speech import *
 from .granitemoe import *
 from .granitemoeshared import *
 from .grounding_dino import *

src/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions

@@ -142,6 +142,7 @@
         ("gptj", "GPTJConfig"),
         ("gptsan-japanese", "GPTSanJapaneseConfig"),
         ("granite", "GraniteConfig"),
+        ("granite_speech", "GraniteSpeechConfig"),
         ("granitemoe", "GraniteMoeConfig"),
         ("granitemoeshared", "GraniteMoeSharedConfig"),
         ("granitevision", "LlavaNextConfig"),
@@ -491,6 +492,7 @@
         ("gptj", "GPT-J"),
         ("gptsan-japanese", "GPTSAN-japanese"),
         ("granite", "Granite"),
+        ("granite_speech", "GraniteSpeech"),
         ("granitemoe", "GraniteMoeMoe"),
         ("granitemoeshared", "GraniteMoeSharedMoe"),
         ("granitevision", "LLaVA-NeXT"),

src/transformers/models/auto/feature_extraction_auto.py

Lines changed: 1 addition & 0 deletions

@@ -61,6 +61,7 @@
         ("encodec", "EncodecFeatureExtractor"),
         ("flava", "FlavaFeatureExtractor"),
         ("glpn", "GLPNFeatureExtractor"),
+        ("granite_speech", "GraniteSpeechFeatureExtractor"),
         ("groupvit", "CLIPFeatureExtractor"),
         ("hubert", "Wav2Vec2FeatureExtractor"),
         ("imagegpt", "ImageGPTFeatureExtractor"),

src/transformers/models/auto/modeling_auto.py

Lines changed: 2 additions & 0 deletions

@@ -973,6 +973,7 @@
         ("encoder-decoder", "EncoderDecoderModel"),
         ("fsmt", "FSMTForConditionalGeneration"),
         ("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
+        ("granite_speech", "GraniteSpeechForConditionalGeneration"),
         ("led", "LEDForConditionalGeneration"),
         ("longt5", "LongT5ForConditionalGeneration"),
         ("m2m_100", "M2M100ForConditionalGeneration"),
@@ -997,6 +998,7 @@

 MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
     [
+        ("granite_speech", "GraniteSpeechForConditionalGeneration"),
         ("moonshine", "MoonshineForConditionalGeneration"),
         ("pop2piano", "Pop2PianoForConditionalGeneration"),
         ("seamless_m4t", "SeamlessM4TForSpeechToText"),

src/transformers/models/auto/processing_auto.py

Lines changed: 1 addition & 0 deletions

@@ -66,6 +66,7 @@
         ("gemma3", "Gemma3Processor"),
         ("git", "GitProcessor"),
         ("got_ocr2", "GotOcr2Processor"),
+        ("granite_speech", "GraniteSpeechProcessor"),
         ("grounding-dino", "GroundingDinoProcessor"),
         ("groupvit", "CLIPProcessor"),
         ("hubert", "Wav2Vec2Processor"),

src/transformers/models/granite_speech/__init__.py

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
    from .configuration_granite_speech import *
    from .feature_extraction_granite_speech import *
    from .modeling_granite_speech import *
    from .processing_granite_speech import *
else:
    import sys

    _file = globals()["__file__"]
    sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
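
With this lazy-module setup, the public symbols are resolved on first attribute access rather than at import time. A quick sanity-check sketch, assuming the config class is exported as shown in the configuration file below:

```python
# The public classes are exposed through _LazyModule and imported only on first access.
from transformers.models.granite_speech import GraniteSpeechConfig

config = GraniteSpeechConfig()
print(config.model_type)  # "granite_speech"
```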

src/transformers/models/granite_speech/configuration_granite_speech.py

Lines changed: 197 additions & 0 deletions

@@ -0,0 +1,197 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config class for Granite Speech."""

from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig


class GraniteSpeechEncoderConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GraniteSpeechCTCEncoder`]. It is used to instantiate
    a Granite Speech audio encoder according to the specified arguments, defining the model architecture. Instantiating a
    configuration with the defaults will yield a similar configuration to that of the audio encoder of the Granite Speech
    architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        input_dim (`int`, *optional*, defaults to 160):
            Dimension of the first hidden layer of the encoder.
        num_layers (`int`, *optional*, defaults to 10):
            Number of encoder blocks.
        hidden_dim (`int`, *optional*, defaults to 1024):
            The size of the intermediate layers in the conformer encoder.
        feedforward_mult (`int`, *optional*, defaults to 4):
            Multiplier for the up/down projections in the encoder's feedforward layers;
            the projections will have an intermediate dim of size `hidden_dim * feedforward_mult`.
        num_heads (`int`, *optional*, defaults to 8):
            Number of attention heads for each attention layer in the Transformer encoder.
        dim_head (`int`, *optional*, defaults to 128):
            Dimension of attention heads for each attention layer in the Transformer encoder.
        output_dim (`int`, *optional*, defaults to 42):
            Intermediate dimension of the feedforward projections in the conformer
            to be added to every other encoder block's output.
        context_size (`int`, *optional*, defaults to 200):
            Context size to be used in conformer attention.
        max_pos_emb (`int`, *optional*, defaults to 512):
            Max pos embeds to be used in attention (Shaw's relative positional encoding).
        dropout (`float`, *optional*, defaults to 0.1):
            The dropout probability for fully connected layers in the encoder.
        conv_kernel_size (`int`, *optional*, defaults to 15):
            Kernel size to be used for 1D convolution in each conformer block.
        conv_expansion_factor (`int`, *optional*, defaults to 2):
            Intermediate dimension to be used in conformer convolutions.

    Example:

    ```python
    >>> from transformers import GraniteSpeechEncoderConfig, GraniteSpeechCTCEncoder

    >>> # Initializing a GraniteSpeechEncoderConfig
    >>> configuration = GraniteSpeechEncoderConfig()

    >>> # Initializing a GraniteSpeechCTCEncoder (with random weights)
    >>> model = GraniteSpeechCTCEncoder(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "granite_speech_encoder"

    def __init__(
        self,
        input_dim=160,
        num_layers=10,
        hidden_dim=1024,
        feedforward_mult=4,
        num_heads=8,
        dim_head=128,
        output_dim=42,
        context_size=200,
        max_pos_emb=512,
        dropout=0.1,
        conv_kernel_size=15,
        conv_expansion_factor=2,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        self.feedforward_mult = feedforward_mult
        self.num_heads = num_heads
        self.dim_head = dim_head
        self.output_dim = output_dim
        self.context_size = context_size
        self.dropout = dropout
        self.conv_kernel_size = conv_kernel_size
        self.conv_expansion_factor = conv_expansion_factor
        self.max_pos_emb = max_pos_emb


class GraniteSpeechConfig(PretrainedConfig):
    r"""
    This is the configuration class to store the configuration of a [`GraniteSpeechForConditionalGeneration`]. It is used to instantiate a
    Granite Speech model according to the specified arguments, defining the model architecture.

    Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
    documentation from [`PretrainedConfig`] for more information.

    Args:
        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `GraniteConfig`):
            The config object or dictionary of the text backbone.
        encoder_config (`GraniteSpeechEncoderConfig`, *optional*):
            The config object or dictionary of the Granite Speech CTC encoder.
        projector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Blip2QFormerConfig`):
            The config object or dictionary of the audio projector.
        audio_token_index (`int`, *optional*, defaults to 49155):
            The audio token index to encode the audio prompt.
        initializer_range (`float`, *optional*, defaults to 0.02):
            The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
        has_lora_adapter (`bool`, *optional*, defaults to `True`):
            Indicates whether or not the model has a LoRA adapter that should only
            be activated when processing audio inputs.
        downsample_rate (`int`, *optional*, defaults to 5):
            Downsample rate for the audio feature extractor.
        window_size (`int`, *optional*, defaults to 15):
            Window size for the audio feature projector.

    Example:

    ```python
    >>> from transformers import GraniteSpeechConfig, GraniteSpeechForConditionalGeneration

    >>> # Initializing a GraniteSpeechConfig
    >>> configuration = GraniteSpeechConfig()

    >>> # Initializing a GraniteSpeechForConditionalGeneration (with random weights)
    >>> model = GraniteSpeechForConditionalGeneration(configuration)

    >>> # Accessing the model configuration
    >>> configuration = model.config
    ```"""

    model_type = "granite_speech"
    sub_configs = {
        "text_config": AutoConfig,
        "encoder_config": GraniteSpeechEncoderConfig,
        "projector_config": AutoConfig,
    }

    def __init__(
        self,
        text_config=None,
        encoder_config=None,
        projector_config=None,
        audio_token_index=49155,
        initializer_range=0.02,
        has_lora_adapter=True,
        downsample_rate=5,
        window_size=15,
        **kwargs,
    ):
        if isinstance(text_config, dict):
            text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "granite"
            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
        elif text_config is None:
            text_config = CONFIG_MAPPING["granite"]()

        if isinstance(projector_config, dict):
            projector_config["model_type"] = (
                projector_config["model_type"] if "model_type" in projector_config else "blip_2_qformer"
            )
            projector_config = CONFIG_MAPPING[projector_config["model_type"]](**projector_config)
        elif projector_config is None:
            projector_config = CONFIG_MAPPING["blip_2_qformer"]()

        if not isinstance(encoder_config, GraniteSpeechEncoderConfig):
            encoder_config = {} if encoder_config is None else encoder_config
            encoder_config = GraniteSpeechEncoderConfig(**encoder_config)

        self.text_config = text_config
        self.encoder_config = encoder_config
        self.projector_config = projector_config
        self.audio_token_index = audio_token_index
        self.initializer_range = initializer_range
        self.has_lora_adapter = has_lora_adapter
        self.downsample_rate = downsample_rate
        self.window_size = window_size
        super().__init__(**kwargs)


__all__ = ["GraniteSpeechEncoderConfig", "GraniteSpeechConfig"]
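
As the `__init__` above shows, dict sub-configs are promoted to full config objects, with `model_type` defaulting to "granite" for the text backbone and "blip_2_qformer" for the projector. A small sketch of composing a custom config under that behavior (the Granite text-config kwargs are assumed to follow the usual `hidden_size` / `num_hidden_layers` naming):

```python
from transformers import GraniteSpeechConfig, GraniteSpeechEncoderConfig

# A smaller-than-default encoder, using fields defined in GraniteSpeechEncoderConfig above.
encoder_config = GraniteSpeechEncoderConfig(num_layers=4, hidden_dim=256, num_heads=4)

config = GraniteSpeechConfig(
    # dict sub-config: "model_type" defaults to "granite", and the dict is promoted to GraniteConfig
    text_config={"hidden_size": 256, "num_hidden_layers": 2, "num_attention_heads": 4},
    encoder_config=encoder_config,
    projector_config=None,  # falls back to CONFIG_MAPPING["blip_2_qformer"]()
)

print(type(config.text_config).__name__)       # GraniteConfig
print(type(config.projector_config).__name__)  # Blip2QFormerConfig
```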
