-import torch, os
+import torch, os, ray
 from accelerate import init_empty_weights
 from dataclasses import dataclass, field
 from transformers import AutoConfig, AutoModel
@@ -33,12 +33,23 @@ class FP8State:
 
 # Global FP8 config that can be accessed by patched vLLM functions
-# initialized by 'init_fp8_cfg()'
+# initialized by 'init_fp8()'
-fp8_config: FP8Config = None
+global_fp8_config: FP8Config = None
 # Global FP8 state that holds runtime fp8 objects
 fp8_state: FP8State = FP8State()
 
+fp8_patches_applied = False
+
+
+from vllm.executor.ray_distributed_executor import RayDistributedExecutor
+original_run_workers = RayDistributedExecutor._run_workers
+
+
+def apply_fp8_patches(self, fp8_config):
+    global global_fp8_config, fp8_patches_applied
+
+    if global_fp8_config is None:
+        global_fp8_config = fp8_config
 
-def init_fp8(vllm_cfg, model_name):
     # This patch is used to support torch.compile with vllm parameter subclasses, such as
     # PerTensorScaleParameter. Because we need weight loaders to update fp8 weights each
     # refit, we patch fp8 parameters to have a reference to their weight loader. Eventually
@@ -47,10 +58,9 @@ def init_fp8(vllm_cfg, model_name):
     func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading"
     patcher1 = patch(func1_path, process_weights_after_loading)
     fp8_state.vllm_patches.append(patcher1)
-
     # These patches add support for pow2, e8 dynamic activation scaling factors, which are believed to have higher
     # SNR compared to plain fp32 scaling factors. This feature is still under active research.
-    if vllm_cfg.get("pow2_activation_scaling_factors", False):
+    if global_fp8_config.use_activation_pow2_scale:
         func2_path = "vllm.model_executor.layers.quantization.utils.fp8_utils.per_token_group_quant_fp8"
         func3_path = "vllm.model_executor.layers.quantization.utils.fp8_utils._per_token_group_quant_fp8"
         func4_path = "vllm.model_executor.layers.quantization.utils.fp8_utils._per_token_group_quant_fp8_colmajor"
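
The patch targets above are applied with the same patcher pattern shown earlier in this hunk (`patcher = patch(path, replacement)`, then `p.start()` for every entry in `fp8_state.vllm_patches`). A minimal sketch of that pattern, assuming `patch` here is `unittest.mock.patch` and vLLM is importable; `fake_per_token_group_quant_fp8` is a hypothetical stand-in, not part of this diff:

from unittest.mock import patch

def fake_per_token_group_quant_fp8(*args, **kwargs):
    # Hypothetical replacement; the real diff substitutes its own pow2-scale implementation.
    ...

patcher = patch(
    "vllm.model_executor.layers.quantization.utils.fp8_utils.per_token_group_quant_fp8",
    fake_per_token_group_quant_fp8,
)
patcher.start()  # from here on, lookups through that module resolve to the replacement
# ... run vLLM code that calls per_token_group_quant_fp8 ...
patcher.stop()   # restore the original vLLM implementation
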
@@ -62,6 +72,33 @@ def init_fp8(vllm_cfg, model_name):
     for p in fp8_state.vllm_patches:
         p.start()
 
+    fp8_patches_applied = True
+
+def patched_run_workers(self, *args, **kwargs):
+    global fp8_patches_applied
+    if not fp8_patches_applied:
+        apply_fp8_patches(self, global_fp8_config)
+        futures = [worker.execute_method.remote(apply_fp8_patches, global_fp8_config) for worker in self.workers]
+        [ray.get(future) for future in futures]
+
+    return original_run_workers(self, *args, **kwargs)
+
+# We patch vLLM's _run_workers so that, before vLLM initializes the model, we execute a remote call that patches
+# each worker with our required fp8 vLLM patches.
+RayDistributedExecutor._run_workers = patched_run_workers
+
+
+def init_fp8(vllm_cfg, model_name):
+    global global_fp8_config
+    global_fp8_config = FP8Config(
+        use_weight_pow2_scale=vllm_cfg.get("pow2_weight_scaling_factors", False),
+        use_activation_pow2_scale=vllm_cfg.get(
+            "pow2_activation_scaling_factors", False
+        ),
+        num_first_layers_in_bf16=vllm_cfg.get("num_first_layers_in_bf16", 0),
+        num_last_layers_in_bf16=vllm_cfg.get("num_last_layers_in_bf16", 0),
+    )
+
     if vllm_cfg.get("use_deep_gemm", False):
         os.environ["VLLM_USE_DEEP_GEMM"] = "1"
 
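
For context, a hedged sketch of how this initialization path is meant to be driven. The config keys, their defaults, and the `vllm_kwargs` return value come from `init_fp8` above; the call site, the model name, and the ordering claim are my reading of the patch and are not shown in the diff:

# Hypothetical driver-side usage (config key names come from init_fp8 above;
# the exact call site in the real codebase is not shown in this diff).
vllm_cfg = {
    "pow2_weight_scaling_factors": True,       # -> FP8Config.use_weight_pow2_scale
    "pow2_activation_scaling_factors": False,  # -> FP8Config.use_activation_pow2_scale
    "num_first_layers_in_bf16": 1,
    "num_last_layers_in_bf16": 1,
    "use_deep_gemm": False,
}

# init_fp8 runs before the vLLM engine is constructed: it fills in global_fp8_config,
# which the patched RayDistributedExecutor._run_workers reads when it pushes
# apply_fp8_patches to the driver process and to every Ray worker.
vllm_kwargs = init_fp8(vllm_cfg, "some-fp8-model")
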
@@ -106,17 +143,6 @@ def init_fp8(vllm_cfg, model_name):
     return vllm_kwargs
 
 
-def init_fp8_cfg(vllm_cfg):
-    global fp8_config
-    fp8_config = FP8Config(
-        use_weight_pow2_scale=vllm_cfg.get("pow2_weight_scaling_factors", False),
-        use_activation_pow2_scale=vllm_cfg.get(
-            "pow2_activation_scaling_factors", False
-        ),
-        num_first_layers_in_bf16=vllm_cfg.get("num_first_layers_in_bf16", 0),
-        num_last_layers_in_bf16=vllm_cfg.get("num_last_layers_in_bf16", 0),
-    )
-
 
 def is_fp8_model(vllm_config):
     from vllm.model_executor.layers.quantization.fp8 import Fp8Config
@@ -269,7 +295,7 @@ def kitchen_block_scale(
     # Calculate descale factor
     descale = max_abs / max_dtype
 
-    if fp8_config.use_weight_pow2_scale:
+    if global_fp8_config.use_weight_pow2_scale:
         exponent = torch.ceil(torch.log2(descale))
         # Post process exponent to be in range of -127 to 127 and to be E8M0 biased
         exponent = torch.clamp(exponent, min=-127, max=127) + 127
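
A small numeric sketch of the power-of-two rounding above, with illustrative values only. It assumes an fp8 e4m3 weight dtype (largest finite value 448); the final un-biasing line is added here purely to show what the biased exponent encodes and is not taken from the diff:

import torch

max_abs = torch.tensor(5.5)    # example per-block max(|w|)
max_dtype = 448.0              # max finite value of fp8 e4m3 (assumed target dtype)
descale = max_abs / max_dtype  # ~0.01228

exponent = torch.ceil(torch.log2(descale))                 # ceil(-6.35) -> -6.0
exponent = torch.clamp(exponent, min=-127, max=127) + 127  # E8M0-biased -> 121.0
pow2_descale = torch.exp2(exponent - 127)                  # 2**-6 = 0.015625 >= descale

Rounding the exponent up with `ceil` guarantees the power-of-two descale never underestimates the true descale, so the quantized values stay within the fp8 range.
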
@@ -483,7 +509,7 @@ def _per_token_group_quant_fp8_colmajor(
 def per_token_group_quant_fp8(
     *args, **kwargs,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    assert fp8_config.use_activation_pow2_scale
+    assert global_fp8_config.use_activation_pow2_scale
     from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 as vllm_per_token_group_quant_fp8
     return vllm_per_token_group_quant_fp8(*args, **kwargs)
 