3 files changed, +16 -5 lines changed

@@ -36,9 +36,11 @@
                     reason="aclgraph only support on v1")
 @pytest.mark.parametrize("model", MODELS)
 @pytest.mark.parametrize("max_tokens", [32])
+@pytest.mark.parametrize("full_graph", [False])
 def test_models(
     model: str,
     max_tokens: int,
+    full_graph: bool,
     monkeypatch: pytest.MonkeyPatch,
 ) -> None:
     with monkeypatch.context() as m:
@@ -54,7 +56,15 @@ def test_models(
                                          temperature=0.0)
         # TODO: change to use vllmrunner when the registry of custom op is solved
         # while running pytest
-        vllm_model = LLM(model)
+        if full_graph:
+            vllm_model = LLM(model,
+                             compilation_config={
+                                 "full_cuda_graph": True,
+                                 "cudagraph_capture_sizes":
+                                 [1, 4, 16, 64, 256]
+                             })
+        else:
+            vllm_model = LLM(model)
         vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
         del vllm_model
         torch.npu.empty_cache()
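
For context, a minimal standalone sketch of what the new full_graph=True branch exercises, outside pytest. The prompt and the model name are placeholders, not entries from the test's MODELS list; the compilation_config keys come straight from the diff above.

    from vllm import LLM, SamplingParams

    prompts = ["Hello, my name is"]
    sampling_params = SamplingParams(max_tokens=32, temperature=0.0)

    # full_cuda_graph captures the whole forward pass as one graph instead of
    # piecewise capture; cudagraph_capture_sizes lists the batch sizes that get
    # captured, with other batch sizes falling back to non-graph execution.
    llm = LLM("some-model",  # placeholder model name
              compilation_config={
                  "full_cuda_graph": True,
                  "cudagraph_capture_sizes": [1, 4, 16, 64, 256],
              })
    outputs = llm.generate(prompts, sampling_params)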
@@ -352,9 +352,10 @@ def forward(
         if self.full_graph:
             graph_params = get_graph_params()
             q = query.view(num_tokens, -1, self.hidden_size)
-            k = self.key_cache.view(-1, self.block_size,
-                                    self.num_kv_heads * self.head_size)
-            v = self.value_cache.view(
+            k = self.key_cache.view(  # type: ignore
+                -1, self.block_size,
+                self.num_kv_heads * self.head_size)
+            v = self.value_cache.view(  # type: ignore
                 -1, self.block_size,
                 self.num_kv_heads * self.head_size)
             actual_seq_lens = attn_metadata.seq_lens_list
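
The added # type: ignore comments suggest key_cache and value_cache are declared Optional (None until the KV cache is bound), so mypy cannot prove .view() is safe even though the caches are set by the time forward() runs; that reading is an inference, not stated in the diff. A sketch of the reshape itself, with invented shapes:

    import torch

    # Invented sizes for illustration only.
    num_blocks, block_size, num_kv_heads, head_size = 8, 128, 4, 64
    key_cache = torch.zeros(num_blocks, block_size, num_kv_heads, head_size)

    # Collapse the head dimensions so each slot in a block is one flat KV row:
    # (num_blocks, block_size, num_kv_heads, head_size)
    #   -> (num_blocks, block_size, num_kv_heads * head_size)
    k = key_cache.view(-1, block_size, num_kv_heads * head_size)
    assert k.shape == (num_blocks, block_size, num_kv_heads * head_size)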
@@ -11,7 +11,7 @@ class AscendCommonAttentionMetadata:
     cache groups and thus having different block table.
     """
 
-    query_start_loc: Optional[torch.Tensor] = None
+    query_start_loc: torch.Tensor = None
    """(batch_size + 1,), the start location of each request in query Tensor"""
     seq_lens: Optional[torch.Tensor] = None
     """(batch_size,), the length of each request including both computed tokens