Skip to content

Commit 5987715

Browse files
committed
[Test] Adds tests to validate model generation with the full graph feature both enabled and disabled
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
1 parent 80fb280 commit 5987715

File tree

3 files changed

+16
-5
lines changed

3 files changed

+16
-5
lines changed

tests/singlecard/test_aclgraph.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,11 @@
                     reason="aclgraph only support on v1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("full_graph", [False])
 def test_models(
     model: str,
     max_tokens: int,
+    full_graph: bool,
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     with monkeypatch.context() as m:
@@ -54,7 +56,15 @@ def test_models(
                                      temperature=0.0)
         # TODO: change to use vllmrunner when the registry of custom op is solved
         # while running pytest
-        vllm_model = LLM(model)
+        if full_graph:
+            vllm_model = LLM(model,
+                             compilation_config={
+                                 "full_cuda_graph": True,
+                                 "cudagraph_capture_sizes":
+                                 [1, 4, 16, 64, 256]
+                             })
+        else:
+            vllm_model = LLM(model)
         vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
         del vllm_model
         torch.npu.empty_cache()

vllm_ascend/attention/attention_v1.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,10 @@ def forward(
         if self.full_graph:
             graph_params = get_graph_params()
             q = query.view(num_tokens, -1, self.hidden_size)
-            k = self.key_cache.view(-1, self.block_size,
-                                    self.num_kv_heads * self.head_size)
-            v = self.value_cache.view(
+            k = self.key_cache.view(  # type: ignore
+                -1, self.block_size,
+                self.num_kv_heads * self.head_size)
+            v = self.value_cache.view(  # type: ignore
                 -1, self.block_size,
                 self.num_kv_heads * self.head_size)
             actual_seq_lens = attn_metadata.seq_lens_list

vllm_ascend/attention/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class AscendCommonAttentionMetadata:
     cache groups and thus having different block table.
     """

-    query_start_loc: Optional[torch.Tensor] = None
+    query_start_loc: torch.Tensor = None
     """(batch_size + 1,), the start location of each request in query Tensor"""
     seq_lens: Optional[torch.Tensor] = None
     """(batch_size,), the length of each request including both computed tokens

0 commit comments

Comments
 (0)