12 changes: 11 additions & 1 deletion tests/singlecard/test_aclgraph.py
@@ -36,9 +36,11 @@
reason="aclgraph only support on v1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("full_graph", [False])
def test_models(
model: str,
max_tokens: int,
full_graph: bool,
monkeypatch: pytest.MonkeyPatch,
) -> None:
with monkeypatch.context() as m:
@@ -54,7 +56,15 @@ def test_models(
temperature=0.0)
# TODO: change to use vllmrunner when the registry of custom op is solved
# while running pytest
vllm_model = LLM(model)
if full_graph:
vllm_model = LLM(model,
compilation_config={
"full_cuda_graph": True,
"cudagraph_capture_sizes":
[1, 4, 16, 64, 256]
})
else:
vllm_model = LLM(model)
vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
del vllm_model
torch.npu.empty_cache()
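For context, a minimal standalone sketch of the full-graph path the new `full_graph` test parameter exercises; the model id and prompt below are placeholders, and the `compilation_config` keys mirror the test above.

```python
# Sketch of enabling full-graph capture through vLLM's compilation_config,
# mirroring the test above. Model id and prompt are placeholders.
from vllm import LLM, SamplingParams

llm = LLM(
    "your/ascend-supported-model",  # placeholder, not a value from this PR
    compilation_config={
        "full_cuda_graph": True,                         # capture the whole decode step
        "cudagraph_capture_sizes": [1, 4, 16, 64, 256],  # batch sizes to pre-capture
    },
)
outputs = llm.generate(
    ["Hello, my name is"],
    SamplingParams(temperature=0.0, max_tokens=32),
)
```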
4 changes: 4 additions & 0 deletions vllm_ascend/ascend_forward_context.py
@@ -55,6 +55,10 @@ def set_ascend_forward_context(

forward_context.in_profile_run = in_profile_run

# NOTE: This cannot be set using set_forward_context
# due to multiple warmups before actual capturing
forward_context.capturing = False

dp_world_size = get_dp_group().world_size
if dp_world_size > 1 and forward_context.dp_metadata is not None:
forward_context.max_tokens_across_dp = forward_context.dp_metadata.max_tokens_across_dp_cpu.item(
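The new `capturing` flag defaults to `False` on every forward context. A hypothetical sketch (not part of this PR) of how a model runner could toggle it around graph capture so the attention backend takes the capture branch instead of launching kernels eagerly; `capture_fn` and `dummy_decode_step` are illustrative names only.

```python
# Hypothetical usage of forward_context.capturing (names are illustrative).
from vllm.forward_context import get_forward_context

def run_capture(capture_fn, dummy_decode_step):
    forward_context = get_forward_context()
    forward_context.capturing = True        # attention layers record task groups
    try:
        capture_fn(dummy_decode_step)       # e.g. wraps the NPU graph capture context
    finally:
        forward_context.capturing = False   # restore eager kernel launches
```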
171 changes: 137 additions & 34 deletions vllm_ascend/attention/attention_v1.py
@@ -24,12 +24,16 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer, AttentionType)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
from vllm.utils import direct_register_custom_op
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch

from vllm_ascend.attention.utils import \
AscendCommonAttentionMetadata as CommonAttentionMetadata
from vllm_ascend.ops.attention import vanilla_chunked_prefill
from vllm_ascend.utils import get_graph_params


class AscendAttentionBackend(AttentionBackend):
@@ -114,6 +118,7 @@ class AscendMetadata:
query_start_loc: torch.Tensor
query_lens: torch.Tensor
seq_lens: torch.Tensor
seq_lens_list: list
# Maximum query length in the batch. None for decoding.
max_query_len: Optional[int] = None
# (num_tokens,). The indices of the token slots that input tokens will be
@@ -149,37 +154,69 @@ def build(self,
num_reqs,
num_actual_tokens,
max_query_len,
common_prefix_len,
enable_dbo_across_dp: bool = False):
common_attn_metadata: CommonAttentionMetadata,
enable_dbo_across_dp: bool = False,
*args,
**kwargs):

block_table = self.runner.input_batch.block_table[0].get_device_tensor(
)
block_table[:num_reqs, :self.runner.max_num_blocks_per_req] = (
block_table[:num_reqs])

query_lens = self.runner.query_lens
seq_lens = self.runner.seq_lens_cpu[:num_reqs]
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True)
query_start_loc = common_attn_metadata.query_start_loc
seq_lens = common_attn_metadata.seq_lens
# TODO: Refactor these two param to common metadata in runners,
# preparing for the hybrid KV groups feature
query_lens = common_attn_metadata.query_lens if common_attn_metadata.query_lens is not None else self.runner.query_lens
seq_lens_list = common_attn_metadata.seq_lens_list if common_attn_metadata.seq_lens_list is not None else self.runner.seq_lens_list

slot_mapping = self.runner.slot_mapping[:num_actual_tokens]
attn_mask = self.runner.attn_mask
attn_state = self.runner.attn_state
query_start_loc_cpu = self.runner.query_start_loc_cpu[:num_reqs + 1]
query_start_loc = query_start_loc_cpu.to(self.runner.device,
non_blocking=True)

attn_metadata = AscendMetadata(
num_actual_tokens=num_actual_tokens,
block_tables=block_table,
query_start_loc=query_start_loc,
query_lens=query_lens,
seq_lens=seq_lens,
seq_lens_list=seq_lens_list,
max_query_len=max_query_len,
slot_mapping=slot_mapping,
attn_mask=attn_mask,
attn_state=attn_state,
enable_dbo_across_dp=enable_dbo_across_dp)
return attn_metadata

def build_dummy_metadata(self, num_actual_tokens, num_reqs,
num_scheduled_tokens, attn_state):
if attn_state == AscendAttentionState.DecodeOnly:
# NOTE: We only need to pay attention to seq_lens_list and block_table here
common_attn_metadata = CommonAttentionMetadata(seq_lens_list=[2] *
num_reqs)

block_table = self.runner.input_batch.block_table[0].block_table
block_table[:num_reqs, 0] = torch.arange(1,
num_reqs + 1,
device=block_table.device,
dtype=block_table.dtype)

attn_metadata = self.build(
num_reqs=num_reqs,
num_actual_tokens=num_actual_tokens,
max_query_len=num_scheduled_tokens.max(),
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
raise NotImplementedError(
"Currently we only support building dummy metadata for DecodeOnly state"
)

attn_metadata.attn_state = attn_state
return attn_metadata


class AscendAttentionBackendImpl(AttentionImpl):

@@ -217,6 +254,10 @@ def __init__(
self.key_cache = None
self.value_cache = None

vllm_config = get_current_vllm_config()
self.full_graph = vllm_config.compilation_config.full_cuda_graph
self.block_size = vllm_config.cache_config.block_size

def forward(
self,
layer: AttentionLayer,
@@ -228,21 +269,7 @@ def forward(
output: Optional[torch.Tensor] = None,
trace_flag: bool = True,
) -> torch.Tensor:
"""Forward pass with Ascend attention.
Args:
query: shape = [batch_size, seq_len, num_heads * head_size]
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
kv_cache: shape = [2, num_blocks, block_size,
num_kv_heads, head_size]
key_cache = [num_blocks, block_size,
num_kv_heads, head_size]
value_cache = [num_blocks, block_size,
num_kv_heads, head_size]
attn_metadata: Metadata for attention.
Returns:
shape = [batch_size * seq_len, num_heads, head_size]
"""
"""Forward pass with Ascend attention."""
num_tokens = query.shape[0]
if output is None:
output = torch.empty(num_tokens,
@@ -322,16 +349,92 @@ def forward(
scale_value=self.scale,
out=output)
elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
if self.full_graph:
graph_params = get_graph_params()
q = query.view(num_tokens, -1, self.hidden_size)
k = self.key_cache.view( # type: ignore
-1, self.block_size,
self.num_kv_heads * self.head_size)
v = self.value_cache.view( # type: ignore
-1, self.block_size,
self.num_kv_heads * self.head_size)
actual_seq_lens = attn_metadata.seq_lens_list
attn_args = {
"query": q,
"key": k,
"value": v,
"actual_seq_lengths_kv": actual_seq_lens,
"block_table": attn_metadata.block_tables,
"num_heads": self.num_heads,
"scale": self.scale,
"input_layout": "BSH",
"num_key_value_heads": self.num_kv_heads,
"block_size": self.block_size,
}

# Prepare tensors for attention output
# TODO: Refactor this to step-level instead of layer-level
attn_output = torch.empty(num_tokens,
1,
self.hidden_size,
dtype=output.dtype,
device=output.device)
softmax_lse = torch.empty(num_tokens,
dtype=output.dtype,
device=output.device)

# Get workspace from cache or calculate it if not present.
workspace = graph_params.workspaces.get(num_tokens)
if workspace is None:
workspace = torch_npu._npu_fused_infer_attention_score_get_max_workspace(
**attn_args)
graph_params.workspaces[num_tokens] = workspace

forward_context = get_forward_context()
if not forward_context.capturing:
# Execute attention kernel directly in non-capturing mode
torch.ops.npu.npu_fused_infer_attention_score.out(
workspace=workspace,
out=[attn_output, softmax_lse],
**attn_args)
else:
# Handle graph capturing mode
stream = torch_npu.npu.current_stream()

event = torch.npu.ExternalEvent()
event.wait(stream)
event.reset(stream)
graph_params.events[num_tokens].append(event)

graph_params.attn_params[num_tokens].append(
(q, k, v, actual_seq_lens,
attn_metadata.block_tables, self.num_heads,
self.scale, self.num_kv_heads, attn_output,
softmax_lse))

torch.npu.graph_task_group_begin(stream)
torch.ops.npu.npu_fused_infer_attention_score.out(
workspace=workspace,
out=[attn_output, softmax_lse],
**attn_args)
handle = torch.npu.graph_task_group_end(stream)
graph_params.handles[num_tokens].append(handle)

# Reshape output to match the expected format
output.copy_(
attn_output.view(num_tokens, self.num_heads,
self.head_size))
else:
torch_npu._npu_paged_attention(
query=query,
key_cache=self.key_cache,
value_cache=self.value_cache,
num_kv_heads=self.num_kv_heads,
num_heads=self.num_heads,
scale_value=self.scale,
block_table=attn_metadata.block_tables,
context_lens=attn_metadata.seq_lens,
out=output)
# Normal V1 situation.
else:
# use chunked prefill for head size 192 scenario, like deepseek
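The full-graph branch above relies on `get_graph_params()` from `vllm_ascend.utils`, which this diff does not show. Based on the fields it accesses (`workspaces`, `events`, `attn_params`, `handles`, each keyed by token count), it presumably returns a container along these lines; this is an inferred sketch, not the actual definition.

```python
# Inferred shape of the object returned by get_graph_params(); field names
# match the accesses in the attention code above, everything else is assumed.
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

@dataclass
class GraphParams:
    # Cached FIA workspace per captured token count.
    workspaces: Dict[int, Any] = field(default_factory=dict)
    # External events waited on before each captured task group replays.
    events: Dict[int, List[Any]] = field(default_factory=lambda: defaultdict(list))
    # Saved kernel arguments so replays can update tensors in place.
    attn_params: Dict[int, List[Tuple]] = field(default_factory=lambda: defaultdict(list))
    # Handles returned by torch.npu.graph_task_group_end().
    handles: Dict[int, List[Any]] = field(default_factory=lambda: defaultdict(list))

_GRAPH_PARAMS = GraphParams()

def get_graph_params() -> GraphParams:
    return _GRAPH_PARAMS
```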
16 changes: 2 additions & 14 deletions vllm_ascend/attention/mla_v1.py
@@ -16,6 +16,8 @@
from vllm_ascend import envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import \
AscendCommonAttentionMetadata as CommonAttentionMetadata
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
from vllm_ascend.multistream.context import get_multistream_comm_context
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -28,20 +30,6 @@
from vllm.v1.worker.gpu_input_batch import InputBatch


@dataclass
class CommonAttentionMetadata:
"""
Attention metadata attributes that can be shared by layers in different KV
cache groups and thus having different block table.
"""

query_start_loc: torch.Tensor
"""(batch_size + 1,), the start location of each request in query Tensor"""
seq_lens: torch.Tensor
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""


class AscendMLABackend(AttentionBackend):

accept_output_buffer: bool = True
23 changes: 23 additions & 0 deletions vllm_ascend/attention/utils.py
@@ -0,0 +1,23 @@
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class AscendCommonAttentionMetadata:
"""
Attention metadata attributes that can be shared by layers in different KV
cache groups and thus having different block table.
"""

query_start_loc: torch.Tensor = None
"""(batch_size + 1,), the start location of each request in query Tensor"""
seq_lens: Optional[torch.Tensor] = None
"""(batch_size,), the length of each request including both computed tokens
and newly scheduled tokens"""
query_lens: Optional[torch.Tensor] = None
"""(batch_size,), the length of each request including only the newly
scheduled tokens"""
seq_lens_list: Optional[list] = None
"""(num_input_tokens,), note that this is specifically for FIA kernel"""