
Conversation

@alex-jw-brooks
Contributor

@alex-jw-brooks alex-jw-brooks commented Mar 18, 2025

What does this PR do?

This PR adds support for (upcoming) Granite Speech models. In terms of model architecture, it uses a conformer-based encoder with a blip2 qformer-based projector to encode the audio, and merges the resulting audio features into a Granite LLM at the audio token positions. This model also uses an audio-specific lora adapter, which should only be enabled when the model is processing audio inputs.

Currently this is handled by using the transformers/peft mixin to load the adapter from the same directory as the base model and overriding .generate to turn the lora on/off based on the presence of audio input features before forwarding to .generate on the superclass. Doing it this way, as opposed to encapsulating a Peft Model, is considerably cleaner and keeps the lora out of the modeling code.
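For illustration, a minimal sketch of the LoRA-toggling generate override described above, with stand-in class names rather than the PR's actual classes (on real transformers models the PEFT integration provides enable_adapters/disable_adapters):

# Hedged sketch, not the exact merged code: toggle the audio adapter around
# generation depending on whether audio features are present.
class _BaseStub:
    """Stand-in for the pretrained superclass."""
    def enable_adapters(self):
        print("audio lora enabled")
    def disable_adapters(self):
        print("audio lora disabled")
    def generate(self, **kwargs):
        return "<generated token ids>"

class GraniteSpeechLikeModel(_BaseStub):
    def generate(self, input_features=None, **kwargs):
        # Only activate the audio-specific lora when audio features are passed in.
        if input_features is not None:
            self.enable_adapters()
        else:
            self.disable_adapters()
        return super().generate(input_features=input_features, **kwargs)

model = GraniteSpeechLikeModel()
model.generate()                              # text-only call: adapter stays off
model.generate(input_features=[[0.0, 0.1]])   # audio present: adapter turned on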

(This PR is a collaboration with @avihu111 and @gsaon)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@Rocketknight1
Member

cc @eustlb

@alex-jw-brooks alex-jw-brooks marked this pull request as ready for review March 24, 2025 22:03
Collaborator

@ArthurZucker ArthurZucker left a comment

Sounds great!!!

Comment on lines 792 to 803
GraniteSpeechConformerBlock(
dim=config.hidden_dim,
dim_head=config.dim_head,
heads=config.num_heads,
ff_mult=config.feedforward_mult,
conv_expansion_factor=config.conv_expansion_factor,
conv_kernel_size=config.conv_kernel_size,
context_size=config.context_size, # attention context size
attn_dropout=config.dropout,
ff_dropout=config.dropout,
conv_dropout=config.dropout,
)
Collaborator

the args should be passed through the config directly!

Contributor Author

Sounds good! Updated the Conformer Block (and the other nn modules it creates internally, i.e., the FF/attention/conv modules) to just pass the config 😄
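A minimal sketch of the config-driven constructor pattern being asked for here; the config fields mirror the snippet above, but the class body is illustrative only:

from dataclasses import dataclass
import torch.nn as nn

@dataclass
class EncoderConfigSketch:
    hidden_dim: int = 256
    num_heads: int = 4
    dim_head: int = 64
    conv_kernel_size: int = 15
    dropout: float = 0.1

class ConformerSubmoduleSketch(nn.Module):
    def __init__(self, config: EncoderConfigSketch):
        super().__init__()
        # Each submodule reads what it needs from the shared config instead of
        # receiving a long list of keyword arguments from the parent block.
        self.norm = nn.LayerNorm(config.hidden_dim)
        self.dropout = nn.Dropout(config.dropout)

submodule = ConformerSubmoduleSketch(EncoderConfigSketch())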

Comment on lines 919 to 178
dots = einsum("b m h i d, b m h j d -> b m h i j", q, k) * self.scale

Collaborator

we don't use einsum in transformers; we should just use matmuls

Contributor

Done! Replaced this einsum with sdpa attention and the second (relative positions dot product) with a standard dot product.
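As a sanity check on that replacement, a small hedged example (shapes are illustrative) showing that the attention-score einsum from the quoted snippet is just a batched matmul:

import torch

batch, chunks, heads, seq, head_dim = 2, 3, 4, 5, 8
q = torch.randn(batch, chunks, heads, seq, head_dim)
k = torch.randn(batch, chunks, heads, seq, head_dim)
scale = head_dim ** -0.5

dots_einsum = torch.einsum("bmhid,bmhjd->bmhij", q, k) * scale
dots_matmul = torch.matmul(q, k.transpose(-1, -2)) * scale  # same result, no einsum

assert torch.allclose(dots_einsum, dots_matmul, atol=1e-5)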

Comment on lines 947 to 232
self.net = nn.Sequential(
nn.Linear(dim, dim * mult), nn.SiLU(), nn.Dropout(dropout), nn.Linear(dim * mult, dim), nn.Dropout(dropout)
)
Collaborator

let's just not use sequential here and explicitly create this layer

Contributor

Agreed and done; it's much cleaner this way, and the tensor names are readable.
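A hedged sketch of what the explicit (non-Sequential) feed-forward module can look like; attribute names are illustrative rather than the exact merged code:

import torch
import torch.nn as nn

class ConformerFeedForwardSketch(nn.Module):
    def __init__(self, dim: int = 256, mult: int = 4, dropout: float = 0.1):
        super().__init__()
        self.up_proj = nn.Linear(dim, dim * mult)
        self.silu = nn.SiLU()
        self.dropout = nn.Dropout(dropout)
        self.down_proj = nn.Linear(dim * mult, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Same computation as the Sequential above, but with named submodules.
        hidden_states = self.dropout(self.silu(self.up_proj(hidden_states)))
        hidden_states = self.dropout(self.down_proj(hidden_states))
        return hidden_states

out = ConformerFeedForwardSketch()(torch.randn(1, 10, 256))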

Comment on lines 962 to 232
self.net = nn.Sequential(
nn.LayerNorm(dim),
GraniteSpeechConformerPermute(dims=(0, 2, 1)),
nn.Conv1d(dim, inner_dim * 2, 1),
nn.GLU(dim=1),
GraniteSpeechConformerDepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding),
nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(),
nn.SiLU(),
nn.Conv1d(inner_dim, dim, 1),
GraniteSpeechConformerPermute(dims=(0, 2, 1)),
nn.Dropout(dropout),
)
Collaborator

would much rather have something explicit for this! 🤗

Contributor

Done.
I also removed the permute layer and dropped the not-causal branch of the logic, which is never used.
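A hedged sketch of an explicit conformer conv module along those lines, with illustrative names: the Permute wrapper becomes plain transposes and the unused causal branch is dropped:

import torch
import torch.nn as nn

class ConformerConvSketch(nn.Module):
    def __init__(self, dim: int = 256, expansion: int = 2, kernel_size: int = 15):
        super().__init__()
        inner = dim * expansion
        self.norm = nn.LayerNorm(dim)
        self.pointwise_in = nn.Conv1d(dim, inner * 2, 1)
        self.glu = nn.GLU(dim=1)
        self.depthwise = nn.Conv1d(inner, inner, kernel_size,
                                   padding=kernel_size // 2, groups=inner)
        self.batch_norm = nn.BatchNorm1d(inner)
        self.act = nn.SiLU()
        self.pointwise_out = nn.Conv1d(inner, dim, 1)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.norm(hidden_states).transpose(1, 2)  # (batch, dim, time)
        hidden_states = self.glu(self.pointwise_in(hidden_states))
        hidden_states = self.act(self.batch_norm(self.depthwise(hidden_states)))
        return self.pointwise_out(hidden_states).transpose(1, 2)  # back to (batch, time, dim)

out = ConformerConvSketch()(torch.randn(2, 20, 256))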

Comment on lines 1024 to 1029
x = self.ff1(x) + x
x = self.attn(x, context_size) + x
x = self.conv(x) + x
x = self.ff2(x) + x
x = self.post_norm(x)
return x
Collaborator

no single-letter variables; let's use transformers' common annotations for this (see the llama DecoderLayer, for example)

Contributor

Done :)
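A hedged sketch of the renamed residual structure (descriptive hidden_states instead of x); the submodules here are simple stand-ins, and the layout mirrors the snippet quoted above rather than the exact merged code:

import torch
import torch.nn as nn

class ConformerBlockSketch(nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        self.ff1 = nn.Linear(dim, dim)
        self.attn = nn.Identity()        # stand-in for the self-attention module
        self.conv = nn.Linear(dim, dim)  # stand-in for the conv module
        self.ff2 = nn.Linear(dim, dim)
        self.post_norm = nn.LayerNorm(dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Residual connections around each submodule, then a final norm.
        hidden_states = self.ff1(hidden_states) + hidden_states
        hidden_states = self.attn(hidden_states) + hidden_states
        hidden_states = self.conv(hidden_states) + hidden_states
        hidden_states = self.ff2(hidden_states) + hidden_states
        return self.post_norm(hidden_states)

out = ConformerBlockSketch()(torch.randn(1, 10, 256))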



# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerMultiHeadAttention with Blip2->GraniteSpeech
class GraniteSpeechQFormerMultiHeadAttention(nn.Module):
Collaborator

let's not use Blip2! We should not register hooks and save attentions! Let's rather take llama or whisper as references!

if self.melspec is None:
self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)

def __call__(
Collaborator

not sure this supports batched features? Also not super aligned with our API; here we use return_tensors=... etc. Have a look at the Phi4 feature extraction code to make it more aligned!

Contributor

We support batched feature extraction and inference, but the batching logic is done in our Processor.
I'll re-organize this to match existing feature-extractors like Phi4. Thanks!

Contributor

I attempted to make the feature_extraction/preprocessor resemble the logic/namings used in Phi4.
I moved all the audio validating+batching+attention_mask logic to the audio_processor.
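A hedged sketch of what a Phi4-style call signature can look like here, batching raw waveforms, computing log-mel features, and returning a BatchFeature that honors return_tensors; this is illustrative, not the merged GraniteSpeechFeatureExtractor:

import torch
import torchaudio
from transformers import BatchFeature

def extract_features(waveforms, sampling_rate=16000, return_tensors="pt"):
    melspec = torchaudio.transforms.MelSpectrogram(sample_rate=sampling_rate)
    tensors = [torch.as_tensor(w, dtype=torch.float32) for w in waveforms]
    lengths = [t.shape[-1] for t in tensors]
    padded = torch.nn.utils.rnn.pad_sequence(tensors, batch_first=True)
    features = melspec(padded).transpose(1, 2).clamp(min=1e-10).log()  # (batch, frames, n_mels)
    mask = torch.zeros(features.shape[:2], dtype=torch.long)
    for i, length in enumerate(lengths):
        mask[i, : 1 + length // melspec.hop_length] = 1  # valid (non-padding) frames
    return BatchFeature(
        data={"input_features": features, "input_features_mask": mask},
        tensor_type=return_tensors,
    )

feats = extract_features([torch.randn(16000), torch.randn(8000)])
print(feats["input_features"].shape, feats["input_features_mask"].sum(dim=-1))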

Collaborator

Nice!!!

@@ -0,0 +1,88 @@
# Copyright 2025 EleutherAI and The HuggingFace Inc. team. All rights reserved.
Collaborator

should use the new init format! See Llama's __init__.py for an example 🤗

text_inputs = self.tokenizer(text, padding=True, **kwargs)
return BatchFeature(data={**text_inputs, **speech_inputs})

def _expand_audio_placeholders(self, text: list[str], num_audio_features: List[int]):
Collaborator

let's align with how we do this in vision language models! (mostly just naming) I have not checked the qwen audio processor, but we should do something similar cc @eustlb

Contributor Author

Sounds good to me! Do you have an example for what convention should be followed here?

This is super similar to what's in the processor for Llava Next here, but we had just separated it out from __call__ to make it a bit easier to read 🙂
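For reference, a hedged sketch of the placeholder-expansion idea (the llava-style approach mentioned above): each audio token in the prompt is expanded to one token per audio feature. The token string and function name are illustrative:

from typing import List

AUDIO_TOKEN = "<|audio|>"  # assumed placeholder string

def expand_audio_placeholders(text: List[str], num_audio_features: List[int]) -> List[str]:
    expanded, idx = [], 0
    for prompt in text:
        # Replace each placeholder with N temporary markers, one per audio feature.
        while AUDIO_TOKEN in prompt:
            prompt = prompt.replace(AUDIO_TOKEN, "<placeholder>" * num_audio_features[idx], 1)
            idx += 1
        expanded.append(prompt.replace("<placeholder>", AUDIO_TOKEN))
    return expanded

print(expand_audio_placeholders(["hi <|audio|> there"], [3]))
# ['hi <|audio|><|audio|><|audio|> there']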

Collaborator

we need an integration test here!!!

Contributor Author

Added! Currently it's disabled while we await the model release, but everything passes on the version I have.

Will submit a PR to turn the tests on and add an example to docs once the model is out!

@alex-jw-brooks
Contributor Author

Awesome, thanks a lot for the quick feedback @ArthurZucker! We'll work on addressing the changes 🙂

One question - some of the things you pointed out are copied from blip2's qformer code, which was needed to avoid depending on modeling.blip2. The current blip2 qformer code does work without any issues though. I believe the reason we did this instead of abstracting it and loading it out of the config like other models, e.g., the visual encoder in llava next (here), is that Blip2QFormerModel can't currently be created through AutoModel.from_config, since it's not an attribute on transformers.models, and it seemed like an anti-pattern to add it there without also pulling it out into its own model dir 😓

Do you have any thoughts on the best way to handle this in this PR? We can keep things as is and work on the changes in the copied code if that is what makes sense, but thought it would be best to check in case you had other thoughts! AFAIK we don't really have a need on our end to abstract the encoder at the moment, just wondered what would be easiest.

if is_torchaudio_available():
from transformers import GraniteSpeechFeatureExtractor, GraniteSpeechProcessor

pytest.skip("Public models not yet available", allow_module_level=True)
Contributor Author

These are disabled since there are no public models available yet, but we've verified that they all pass on our end - if it would be better to create the processor off of a config so that these can be enabled, we can definitely do that, otherwise happy to submit a follow-up PR to turn these back on with the released model later on!


## Overview
Currently being updated!
Contributor Author

Need to add an overview and example here - still working on it 🙂

Comment on lines 571 to 579
# Copied from transformers.models.blip_2.modeling_blip_2.Blip2QFormerModel._prune_heads
def _prune_heads(self, heads_to_prune):
"""
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
class PreTrainedModel
"""
for layer, heads in heads_to_prune.items():
self.encoder.layer[layer].attention.prune_heads(heads)

Contributor Author

Yup! Deleted it 🙂

Collaborator

@ArthurZucker ArthurZucker left a comment

Have not reviewed all yet but:

The current blip2 qformer code does work without any issues though. I believe the reason we did this instead of abstracting it and loading it out of the config like other models, e.g., the visual encoder in llava next (here) is that Blip2QFormerModel can't currently be created through AutoModel.from_config, since it's not an attribute on transformers.models, and it seemed like an anti pattern to add it there without also pulling it out into its own model dir 😓

A way to fix this is just to make the qformer model part of the public API!
Something along the lines of #36493 should be possible!

Also let's make sure you create a modular_granite_speech.py file; this will help keep the model up to date and isolate differences / simplify the actual diff! 🤗 See https://huggingface.co/docs/transformers/en/modular_transformers for details, and don't hesitate to ping me!

)
self.ff2 = GraniteSpeechConformerFeedForward(dim=dim, mult=ff_mult, dropout=ff_dropout)
self.ff1 = GraniteSpeechConformerFeedForward(config)
self.attn = GraniteSpeechConformerAttention(config)
Collaborator

this is overwritten by the next call, no?

Contributor

Yeah, I didn't like that part either :)
Simplified those in here and here
I hope the new conformer structure looks better!

Comment on lines 972 to 974
self.attn = GraniteSpeechConformerPreNormAttn(config.hidden_dim, self.attn)
self.ff1 = GraniteSpeechConformerScale(0.5, GraniteSpeechConformerPreNorm(config.hidden_dim, self.ff1))
self.ff2 = GraniteSpeechConformerScale(0.5, GraniteSpeechConformerPreNorm(config.hidden_dim, self.ff2))
Collaborator

same here

Contributor

same :)

if self.melspec is None:
self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)

def __call__(
Collaborator

Nice!!!

@alex-jw-brooks
Contributor Author

alex-jw-brooks commented Apr 4, 2025

Hey @ArthurZucker! I think we have addressed all of the changes you have requested, except for the modular_granite_speech part. Could you please take a look when you get the chance?

Regarding adding a modular file - I've opened a PR to expose Blip2QFormer here (and also rebased this PR on top of it). With this change, we can delete the mountain of copied Blip2 code and generically load the projector through AutoModel.from_config, similar to the vision tower in most VLMs 🎉 There is no functional diff with the existing Blip2 QFormer code, and it was only originally copied because it wasn't registered as a special auto model yet.

Given this, the model code is pretty much independent from existing transformers modules (i.e., not copying or inheriting from anything), so the code in the modular file would be identical to the current modeling code. Any thoughts on how to proceed here would be super appreciated!
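A hedged sketch of the generic projector loading this enables, assuming a transformers version in which Blip2QFormerModel is registered with the auto classes (the change referenced above); the config values are illustrative, not Granite Speech's actual projector settings:

from transformers import AutoModel, Blip2QFormerConfig

qformer_config = Blip2QFormerConfig(hidden_size=256, num_hidden_layers=2, num_attention_heads=4)
# Instantiated generically from its sub-config, the way VLMs build their vision towers.
projector = AutoModel.from_config(qformer_config)
print(type(projector).__name__)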

Collaborator

@ArthurZucker ArthurZucker left a comment

Very nice! 🤗 Thanks for bearing with me!

return hidden_states


class GraniteSpeechConformerFeedForward(nn.Module):
Collaborator

this should be above! 🤗

return hidden_states


class GraniteSpeechConformerBlock(nn.Module):
Collaborator

definitions of classes being used should be placed above!



class GraniteSpeechConformerAttention(nn.Module):
"""Attention for conformer blocks with shaw's relpos embeddings."""
Collaborator

can we have a link to the paper! 🤗


query_states = self.to_q(hidden_states)
key_states, value_states = self.to_kv(hidden_states).chunk(2, dim=-1)
query_states, key_states, value_states = [
Collaborator

let's have this written explicitly please

]

# shaw's relative positional embedding
seq = torch.arange(self.context_size, device=hidden_states.device)
Collaborator

we should compute this in the model; this way each layer receives it already prepared, no?

Contributor Author

@alex-jw-brooks alex-jw-brooks Apr 9, 2025

Good point! I moved the calculation to precompute the clamped distances into __init__ in the CTC model and pass them through forward.
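A hedged sketch of precomputing Shaw-style clamped relative-position indices once (e.g. in __init__) so every layer can reuse them; sizes and names are illustrative:

import torch

def build_relpos_distances(context_size: int, max_pos_emb: int) -> torch.Tensor:
    seq = torch.arange(context_size)
    distances = seq.view(-1, 1) - seq.view(1, -1)  # distances[i, j] = i - j
    # Clamp and shift so the result indexes an embedding table of size 2 * max_pos_emb + 1.
    return distances.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb

rel_pos_emb = torch.nn.Embedding(2 * 8 + 1, 64)
distances = build_relpos_distances(context_size=4, max_pos_emb=8)
print(rel_pos_emb(distances).shape)  # torch.Size([4, 4, 64])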

mask_value = -torch.finfo(pos_attn.dtype).max
pos_attn[:, -1, :].masked_fill_(mask, mask_value)

with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
Collaborator

we usually try to have the attention_interface but should be alright

Contributor Author

Sounds good! Happy to make a potential follow-up PR with that in the future 😄

def __init__(self, config: GraniteSpeechEncoderConfig):
super().__init__()
inner_dim = config.hidden_dim * config.conv_expansion_factor
padding = self.calc_same_padding(config.conv_kernel_size)
Collaborator

let's remove the function entirely!

special_audio_mask = is_audio_index.unsqueeze(-1)
audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
if input_features_mask is not None:
assert torch.all(is_audio_index.int().sum(dim=1) == input_features_mask.int().sum(dim=1)).item(), (
Collaborator

let's gracefully raise an error!
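A hedged sketch of the requested change: replace the assert with an explicit ValueError so the check survives python -O and gives a readable message. Variable names follow the snippet above, but the helper itself is illustrative:

import torch

def check_audio_token_count(is_audio_index: torch.Tensor, input_features_mask: torch.Tensor) -> None:
    num_audio_tokens = is_audio_index.int().sum(dim=1)
    num_audio_features = input_features_mask.int().sum(dim=1)
    if not torch.all(num_audio_tokens == num_audio_features):
        raise ValueError(
            f"Number of audio tokens per sample {num_audio_tokens.tolist()} does not match "
            f"the number of audio features {num_audio_features.tolist()}"
        )

check_audio_token_count(torch.tensor([[1, 1, 0]]), torch.tensor([[1, 1, 0]]))  # passes silently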

@alex-jw-brooks
Contributor Author

alex-jw-brooks commented Apr 9, 2025

Awesome, thank you for all the guidance @ArthurZucker! I made the changes and have verified that the skipped tests still look good on our side 🙂

Will open a follow-up PR to add an example to the docs and enable the tests once the version of the model that is happy with transformers is up on the hub!

Contributor

@eustlb eustlb left a comment

Great work guys !! 🤗
Mostly comments about the processor part, which can be a bit improved. Do not hesitate to take inspiration from processing_mllama.py, processing_llava.py (the API for audio should be the same as for vision), and processing_qwen2_audio.py.

Comment on lines 488 to 489
- local: model_doc/granite_speech
title: GraniteSpeech
Contributor

This should be under the audio section IMO since ASR is its primary intended usage

Contributor Author

Good point! Moved 🙂

Comment on lines +31 to +37
attributes = ["audio_processor", "tokenizer"]
valid_kwargs = ["audio_token"]

audio_processor_class = "GraniteSpeechFeatureExtractor"
tokenizer_class = "AutoTokenizer"
Contributor

@eustlb eustlb Apr 10, 2025

We should also be able to set a chat_template directly in the processor (cf. MLlama processor or Qwen2-Audio processor). The idea here is to be able to call the processor directly, like in:

chat = [
    {
        "role": "system",
        "content": [
            {"type": "text", "text": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant"},
        ]
    },
    {
        "role": "user",
        "content": [
            {"type": "audio", "audio": input_speech[0]},
            {"type": "text", "text": "can you transcribe the speech into a written format?"},
        ]
    },
]
inputs = processor.apply_chat_template(chat, add_generation_prompt=True, tokenize=True, return_dict=True, return_tensors="pt")

Note particularly (even if AFAIK you are already doing it) that when tokenize=True we should expand <|AUDIO|><|AUDIO|>...<|AUDIO|> directly in the processor, to avoid having to change the inputs_embeds's seq_len in the forward pass of the model.

Contributor Author

@alex-jw-brooks alex-jw-brooks Apr 11, 2025

Great catch, thank you! 😄 Added and verified that the audio token gets expanded in the input processing when it's called in this way.

When we update the docs for the example, we will be sure to use the processor's chat template as shown here!

Comment on lines 92 to 106
if self.melspec is None:
self.melspec = torchaudio.transforms.MelSpectrogram(**self.melspec_kwargs)
Contributor

@eustlb eustlb Apr 10, 2025

Ok for now. In the future we should rather use the standardized torch batched feature extraction function in audio_utils (that is yet to be written ahah).

Contributor Author

That sounds great! I added a TODO here with both of our names, but please feel free to tag me as the audio utils are developed, I'm happy to try to help change this to use the common utilities 😄

batched_audio,
device=device,
)
audio_embed_sizes = self._get_num_audio_features(audio_lengths)
Contributor

Great !!! Really useful for the processor. Nevertheless, I would rather not have audio_embed_sizes returned, and instead have a call to get_num_audio_tokens (modified so that it takes feature_lengths and not audio_lengths) in the processor to retrieve the correct number of <|AUDIO|> tokens to expand.

Comment on lines +321 to +330
chat = [
{
"role": "system",
"content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
},
{
"role": "user",
"content": "<|audio|>can you transcribe the speech into a written format?",
},
]
Contributor

We want to rather handle <|AUDIO|> token placement with the chat_template and go for something like the following (the system prompt should also be handled by the chat_template, BUT that is not applicable here in tests):

Suggested change
chat = [
{
"role": "system",
"content": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant",
},
{
"role": "user",
"content": "<|audio|>can you transcribe the speech into a written format?",
},
]
chat = [
{
"role": "system",
"content": [
{"type": "text", "text": "Knowledge Cutoff Date: April 2024.\nToday's Date: December 19, 2024.\nYou are Granite, developed by IBM. You are a helpful AI assistant"},
]
},
{
"role": "user",
"content": [
{"type": "audio", "url": "https://huggingface.co/datasets/hf-internal-testing/librispeech_asr_dummy/resolve/main/audio/1.flac"},
{"type": "text", "text": "can you transcribe the speech into a written format?"},
]
},
]

See Llava example which is a good reference of the expected API.

Contributor Author

Sounds good! We will submit an example once the model that is compatible with transformers is released and use this API! Still working on the final processor chat template from our side 🙂

inputs_embeds = self.get_merged_audio_embeddings(
input_ids=input_ids,
audio_features=audio_features,
input_features_mask=input_features_mask,
)
Contributor

@eustlb eustlb Apr 10, 2025

to modify so that it really takes input_features_mask (so having input_features_mask the same shape as input_features)

@eustlb eustlb merged commit 623d395 into huggingface:main Apr 11, 2025
20 checks passed
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

cyr0930 pushed a commit to cyr0930/transformers that referenced this pull request Apr 18, 2025
* First pass at speech granite

Add encoder / projector, rename things

* Combine into one model file with causal lm outputs for forward

* Add loss calc

* Fix config loading

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>

* Split new / old loading logic

* Use transformers integration for loading peft adapters

* Add generation wrapper for selective lora enablement

* Add note for qformer encoder automodel

* Guard torch/audio imports in feature extractor

* Handle granite speech autoclasses

* Handle optional deps in package structure for granite speech

* Add granite pretrained model def for init

* Add dummy objects for torch/torchaudio

* Add tests for granite speech processor

* Minor formatting fixes and refactoring

* Add options for falling back to config in forward

* Tentative model docstrings for granite speech

* Fix config type

* Remove legacy load

* Allow non-lora variants for granite speech

* Override weight tying for llm

* Use text config instead of llm config

* Add output embeddings getter to fix weight tying

* Fix relative imports

* computing the number of audio features, based on the raw audio sequence.

* collating audio inputs, and keeping the original lengths.

* asserted we have text. otherwise we can't specify the audio special token.

* asserting the number of audio-symbols/audios match correctly.
running get_validated_audios only when audio is present

* indentation bugfix + supporting different feature lengths when expanding audio.

* redundant, done in _get_validated_text

* adapting the tests:
- we must have text (not either audio or text)
- _get_num_audio_features takes a list of raw lengths, provided it instead.

* Minor cleanup, remove unused import

* Add more tests for batch feature processing

* Allow setting offset in rel position embeddings

* Add config option for warning if peft is not installed w/ lora

* Port blip2 qformer code into granite speech

* Add sad test for numpy arr processing

* Allow numpy arrays / tuples in granite speech processor

* Fix config type for projector

* - pad instead of creating a zeros tensor, to keep the original dtype/device (support bfloat16)
- cast input_features to the model dtype (support bfloat16)

* merge Blip2QFormerConfig to GraniteSpeechProjectorConfig

* prevent a crash when re-saving/loading the model (line 109)

* consider additional edge cases during preprocessing.

* consider additional edge cases during preprocessing.

* add features mask for batched inference (bugfix)

* Minor refactor, remove multiaudio processor tests

* Add set input/output embeddings for granite speech

* Fix feature dim check in processor test

* Pop input features in embed test for granite speech

* Small fixes for test edge cases

Add granite speech to seq2seq causal lm mapping names

* Add small tests for granite speech model

* Fix data parallelism test

* Standardize model class names

* Fix check for copies

* Fix misaligned init check

* Skip granite speech in checkpoint check

* Use default for tie_word_embeddings in granite speech

* Fix non documentation granite speech repo issues

* Fix comments and docstring checks

* Add placeholder docs for granite speech

* Fix test naming collision

* Code formatting

* Rerun torch dummy obj regen

* Fix save pretrained for granite speech

* Import sorting

* Fix tests typo

* Remove offset hack

* Pass args through encoder config

* Remove unused prune heads from blip2

* removing einsum. replaced with explicit multiplication (relative positional encodings) and sdpa attention.

* remove Sequential from ConformerFeedForward and ConformerConvModule. + fix for sdpa attention

* remove GraniteSpeechConformerScale

* rename to hidden_states

* rename conformer layers to self.layers, remove the first linear from the list to keep the list homogenous.

* move pre-norm to the attention/feedforward blocks (avoid complex module wrapping)

* adding pre_norm into forward

* feature extractor refactoring to resemble how it's done in phi4multimodal.

* rename feature_extractor to audio_processor

* bugfix: input_feature_mask fix to get the exact number tokens.

* Fix pytest decorator in processor test

* Add (disabled) integration tests for granite speech

* Fix handling of optional feature masking

* Loosen validation in processing for vLLM compatibility

* Formatting fixes

* Update init structure to mirror llama

* Make granite speech projector generic

* Update test config to reflect generic projector

* Formatting fixes

* Fix typos, add license

* Fix undefined var in input processing

* Cleanup and expose ctc encoder

* Add missing config docstrings

* Better var names, type hints, etc

* Set attn context size in init

* Add max pos emb to encoder config

* Cleanup feature extractor

* Add granite speech architecture details

* Remove granite speech qformer ref

* Add paper link, explicit calc for qkv

* Calculate padding directly in depthwise conv1d init

* Raise value error instead of asserting

* Reorder class defs (classes used at top)

* Precompute relpos distances

* Run formatting

* Pass attention distances through forward

* Apply suggestions from code review

Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>

* Add todo for using common batch feature extraction

* Rename audios/features

* Ensure chat template may be provided to processor

* Move granite speech docs to audio models

* Add todos for input proc refactoring

* Fix import order

* Guard torch import

* Use relative imports

* Require torch backend for processor in granite speech

* Add backend guards in feature extractor

---------

Signed-off-by: Alex-Brooks <Alex.brooks@ibm.com>
Co-authored-by: Avihu Dekel <avihu.dekel@ibm.com>
Co-authored-by: eustlb <94853470+eustlb@users.noreply.github.com>
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025
