add chunk_length parameter to Whisper #1909
@MahmoudAshraf97 Hi, thanks for your effort. I will take this PR into our internal GitLab. We will also add your name to the co-author list and credit your work in the release notes for the Whisper IFB feature.
@MahmoudAshraf97 Hi, I just tried the conv1d solution with more than 1 dynamic shape by setting the code below:
However, the build process failed. It seems the slice operator needs to know the value of x.shape[1]. I was wondering why you set a fixed config.chunk_length here rather than letting it be dynamic.
As I mentioned in my trials in the PR, this was a step to make it work, but I couldn't complete it because of the slice operator or other operators that aim to add the positional embeddings to
@MahmoudAshraf97 I see, thanks. Btw, the remove_input_padding issue for the decoder has been fixed. The code will be synced to GitHub in about a week.
cross_attention_mask = torch.ones(
    [encoder_outputs.shape[0], 1,
     encoder_outputs.shape[1]]).int().cuda()
cross_attention_mask = (
Hi @MahmoudAshraf97, if I understand correctly, you are making this change because distil-whisper can work on dynamic chunk sizes, unlike Whisper, which must use fixed 30 second chunks. Is that right? Thank you.
Hi @galv, this PR contains 2 main changes:
- Encoder is no longer restricted to 30s inputs, which helps in the case of distil-whisper as you mentioned
- Decoder now supports remove_input_padding and accepts packed inputs to save memory
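To illustrate what "packed inputs" means in the second point, here is a minimal pure-Python sketch that compares a padded batch layout with a packed one. It is not the TensorRT-LLM implementation: `PAD_ID`, `pad_batch`, `pack_batch`, and `unpack_batch` are hypothetical names chosen only for this example.

```python
from itertools import accumulate

PAD_ID = 0  # hypothetical padding token id, chosen only for this sketch


def pad_batch(sequences, pad_id=PAD_ID):
    """Padded layout: every row is as long as the longest sequence."""
    max_len = max(len(seq) for seq in sequences)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in sequences]


def pack_batch(sequences):
    """Packed layout: one flat token list plus the per-sequence lengths."""
    packed = [tok for seq in sequences for tok in seq]
    lengths = [len(seq) for seq in sequences]
    return packed, lengths


def unpack_batch(packed, lengths):
    """Recover the original sequences from the packed layout."""
    ends = list(accumulate(lengths))
    starts = [0] + ends[:-1]
    return [packed[s:e] for s, e in zip(starts, ends)]


batch = [[11, 12, 13, 14, 15], [21, 22], [31]]
padded = pad_batch(batch)
packed, lengths = pack_batch(batch)

# Number of token slots each layout has to store.
padded_slots = sum(len(row) for row in padded)
packed_slots = len(packed)
```

In this toy batch the padded layout stores 15 slots while the packed layout stores only the 8 real tokens plus three lengths, which is where the memory saving comes from when sequence lengths vary.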
closing this since it was merged
distil-whisper models perform best with chunk sizes shorter than the 30s that the original Whisper models use. This PR introduces the option to build the engine with a different chunk length.
Summary of the changes in this PR:
- remove_input_padding in the decoder
- conv1d now supports input with more than 1 dynamic shape
- paged_kv_cache using the executor, although there is no clear way to feed the encoder input and the prompt to the tensorrt_llm.bindings.Request class, as it only accepts lists of tokens in all inputs and the encoder output is a float tensor

Enabling remove_input_padding in the encoder wasn't as easy as I thought. All of my trials failed at the step where the positional embeddings are added to the conv output. When the chunk size is not defined at build time, this doesn't work because the first dim of the positional embeddings tensor is 1500, which corresponds to 30s inputs. When the chunk_size is known at build time, it's easy to slice the positional embeddings tensor to the correct size and add it to the conv output. When the chunk size is unknown, the build fails at fetching the correct indices. For example, input_lengths.unbind() fails because the shape of input_lengths is [-1].
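The build-time case described above can be sketched in plain Python with nested lists standing in for tensors. This is only an illustration, not the code in this PR: `positions_for_chunk` and `add_positional_embeddings` are hypothetical helpers, and the sketch assumes the number of encoder positions scales linearly with the chunk length (1500 positions for 30s).

```python
MAX_CHUNK_SECONDS = 30  # chunk length used by the original Whisper models
MAX_POSITIONS = 1500    # first dim of the positional embeddings tensor (30s)


def positions_for_chunk(chunk_length):
    """Encoder positions for a chunk of `chunk_length` seconds (assumed linear)."""
    if not 0 < chunk_length <= MAX_CHUNK_SECONDS:
        raise ValueError("chunk_length must be in (0, 30] seconds")
    return chunk_length * MAX_POSITIONS // MAX_CHUNK_SECONDS


def add_positional_embeddings(conv_output, positional_embeddings, chunk_length):
    """Slice the table to a size known up front, then add it frame by frame."""
    n = positions_for_chunk(chunk_length)
    if len(conv_output) != n:
        raise ValueError(f"expected {n} frames, got {len(conv_output)}")
    sliced = positional_embeddings[:n]
    return [
        [x + p for x, p in zip(frame, pos)]
        for frame, pos in zip(conv_output, sliced)
    ]


# Toy data: hidden size 2, integer values so the sums are exact.
table = [[i, -i] for i in range(MAX_POSITIONS)]
conv_out = [[1, 1] for _ in range(positions_for_chunk(15))]
result = add_positional_embeddings(conv_out, table, chunk_length=15)
```

Because `chunk_length` is a plain constant here, the slice bound is a known integer. The failing case is the one where that bound would have to be read from a tensor whose shape is only known at run time.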
Removing input padding from the encoder isn't that important, TBH, as we expect encoder inputs to be of the same shape and size except for the last window in an audio file. It would be beneficial in scenarios where the requests are multiple audio files that are all shorter than 30s and vary a lot in length.
On the other hand, remove_input_padding is important on the decoder side because it's required to enable inflight batching. In a quick trial on a 30 min audio file, the larger the batch size, the slower the generation.
As we can see, the time taken increases with batch size, which is counterproductive for large workloads, hence the need for inflight batching.
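As a rough intuition for why inflight batching helps, the toy model below counts decoder steps when a batch must wait for its longest sequence versus when a freed slot is refilled immediately. It does not reproduce the measurements from the trial and ignores real kernel costs. The function names and the sample output lengths are made up for illustration.

```python
import heapq


def static_batching_steps(output_lengths, batch_size):
    """Each batch runs until its longest sequence has finished."""
    total = 0
    for start in range(0, len(output_lengths), batch_size):
        total += max(output_lengths[start:start + batch_size])
    return total


def inflight_batching_steps(output_lengths, batch_size):
    """A freed slot is refilled right away with the next waiting request."""
    slots = [0] * batch_size  # step at which each slot becomes free
    heapq.heapify(slots)
    for length in output_lengths:
        free_at = heapq.heappop(slots)  # earliest available slot
        heapq.heappush(slots, free_at + length)
    return max(slots)


# Hypothetical numbers of generated tokens per request.
lengths = [10, 2, 2, 2, 10, 2, 2, 2]
static_steps = static_batching_steps(lengths, batch_size=2)
inflight_steps = inflight_batching_steps(lengths, batch_size=2)
```

With these made-up lengths the static schedule needs 24 steps and the inflight schedule 16, because short requests no longer idle in a slot while a long one finishes.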