@@ -36,55 +36,56 @@ class Relu3(ReLUSquaredActivation):
3636
3737
3838@pytest .mark .parametrize (
39- "env, torch_level, use_inductor , ops_enabled, default_on" ,
39+ "env, torch_level, backend , ops_enabled, default_on" ,
4040 [
4141 # Default values based on compile level
4242 # - All by default (no Inductor compilation)
43- (None , 0 , False , [True ] * 4 , True ),
44- (None , 1 , True , [True ] * 4 , True ),
45- (None , 2 , False , [True ] * 4 , True ),
43+ (None , 0 , "eager" , [True ] * 4 , True ),
44+ (None , 1 , "eager" , [True ] * 4 , True ),
45+ (None , 2 , "eager" , [True ] * 4 , True ),
46+ (None , 3 , "eager" , [True ] * 4 , True ),
4647 # - All by default (with Inductor at level 0)
47- (None , 3 , True , [False ] * 4 , False ),
48- ( None , 4 , True , [ False ] * 4 , False ),
49- # - All by default (without Inductor)
50- (None , 3 , False , [True ] * 4 , True ),
51- (None , 4 , False , [True ] * 4 , True ),
48+ (None , 0 , "inductor" , [True ] * 4 , True ),
49+ # - None by default (with Inductor)
50+ ( None , 1 , "inductor" , [ False ] * 4 , False ),
51+ (None , 2 , "inductor" , [False ] * 4 , False ),
52+ (None , 3 , "inductor" , [False ] * 4 , False ),
5253 # Explicitly enabling/disabling
5354 #
5455 # Default: all
5556 #
5657 # All but SiluAndMul
57- ("+rms_norm,-silu_and_mul" , 0 , True , [1 , 0 , 1 , 1 ], True ),
58+ ("+rms_norm,-silu_and_mul" , 0 , "inductor" , [1 , 0 , 1 , 1 ], True ),
5859 # Only ReLU3
59- ("none,-rms_norm,+relu3" , 1 , False , [0 , 0 , 0 , 1 ], False ),
60+ ("none,-rms_norm,+relu3" , 1 , "eager" , [0 , 0 , 0 , 1 ], False ),
6061 # All but SiluAndMul
61- ("all,-silu_and_mul" , 2 , True , [1 , 0 , 1 , 1 ], True ),
62+ ("all,-silu_and_mul" , 2 , "inductor" , [1 , 0 , 1 , 1 ], True ),
6263 # All but ReLU3 (even if ReLU2 is on)
63- ("-relu3,+relu2" , 3 , False , [1 , 1 , 1 , 0 ], True ),
64+ ("-relu3,+relu2" , 3 , "eager" , [1 , 1 , 1 , 0 ], True ),
6465 # RMSNorm and SiluAndMul
65- ("none,-relu3,+rms_norm,+silu_and_mul" , 4 , False , [1 , 1 , 0 , 0 ], False ),
66+ ("none,-relu3,+rms_norm,+silu_and_mul" , 3 , "eager" , [1 , 1 , 0 , 0 ], False ),
6667 # All but RMSNorm
67- ("-rms_norm" , 3 , False , [0 , 1 , 1 , 1 ], True ),
68+ ("-rms_norm" , 3 , "eager" , [0 , 1 , 1 , 1 ], True ),
6869 #
6970 # Default: none
7071 #
7172 # Only ReLU3
72- ("-silu_and_mul ,+relu3" , 3 , True , [0 , 0 , 0 , 1 ], False ),
73+ ("none ,+relu3" , 3 , "inductor" , [0 , 0 , 0 , 1 ], False ),
7374 # All but RMSNorm
74- ("all,-rms_norm" , 4 , True , [0 , 1 , 1 , 1 ], True ),
75+ ("all,-rms_norm" , 3 , "inductor" , [0 , 1 , 1 , 1 ], True ),
7576 ],
7677)
7778def test_enabled_ops (
7879 env : str | None ,
7980 torch_level : int ,
80- use_inductor : bool ,
81+ backend : str ,
8182 ops_enabled : list [int ],
8283 default_on : bool ,
8384):
8485 custom_ops = env .split ("," ) if env else []
8586 vllm_config = VllmConfig (
8687 compilation_config = CompilationConfig (
87- use_inductor = bool ( use_inductor ) , level = torch_level , custom_ops = custom_ops
88+ backend = backend , level = torch_level , custom_ops = custom_ops
8889 )
8990 )
9091 with set_current_vllm_config (vllm_config ):
0 commit comments