Description
Motivation.
The goal of this RFC is to propose a simple initial design to support multi-modal models that use cross attention in the V1 architecture. Whisper is a prime example of such a model. The design aims to be as simple as possible and easily replaceable without disrupting other ongoing V1 work. Currently, the only encoder/decoder models supported in V1 are ones that do not use cross attention; these models use the `EncoderCacheManager` to communicate the outputs of the encoder to the decoder. Multi-modal models that use cross attention need a separate KV cache for the encoder portion of the model. This cross attention KV cache is populated by the encoder and read (but never written) by the decoder's cross attention layers. It is separate from the existing decoder KV cache and has to be managed separately.
Non-goals
Since we are focusing on Whisper for the initial design, there are certain features/optimizations that can be deferred.
- For now, the following optimizations will be disabled for cross attention models since they probably won't provide much benefit:
  - Chunked prefill for the encoder
  - Prefix caching
- Supporting attention backends other than flash attention
- Abstracting the `GPUModelRunner`
For additional background see:
Proposed Change.
The proposed changes touch the following areas of the code:
- `Scheduler`/`Request`
- `GPUModelRunner`
- `FlashAttentionImpl`
- Cross attention models, e.g. `WhisperForConditionalGeneration`
Scheduler/Request
The scheduler will be updated to allocate the cross attention KV cache when the encoder portion of an encoder/decoder cross attention model is executed. Whether a model needs this can be determined from the model config. The cache will be stored in the `Request` object since it is persistent (and read-only for the decoder) once filled.
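
As a rough illustration, a minimal sketch of how the allocation could be tracked is shown below. The `cross_attn_block_ids` field, the `has_cross_attn_cache` property, and the `kv_cache_manager.allocate()` helper are hypothetical names used only to show the lifecycle, not vLLM's actual API:

```python
# Minimal sketch; field and helper names are illustrative, not vLLM's real API.
from dataclasses import dataclass, field


@dataclass
class Request:
    request_id: str
    # Block IDs of the cross attention KV cache. Allocated once when the
    # encoder runs, then read-only for the decoder's cross attention layers
    # for the remainder of the request's lifetime.
    cross_attn_block_ids: list[int] = field(default_factory=list)

    @property
    def has_cross_attn_cache(self) -> bool:
        return bool(self.cross_attn_block_ids)


def schedule_encoder(request: Request, kv_cache_manager, num_encoder_tokens: int) -> None:
    """Allocate the cross attention KV cache before the encoder forward pass."""
    if not request.has_cross_attn_cache:
        # Allocate exactly once per request; the blocks are freed only when the
        # request finishes, since the decoder keeps reading them on every step.
        request.cross_attn_block_ids = kv_cache_manager.allocate(num_encoder_tokens)
```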
GPUModelRunner
Currently, in `execute_model`, the first stage is to generate the encoder outputs if the model is multi-modal. This stage will be updated to handle cross-attention multi-modal models.
- Add an `_execute_encoder_decoder` function (separate from `_execute_encoder`). This function will do the following (see the sketch after this list):
  - Construct a cross attention metadata object and initialize it with the KV cache information from the `Request`/`SchedulerOutput`.
  - Run a forward pass on the encoder model while updating the cross attention KV cache.
  - Return an instance of the cross attention metadata class.
- `_execute_encoder_decoder` will be called during the first stage of `execute_model`, instead of `_execute_encoder`, when the model is multi-modal and uses cross attention. It returns an instance of `FlashAttentionMetadata` that is populated with the cross attention KV cache. Optionally, this could be moved to a separate `_gather_cross_metadata`, analogous to `_gather_encoder_outputs`.
- The `profile_run`/`_dummy_run` functions also need to be updated to follow the same general steps as `execute_model`.
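
The sketch below walks through the three steps listed above. The helper names (`_build_cross_attn_metadata`, `_prepare_encoder_inputs`) and the `self.model.encoder(...)` call are assumptions for illustration only; the real method would live on `GPUModelRunner` and use its existing input-preparation machinery:

```python
import torch


def _execute_encoder_decoder(self, scheduler_output):
    """Sketch of the proposed GPUModelRunner method (illustrative names only)."""
    # 1. Build the cross attention metadata from the KV cache information
    #    carried by the Request / SchedulerOutput (block tables, encoder
    #    sequence lengths). _build_cross_attn_metadata is hypothetical.
    cross_attn_metadata = self._build_cross_attn_metadata(scheduler_output)

    # 2. Run the encoder forward pass; the cross attention KV cache is
    #    written as a side effect of this pass.
    encoder_inputs = self._prepare_encoder_inputs(scheduler_output)
    with torch.inference_mode():
        self.model.encoder(**encoder_inputs, attn_metadata=cross_attn_metadata)

    # 3. Return the populated metadata so every subsequent decode step can
    #    read from the same cross attention KV cache.
    return cross_attn_metadata
```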
FlashAttentionImpl
Add support for `AttentionType.ENCODER` and `AttentionType.ENCODER_DECODER`.
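
One way the new branches might look is sketched below. The helper methods (`_run_without_kv_cache`, `_write_cross_attn_kv_cache`, etc.) are hypothetical and only indicate which path each attention type would take; the real `FlashAttentionImpl.forward` signature may differ:

```python
from vllm.attention import AttentionType  # assumed import path


def forward(self, query, key, value, kv_cache, attn_metadata):
    """Sketch of the extra attention-type branches (helper names illustrative)."""
    if self.attn_type == AttentionType.ENCODER:
        # Encoder self attention: bidirectional attention over the encoder
        # sequence, no KV cache involved.
        return self._run_without_kv_cache(query, key, value, attn_metadata)
    if self.attn_type == AttentionType.ENCODER_DECODER:
        # Cross attention: write the encoder K/V into the cross attention KV
        # cache when they are supplied (encoder pass), then attend decoder
        # queries against that cache on every decode step.
        if key is not None and value is not None:
            self._write_cross_attn_kv_cache(key, value, kv_cache, attn_metadata)
        return self._run_cross_attention(query, kv_cache, attn_metadata)
    # Existing path: causal decoder self attention.
    return self._run_decoder_self_attention(query, key, value, kv_cache, attn_metadata)
```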
Cross attention models, e.g. WhisperForConditionalGeneration
Add an extra `AttentionMetadata` parameter for the cross attention layers.
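
A minimal sketch of where that extra parameter would flow through a Whisper-style decoder layer is shown below. The class, the `self_attn`/`encoder_attn` submodules (constructed in `__init__`, omitted here), and the argument names are hypothetical and only illustrate the plumbing:

```python
import torch
from torch import nn


class WhisperCrossAttnDecoderLayer(nn.Module):
    """Illustrative decoder layer; attention submodules are built in __init__ (omitted)."""

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attn_metadata,        # metadata for decoder self attention (existing parameter)
        cross_attn_metadata,  # extra metadata carrying the cross attention KV cache
    ) -> torch.Tensor:
        # Decoder self attention keeps using the regular decoder KV cache.
        hidden_states = self.self_attn(hidden_states, attn_metadata=attn_metadata)
        # Cross attention reads the encoder K/V from the separate cross
        # attention KV cache described by cross_attn_metadata.
        hidden_states = self.encoder_attn(
            hidden_states,
            encoder_hidden_states,
            attn_metadata=cross_attn_metadata,
        )
        return hidden_states
```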
Feedback Period.
No response
CC List.
No response
Any Other Things.
No response
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.