@@ -28,42 +28,49 @@ class Relu3(ReLUSquaredActivation):
 
 
 @pytest.mark.parametrize(
-    "env, torch_level, ops_enabled, default_on",
+    "env, torch_level, use_inductor, ops_enabled, default_on",
     [
         # Default values based on compile level
-        ("", 0, [True] * 4, True),
-        ("", 1, [True] * 4, True),
-        ("", 2, [True] * 4, True),  # All by default
-        ("", 3, [False] * 4, False),
-        ("", 4, [False] * 4, False),  # None by default
+        # - All by default (no Inductor compilation)
+        ("", 0, False, [True] * 4, True),
+        ("", 1, True, [True] * 4, True),
+        ("", 2, False, [True] * 4, True),
+        # - None by default (with Inductor)
+        ("", 3, True, [False] * 4, False),
+        ("", 4, True, [False] * 4, False),
+        # - All by default (without Inductor)
+        ("", 3, False, [True] * 4, True),
+        ("", 4, False, [True] * 4, True),
         # Explicitly enabling/disabling
         #
         # Default: all
         #
         # All but SiluAndMul
-        ("+rms_norm,-silu_and_mul", 0, [1, 0, 1, 1], True),
+        ("+rms_norm,-silu_and_mul", 0, True, [1, 0, 1, 1], True),
         # Only ReLU3
-        ("none,-rms_norm,+relu3", 0, [0, 0, 0, 1], False),
+        ("none,-rms_norm,+relu3", 1, False, [0, 0, 0, 1], False),
         # All but SiluAndMul
-        ("all,-silu_and_mul", 1, [1, 0, 1, 1], True),
+        ("all,-silu_and_mul", 2, True, [1, 0, 1, 1], True),
         # All but ReLU3 (even if ReLU2 is on)
-        ("-relu3,relu2", 1, [1, 1, 1, 0], True),
-        # GeluAndMul and SiluAndMul
-        ("none,-relu3,+gelu_and_mul,+silu_and_mul", 2, [0, 1, 1, 0], False),
+        ("-relu3,relu2", 3, False, [1, 1, 1, 0], True),
+        # RMSNorm and SiluAndMul
+        ("none,-relu3,+rms_norm,+silu_and_mul", 4, False, [1, 1, 0, 0], False),
         # All but RMSNorm
-        ("-rms_norm", 2, [0, 1, 1, 1], True),
+        ("-rms_norm", 3, False, [0, 1, 1, 1], True),
         #
         # Default: none
         #
         # Only ReLU3
-        ("-silu_and_mul,+relu3", 3, [0, 0, 0, 1], False),
+        ("-silu_and_mul,+relu3", 3, True, [0, 0, 0, 1], False),
         # All but RMSNorm
-        ("all,-rms_norm", 4, [0, 1, 1, 1], True),
+        ("all,-rms_norm", 4, True, [0, 1, 1, 1], True),
     ])
-def test_enabled_ops(env: str, torch_level: int, ops_enabled: list[int],
-                     default_on: bool):
-    vllm_config = VllmConfig(compilation_config=CompilationConfig(
-        level=torch_level, custom_ops=env.split(",")))
+def test_enabled_ops(env: str, torch_level: int, use_inductor: bool,
+                     ops_enabled: list[int], default_on: bool):
+    vllm_config = VllmConfig(
+        compilation_config=CompilationConfig(use_inductor=bool(use_inductor),
+                                             level=torch_level,
+                                             custom_ops=env.split(",")))
     with set_current_vllm_config(vllm_config):
         assert CustomOp.default_on() == default_on
 
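For reference, a minimal sketch of the default-enablement rule the parametrized cases above encode, assuming level 3 marks the start of piecewise Inductor compilation; the helper name ops_default_on is hypothetical and this is an illustration of the behavior the test expects, not vLLM's actual CustomOp.default_on() implementation:

# Hypothetical helper restating the rule the test cases encode;
# not vLLM's actual CustomOp.default_on() implementation.
def ops_default_on(level: int, use_inductor: bool) -> bool:
    # Custom ops stay enabled by default unless Inductor compiles the
    # model at level 3 or above, where Inductor-generated kernels are
    # expected to replace the custom implementations.
    return not (use_inductor and level >= 3)

# Consistent with the unannotated cases above:
assert ops_default_on(0, False)        # no compilation: all on by default
assert not ops_default_on(3, True)     # Inductor at level 3: none on
assert ops_default_on(4, False)        # level 4 without Inductor: all on

The explicit "+op"/"-op" entries in custom_ops then override this default on a per-op basis, which is what the ops_enabled column checks.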