Add Granite Speech Support #36801
Conversation
cc @eustlb
Sounds great!!!
GraniteSpeechConformerBlock(
    dim=config.hidden_dim,
    dim_head=config.dim_head,
    heads=config.num_heads,
    ff_mult=config.feedforward_mult,
    conv_expansion_factor=config.conv_expansion_factor,
    conv_kernel_size=config.conv_kernel_size,
    context_size=config.context_size,  # attention context size
    attn_dropout=config.dropout,
    ff_dropout=config.dropout,
    conv_dropout=config.dropout,
)
the args should be passed through the config directly!
Sounds good! Updated the Conformer Block (and the other nn modules it creates internally, i.e., the FF/attention/conv modules) to just pass the config 😄
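For reference, a minimal sketch of what that refactor looks like (the class and config field names follow the PR snippet above; treat the exact structure as an assumption rather than the final code):

```python
import torch.nn as nn

class GraniteSpeechConformerBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Each submodule pulls what it needs from the shared encoder config
        # instead of receiving a long list of unpacked keyword arguments.
        self.ff1 = GraniteSpeechConformerFeedForward(config)
        self.attn = GraniteSpeechConformerAttention(config)
        self.conv = GraniteSpeechConformerConvModule(config)
        self.ff2 = GraniteSpeechConformerFeedForward(config)
        self.post_norm = nn.LayerNorm(config.hidden_dim)
```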
dots = einsum("b m h i d, b m h j d -> b m h i j", q, k) * self.scale
we don't use einsum in transformers, we should just use matmuls
Done! Replaced this einsum with sdpa attention and the second (relative positions dot product) with a standard dot product.
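For illustration, a minimal sketch of this kind of replacement (shapes and names here are illustrative, not the exact PR code):

```python
import torch
import torch.nn.functional as F

batch, heads, seq_len, head_dim = 2, 4, 16, 32
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)
v = torch.randn(batch, heads, seq_len, head_dim)

# The einsum over "... i d, ... j d -> ... i j" is just a batched matmul:
scale = head_dim ** -0.5
dots = torch.matmul(q, k.transpose(-1, -2)) * scale

# ...and the full softmax(QK^T / sqrt(d)) @ V can be delegated to SDPA:
attn_out = F.scaled_dot_product_attention(q, k, v)
```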
self.net = nn.Sequential(
    nn.Linear(dim, dim * mult), nn.SiLU(), nn.Dropout(dropout), nn.Linear(dim * mult, dim), nn.Dropout(dropout)
)
let's just not use sequential here and explicitly create this layer
agreed and done, much cleaner this way, and tensor names are readable.
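Roughly what the unrolled feed-forward looks like (a sketch mirroring the Sequential above; not the exact PR code):

```python
import torch
import torch.nn as nn

class ConformerFeedForwardSketch(nn.Module):
    """Explicit layers instead of nn.Sequential, so intermediate tensors get readable names."""

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.1):
        super().__init__()
        self.up_proj = nn.Linear(dim, dim * mult)
        self.silu = nn.SiLU()
        self.down_proj = nn.Linear(dim * mult, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(self.silu(self.up_proj(hidden_states)))
        hidden_states = self.dropout(self.down_proj(hidden_states))
        return hidden_states
```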
self.net = nn.Sequential(
    nn.LayerNorm(dim),
    GraniteSpeechConformerPermute(dims=(0, 2, 1)),
    nn.Conv1d(dim, inner_dim * 2, 1),
    nn.GLU(dim=1),
    GraniteSpeechConformerDepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding),
    nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
    nn.SiLU(),
    nn.Conv1d(inner_dim, dim, 1),
    GraniteSpeechConformerPermute(dims=(0, 2, 1)),
    nn.Dropout(dropout),
)
would much rather have something explicit for this! 🤗
Done.
I also removed the permute layer and dropped the `not causal` part of the logic, which is never used.
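For reference, a sketch of what the unrolled convolution module can look like without Sequential, the Permute helper, or the unused causal branch (names and padding handling are simplified assumptions):

```python
import torch
import torch.nn as nn

class ConformerConvModuleSketch(nn.Module):
    def __init__(self, dim: int, expansion_factor: int = 2, kernel_size: int = 15, dropout: float = 0.1):
        super().__init__()
        inner_dim = dim * expansion_factor
        self.norm = nn.LayerNorm(dim)
        self.up_conv = nn.Conv1d(dim, inner_dim * 2, kernel_size=1)
        self.glu = nn.GLU(dim=1)
        self.depth_conv = nn.Conv1d(
            inner_dim, inner_dim, kernel_size=kernel_size, padding=kernel_size // 2, groups=inner_dim
        )
        self.silu = nn.SiLU()
        self.batch_norm = nn.BatchNorm1d(inner_dim)
        self.down_conv = nn.Conv1d(inner_dim, dim, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # (batch, seq, dim) -> (batch, dim, seq) for the conv stack; no separate Permute module needed.
        hidden_states = self.norm(hidden_states).transpose(1, 2)
        hidden_states = self.glu(self.up_conv(hidden_states))
        hidden_states = self.silu(self.batch_norm(self.depth_conv(hidden_states)))
        hidden_states = self.down_conv(hidden_states).transpose(1, 2)
        return self.dropout(hidden_states)
```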
x = self.ff1(x) + x
x = self.attn(x, context_size) + x
x = self.conv(x) + x
x = self.ff2(x) + x
x = self.post_norm(x)
return x
no single letter variables, let's use transformers' common annotations for this (see the DecoderLayer for llama for example)
Done :)
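As a sketch, that forward pass with the DecoderLayer-style naming applied to the snippet quoted above (signature kept as in the quote):

```python
def forward(self, hidden_states, context_size):
    # Half-step feed-forward, attention, convolution, second half-step feed-forward,
    # each with a residual connection, then a final post-norm (the conformer recipe).
    hidden_states = self.ff1(hidden_states) + hidden_states
    hidden_states = self.attn(hidden_states, context_size) + hidden_states
    hidden_states = self.conv(hidden_states) + hidden_states
    hidden_states = self.ff2(hidden_states) + hidden_states
    hidden_states = self.post_norm(hidden_states)
    return hidden_states
```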
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerMultiHeadAttention with Blip2->GraniteSpeech
class GraniteSpeechQFormerMultiHeadAttention(nn.Module):
let's not use Blip2! we should not register hooks and save attentions! Let's rather take llama or whisper as references!
if self.melspec is None:
    self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)

def __call__(
not sure this supports batched features? Also it's not super aligned with our API; here we use return_tensors=... etc. Have a look at the Phi4 feature extraction code to make it more aligned!
We support batched feature extraction and inference, but the batching logic is done in our Processor.
I'll re-organize this to match existing feature-extractors like Phi4. Thanks!
I attempted to make the feature_extraction/preprocessor resemble the logic/namings used in Phi4.
I moved all the audio validating+batching+attention_mask logic to the audio_processor.
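As a rough illustration of that reorganization (the mel-spectrogram handling and field names here are assumptions for the sketch, not the PR's exact code):

```python
import torch
from transformers.feature_extraction_utils import BatchFeature

def extract_audio_features_sketch(melspec, audios: list) -> BatchFeature:
    """Pad raw audio to the longest clip, compute log-mel features, and wrap them in a BatchFeature."""
    lengths = [audio.shape[-1] for audio in audios]
    batched_audio = torch.zeros(len(audios), max(lengths))
    for i, audio in enumerate(audios):
        batched_audio[i, : lengths[i]] = torch.as_tensor(audio)

    # melspec is e.g. a torchaudio.transforms.MelSpectrogram instance -> (batch, n_mels, frames)
    features = melspec(batched_audio).transpose(-1, -2)  # (batch, frames, n_mels)
    features = torch.log(features.clamp(min=1e-10))
    return BatchFeature(data={"input_features": features}, tensor_type="pt")
```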
Nice!!!
@@ -0,0 +1,88 @@
# Copyright 2025 EleutherAI and The HuggingFace Inc. team. All rights reserved.
should use the new init format! See for example Llama's __init__.py 🤗
text_inputs = self.tokenizer(text, padding=True, **kwargs)
return BatchFeature(data={**text_inputs, **speech_inputs})

def _expand_audio_placeholders(self, text: list[str], num_audio_features: List[int]):
let's align with how we do this in vision language models! (mostly just naming) I have not checked the qwen audio processor, but we should do something similar cc @eustlb
Sounds good to me! Do you have an example for what convention should be followed here?
This is super similar to what's in the processor for Llava Next here, but we had just separated it out from __call__ to make it a bit easier to read 🙂
we need an integration test here!!!
Added! Currently it's disabled while we await the model release, but everything passes on the version I have.
Will submit a PR to turn the tests on and add an example to docs once the model is out!
Awesome, thanks a lot for the quick feedback @ArthurZucker! We'll work on addressing the changes 🙂 One question: some of the things you pointed out are copied from Blip2. Do you have any thoughts on the best way to handle this in this PR? We can keep things as is and work on the changes in the copied code if that is what makes sense, but thought it would be best to check in case you had other thoughts! AFAIK we don't really have a need on our end to abstract the encoder at the moment, just wondered what would be easiest.
if is_torchaudio_available():
    from transformers import GraniteSpeechFeatureExtractor, GraniteSpeechProcessor

pytest.skip("Public models not yet available", allow_module_level=True)
These are disabled since there are no public models available yet, but we've verified that they all pass on our end - if it would be better to create the processor off of a config so that these can be enabled, we can definitely do that, otherwise happy to submit a follow-up PR to turn these back on with the released model later on!
</div>

## Overview
Currently being updated!
Need to add an overview and example here - still working on it 🙂
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerModel._prune_heads
def _prune_heads(self, heads_to_prune):
    """
    Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
    class PreTrainedModel
    """
    for layer, heads in heads_to_prune.items():
        self.encoder.layer[layer].attention.prune_heads(heads)
Yup! Deleted it 🙂
Have not reviewed all yet but:

> The current blip2 qformer code does work without any issues though. I believe the reason we did this instead of abstracting it and loading it out of the config like other models, e.g., the visual encoder in llava next (here) is that Blip2QFormerModel can't currently be created through AutoModel.from_config, since it's not an attribute on transformers.models, and it seemed like an anti-pattern to add it there without also pulling it out into its own model dir 😓
A way to fix this is just to make the qformer model part of the public API!
Something along the lines of #36493 should be possible!
Also let's make sure you create a modular_granite_speech.py model, this will help keep the model up to date and isolate differences / simplify the actual diff! 🤗 https://huggingface.co/docs/transformers/en/modular_transformers for details and don't hesitate to ping me!
)
self.ff2 = GraniteSpeechConformerFeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
self.ff1 = GraniteSpeechConformerFeedForward(config)
self.attn = GraniteSpeechConformerAttention(config)
this is overwritten by the next call no?
self.attn = GraniteSpeechConformerPreNormAttn(config.hidden_dim, self.attn)
self.ff1 = GraniteSpeechConformerScale(0.5, GraniteSpeechConformerPreNorm(config.hidden_dim, self.ff1))
self.ff2 = GraniteSpeechConformerScale(0.5, GraniteSpeechConformerPreNorm(config.hidden_dim, self.ff2))
same here
same :)
Hey @ArthurZucker! I think we have addressed all of the changes you have requested, except for the modular file. Regarding adding a modular file - I've opened a PR to expose the qformer model. Given this, the model code is pretty much independent from existing Blip2 code.
Very nice! 🤗 Thanks for bearing with me!
    return hidden_states


class GraniteSpeechConformerFeedForward(nn.Module):
this should be above! 🤗
    return hidden_states


class GraniteSpeechConformerBlock(nn.Module):
definition of classes being used should be placed above!
class GraniteSpeechConformerAttention(nn.Module):
    """Attention for conformer blocks with shaw's relpos embeddings."""
can we have a link to the paper! 🤗
query_states = self.to_q(hidden_states)
key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
query_states, key_states, value_states = [
let's have this written explicitly please
]

# shaw's relative positional embedding
seq = torch.arange(self.context_size, device=hidden_states.device)
we should compute this in the Model, this way each layer receives it already prepared no?
Good point! I moved the calculation to precompute the clamped distances into __init__ in the CTC model and pass them through forward
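For illustration, a sketch of precomputing the clamped Shaw-style relative distances once up front (the attribute names and the `max_pos_emb` bound are assumptions):

```python
import torch

context_size, max_pos_emb = 200, 512  # illustrative values

# Done once (e.g. in __init__) instead of in every attention forward pass:
seq = torch.arange(context_size)
relpos_dist = seq.view(-1, 1) - seq.view(1, -1)
attention_dists = torch.clamp(relpos_dist, -max_pos_emb, max_pos_emb) + max_pos_emb

# Each layer can then look up its relative position embeddings directly, e.g. with
# rel_pos_emb = torch.nn.Embedding(2 * max_pos_emb + 1, head_dim) and rel_pos_emb(attention_dists).
```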
mask_value = -torch.finfo(pos_attn.dtype).max
pos_attn[:, -1, :].masked_fill_(mask, mask_value)

with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
we usually try to have the attention_interface but should be alright
Sounds good! Happy to make a potential follow-up PR with that in the future 😄
def __init__(self, config: GraniteSpeechEncoderConfig):
    super().__init__()
    inner_dim = config.hidden_dim * config.conv_expansion_factor
    padding = self.calc_same_padding(config.conv_kernel_size)
let's remove the function entirely!
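i.e., something along these lines, with the padding folded into the depthwise conv's `__init__` (a sketch; argument names are assumptions):

```python
import torch.nn as nn
import torch.nn.functional as F

class DepthWiseConv1dSketch(nn.Module):
    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        # "same"-style padding computed inline rather than via a calc_same_padding helper.
        self.padding = (kernel_size // 2, kernel_size // 2 - (kernel_size + 1) % 2)
        self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, groups=in_channels)

    def forward(self, hidden_states):
        hidden_states = F.pad(hidden_states, self.padding)
        return self.conv(hidden_states)
```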
special_audio_mask = is_audio_index.unsqueeze(-1)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
if input_features_mask is not None:
    assert torch.all(is_audio_index.int().sum(dim=1) == input_features_mask.int().sum(dim=1)).item(), (
let's gracefully raise an error!
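For example, something along these lines instead of the assert (a sketch reusing the names from the quoted snippet):

```python
if input_features_mask is not None:
    if not torch.all(is_audio_index.int().sum(dim=1) == input_features_mask.int().sum(dim=1)):
        raise ValueError(
            "Number of audio tokens in the text does not match the number of valid audio features"
        )
```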
Awesome, thank you for all the guidance @ArthurZucker! I made the changes and have verified that the skipped tests still look good on our side 🙂 Will open a follow-up PR to add an example to the docs and enable the tests once the version of the model that is happy with transformers is up on the hub!
Great work guys!! 🤗
Mostly comments about the processor part that can be a bit improved. Do not hesitate to take inspiration from processing_mllama.py, processing_llava.py (the API for audio should be the same as for vision) and processing_qwen2_audio.py.
docs/source/en/_toctree.yml
- local: model_doc/granite_speech
  title: GraniteSpeech
This should be under the audio section IMO since ASR is its primary intended usage
Good point! Moved 🙂
attributes = ["audio_processor", "tokenizer"]
valid_kwargs = ["audio_token"]

audio_processor_class = "GraniteSpeechFeatureExtractor"
tokenizer_class = "AutoTokenizer"
We should also be able to set a chat_template directly in the processor (cf. MLlama processor or Qwen2-Audio processor). The idea here is to be able to call the processor directly, like in:

chat = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant"},
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": input_speech[0]},
            {"type": "text", "text": "can you transcribe the speech into a written format?"},
        ]
    },
]
inputs = processor.apply_chat_template(chat, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt")

Note particularly (even if AFAIK you are already doing it) that we should, when tokenize=True, expand <|AUDIO|> → <|AUDIO|>...<|AUDIO|> directly in the processor to avoid having to change the inputs_embeds's seq_len in the forward pass of the model.
Great catch, thank you! 😄 Added and verified that the audio token gets expanded in the input processing when it's called in this way.
When we update the docs for the example, we will be sure to use the processor's chat template as shown here!
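For context, a minimal sketch of the kind of placeholder expansion the processor performs when tokenizing (the token string and helper name are illustrative, not the PR's exact implementation):

```python
def expand_audio_placeholders_sketch(text: list, num_audio_features: list, audio_token: str = "<|audio|>") -> list:
    """Replace each audio placeholder with one copy per audio feature, so the tokenized
    sequence length already matches the number of audio embeddings to merge in."""
    expanded, audio_idx = [], 0
    for sample in text:
        while audio_token in sample:
            sample = sample.replace(audio_token, "<placeholder>" * num_audio_features[audio_idx], 1)
            audio_idx += 1
        expanded.append(sample.replace("<placeholder>", audio_token))
    return expanded

# e.g. expand_audio_placeholders_sketch(["<|audio|>transcribe this"], [3])
# -> ["<|audio|><|audio|><|audio|>transcribe this"]
```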
if self.melspec is None:
    self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
Ok for now. In the future we should rather use the standardized torch batched feature extraction function in audio_utils (that is to be written ahah).
That sounds great! I added a TODO here with both of our names, but please feel free to tag me as the audio utils are developed, I'm happy to try to help change this to use the common utilities 😄
    batched_audio,
    device=device,
)
audio_embed_sizes = self._get_num_audio_features(audio_lengths)
Great!!! Really useful for the processor. Nevertheless, I would rather not have audio_embed_sizes returned, and instead have a call to get_num_audio_tokens (modified so that it takes feature_lengths and not audio_lengths) in the processor to retrieve the correct number of <|AUDIO|> tokens to expand.
chat = [
    {
        "role": "system",
        "content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
    },
    {
        "role": "user",
        "content": "<|audio|>can you transcribe the speech into a written format?",
    },
]
We want to rather handle <|AUDIO|> token placement with the chat_template and go for something like this (and the system prompt should also be handled by the chat_template, BUT that's not applicable here in tests):
Suggested change:

chat = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant"},
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/librispeech_asr_dummy/resolve/main/audio/1.flac"},
            {"type": "text", "text": "can you transcribe the speech into a written format?"},
        ]
    },
]
See the Llava example, which is a good reference for the expected API.
Sounds good! We will submit an example once the model that is compatible with transformers is released and use this API! Still working on the final processor chat template from our side 🙂
inputs_embeds = self.get_merged_audio_embeddings(
    input_ids=input_ids,
    audio_features=audio_features,
    input_features_mask=input_features_mask,
)
To modify so that it really takes input_features_mask (i.e., input_features_mask should have the same shape as input_features).
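Roughly, the intent (a sketch of the shapes, not the PR's exact code): `input_features_mask` shares its leading shape with `input_features`, marking which audio frames are real vs. padding, so the valid audio embeddings can be selected before being scattered into the text embeddings.

```python
import torch

batch, frames, hidden = 2, 10, 8
audio_features = torch.randn(batch, frames, hidden)
input_features_mask = torch.zeros(batch, frames, dtype=torch.bool)
input_features_mask[0, :7] = True   # first sample: 7 valid frames
input_features_mask[1, :10] = True  # second sample: 10 valid frames

# Keep only un-padded audio embeddings; these are then written into inputs_embeds
# at the audio-token positions (e.g. via masked_scatter).
valid_audio_features = audio_features[input_features_mask]  # (17, hidden)
```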
Add encoder / projector, rename things
Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
* First pass at speech granite Add encoder / projector, rename things
* Combine into one model file with causal lm outputs for forward
* Add loss calc
* Fix config loading Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
* Split new / old loading logic
* Use transformers integration for loading peft adapters
* Add generation wrapper for selective lora enablement
* Add note for qformer encoder automodel
* Guard torch/audio imports in feature extractor
* Handle granite speech autoclasses
* Handle optional deps in package structure for granite speech
* Add granite pretrained model def for init
* Add dummy objects for torch/torchaudio
* Add tests for granite speech processor
* Minor formatting fixes and refactoring
* Add options for falling back to config in forward
* Tentative model docstrings for granite speech
* Fix config type
* Remove legacy load
* Allow non-lora variants for granite speech
* Override weight tying for llm
* Use text config instead of llm config
* Add output embeddings getter to fix weight tying
* Fix relative imports
* computing the number of audio features, based on the raw audio sequence.
* collating audio inputs, and keeping the original lengths.
* asserted we have text. otherwise we can't specify the audio special token.
* assering the number of audio-symbols/audios match correctly. running get validated_audios only when audio is present
* indentation bugfix + supporting different feature lengths when expanding audio.
* redundant, done in _get_validated_text
* adapting the tests: - we must have text (not either audio or text) - _get_num_audio_features takes a list of raw lengths, provided it insetad.
* Minor cleanup, remove unused import
* Add more tests for batch feature processing
* Allow setting offset in rel position embeddings
* Add config option for warning if peft is not installed w/ lora
* Port blip2 qformer code into granite speech
* Add sad test for numpy arr processing
* Allow numpy arrays / tuples in granite speech processor
* Fix config type for projector
* - pad instead of creating a zeros tensor, to keep the original dtype/device (support bfloat16) - cast input_features to the model dtype (support bfloat16)
* merge Blip2QFormerConfig to GraniteSpeechProjectorConfig
* prevent a crash when re-saving/loading the model (line 109)
* consider additional edge cases during preprocessing.
* consider additional edge cases during preprocessing.
* add features mask for batched inference (bugfix)
* Minor refactor, remove multiaudio processor tests
* Add set input/output embeddings for granite speech
* Fix feature dim check in processor test
* Pop input features in embed test for granite speech
* Small fixes for test edge cases Add granite speech to seq2seq causal lm mapping names
* Add small tests for granite speech model
* Fix data parallelism test
* Standardize model class names
* Fix check for copies
* Fix misaligned init check
* Skip granite speech in checkpoint check
* Use default for tie_word_embeddings in granite speech
* Fix non documentation granite speech repo issues
* Fix comments and docstring checks
* Add placeholder docs for granite speech
* Fix test naming collision
* Code formatting
* Rerun torch dummy obj regen
* Fix save pretrained for granite speech
* Import sorting
* Fix tests typo
* Remove offset hack
* Pass args through encoder config
* Remove unused prune heads from blip2
* removing einsum. replaced with explicit multiplication (relative positional encodings) and sdpa attention.
* remove Sequential from ConformerFeedForward and ConformerConvModule. + fix for sdpa attention
* remove GraniteSpeechConformerScale
* rename to hidden_states
* rename conformer layers to self.layers, remove the first linear from the list to keep the list homogenous.
* move pre-norm to the attention/feedforward blocks (avoid complex module wrapping)
* adding pre_norm into forward
* feature extractor refactoring to resemble how it's done in phi4multimodal.
* rename feature_extractor to audio_processor
* bugfix: input_feature_mask fix to get the exact number tokens.
* Fix pytest decorator in processor test
* Add (disabled) integration tests for granite speech
* Fix handling of optional feature masking
* Loosen validation in processing for vLLM compatability
* Formatting fixes
* Update init structure to mirror llama
* Make granite speech projector generic
* Update test config to reflect generic projector
* Formatting fixes
* Fix typos, add license
* Fix undefined var in input processing
* Cleanup and expose ctc encoder
* Add missing config docstrings
* Better var names, type hints, etc
* Set attn context size in init
* Add max pos emb to encoder config
* Cleanup feature extractor
* Add granite speech architecture details
* Remove granite speech qformer ref
* Add paper link, explicit calc for qkv
* Calculate padding directly in depthwise conv1d init
* Raise value error instead of asserting
* Reorder class defs (classes used at top)
* Precompute relpos distances
* Run formatting
* Pass attention distances through forward
* Apply suggestions from code review Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
* Add todo for using common batch feature extraction
* Rename audios/features
* Ensure chat template may be provided to processor
* Move granite speech docs to audio models
* Add todos for input proc refactoring
* Fix import order
* Guard torch import
* Use relative imports
* Require torch backend for processor in granite speech
* Add backend guards in feature extractor

---------

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Co-authored-by: Avihu Dekel <avihu.dekel@ibm.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
What does this PR do?
This PR adds support for (upcoming) Granite Speech models. In terms of model architecture, it uses a conformer-based encoder with a blip2 qformer-based projector to encode the audio, whose features are then merged into a Granite LLM at the audio token positions. This model also uses an audio-specific lora adapter, which should only be enabled when the model is processing audio inputs.
Currently this is handled by using the transformers/peft mixin to load the adapter from the same directory as the base model and overriding `.generate` to turn the lora on/off based on the presence of audio input features before forwarding to `.generate` on the superclass. Doing it this way, as opposed to encapsulating a Peft Model, makes this a lot cleaner in a lot of ways and keeps the lora out of the modeling code. (This PR is a collaboration with @avihu111 and @gsaon)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.