[RFC]: Decode Context Parallel for GQA #24685

@frank-wei

Description

Motivation.

A follow-up to #23734 to adapt DCP (decode context parallel) to GQA. The purpose is to reduce the KV cache storage on each GPU by splitting the cache across the DCP ranks, at the cost of extra collective-communication overhead from CP attention. This matters for GQA because it has only a few KV heads, so attention cannot be parallelized further by TP across heads.

Proposed Change.

  • For decode, we need to make sure the KV from prefill is stored in interleaved mode, with rank r storing the tokens whose index i satisfies i % dcp_world_size == r. This should not change the KV head-group index, since token interleaving and head grouping are orthogonal (see the first sketch after this list).
  • The attention kernel needs to be checked for, and support, returning the LSE (log-sum-exp). We can start with FlashInfer; FlashAttention and Triton will be checked afterwards.
  • After the blockwise attention, the LSE values need to be collected via allgather and used to correct the partial outputs; the corrected outputs are then allreduced to produce the final result (see the second sketch after this list). The corresponding implementation is already done and presumably needs no change.
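As a concrete illustration of the interleaved layout from the first bullet, here is a minimal sketch of how prefill token indices map to DCP ranks. The helper name `tokens_for_rank` and its arguments are made up for illustration; this is not the actual vLLM code.

```python
# Hypothetical sketch of interleaved KV placement across DCP ranks.
# Token i is stored on rank i % dcp_world_size; the GQA KV head-group
# index is untouched, since head grouping and token interleaving are
# orthogonal dimensions of the cache.
def tokens_for_rank(num_tokens: int, dcp_rank: int, dcp_world_size: int) -> list[int]:
    """Global token indices whose KV this DCP rank stores."""
    return [i for i in range(num_tokens) if i % dcp_world_size == dcp_rank]

# Example with 10 prefill tokens and dcp_world_size = 4:
#   rank 0 -> [0, 4, 8]   rank 1 -> [1, 5, 9]
#   rank 2 -> [2, 6]      rank 3 -> [3, 7]
```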
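And a sketch of the LSE correction from the last bullet, assuming each rank has run blockwise attention over its own KV shard and holds a partial output plus the LSE of its local scores. Function and tensor names are illustrative, not the existing implementation.

```python
import torch
import torch.distributed as dist

def merge_dcp_attention(out_partial: torch.Tensor,   # [heads, q_len, head_dim]
                        lse_partial: torch.Tensor,   # [heads, q_len], natural log
                        group=None) -> torch.Tensor:
    """Combine per-rank blockwise attention results into the exact output."""
    world_size = dist.get_world_size(group)
    # Allgather every rank's LSE so each rank can compute the global LSE.
    lse_all = [torch.empty_like(lse_partial) for _ in range(world_size)]
    dist.all_gather(lse_all, lse_partial, group=group)
    lse_global = torch.logsumexp(torch.stack(lse_all), dim=0)  # [heads, q_len]
    # Rescale the local partial output: exp(lse_local - lse_global) is the
    # weight of this rank's softmax block within the global softmax.
    out = out_partial * torch.exp(lse_partial - lse_global).unsqueeze(-1)
    # Allreduce (sum) the rescaled outputs to obtain the final attention output.
    dist.all_reduce(out, group=group)
    return out
```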

Feedback Period.

1-2 weeks

CC List.

@youkaichao @youzhedian

Any Other Things.

The test plan is to experiment with a Llama 3 model.
