
Commit f5401dc

fix TP
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
1 parent 1017715 commit f5401dc

2 files changed: +44 additions, -21 deletions


nemo_rl/models/generation/fp8.py

Lines changed: 44 additions & 18 deletions
@@ -1,4 +1,4 @@
-import torch, os
+import torch, os, ray
 from accelerate import init_empty_weights
 from dataclasses import dataclass, field
 from transformers import AutoConfig, AutoModel
@@ -33,12 +33,23 @@ class FP8State:
 
 # Global FP8 config that can be accessed by patched vLLM functions
 # initialized by 'init_fp8_cfg()'
-fp8_config: FP8Config = None
+global_fp8_config: FP8Config = None
 # Global FP8 state that holds runtime fp8 objects
 fp8_state: FP8State = FP8State()
 
+fp8_patches_applied = False
+
+
+from vllm.executor.ray_distributed_executor import RayDistributedExecutor
+original_run_workers = RayDistributedExecutor._run_workers
+
+
+def apply_fp8_patches(self, fp8_config):
+    global global_fp8_config, fp8_patches_applied
+
+    if global_fp8_config is None:
+        global_fp8_config = fp8_config
 
-def init_fp8(vllm_cfg, model_name):
     # This patch is used to support torch.compile with vllm parameter subclasses, such as
     # PerTensorScaleParameter. Because we need weight loaders to update fp8 weights each
     # refit, we patch fp8 parameters to have a reference to their weight loader. Eventually
@@ -47,10 +58,9 @@ def init_fp8(vllm_cfg, model_name):
     func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading"
     patcher1 = patch(func1_path, process_weights_after_loading)
     fp8_state.vllm_patches.append(patcher1)
-
     # These patches add support for pow2, e8 dynamic activation scalings factors which are believed to have higher
     # SNR compared to plain fp32 scaling factors. This feature is still under active research.
-    if vllm_cfg.get("pow2_activation_scaling_factors", False):
+    if global_fp8_config.use_activation_pow2_scale:
         func2_path = "vllm.model_executor.layers.quantization.utils.fp8_utils.per_token_group_quant_fp8"
         func3_path = "vllm.model_executor.layers.quantization.utils.fp8_utils._per_token_group_quant_fp8"
         func4_path = "vllm.model_executor.layers.quantization.utils.fp8_utils._per_token_group_quant_fp8_colmajor"
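The patchers collected in fp8_state.vllm_patches are unittest.mock.patch objects that get started manually with p.start() rather than used as context managers, so they stay in effect until stop() is called. A minimal standalone sketch of that start/stop pattern, with a placeholder target path ("json.dumps") and replacement function instead of the real vLLM symbols:

from unittest.mock import patch

vllm_patches = []  # keep handles so the patches can be undone later

def replacement_fn(*args, **kwargs):
    # placeholder standing in for a function like process_weights_after_loading
    return None

# patch() only records the target and replacement; nothing is swapped yet
patcher = patch("json.dumps", replacement_fn)  # any importable dotted path works as a target
vllm_patches.append(patcher)

# start() swaps the target in place; it remains patched until stop()
for p in vllm_patches:
    p.start()

# later, e.g. at teardown
for p in vllm_patches:
    p.stop()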
@@ -62,6 +72,33 @@ def init_fp8(vllm_cfg, model_name):
     for p in fp8_state.vllm_patches:
         p.start()
 
+    fp8_patches_applied = True
+
+def patched_run_workers(self, *args, **kwargs):
+    global fp8_patches_applied
+    if not fp8_patches_applied:
+        apply_fp8_patches(self, global_fp8_config)
+        futures = [worker.execute_method.remote(apply_fp8_patches, global_fp8_config) for worker in self.workers]
+        [ray.get(future) for future in futures]
+
+    return original_run_workers(self, *args, **kwargs)
+
+# we patch vllm's _run_workers so that before vllm initializes the model, we execute a remote call that patches
+# each worker with our required fp8 vllm patches
+RayDistributedExecutor._run_workers = patched_run_workers
+
+
+def init_fp8(vllm_cfg, model_name):
+    global global_fp8_config
+    global_fp8_config = FP8Config(
+        use_weight_pow2_scale=vllm_cfg.get("pow2_weight_scaling_factors", False),
+        use_activation_pow2_scale=vllm_cfg.get(
+            "pow2_activation_scaling_factors", False
+        ),
+        num_first_layers_in_bf16=vllm_cfg.get("num_first_layers_in_bf16", 0),
+        num_last_layers_in_bf16=vllm_cfg.get("num_last_layers_in_bf16", 0),
+    )
+
     if vllm_cfg.get("use_deep_gemm", False):
         os.environ["VLLM_USE_DEEP_GEMM"] = "1"
 
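The heart of the change is patched_run_workers above: rather than patching only the local process, the executor's _run_workers entry point is wrapped so that the FP8 patches are installed on the driver and pushed to every Ray worker before vLLM first initializes the model. A self-contained sketch of this wrap-and-forward pattern, using stand-in Executor and Worker classes rather than the real vLLM API (the names, the execute_method signature, and the config dict below are illustrative):

import ray

ray.init(ignore_reinit_error=True)

patches_applied = False  # module-level guard: patch each process at most once


def apply_patches(cfg):
    # Stand-in for apply_fp8_patches(); here it only flips the guard flag.
    global patches_applied
    patches_applied = True
    return f"patched with {cfg}"


@ray.remote
class Worker:
    # Stand-in worker that can execute an arbitrary callable.
    def execute_method(self, fn, *args):
        return fn(*args)


class Executor:
    # Stand-in for a Ray-based executor that fans calls out to its workers.
    def __init__(self, workers):
        self.workers = workers

    def _run_workers(self, method, *args):
        return ray.get([w.execute_method.remote(method, *args) for w in self.workers])


original_run_workers = Executor._run_workers


def patched_run_workers(self, *args, **kwargs):
    # On the first call, patch the local process and every remote worker, then forward.
    if not patches_applied:
        apply_patches({"precision": "fp8"})
        ray.get([w.execute_method.remote(apply_patches, {"precision": "fp8"})
                 for w in self.workers])
    return original_run_workers(self, *args, **kwargs)


Executor._run_workers = patched_run_workers

executor = Executor([Worker.remote() for _ in range(2)])
print(executor._run_workers(lambda: "load_model"))  # workers are patched before this runs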
@@ -106,17 +143,6 @@ def init_fp8(vllm_cfg, model_name):
     return vllm_kwargs
 
 
-def init_fp8_cfg(vllm_cfg):
-    global fp8_config
-    fp8_config = FP8Config(
-        use_weight_pow2_scale=vllm_cfg.get("pow2_weight_scaling_factors", False),
-        use_activation_pow2_scale=vllm_cfg.get(
-            "pow2_activation_scaling_factors", False
-        ),
-        num_first_layers_in_bf16=vllm_cfg.get("num_first_layers_in_bf16", 0),
-        num_last_layers_in_bf16=vllm_cfg.get("num_last_layers_in_bf16", 0),
-    )
-
 
 def is_fp8_model(vllm_config):
     from vllm.model_executor.layers.quantization.fp8 import Fp8Config
@@ -269,7 +295,7 @@ def kitchen_block_scale(
     # Calculate descale factor
     descale = max_abs / max_dtype
 
-    if fp8_config.use_weight_pow2_scale:
+    if global_fp8_config.use_weight_pow2_scale:
         exponent = torch.ceil(torch.log2(descale))
         # Post process exponent to be in range of -127 to 127 and to be E8M0 biased
         exponent = torch.clamp(exponent, min=-127, max=127) + 127
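When use_weight_pow2_scale is set, the branch above rounds each descale factor up to the next power of two and stores it as an E8M0-style biased exponent (bias 127). A small standalone illustration of that arithmetic, not the NeMo-RL function itself:

import torch

def pow2_descale(descale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # Round the descale up to the next power of two, expressed as a biased exponent.
    exponent = torch.ceil(torch.log2(descale))
    biased = torch.clamp(exponent, min=-127, max=127) + 127  # E8M0-style bias of 127
    # Return the biased exponent and the power-of-two descale it encodes.
    return biased, torch.pow(2.0, biased - 127)

descale = torch.tensor([0.013, 0.75, 3.2])
biased, rounded = pow2_descale(descale)
print(biased)   # [121., 127., 129.]
print(rounded)  # -> 2**-6, 2**0, 2**2; each is >= the original descale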
@@ -483,7 +509,7 @@ def _per_token_group_quant_fp8_colmajor(
 def per_token_group_quant_fp8(
     *args, **kwargs,
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    assert fp8_config.use_activation_pow2_scale
+    assert global_fp8_config.use_activation_pow2_scale
     from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 as vllm_per_token_group_quant_fp8
     return vllm_per_token_group_quant_fp8(*args, **kwargs)
 
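For reference, the relocated FP8Config construction in init_fp8 reads these vllm_cfg keys: pow2_weight_scaling_factors, pow2_activation_scaling_factors, num_first_layers_in_bf16, and num_last_layers_in_bf16 (use_deep_gemm additionally sets VLLM_USE_DEEP_GEMM). A quick sketch with a stand-in dataclass and an illustrative config dict, not the actual NeMo-RL config schema:

from dataclasses import dataclass

@dataclass
class FP8Config:
    # Stand-in mirroring the fields read in init_fp8 above.
    use_weight_pow2_scale: bool = False
    use_activation_pow2_scale: bool = False
    num_first_layers_in_bf16: int = 0
    num_last_layers_in_bf16: int = 0

# Illustrative vllm_cfg contents; keys match those read by init_fp8.
vllm_cfg = {
    "pow2_weight_scaling_factors": True,
    "num_first_layers_in_bf16": 1,
}

global_fp8_config = FP8Config(
    use_weight_pow2_scale=vllm_cfg.get("pow2_weight_scaling_factors", False),
    use_activation_pow2_scale=vllm_cfg.get("pow2_activation_scaling_factors", False),
    num_first_layers_in_bf16=vllm_cfg.get("num_first_layers_in_bf16", 0),
    num_last_layers_in_bf16=vllm_cfg.get("num_last_layers_in_bf16", 0),
)
print(global_fp8_config)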
nemo_rl/models/generation/vllm.py

Lines changed: 0 additions & 3 deletions
@@ -359,9 +359,6 @@ def _patch_vllm_init_workers_ray():
         # used in update_weights_from_ipc_handles
         self.vllm_device_ids = None
 
-        if self.cfg["vllm_cfg"]["precision"] == "fp8":
-            self.llm.collective_rpc("init_fp8_cfg", args=(self.cfg["vllm_cfg"],))
-
     def post_init(self):
         self.vllm_device_ids = self.report_device_id()
 
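With global_fp8_config now built inside init_fp8 and the FP8 patches pushed to each worker through the patched _run_workers in fp8.py, this collective_rpc has nothing left to do, and its init_fp8_cfg target no longer exists. The "fix TP" title presumably refers to the same point: under tensor parallelism the patches have to reach every Ray worker process, not just the driver.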