@@ -11,7 +11,13 @@
 from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe, is_func
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.config import (
+    CompilationConfig,
+    ModelConfig,
+    PassConfig,
+    VllmConfig,
+    set_current_vllm_config,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
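Aside: the newly imported set_current_vllm_config is what makes the config change in the next hunk take effect. A minimal sketch of the pattern, using only the names imported above and assuming the behavior the test below relies on (note ModelConfig may fetch the default model's HF config):

import torch

from vllm.config import (
    CompilationConfig,
    ModelConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.model_executor.layers.layernorm import RMSNorm

cfg = VllmConfig(
    model_config=ModelConfig(dtype=torch.bfloat16),
    compilation_config=CompilationConfig(custom_ops=["all"]),
)

with set_current_vllm_config(cfg):
    # Inside the context, custom_ops=["all"] is globally visible, so the
    # custom (non-torch-native) RMSNorm implementation reports enabled.
    assert RMSNorm.enabled()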
@@ -217,42 +223,48 @@ def ops_not_in_model(self): |
 def test_fix_functionalization(model_class: torch.nn.Module, do_fusion: bool):
     torch.set_default_device("cuda")
 
-    vllm_config = VllmConfig()
-    vllm_config.compilation_config = CompilationConfig(
-        pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True)
+    vllm_config = VllmConfig(
+        model_config=ModelConfig(dtype=torch.bfloat16),
+        compilation_config=CompilationConfig(
+            custom_ops=["all"],
+            pass_config=PassConfig(enable_fusion=do_fusion, enable_noop=True),
+        ),
     )
-    noop_pass = NoOpEliminationPass(vllm_config)
-    fusion_pass = RMSNormQuantFusionPass(vllm_config)
-    cleanup_pass = PostCleanupPass(vllm_config)
-    act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
-
-    passes = (
-        [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
-        if do_fusion
-        else [noop_pass, cleanup_pass]
-    )
-    func_pass = FixFunctionalizationPass(vllm_config)
 
-    backend_func = TestBackend(*passes, func_pass)
-    backend_no_func = TestBackend(*passes)
+    with set_current_vllm_config(vllm_config):
+        assert RMSNorm.enabled()
+        noop_pass = NoOpEliminationPass(vllm_config)
+        fusion_pass = RMSNormQuantFusionPass(vllm_config)
+        cleanup_pass = PostCleanupPass(vllm_config)
+        act_quant_fusion_pass = ActivationQuantFusionPass(vllm_config)
+
+        passes = (
+            [noop_pass, fusion_pass, act_quant_fusion_pass, cleanup_pass]
+            if do_fusion
+            else [noop_pass, cleanup_pass]
+        )
+        func_pass = FixFunctionalizationPass(vllm_config)
 
-    model = model_class()
-    torch.compile(model, backend=backend_func)(*model.example_inputs())
-    torch.compile(model, backend=backend_no_func)(*model.example_inputs())
+        backend_func = TestBackend(*passes, func_pass)
+        backend_no_func = TestBackend(*passes)
 
-    # check if the functionalization pass is applied
-    for op in model.ops_in_model(do_fusion):
-        find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
-        assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
+        model = model_class()
+        torch.compile(model, backend=backend_func)(*model.example_inputs())
+        torch.compile(model, backend=backend_no_func)(*model.example_inputs())
 
-    # make sure the ops were all de-functionalized
-    found = dict()
-    for node in backend_func.graph_post_pass.nodes:
+        # check if the functionalization pass is applied
         for op in model.ops_in_model(do_fusion):
-            if is_func(node, op):
-                found[op] = True
-        for op in model.ops_not_in_model():
-            if is_func(node, op):
-                found[op] = True
-    assert all(found[op] for op in model.ops_in_model(do_fusion))
-    assert all(not found.get(op) for op in model.ops_not_in_model())
+            find_auto_fn(backend_no_func.graph_post_pass.nodes, op)
+            assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, op) is None
+
+        # make sure the ops were all de-functionalized
+        found = dict()
+        for node in backend_func.graph_post_pass.nodes:
+            for op in model.ops_in_model(do_fusion):
+                if is_func(node, op):
+                    found[op] = True
+            for op in model.ops_not_in_model():
+                if is_func(node, op):
+                    found[op] = True
+        assert all(found[op] for op in model.ops_in_model(do_fusion))
+        assert all(not found.get(op) for op in model.ops_not_in_model())
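Why these assertions hold: during tracing, each mutating custom op is wrapped as torch.ops.higher_order.auto_functionalized(op, ...), and FixFunctionalizationPass rewrites that wrapper back into a direct call on op. A paraphrased sketch of the fx_utils helpers used above (close in spirit to, but not guaranteed to match, the actual vLLM source):

from typing import Optional

import torch
from torch.fx import Node

def is_func(node: Node, target) -> bool:
    # True for a direct (de-functionalized) call to `target`.
    return node.op == "call_function" and node.target == target

def find_auto_fn_maybe(nodes, op) -> Optional[Node]:
    # The functionalized form carries `op` as the wrapper's first argument.
    for node in nodes:
        if (is_func(node, torch.ops.higher_order.auto_functionalized)
                and node.args[0] == op):
            return node
    return None

With the pass applied, find_auto_fn_maybe should return None for every op in the model while a plain call_function node for the same op remains, which is exactly what the found dictionary tallies.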