MultiHeadCacheAttention

The original definition of MultiHeadAttention can be found here.

The original definition of KeyValueCache can be found here.

MultiHeadCacheAttention simply fuses MultiHeadAttention and KeyValueCache into a single operator.

Allows the model to jointly attend to information from different representation subspaces as described in the paper: Attention Is All You Need.

Multi-Head Attention (${\rm MHA}$) is defined as:

$${\rm MHA}(Q,K,V)=[head_1, head_2,...,head_h]$$

$$head_i={\rm softmax}(\frac{Q_iK_i^T}{\sqrt{head\_dim}})V_i$$

$Q$ is the query, $K$ the key and $V$ the value.

$Q$ has shape $(batch, num\_heads, seqlen\_q, head\_dim)$, and $K$ and $V$ have shape $(batch, num\_kv\_heads, seqlen\_kv, head\_dim)$.
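To make the formula concrete, here is a minimal pure-Python sketch of a single head, $head_i$, with nested lists standing in for tensors. The function names are illustrative only and are not part of the operator's API; batching, masking and the key/value cache are left out.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def attention_head(q, k, v):
    """softmax(q @ k^T / sqrt(head_dim)) @ v for a single head.

    q has shape (seqlen_q, head_dim); k and v have shape (seqlen_kv, head_dim).
    Returns a (seqlen_q, head_dim) list of lists.
    """
    head_dim = len(q[0])
    scale = 1.0 / math.sqrt(head_dim)
    out = []
    for q_row in q:
        # One scaled dot-product score per key position
        scores = [scale * sum(a * b for a, b in zip(q_row, k_row)) for k_row in k]
        probs = softmax(scores)
        # Probability-weighted sum of the value rows
        out.append([sum(p * v_row[d] for p, v_row in zip(probs, v)) for d in range(head_dim)])
    return out


# Identical keys give uniform weights, so the output is the mean of the values.
print(attention_head([[1.0, 0.0]], [[1.0, 0.0], [1.0, 0.0]], [[1.0, 2.0], [3.0, 4.0]]))  # [[2.0, 3.0]]
```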

In MultiHeadCacheAttention, key = cat(past_key, current_key) and value = cat(past_value, current_value), where the concatenation runs along the sequence dimension, so $seqlen\_kv$ is the past length plus $seqlen\_q$.

In this operator, however, $Q$ has shape $(batch, seqlen\_q, num\_heads, head\_dim)$ and $K$ and $V$ have shape $(batch, seqlen\_kv, num\_kv\_heads, head\_dim)$, so the sequence and head dimensions must be transposed before attention is applied.
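A minimal sketch of the layout handling described above, again with nested lists standing in for tensors: `cat_seq` joins past and current tokens along the sequence dimension, and `to_bhsd` swaps the sequence and head axes so the result matches the layout used in the formula. Both helper names are hypothetical.

```python
def cat_seq(past, current):
    """Concatenate two (batch, seqlen, heads, head_dim) tensors along seqlen."""
    return [p + c for p, c in zip(past, current)]


def to_bhsd(x):
    """Transpose (batch, seqlen, heads, head_dim) to (batch, heads, seqlen, head_dim)."""
    return [
        [[x_b[s][h] for s in range(len(x_b))] for h in range(len(x_b[0]))]
        for x_b in x
    ]


past_key = [[[[1.0], [2.0]]]]     # batch=1, seqlen=1, num_kv_heads=2, head_dim=1
current_key = [[[[3.0], [4.0]]]]  # one new token with the same head layout
key = to_bhsd(cat_seq(past_key, current_key))
print(key)  # [[[[1.0], [3.0]], [[2.0], [4.0]]]]
```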

Attributes/Parameters

num_heads: int

Number of heads

head_dim: int

Dimension of each head, where $head\_dim * num\_heads = hidden\_dim$

is_causal: bool

Whether to apply a causal mask when the sequence length is greater than 1.
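With a key/value cache, $seqlen\_q$ and $seqlen\_kv$ usually differ, so the causal mask is rectangular. The sketch below assumes the queries are the last $seqlen\_q$ positions of the key sequence (query $i$ sits at position $seqlen\_kv - seqlen\_q + i$) and builds an additive mask of $0$ and $-\infty$. The helper name and that alignment are assumptions, not taken from this specification.

```python
import math


def causal_mask(seqlen_q, seqlen_kv):
    """Additive causal mask of shape (seqlen_q, seqlen_kv).

    Assumes query i is at absolute position seqlen_kv - seqlen_q + i, so it may
    attend to every key at or before that position.
    """
    offset = seqlen_kv - seqlen_q
    return [
        [0.0 if j <= offset + i else -math.inf for j in range(seqlen_kv)]
        for i in range(seqlen_q)
    ]


for row in causal_mask(2, 3):
    print(row)
# [0.0, 0.0, -inf]
# [0.0, 0.0, 0.0]
```

With a single query token every entry is 0.0, which is consistent with the mask only mattering when the sequence length exceeds 1.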

is_alibi: bool(default: False)

Whether to apply the ALiBi mask within the operator. When it is True, the ALiBi mask does not need to be included in attn_mask.
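For reference, ALiBi adds a per-head linear distance penalty to the attention scores. The sketch below assumes the slope recipe from the ALiBi paper for a power-of-two num_heads (slopes $2^{-8/n}, 2^{-16/n}, \ldots, 2^{-8}$) and the same query alignment as the causal-mask sketch. Whether this operator uses exactly these slopes, and how it handles other head counts, is not stated in this specification; the helper names are hypothetical.

```python
def alibi_slopes(num_heads):
    """Per-head ALiBi slopes 2^(-8/n), 2^(-16/n), ..., 2^(-8) for n = num_heads."""
    if num_heads <= 0 or num_heads & (num_heads - 1):
        raise ValueError("this sketch only handles a power-of-two num_heads")
    return [2.0 ** (-8.0 * (h + 1) / num_heads) for h in range(num_heads)]


def alibi_bias(num_heads, seqlen_q, seqlen_kv):
    """Additive bias of shape (num_heads, seqlen_q, seqlen_kv) equal to -slope * distance.

    Future positions get a positive bias here; the causal mask is expected to remove them.
    """
    offset = seqlen_kv - seqlen_q
    return [
        [[-m * (offset + i - j) for j in range(seqlen_kv)] for i in range(seqlen_q)]
        for m in alibi_slopes(num_heads)
    ]


print(alibi_slopes(8))  # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```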

num_kv_heads: int(default: 0)

Number of key/value heads, used for Grouped-Query Attention. If num_kv_heads differs from num_heads, key and value are repeated num_heads/num_kv_heads times before ${\rm MHA}$ is applied to each token, so num_heads must be divisible by num_kv_heads. The default, 0, means num_heads is used as num_kv_heads.
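An implementation does not have to materialize the repeated copies: each query head can read its key/value head directly. The sketch below assumes the common convention in which consecutive query heads share one key/value head (query head $h$ uses key/value head $h / (num\_heads / num\_kv\_heads)$, rounded down); the helper names are hypothetical.

```python
def kv_head_for(q_head, num_heads, num_kv_heads=0):
    """Index of the key/value head read by query head q_head."""
    if num_kv_heads == 0:  # default: num_heads is used as num_kv_heads
        num_kv_heads = num_heads
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group = num_heads // num_kv_heads  # query heads sharing one kv head
    return q_head // group


def repeat_kv(kv_heads, num_heads):
    """Repeat each of the len(kv_heads) heads num_heads // len(kv_heads) times, in order."""
    return [kv_heads[kv_head_for(h, num_heads, len(kv_heads))] for h in range(num_heads)]


print([kv_head_for(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]
print(repeat_kv(["k0", "k1"], 4))                # ['k0', 'k0', 'k1', 'k1']
```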

num_layer: int(default: 1)

Number of attention layers.

layer_idx: int(default: 0)

Index of this attention layer within cache and scale.

quant_bit: int(default: 0)

Number of bits used to quantize the cache. For example, 8 means int8 compression; 0 disables quantization.

quant_group: int(default: 8)

Number of elements along head_dim that share one quantization scale. A power of two, $2^n$ with $n > 2$, is recommended for hardware implementations.
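A minimal sketch of group-wise cache quantization, assuming a symmetric scheme in which every quant_group consecutive elements of a head_dim vector share one scale equal to their largest magnitude divided by $2^{quant\_bit-1}-1$. That gives $head\_dim / quant\_group$ scales per head. The exact scale formula and rounding used by the operator are assumptions here, and the helper names are hypothetical.

```python
def quantize_groups(x, quant_bit=8, quant_group=8):
    """Symmetric group-wise quantization of one head_dim vector.

    Returns (q, scales) with len(scales) == len(x) // quant_group.
    """
    if len(x) % quant_group != 0:
        raise ValueError("head_dim must be a multiple of quant_group")
    qmax = 2 ** (quant_bit - 1) - 1  # 127 for int8, 7 for int4
    q, scales = [], []
    for start in range(0, len(x), quant_group):
        group = x[start:start + quant_group]
        amax = max(abs(v) for v in group)
        scale = amax / qmax if amax > 0 else 1.0  # any scale works for an all-zero group
        scales.append(scale)
        # Round to the nearest integer step and clamp to the signed range
        q.extend(max(-qmax, min(qmax, round(v / scale))) for v in group)
    return q, scales


def dequantize_groups(q, scales, quant_group=8):
    """Inverse of quantize_groups, up to rounding error."""
    return [v * scales[i // quant_group] for i, v in enumerate(q)]


x = [0.5, -1.0, 0.25, 1.0]
q, scales = quantize_groups(x, quant_bit=8, quant_group=2)
print(q, scales)
print(dequantize_groups(q, scales, quant_group=2))
```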

cache_layout: int(default: 0)

Defines the data layout of cache and scale. Default is 0.

Meaning of each value:

  • 0: $cache(MaxB,L,2,MaxS,H,Dh)$ and $scale(MaxB,L,2,MaxS,H,Dh/quant\_group)$
  • 1: $cache(L,MaxB,2,H,MaxS,Dh)$ and $scale(L,MaxB,2,H,MaxS,Dh/quant\_group)$
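The two layouts differ only in the order of their axes. The sketch below computes the flat offset of one element in a contiguous (row-major) cache under each layout. It assumes MaxB is the maximum batch size, L is num_layer, the axis of size 2 selects key (0) or value (1), MaxS is the maximum sequence length, H is num_kv_heads and Dh is head_dim; these readings of the symbols and the function names are assumptions. The same function gives scale offsets if head_dim is replaced by $Dh/quant\_group$.

```python
def row_major_offset(index, shape):
    """Flat offset of a multi-dimensional index in a contiguous row-major tensor."""
    if len(index) != len(shape):
        raise ValueError("index and shape must have the same rank")
    offset = 0
    for i, n in zip(index, shape):
        if not 0 <= i < n:
            raise IndexError(f"index {i} out of range for dimension of size {n}")
        offset = offset * n + i
    return offset


def cache_offset(cache_layout, b, layer, kv, s, h, d,
                 max_b, num_layer, max_s, num_kv_heads, head_dim):
    """Flat offset of cache element (batch b, layer, key/value kv, position s, kv head h, channel d)."""
    if cache_layout == 0:  # cache(MaxB, L, 2, MaxS, H, Dh)
        return row_major_offset((b, layer, kv, s, h, d),
                                (max_b, num_layer, 2, max_s, num_kv_heads, head_dim))
    if cache_layout == 1:  # cache(L, MaxB, 2, H, MaxS, Dh)
        return row_major_offset((layer, b, kv, h, s, d),
                                (num_layer, max_b, 2, num_kv_heads, max_s, head_dim))
    raise ValueError(f"unknown cache_layout {cache_layout}")


dims = dict(max_b=1, num_layer=1, max_s=4, num_kv_heads=2, head_dim=8)
# One step along the sequence: the stride is H*Dh in layout 0 but only Dh in layout 1.
print(cache_offset(0, 0, 0, 0, 1, 0, 0, **dims))  # 16
print(cache_offset(1, 0, 0, 0, 1, 0, 0, **dims))  # 8
```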

Inputs

query: tensor(T1)

Input Query tensor

Shape: $(batch, seqlen\_q, num\_heads, head\_dim)$

current_key: tensor(T1)

Input Key tensor

Shape: $(batch, seqlen\_q, num\_kv\_heads, head\_dim)$

current_value: tensor(T1)

Input Value tensor

Shape: $(batch, seqlen\_q, num\_kv\_heads, head\_dim)$

start_pos: scalar(int64)

Sequence position in the cache at which current_key and current_value begin to be stored.

cache: tensor(T2)

Shape: Determined by cache_layout.

Contains the key and value caches of the attention layers. When cache_layout is 0, subspace $(:B,:,0,:,:,:)$ contains the key caches and subspace $(:B,:,1,:,:,:)$ contains the value caches. This tensor is modified in place.
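The cache update driven by start_pos can be sketched with a layout-0 cache held as nested lists, `cache[b][layer][kv][s][h][d]`. The sketch assumes the past tokens occupy positions 0 to start_pos - 1, so after the write, positions $[0, start\_pos + seqlen\_q)$ of subspace 0 hold cat(past_key, current_key) and the same range of subspace 1 holds cat(past_value, current_value). The helper names are hypothetical, and a real kernel would write into a contiguous buffer (quantizing when quant_bit is non-zero) rather than into lists.

```python
def make_cache(max_b, num_layer, max_s, num_kv_heads, head_dim):
    """Zero-filled layout-0 cache nested as cache[b][layer][kv][s][h][d]."""
    return [[[[[[0.0] * head_dim for _ in range(num_kv_heads)]
               for _ in range(max_s)]
              for _ in range(2)]
             for _ in range(num_layer)]
            for _ in range(max_b)]


def store_current(cache, layer_idx, start_pos, current_key, current_value):
    """Write (batch, seqlen_q, num_kv_heads, head_dim) keys/values at start_pos; return seqlen_kv."""
    seqlen_q = len(current_key[0])
    for b in range(len(current_key)):
        for t in range(seqlen_q):
            pos = start_pos + t
            cache[b][layer_idx][0][pos] = [list(head) for head in current_key[b][t]]
            cache[b][layer_idx][1][pos] = [list(head) for head in current_value[b][t]]
    return start_pos + seqlen_q


cache = make_cache(max_b=1, num_layer=2, max_s=4, num_kv_heads=1, head_dim=2)
seqlen_kv = store_current(cache, layer_idx=1, start_pos=2,
                          current_key=[[[[1.0, 2.0]]]], current_value=[[[[3.0, 4.0]]]])
key = cache[0][1][0][:seqlen_kv]  # past keys (still zeros here) followed by the new key
print(seqlen_kv, key)  # 3 [[[0.0, 0.0]], [[0.0, 0.0]], [[1.0, 2.0]]]
```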

scale(optional): tensor(T3)

Shape: Determined by cache_layout.

Contains the quantization scales of the key and value caches. When cache_layout is 0, subspace $(:B,:,0,:,:,:)$ contains the key cache scales and subspace $(:B,:,1,:,:,:)$ contains the value cache scales. Required when quant_bit is non-zero. This tensor is modified in place.

attn_mask(optional): tensor(T1)

Optional custom mask. If its shape is $(seqlen\_q, >=seqlen\_kv)$, attn_mask is broadcast across batch and heads.

Note: the last dimension of the mask may be larger than $seqlen\_kv$, because some flash attention implementations require it to be padded to a specific alignment.

Shape: $(seqlen\_q, >=seqlen\_kv)$ or $(num\_heads, seqlen\_q, >=seqlen\_kv)$ or $(batch, num\_heads, seqlen\_q, >=seqlen\_kv)$
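A minimal sketch of how the three accepted mask shapes could be broadcast to $(batch, num\_heads, seqlen\_q, seqlen\_kv)$. It assumes any padding sits at the end of the last dimension and simply drops it; the function name is hypothetical.

```python
def expand_mask(attn_mask, batch, num_heads, seqlen_q, seqlen_kv):
    """Broadcast attn_mask to (batch, num_heads, seqlen_q, seqlen_kv), dropping padded columns."""
    rank, probe = 0, attn_mask
    while isinstance(probe, list):
        rank, probe = rank + 1, probe[0]
    if rank not in (2, 3, 4):
        raise ValueError(f"unsupported attn_mask rank {rank}")

    def plane(b, h):
        # Select the (seqlen_q, >=seqlen_kv) plane that applies to batch b and head h.
        if rank == 2:
            return attn_mask
        if rank == 3:
            return attn_mask[h]
        return attn_mask[b][h]

    return [[[row[:seqlen_kv] for row in plane(b, h)[:seqlen_q]]
             for h in range(num_heads)]
            for b in range(batch)]


# A (seqlen_q=1, 4) mask padded beyond seqlen_kv=3, broadcast to batch=2, num_heads=2.
mask = [[0.0, 0.0, -1e9, 123.0]]
print(expand_mask(mask, batch=2, num_heads=2, seqlen_q=1, seqlen_kv=3))
```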

Outputs

attn_output: tensor(T1)

Output features of the attention computation.

Shape: $(batch, seqlen\_q, num\_heads, head\_dim)$

Type Constraints

T1: float32, float16

T2: float32, float16, int8, int4

T3: float32, float16