
[RFC]: Initial support for multi-modal models using cross attention in V1 #12761

@bnellnm

Description

Motivation.

The goal of this RFC is to propose a simple initial design to support multi-modal models that use cross attention in the V1 architecture. Whisper is a prime example of such a model. The design aims to be as simple as possible and easily replaceable without disrupting other ongoing V1 work.

Currently in V1, the only encoder/decoder models that are supported are ones that do not use cross attention. These models use the EncoderCacheManager to communicate the outputs of the encoder to the decoder. Multi-modal models that use cross attention need a separate KV cache for the encoder portion of the model. This cross attention KV cache has to be populated by the encoder and is read (but never written) by the decoder's cross attention layers. It is separate from the existing decoder KV cache and has to be managed separately.

Non-goals

Since we are focusing on Whisper for the initial design, there are certain features/optimizations that can be deferred. For now, the following optimizations will be disabled for cross attention models, since they probably won't provide much benefit:

  • Chunked prefill for the encoder
  • Prefix caching

The following are also out of scope for the initial design:

  • Supporting attention backends other than flash attention
  • Abstracting the GPUModelRunner

For additional background see:

Proposed Change.

The proposed changes touch the following areas of the code:

  • Scheduler/Request
  • GPUModelRunner
  • FlashAttentionImpl
  • cross attention models, e.g. WhisperForConditionalGeneration

Scheduler/Request

The scheduler will be updated to allocate the cross attention KV cache when the encoder portion of an encoder/decoder cross attention model is executed. Whether a model needs this can be determined from the model config. The cache will be stored in the Request object since, once filled, it is persistent and read-only to the decoder.
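
As a rough illustration, a minimal sketch of the intended flow is shown below. The field and function names here (cross_block_ids, allocate_cross_blocks, maybe_allocate_cross_kv_cache) are hypothetical assumptions used only to show where the allocation would live, not the actual vLLM V1 scheduler API.

```python
# Hypothetical sketch -- names like cross_block_ids / allocate_cross_blocks are
# illustrative assumptions, not the real vLLM V1 scheduler API.
from dataclasses import dataclass, field


@dataclass
class Request:
    request_id: str
    # Existing decoder KV cache blocks, managed per scheduling step.
    block_ids: list[int] = field(default_factory=list)
    # Cross attention KV cache blocks: allocated once when the encoder runs
    # and treated as read-only by the decoder afterwards.
    cross_block_ids: list[int] = field(default_factory=list)


def maybe_allocate_cross_kv_cache(request: Request, kv_cache_manager,
                                  num_encoder_tokens: int) -> None:
    """Allocate the cross attention KV cache before the encoder forward pass.

    Only applies to encoder/decoder models that use cross attention, which
    can be determined from the model config.
    """
    if not request.cross_block_ids:
        request.cross_block_ids = kv_cache_manager.allocate_cross_blocks(
            request, num_encoder_tokens)
```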

GPUModelRunner

Currently, in execute_model, the first stage is to generate the encoder outputs if the model is multi-modal. This stage will be updated to handle cross-attention multi-modal models.

  1. Add an _execute_encoder_decoder function (separate from _execute_encoder). This function will do the following (a sketch follows this list):
  • Construct a cross attention metadata object and initialize it with the KV cache information from the Request/SchedulerOutput.
  • Run a forward pass on the encoder model while updating the cross attention KV cache.
  • Return an instance of the cross attention metadata class.
  2. _execute_encoder_decoder will be called during the first stage of execute_model instead of _execute_encoder, when the model is multi-modal and uses cross attention. It returns an instance of FlashAttentionMetadata that is populated with the cross attention KV cache. Optionally, this could be moved to a separate _gather_cross_metadata, analogous to _gather_encoder_outputs.

  3. The profile_run/_dummy_run functions also need to be updated to follow the same general steps as execute_model.
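
To make the control flow concrete, here is a rough sketch of how _execute_encoder_decoder could fit into execute_model. The helper names and config attributes (_build_cross_attn_metadata, _gather_encoder_inputs, get_encoder_outputs, is_encoder_decoder) are assumptions for illustration; only _execute_encoder_decoder, _execute_encoder, and FlashAttentionMetadata come from the proposal above.

```python
# Illustrative sketch of a GPUModelRunner fragment -- helper names and config
# attributes are assumptions, not the actual implementation.
class GPUModelRunner:

    def _execute_encoder_decoder(self, scheduler_output) -> "FlashAttentionMetadata":
        # Build cross attention metadata from the cross KV cache blocks
        # carried in the Request / SchedulerOutput.
        cross_attn_metadata = self._build_cross_attn_metadata(scheduler_output)

        # Encoder forward pass. The ENCODER_DECODER attention layers write
        # their keys/values into the cross attention KV cache as a side
        # effect of this call.
        self.model.get_encoder_outputs(
            self._gather_encoder_inputs(scheduler_output),
            attn_metadata=cross_attn_metadata,
        )

        # The populated metadata is returned so the decoder can read (but
        # never write) the cross attention KV cache.
        return cross_attn_metadata

    def execute_model(self, scheduler_output):
        cross_attn_metadata = None
        if self.model_config.is_encoder_decoder:   # cross attention model
            cross_attn_metadata = self._execute_encoder_decoder(scheduler_output)
        elif self.is_multimodal_model:              # existing V1 encoder path
            self._execute_encoder(scheduler_output)
        # ... rest of execute_model: decoder forward pass, sampling, etc.
```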

FlashAttentionImpl

Add support for AttentionType.ENCODER and AttentionType.ENCODER_DECODER.
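
A minimal sketch of the distinctions the backend would need to make per attention type is given below. The helper is purely illustrative (the real logic would live inside FlashAttentionImpl.forward) and the import path is an assumption.

```python
# Illustrative helper only -- the real logic would live in
# FlashAttentionImpl.forward; the import path is assumed.
from vllm.attention import AttentionType


def attn_flags(attn_type: AttentionType) -> tuple[bool, bool]:
    """Return (is_causal, updates_kv_cache) for a given attention type."""
    if attn_type == AttentionType.ENCODER:
        # Encoder self-attention: bidirectional, no KV cache involved.
        return False, False
    if attn_type == AttentionType.ENCODER_DECODER:
        # Cross attention: bidirectional over the encoder outputs. The cross
        # attention KV cache is written once (when the encoder side runs)
        # and is read-only afterwards.
        return False, True
    # AttentionType.DECODER: the existing causal self-attention path.
    return True, True
```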

Cross attention models, e.g. WhisperForConditionalGeneration

Add an extra AttentionMetadata parameter for the cross attention layers.
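
For example, a Whisper decoder layer's forward pass might thread the extra metadata object through to its cross attention layer, roughly as sketched below. The module wiring and argument names are simplified assumptions; only the idea of the extra cross attention metadata parameter comes from this RFC.

```python
# Simplified sketch -- module wiring and argument names are assumptions, not
# the actual WhisperForConditionalGeneration implementation.
import torch
import torch.nn as nn


class WhisperDecoderLayer(nn.Module):

    def __init__(self, self_attn: nn.Module, cross_attn: nn.Module):
        super().__init__()
        self.self_attn = self_attn
        self.cross_attn = cross_attn

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attn_metadata,         # metadata for decoder self-attention (existing)
        cross_attn_metadata,   # NEW: metadata carrying the cross attention KV cache
    ) -> torch.Tensor:
        # Decoder self-attention uses the ordinary decoder KV cache.
        hidden_states = self.self_attn(hidden_states, attn_metadata=attn_metadata)
        # Cross attention reads keys/values from the read-only cross
        # attention KV cache described by cross_attn_metadata.
        hidden_states = self.cross_attn(
            hidden_states,
            encoder_hidden_states,
            attn_metadata=cross_attn_metadata,
        )
        return hidden_states
```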

Feedback Period.

No response

CC List.

No response

Any Other Things.

No response

