# SPDX-License-Identifier: Apache-2.0
"""
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""

import os

import torch
import vllm_ascend  # noqa: F401
from torch import nn
from torch.library import Library
from torch_npu.contrib import transfer_to_npu  # noqa: F401
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
                         set_current_vllm_config)
from vllm.utils import direct_register_custom_op

global_counter = 0

# create a library to hold the custom op
silly_lib = Library("silly", "FRAGMENT")  # noqa

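# "real" implementation of the custom op: bumps a global counter as a visible
# side effect, copies q into out, and adds 1 to out[0]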
def silly_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                    out: torch.Tensor) -> None:
    global global_counter
    global_counter += 1
    print(f"{global_counter=}")
    out.copy_(q)
    out[0] += 1

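# fake (meta) implementation so the op can be traced without executing it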
def silly_attention_fake(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                         out: torch.Tensor) -> None:
    return


direct_register_custom_op(
    op_name="attention",
    op_func=silly_attention,
    mutates_args=["out"],
    fake_impl=silly_attention_fake,
    dispatch_key="PrivateUse1",
    target_lib=silly_lib,
)

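# the decorator compiles forward(); with silly.attention listed in
# splitting_ops below, the graph is split at every attention call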
@support_torch_compile
class SillyModel(nn.Module):

    def __init__(self,
                 *,
                 vllm_config: VllmConfig,
                 prefix: str = '',
                 **kwargs) -> None:
        super().__init__()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Overall effect:
        x += 1
        x[0] += 2
        global_counter += 2
        """
        x = x + 1
        x = x + 2
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x - 2
        x = x - 1
        out = torch.empty_like(x)
        torch.ops.silly.attention(x, x, x, out)
        x = out
        x = x + 1
        return x

def test_simple_piecewise_compile():

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        use_inductor=False,
        use_cudagraph=True,
        splitting_ops=["silly.attention"],
        cudagraph_copy_inputs=True,
        cudagraph_capture_sizes=[1, 2],
    ))
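    # turn off the fusion pass for this test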
    vllm_config.compilation_config.pass_config.enable_fusion = False
    with set_current_vllm_config(vllm_config):
        model = SillyModel(vllm_config=vllm_config, prefix='')

    inputs = torch.randn(100).npu()

    with compilation_counter.expect(
            num_graphs_seen=1,  # one graph for the model
            num_piecewise_graphs_seen=5,  # 2 * num_layers + 1
            num_piecewise_capturable_graphs_seen=3,  # 1 + num_layers
            num_backend_compilations=3,  # num_piecewise_capturable_graphs_seen
            num_cudagraph_caputured=
            6,  # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
    ):

        model(inputs)

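        # run once per configured capture size (cudagraph_capture_sizes=[1, 2])
        # so a graph is captured for each size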
        model(torch.randn(2).npu())
        model(torch.randn(1).npu())

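        # replay on a known input: per forward()'s docstring the result should
        # be input + 1 with [0] bumped by 2, and the counter should grow by 2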
        input = torch.zeros(2).npu()
        global global_counter
        global_counter = 0
        output = model(input)
        assert global_counter == 2
        assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))


if __name__ == "__main__":
    os.environ["VLLM_USE_V1"] = "1"
    test_simple_piecewise_compile()