Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
117 commits
Select commit Hold shift + click to select a range
a837f66
First pass at speech granite
alex-jw-brooks Mar 13, 2025
3bdd91a
Combine into one model file with causal lm outputs for forward
alex-jw-brooks Mar 14, 2025
a841bc1
Add loss calc
alex-jw-brooks Mar 15, 2025
a943aed
Fix config loading
alex-jw-brooks Mar 16, 2025
d459591
Split new / old loading logic
alex-jw-brooks Mar 16, 2025
edcffd4
Use transformers integration for loading peft adapters
alex-jw-brooks Mar 16, 2025
1f1ec31
Add generation wrapper for selective lora enablement
alex-jw-brooks Mar 16, 2025
b86d169
Add note for qformer encoder automodel
alex-jw-brooks Mar 17, 2025
7ad8b24
Guard torch/audio imports in feature extractor
alex-jw-brooks Mar 17, 2025
07afe8b
Handle granite speech autoclasses
alex-jw-brooks Mar 17, 2025
1814333
Handle optional deps in package structure for granite speech
alex-jw-brooks Mar 17, 2025
e983252
Add granite pretrained model def for init
alex-jw-brooks Mar 17, 2025
7b202eb
Add dummy objects for torch/torchaudio
alex-jw-brooks Mar 17, 2025
6844148
Add tests for granite speech processor
alex-jw-brooks Mar 17, 2025
265dfb3
Minor formatting fixes and refactoring
alex-jw-brooks Mar 17, 2025
cff180e
Add options for falling back to config in forward
alex-jw-brooks Mar 17, 2025
61a9495
Tentative model docstrings for granite speech
alex-jw-brooks Mar 17, 2025
88633ff
Fix config type
alex-jw-brooks Mar 17, 2025
a04da45
Remove legacy load
alex-jw-brooks Mar 17, 2025
581bae2
Allow non-lora variants for granite speech
alex-jw-brooks Mar 17, 2025
deb77eb
Override weight tying for llm
alex-jw-brooks Mar 17, 2025
674e971
Use text config instead of llm config
alex-jw-brooks Mar 17, 2025
0cbe18a
Add output embeddings getter to fix weight tying
alex-jw-brooks Mar 17, 2025
cafb696
Fix relative imports
alex-jw-brooks Mar 18, 2025
d6f866d
computing the number of audio features, based on the raw audio sequence.
avihu111 Mar 19, 2025
ab4cdc2
collating audio inputs, and keeping the original lengths.
avihu111 Mar 19, 2025
2f8f576
asserted we have text. otherwise we can't specify the audio special t…
avihu111 Mar 19, 2025
5794333
assering the number of audio-symbols/audios match correctly.
avihu111 Mar 19, 2025
7dddeff
indentation bugfix + supporting different feature lengths when expand…
avihu111 Mar 19, 2025
bf66295
redundant, done in _get_validated_text
avihu111 Mar 19, 2025
42b331d
adapting the tests:
avihu111 Mar 19, 2025
1ada05c
Minor cleanup, remove unused import
alex-jw-brooks Mar 20, 2025
5a4ece2
Add more tests for batch feature processing
alex-jw-brooks Mar 20, 2025
cd53167
Allow setting offset in rel position embeddings
alex-jw-brooks Mar 20, 2025
6a0d62c
Add config option for warning if peft is not installed w/ lora
alex-jw-brooks Mar 20, 2025
ed41307
Port blip2 qformer code into granite speech
alex-jw-brooks Mar 21, 2025
eff7982
Add sad test for numpy arr processing
alex-jw-brooks Mar 21, 2025
1f2e4da
Allow numpy arrays / tuples in granite speech processor
alex-jw-brooks Mar 21, 2025
cb6bf4a
Fix config type for projector
alex-jw-brooks Mar 21, 2025
4ca7e44
- pad instead of creating a zeros tensor, to keep the original dtype/…
avihu111 Mar 23, 2025
bd82de0
merge Blip2QFormerConfig to GraniteSpeechProjectorConfig
avihu111 Mar 23, 2025
f7384f6
prevent a crash when re-saving/loading the model (line 109)
avihu111 Mar 23, 2025
e0f8b53
consider additional edge cases during preprocessing.
avihu111 Mar 23, 2025
085e0fa
consider additional edge cases during preprocessing.
avihu111 Mar 23, 2025
9c1eac3
add features mask for batched inference (bugfix)
avihu111 Mar 23, 2025
633e8e7
Minor refactor, remove multiaudio processor tests
alex-jw-brooks Mar 24, 2025
d40175f
Add set input/output embeddings for granite speech
alex-jw-brooks Mar 24, 2025
672c5cc
Fix feature dim check in processor test
alex-jw-brooks Mar 24, 2025
56f2e6e
Pop input features in embed test for granite speech
alex-jw-brooks Mar 24, 2025
22a9d24
Small fixes for test edge cases
alex-jw-brooks Mar 24, 2025
c13d3ed
Add small tests for granite speech model
alex-jw-brooks Mar 24, 2025
7321c12
Fix data parallelism test
alex-jw-brooks Mar 24, 2025
5738ad1
Standardize model class names
alex-jw-brooks Mar 24, 2025
bead784
Fix check for copies
alex-jw-brooks Mar 24, 2025
4961191
Fix misaligned init check
alex-jw-brooks Mar 24, 2025
6719704
Skip granite speech in checkpoint check
alex-jw-brooks Mar 24, 2025
938bece
Use default for tie_word_embeddings in granite speech
alex-jw-brooks Mar 24, 2025
d6145dd
Fix non documentation granite speech repo issues
alex-jw-brooks Mar 24, 2025
7e816b4
Fix comments and docstring checks
alex-jw-brooks Mar 24, 2025
4f4659f
Add placeholder docs for granite speech
alex-jw-brooks Mar 24, 2025
7b8cd96
Fix test naming collision
alex-jw-brooks Mar 24, 2025
92eb144
Code formatting
alex-jw-brooks Mar 24, 2025
b0fe344
Rerun torch dummy obj regen
alex-jw-brooks Mar 24, 2025
42b940c
Fix save pretrained for granite speech
alex-jw-brooks Mar 24, 2025
02fc57a
Import sorting
alex-jw-brooks Mar 24, 2025
f7e53ed
Fix tests typo
alex-jw-brooks Mar 25, 2025
7853cc7
Remove offset hack
alex-jw-brooks Mar 25, 2025
19b41bb
Pass args through encoder config
alex-jw-brooks Mar 26, 2025
7f22f15
Remove unused prune heads from blip2
alex-jw-brooks Mar 26, 2025
10d2ad9
removing einsum. replaced with explicit multiplication (relative posi…
avihu111 Mar 26, 2025
85eaab5
remove Sequential from ConformerFeedForward and ConformerConvModule. …
avihu111 Mar 26, 2025
10ca6ea
remove GraniteSpeechConformerScale
avihu111 Mar 27, 2025
ca0c8ea
rename to hidden_states
avihu111 Mar 27, 2025
51f81ae
rename conformer layers to self.layers, remove the first linear from …
avihu111 Mar 27, 2025
2544100
move pre-norm to the attention/feedforward blocks (avoid complex modu…
avihu111 Mar 27, 2025
40e2c68
adding pre_norm into forward
avihu111 Mar 27, 2025
d529e29
feature extractor refactoring to resemble how it's done in phi4multim…
avihu111 Mar 27, 2025
e628517
rename feature_extractor to audio_processor
avihu111 Mar 27, 2025
9569d76
bugfix: input_feature_mask fix to get the exact number tokens.
avihu111 Mar 30, 2025
f62be3b
Fix pytest decorator in processor test
alex-jw-brooks Mar 31, 2025
6eeaab4
Add (disabled) integration tests for granite speech
alex-jw-brooks Mar 31, 2025
5ad01a2
Fix handling of optional feature masking
alex-jw-brooks Apr 3, 2025
fae6307
Loosen validation in processing for vLLM compatability
alex-jw-brooks Apr 3, 2025
d18f459
Formatting fixes
alex-jw-brooks Apr 3, 2025
5adc0a9
Update init structure to mirror llama
alex-jw-brooks Apr 3, 2025
e7f7af6
Make granite speech projector generic
alex-jw-brooks Apr 3, 2025
3725b04
Update test config to reflect generic projector
alex-jw-brooks Apr 3, 2025
a5216fb
Formatting fixes
alex-jw-brooks Apr 4, 2025
ff5869f
Fix typos, add license
alex-jw-brooks Apr 4, 2025
edfdfbe
Fix undefined var in input processing
alex-jw-brooks Apr 4, 2025
de0bf8b
Cleanup and expose ctc encoder
alex-jw-brooks Apr 4, 2025
1db4fd8
Add missing config docstrings
alex-jw-brooks Apr 4, 2025
e078bb9
Better var names, type hints, etc
alex-jw-brooks Apr 4, 2025
ea9381f
Set attn context size in init
alex-jw-brooks Apr 4, 2025
dc95123
Add max pos emb to encoder config
alex-jw-brooks Apr 4, 2025
a125ac8
Cleanup feature extractor
alex-jw-brooks Apr 4, 2025
bc88797
Add granite speech architecture details
alex-jw-brooks Apr 4, 2025
c4a9f64
Remove granite speech qformer ref
alex-jw-brooks Apr 4, 2025
f207aac
Add paper link, explicit calc for qkv
alex-jw-brooks Apr 9, 2025
882fb63
Calculate padding directly in depthwise conv1d init
alex-jw-brooks Apr 9, 2025
e7db05a
Raise value error instead of asserting
alex-jw-brooks Apr 9, 2025
fe8242d
Reorder class defs (classes used at top)
alex-jw-brooks Apr 9, 2025
70d84ea
Precompute relpos distances
alex-jw-brooks Apr 9, 2025
4d7e794
Run formatting
alex-jw-brooks Apr 9, 2025
c419426
Pass attention distances through forward
alex-jw-brooks Apr 9, 2025
b281bc4
Apply suggestions from code review
alex-jw-brooks Apr 10, 2025
d66fb7b
Add todo for using common batch feature extraction
alex-jw-brooks Apr 10, 2025
cce6253
Rename audios/features
alex-jw-brooks Apr 10, 2025
15ecc07
Ensure chat template may be provided to processor
alex-jw-brooks Apr 10, 2025
6819ea6
Move granite speech docs to audio models
alex-jw-brooks Apr 10, 2025
ab60ab6
Add todos for input proc refactoring
alex-jw-brooks Apr 11, 2025
ebf694a
Fix import order
alex-jw-brooks Apr 11, 2025
8ce898d
Guard torch import
alex-jw-brooks Apr 11, 2025
c0ca6bc
Use relative imports
alex-jw-brooks Apr 11, 2025
93486bf
Require torch backend for processor in granite speech
alex-jw-brooks Apr 11, 2025
677b4e5
Add backend guards in feature extractor
alex-jw-brooks Apr 11, 2025
5e6dc91
Merge branch 'main' into granite_speech
eustlb Apr 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,8 @@
title: EnCodec
- local: model_doc/fastspeech2_conformer
title: FastSpeech2Conformer
- local: model_doc/granite_speech
title: GraniteSpeech
- local: model_doc/hubert
title: Hubert
- local: model_doc/mctct
Expand Down
68 changes: 68 additions & 0 deletions docs/source/en/model_doc/granite_speech.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.

-->

# Granite Speech

<div class="flex flex-wrap space-x-1">
<img alt="PyTorch" src="https://img.shields.io/badge/PyTorch-DE3412?style=flat&logo=pytorch&logoColor=white">
</div>

## Overview
The Granite Speech model is a multimodal language model, consisting of a speech encoder, speech projector, large language model, and LoRA adapter(s). More details regarding each component for the current (Granite 3.2 Speech) model architecture may be found below.

1. Speech Encoder: A [Conformer](https://arxiv.org/abs/2005.08100) encoder trained with Connectionist Temporal Classification (CTC) on character-level targets on ASR corpora. The encoder uses block-attention and self-conditioned CTC from the middle layer.

2. Speech Projector: A query transformer (q-former) operating on the outputs of the last encoder block. The encoder and projector temporally downsample the audio features to be merged into the multimodal embeddings to be processed by the llm.

3. Large Language Model: The Granite Speech model leverages Granite LLMs, which were originally proposed in [this paper](https://arxiv.org/abs/2408.13359).

4. LoRA adapter(s): The Granite Speech model contains a modality specific LoRA, which will be enabled when audio features are provided, and disabled otherwise.


Note that most of the aforementioned components are implemented generically to enable compatability and potential integration with other model architectures in transformers.


This model was contributed by [Alexander Brooks](https://huggingface.co/abrooks9944), [Avihu Dekel](https://huggingface.co/Avihu), and [George Saon](https://huggingface.co/gsaon).

## Usage tips
- This model bundles its own LoRA adapter, which will be automatically loaded and enabled/disabled as needed during inference calls. Be sure to install [PEFT](https://github.com/huggingface/peft) to ensure the LoRA is correctly applied!

<!-- TODO (@alex-jw-brooks) Add an example here once the model compatible with the transformers implementation is released -->

## GraniteSpeechConfig

[[autodoc]] GraniteSpeechConfig


## GraniteSpeechEncoderConfig

[[autodoc]] GraniteSpeechEncoderConfig


## GraniteSpeechProcessor

[[autodoc]] GraniteSpeechProcessor


## GraniteSpeechFeatureExtractor

[[autodoc]] GraniteSpeechFeatureExtractor


## GraniteSpeechForConditionalGeneration

[[autodoc]] GraniteSpeechForConditionalGeneration
- forward
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,7 @@
from .gpt_sw3 import *
from .gptj import *
from .granite import *
from .granite_speech import *
from .granitemoe import *
from .granitemoeshared import *
from .grounding_dino import *
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@
("gptj", "GPTJConfig"),
("gptsan-japanese", "GPTSanJapaneseConfig"),
("granite", "GraniteConfig"),
("granite_speech", "GraniteSpeechConfig"),
("granitemoe", "GraniteMoeConfig"),
("granitemoeshared", "GraniteMoeSharedConfig"),
("granitevision", "LlavaNextConfig"),
Expand Down Expand Up @@ -491,6 +492,7 @@
("gptj", "GPT-J"),
("gptsan-japanese", "GPTSAN-japanese"),
("granite", "Granite"),
("granite_speech", "GraniteSpeech"),
("granitemoe", "GraniteMoeMoe"),
("granitemoeshared", "GraniteMoeSharedMoe"),
("granitevision", "LLaVA-NeXT"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/feature_extraction_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
("encodec", "EncodecFeatureExtractor"),
("flava", "FlavaFeatureExtractor"),
("glpn", "GLPNFeatureExtractor"),
("granite_speech", "GraniteSpeechFeatureExtractor"),
("groupvit", "CLIPFeatureExtractor"),
("hubert", "Wav2Vec2FeatureExtractor"),
("imagegpt", "ImageGPTFeatureExtractor"),
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@
("encoder-decoder", "EncoderDecoderModel"),
("fsmt", "FSMTForConditionalGeneration"),
("gptsan-japanese", "GPTSanJapaneseForConditionalGeneration"),
("granite_speech", "GraniteSpeechForConditionalGeneration"),
("led", "LEDForConditionalGeneration"),
("longt5", "LongT5ForConditionalGeneration"),
("m2m_100", "M2M100ForConditionalGeneration"),
Expand All @@ -997,6 +998,7 @@

MODEL_FOR_SPEECH_SEQ_2_SEQ_MAPPING_NAMES = OrderedDict(
[
("granite_speech", "GraniteSpeechForConditionalGeneration"),
("moonshine", "MoonshineForConditionalGeneration"),
("pop2piano", "Pop2PianoForConditionalGeneration"),
("seamless_m4t", "SeamlessM4TForSpeechToText"),
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/auto/processing_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
("gemma3", "Gemma3Processor"),
("git", "GitProcessor"),
("got_ocr2", "GotOcr2Processor"),
("granite_speech", "GraniteSpeechProcessor"),
("grounding-dino", "GroundingDinoProcessor"),
("groupvit", "CLIPProcessor"),
("hubert", "Wav2Vec2Processor"),
Expand Down
29 changes: 29 additions & 0 deletions src/transformers/models/granite_speech/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING

from ...utils import _LazyModule
from ...utils.import_utils import define_import_structure


if TYPE_CHECKING:
from .configuration_granite_speech import *
from .feature_extraction_granite_speech import *
from .modeling_granite_speech import *
from .processing_granite_speech import *
else:
import sys

_file = globals()["__file__"]
sys.modules[__name__] = _LazyModule(__name__, _file, define_import_structure(_file), module_spec=__spec__)
197 changes: 197 additions & 0 deletions src/transformers/models/granite_speech/configuration_granite_speech.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# coding=utf-8
# Copyright 2025 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Config class for Granite Speech."""

from ...configuration_utils import PretrainedConfig
from ..auto import CONFIG_MAPPING, AutoConfig


class GraniteSpeechEncoderConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GraniteSpeechCTCEncoder`]. It is used to instantiate
a Granite Speech audio encoder according to the specified arguments, defining the model architecture. Instantiating a
configuration with the dfefaults will yield a similar configuration to that of the audio encoder of the Granite Speech
architecture.

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
input_dim (`int`, *optional*, defaults to 160):
Dimension of the first hidden layer of the encoder.
num_layers (`int`, *optional*, defaults to 10):
Number of encoder blocks.
hidden_dim (`int`, *optional*, defaults to 1024):
The size of the intermediate layers in the conformer encoder.
feedforward_mult (`int`, *optional*, defaults to 4):
Multiplier for the up/down projections in the encoder's feedforward layers;
The projections will have intermediate dim of size `hidden_dim * feedforward_mult`.
num_heads (`int`, *optional*, defaults to 8):
Number of attention heads for each attention layer in the Transformer encoder.
dim_head (`int`, *optional*, defaults to 128):
Dimension of attention heads for each attention layer in the Transformer encoder.
output_dim (`int`, *optional*, defaults to 42):
Intermediate dimension of the feedforward projections in the conformer
to be added to every other encoder block's output.
context_size (`int`, *optional*, defaults to 200):
Context size to be used in conformer attention.
max_pos_emb (`int`, *optional*, defaults to 512):
Max pos embeds to be used in attention (shaw's relative positional encoding).
dropout (`float`, *optional*, defaults to 0.1):
The dropout probability for fully connected layers in the encoder.
conv_kernel_size (`int`, *optional*, defaults to 15):
Kernel size to be used for 1D convolution in each conformer block.
conv_expansion_factor (`int`, *optional*, defaults to 2):
Intermediate dimension to be used in conformer convolutions.

Example:

```python
>>> from transformers import GraniteSpeechEncoderConfig, GraniteSpeechCTCEncoder

>>> # Initializing a GraniteSpeechEncoderConfig
>>> configuration = GraniteSpeechEncoderConfig()

>>> # Initializing a GraniteSpeechCTCEncoder (with random weights)
>>> model = GraniteSpeechCTCEncoder(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "granite_speech_encoder"

def __init__(
self,
input_dim=160,
num_layers=10,
hidden_dim=1024,
feedforward_mult=4,
num_heads=8,
dim_head=128,
output_dim=42,
context_size=200,
max_pos_emb=512,
dropout=0.1,
conv_kernel_size=15,
conv_expansion_factor=2,
**kwargs,
):
super().__init__(**kwargs)
self.input_dim = input_dim
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.feedforward_mult = feedforward_mult
self.num_heads = num_heads
self.dim_head = dim_head
self.output_dim = output_dim
self.context_size = context_size
self.dropout = dropout
self.conv_kernel_size = conv_kernel_size
self.conv_expansion_factor = conv_expansion_factor
self.max_pos_emb = max_pos_emb


class GraniteSpeechConfig(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a [`GraniteSpeechForConditionalGeneration`]. It is used to instantiate an
Granite Speech model according to the specified arguments, defining the model architecture.

Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
documentation from [`PretrainedConfig`] for more information.

Args:
text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `GraniteConfig`):
The config object or dictionary of the text backbone.
encoder_config (`GraniteSpeechEncoderConfig`, *optional*):
The config object or dictionary of the Granite Speech CTC Encoder.
projector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Blip2QFormerConfig`):
The config object or dictionary of the audio projector.
audio_token_index (`int`, *optional*, defaults to 49155):
The audio token index to encode the audio prompt.
initializer_range (`float`, *optional*, defaults to 0.02):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
has_lora_adapter (`bool`, *optional*, defaults to `True`):
Indicates whether or not the model has a lora adapter that should only
be activate when processing audio inputs.
downsample_rate (`int`, *optional*, defaults to 5):
Downsample rate for the audio feature extractor.
window_size (`int`, *optional*, defaults to 15):
Window size for the audio feature projector.

Example:

```python
>>> from transformers import GraniteSpeechConfig, GraniteSpeechForConditionalGeneration

>>> # Initializing a GraniteSpeechConfig
>>> configuration = GraniteSpeechConfig()

>>> # Initializing a GraniteSpeechForConditionalGeneration (with random weights)
>>> model = GraniteSpeechForConditionalGeneration(configuration)

>>> # Accessing the model configuration
>>> configuration = model.config
```"""

model_type = "granite_speech"
sub_configs = {
"text_config": AutoConfig,
"encoder_config": GraniteSpeechEncoderConfig,
"projector_config": AutoConfig,
}

def __init__(
self,
text_config=None,
encoder_config=None,
projector_config=None,
audio_token_index=49155,
initializer_range=0.02,
has_lora_adapter=True,
downsample_rate=5,
window_size=15,
**kwargs,
):
if isinstance(text_config, dict):
text_config["model_type"] = text_config["model_type"] if "model_type" in text_config else "granite"
text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
elif text_config is None:
text_config = CONFIG_MAPPING["granite"]()

if isinstance(projector_config, dict):
projector_config["model_type"] = (
projector_config["model_type"] if "model_type" in projector_config else "blip_2_qformer"
)
projector_config = CONFIG_MAPPING[projector_config["model_type"]](**projector_config)
elif projector_config is None:
projector_config = CONFIG_MAPPING["blip_2_qformer"]()

if not isinstance(encoder_config, GraniteSpeechEncoderConfig):
encoder_config = {} if encoder_config is None else encoder_config
encoder_config = GraniteSpeechEncoderConfig(**encoder_config)

self.text_config = text_config
self.encoder_config = encoder_config
self.projector_config = projector_config
self.audio_token_index = audio_token_index
self.initializer_range = initializer_range
self.has_lora_adapter = has_lora_adapter
self.downsample_rate = downsample_rate
self.window_size = window_size
super().__init__(**kwargs)


__all__ = ["GraniteSpeechEncoderConfig", "GraniteSpeechConfig"]
Loading