diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 81fabdbdfc83c..31ae0751486f5 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -166,6 +166,37 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
         return self._cached_decode_metadata
 
 
+def _make_alibi_bias(alibi_slopes: torch.Tensor,
+                     dtype: torch.dtype,
+                     seq_lens: Optional[List[int]],
+                     make_attn_mask: bool = True) -> List[torch.Tensor]:
+    attn_biases = []
+    if seq_lens:
+        for seq_len in seq_lens:
+            bias = torch.arange(seq_len, dtype=dtype)
+            # NOTE(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(seq_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
+            bias = bias[None, :] - bias[:, None]
+
+            num_heads = alibi_slopes.shape[0]
+            bias = bias[None, :].repeat(
+                (num_heads, 1, 1)).to(alibi_slopes.device)
+            bias.mul_(alibi_slopes[:, None, None])
+            if make_attn_mask:
+                inf_mask = torch.empty(
+                    (1, seq_len, seq_len),
+                    dtype=bias.dtype).fill_(-torch.inf).triu_(diagonal=1).to(
+                        alibi_slopes.device)
+                attn_biases.append((bias + inf_mask).to(dtype))
+            else:
+                attn_biases.append(bias.to(dtype))
+
+    return attn_biases
+
+
 class ROCmFlashAttentionImpl(AttentionImpl):
     """
     If the input tensors contain prompt tokens, the layout is as follows:
@@ -324,7 +355,14 @@ def forward(
                 # triton attention
                 # When block_tables are not filled, it means q and k are the
                 # prompt, and they have the same length.
+                attn_masks = None
                 if self.use_triton_flash_attn:
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=False)  # type: ignore
                     out, _ = self.attn_func(
                         query,
                         key,
@@ -336,12 +374,20 @@ def forward(
                         prefill_meta.max_prefill_seq_len,
                         True,
                         self.scale,
+                        attn_masks[0][None]
+                        if attn_masks is not None else None,
                     )
                 elif self.use_naive_attn:
                     if self.num_kv_heads != self.num_heads:
                         # Interleave for MQA workaround.
                         key = self.repeat_kv(key, self.num_queries_per_kv)
                         value = self.repeat_kv(value, self.num_queries_per_kv)
+                    if self.alibi_slopes is not None:
+                        attn_masks = _make_alibi_bias(
+                            self.alibi_slopes,
+                            query.dtype,
+                            attn_metadata.seq_lens,
+                            make_attn_mask=True)  # type: ignore
                     query = query.movedim(0, query.dim() - 2)
                     key = key.movedim(0, key.dim() - 2)
                     value = value.movedim(0, value.dim() - 2)
@@ -355,6 +401,7 @@ def forward(
                         self.num_heads,
                         self.head_size,
                         self.scale,
+                        attn_masks,
                     )
                 else:
                     out = self.attn_func(
@@ -418,13 +465,14 @@ def _sdpa_attention(
     num_heads: int,
     head_size: int,
     scale: float,
+    attn_masks: Optional[List[torch.Tensor]] = None,
 ) -> torch.Tensor:
     start = 0
     output = torch.empty((num_tokens, num_heads, head_size),
                          dtype=query.dtype,
                          device=query.device)
 
-    for seq_len in seq_lens:
+    for i, seq_len in enumerate(seq_lens):
         end = start + seq_len
         with torch.backends.cuda.sdp_kernel(enable_math=True,
                                             enable_flash=False,
@@ -434,7 +482,8 @@ def _sdpa_attention(
                 key[:, start:end, :],
                 value[:, start:end, :],
                 dropout_p=0.0,
-                is_causal=True,
+                is_causal=attn_masks is None,
+                attn_mask=attn_masks[i] if attn_masks else None,
                 scale=scale).movedim(query.dim() - 2, 0)
             output[start:end, :, :] = sub_out
             start = end
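
For reference only (not part of the patch): a minimal CPU sketch of the bias layout the added `_make_alibi_bias` helper produces, i.e. one additive tensor of shape (num_heads, seq_len, seq_len) per sequence, with the causal -inf mask folded in when `make_attn_mask=True`. The function name, slopes, and sequence lengths below are illustrative assumptions, not values taken from the patch.

from typing import List, Optional

import torch


def make_alibi_bias_sketch(alibi_slopes: torch.Tensor,
                           dtype: torch.dtype,
                           seq_lens: Optional[List[int]],
                           make_attn_mask: bool = True) -> List[torch.Tensor]:
    """Sketch: one additive ALiBi bias of shape (num_heads, L, L) per sequence."""
    attn_biases: List[torch.Tensor] = []
    if seq_lens:
        for seq_len in seq_lens:
            pos = torch.arange(seq_len, dtype=dtype)
            # Relative positions: entry (i, j) is j - i.
            bias = pos[None, :] - pos[:, None]
            num_heads = alibi_slopes.shape[0]
            # Replicate per head and scale by each head's slope.
            bias = bias[None, :].repeat((num_heads, 1, 1))
            bias.mul_(alibi_slopes[:, None, None])
            if make_attn_mask:
                # Fold the causal constraint in as -inf above the diagonal,
                # as the SDPA math path expects a full additive mask.
                causal = torch.full((1, seq_len, seq_len),
                                    -torch.inf).triu_(diagonal=1)
                bias = bias + causal
            attn_biases.append(bias.to(dtype))
    return attn_biases


slopes = torch.tensor([0.5, 0.25])  # one slope per attention head (2 heads here)
masks = make_alibi_bias_sketch(slopes, torch.float32, [3, 5])
print([m.shape for m in masks])  # [torch.Size([2, 3, 3]), torch.Size([2, 5, 5])]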