Commit e27eae8

jiemingz and SahilJain314 authored and committed
feat: fp8 block scaling (NVIDIA-NeMo#543)
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
Signed-off-by: Jimmy Zhang <133159885+jiemingz@users.noreply.github.com>
Signed-off-by: Sahil Jain <sahilj@nvidia.com>
Co-authored-by: Sahil Jain <sahilj@nvidia.com>
Co-authored-by: Sahil Jain <48468750+SahilJain314@users.noreply.github.com>
Signed-off-by: Qidong Su <qidongs@nvidia.com>
1 parent d22d309 commit e27eae8

15 files changed, +1172 -171 lines


docs/assets/fp8_curves.png

343 KB

docs/fp8.md

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
# FP8 for NeMo-RL

This module provides a suite of tools to enable FP8 quantization for large language models. The module is still in development. Currently we support FP8 generation using DeepSeek-style FP8 (sub-channel scaling).
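
As a rough illustration of what sub-channel (block) scaling means, the sketch below quantizes a weight matrix tile by tile, giving each 128x128 block its own scaling factor. This is a simplified, hypothetical example for exposition only; the block size and FP8 E4M3 range are assumptions, and it is not the NeMo-RL or vLLM implementation.

```python
import torch

def quantize_weight_blockwise(w: torch.Tensor, block: int = 128):
    """Toy DeepSeek-style block scaling: one FP32 scale per (block x block) tile."""
    out_f, in_f = w.shape
    assert out_f % block == 0 and in_f % block == 0, "toy example: dims must divide evenly"
    w_fp8 = torch.empty_like(w, dtype=torch.float8_e4m3fn)
    scales = torch.empty(out_f // block, in_f // block, dtype=torch.float32)
    fp8_max = 448.0  # largest magnitude representable in float8_e4m3fn
    for i in range(0, out_f, block):
        for j in range(0, in_f, block):
            tile = w[i:i + block, j:j + block]
            # Pick the scale so the tile's max magnitude maps onto the FP8 range.
            scale = tile.abs().amax().clamp(min=1e-12) / fp8_max
            scales[i // block, j // block] = scale
            w_fp8[i:i + block, j:j + block] = (tile / scale).to(torch.float8_e4m3fn)
    return w_fp8, scales
```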

NeMo-RL monkey-patches several vLLM functions to enable FP8 generation for reinforcement learning. The `init_fp8` function patches key `vLLM` components at initialization (a rough sketch of this pattern follows the list):

1. **`RayDistributedExecutor`**: For multi-GPU inference, the executor is patched to ensure that every worker process applies the same FP8 patches before model initialization.
2. **Quantization Utilities**: Functions within `vllm.model_executor.layers.quantization` are replaced with versions that support power-of-2 scaling and other custom features.
3. **Weight Loading**: A custom `load_weights` function handles the on-the-fly quantization of model weights from a higher-precision format to FP8 with the correct scaling factors.
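
A minimal sketch of the patching pattern described above. The patched attribute `fp8_linear_quant` and the helper `custom_fp8_quant` are made-up placeholders for illustration, not the actual NeMo-RL or vLLM symbols.

```python
# Hypothetical illustration of the monkey-patching pattern; attribute and
# function names below are placeholders, not real vLLM or NeMo-RL symbols.
import vllm.model_executor.layers.quantization as vllm_quant


def custom_fp8_quant(weight, use_pow2_scale=False):
    """Stand-in for a custom FP8 quantization routine with optional power-of-2 scales."""
    raise NotImplementedError


def init_fp8():
    # Swap in the custom routine before any model (or Ray worker) builds its
    # layers, so every process sees the patched behavior.
    vllm_quant.fp8_linear_quant = custom_fp8_quant
```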

---

## Usage

We recommend configuring FP8 generation with the following settings:
```yaml
loss_fn:
  # Importance sampling helps improve stability
  use_importance_sampling_correction: true

policy:
  generation:
    vllm_cfg:
      precision: 'fp8'
      # DeepGEMM is much more performant than vLLM's default CUTLASS FP8 sub-channel scaling kernels
      use_deep_gemm: true
      # Keeping the first layer and the last three layers in bf16 reduces the multi-token error
      # without a significant effect on performance
      num_last_layers_in_bf16: 3
      num_first_layers_in_bf16: 1
      # Use FP32 scaling factors. Rounding scaling factors to the nearest power of 2 may improve
      # quantization fidelity; however, this feature is still under research.
      use_weight_pow2_scale: False
      use_activation_pow2_scale: False
```
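
The two `*_pow2_scale` options above control whether scaling factors are rounded to powers of two. A minimal sketch of that rounding, assuming scales are stored in a tensor (illustrative only, not the NeMo-RL kernel path):

```python
import torch

def round_scales_up_to_pow2(scales: torch.Tensor) -> torch.Tensor:
    # Round each scale up to the nearest power of two; power-of-two scales can
    # be applied exactly in floating point, which may improve quantization fidelity.
    return torch.exp2(torch.ceil(torch.log2(scales)))
```
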
## Accuracy

On the Llama 8B recipe, we observe a ~5% accuracy loss with FP8 generation. Convergence is still under active research, and FP8 generation should be used with caution. We are investigating ways to close the accuracy gap and further improve performance.

docs/index.md

Lines changed: 1 addition & 0 deletions
@@ -49,6 +49,7 @@ testing.md
documentation.md
debugging.md
nsys-profiling.md
+fp8.md
guides/use-custom-vllm.md
apidocs/index.rst
```

examples/configs/grpo_math_1B.yaml

Lines changed: 3 additions & 0 deletions
@@ -174,6 +174,9 @@ policy:
       gpu_memory_utilization: 0.6
       max_model_len: ${policy.max_total_sequence_length}
       enforce_eager: False
+      use_deep_gemm: False
+      num_last_layers_in_bf16: 0
+      num_first_layers_in_bf16: 0
     colocated:
       # true: generation shares training GPUs
       # false: uses dedicated generation resources
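
The two `num_*_layers_in_bf16` options added above default to 0, meaning every transformer layer is quantized to FP8. A hypothetical helper showing the kind of per-layer check these options imply (not the actual NeMo-RL code):

```python
def keep_layer_in_bf16(layer_idx: int, num_layers: int,
                       num_first_layers_in_bf16: int = 0,
                       num_last_layers_in_bf16: int = 0) -> bool:
    """Return True if this transformer layer should stay in bf16 instead of FP8."""
    return (layer_idx < num_first_layers_in_bf16
            or layer_idx >= num_layers - num_last_layers_in_bf16)
```
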
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# GRPO Algorithm Configuration
defaults: "grpo_math_8B_megatron.yaml"

loss_fn:
  use_importance_sampling_correction: true

policy:
  generation:
    vllm_cfg:
      precision: 'fp8'
      use_deep_gemm: true
      num_last_layers_in_bf16: 0
      num_first_layers_in_bf16: 0
Lines changed: 161 additions & 0 deletions
@@ -0,0 +1,161 @@
grpo:
  num_prompts_per_step: 64
  num_generations_per_prompt: 32
  max_rollout_turns: 1
  max_num_steps: 500
  normalize_rewards: true
  use_leave_one_out_baseline: true
  val_period: 10
  val_at_start: false
  max_val_samples: 256
  val_batch_size: 256
  seed: 42
loss_fn:
  reference_policy_kl_penalty: 0.01
  ratio_clip_min: 0.2
  ratio_clip_max: 0.2
  ratio_clip_c: null
  use_on_policy_kl_approximation: false
  use_importance_sampling_correction: True
  token_level_loss: true
checkpointing:
  enabled: true
  checkpoint_dir: results/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8
  metric_name: val_reward
  higher_is_better: true
  keep_top_k: 3
  save_period: 10
  checkpoint_must_save_by: null
policy:
  model_name: meta-llama/Llama-3.1-8B-Instruct
  tokenizer:
    name: meta-llama/Llama-3.1-8B-Instruct
  train_global_batch_size: 512
  train_micro_batch_size: 1
  generation_batch_size: 32
  logprob_batch_size: 2
  max_total_sequence_length: 4096
  precision: bfloat16
  make_sequence_length_divisible_by: 1
  max_grad_norm: 1

  dtensor_cfg:
    enabled: False

  dynamic_batching:
    enabled: False

  sequence_packing:
    enabled: True
    train_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.train_micro_batch_size}}
    logprob_mb_tokens: ${mul:${policy.max_total_sequence_length}, ${policy.logprob_batch_size}}
    algorithm: "modified_first_fit_decreasing"
    sequence_length_round: 64

  megatron_cfg:
    enabled: True
    empty_unused_memory_level: 1
    converter_type: "LlamaForCausalLM"
    tensor_model_parallel_size: 1
    pipeline_model_parallel_size: 2
    context_parallel_size: 1
    expert_tensor_parallel_size: 1
    expert_model_parallel_size: 1
    sequence_parallel: False
    pipeline_dtype: ${policy.precision}
    num_layers_in_first_pipeline_stage: null
    num_layers_in_last_pipeline_stage: null
    freeze_moe_router: True
    moe_router_dtype: "fp64"
    moe_router_load_balancing_type: "none" # "seq_aux_loss" causes logprob error divergence for grpo
    moe_router_bias_update_rate: 0.0 # by default, disable bias updates for grpo
    apply_rope_fusion: True
    activation_checkpointing: True
    defer_fp32_logits: True

    optimizer:
      optimizer: "adam"
      lr: 5.0e-7
      min_lr: 5.0e-8
      weight_decay: 0.0
      bf16: True
      fp16: False
      params_dtype: "float32"

      adam_beta1: 0.9
      adam_beta2: 0.999
      adam_eps: 1e-8

      use_distributed_optimizer: True
      use_precision_aware_optimizer: True

      clip_grad: ${policy.max_grad_norm}

    scheduler:
      start_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      end_weight_decay: ${policy.megatron_cfg.optimizer.weight_decay}
      weight_decay_incr_style: "constant"
      lr_decay_style: "constant"
      lr_decay_iters: null
      lr_warmup_iters: 2
      lr_warmup_init: 5.0e-8

    distributed_data_parallel_config:
      grad_reduce_in_fp32: False
      overlap_grad_reduce: True
      overlap_param_gather: True
      average_in_collective: True
      use_custom_fsdp: False
      data_parallel_sharding_strategy: "optim_grads_params"

  generation:
    backend: vllm
    max_new_tokens: 4096
    temperature: 1
    top_p: 1
    top_k: null
    stop_token_ids:
      - 128009
    stop_strings: null
    vllm_cfg:
      async_engine: false
      precision: 'fp8'
      tensor_parallel_size: 1
      pipeline_parallel_size: 1
      gpu_memory_utilization: 0.6
      max_model_len: 4096
      enforce_eager: False
      use_deep_gemm: true
      num_last_layers_in_bf16: 0
      num_first_layers_in_bf16: 0
    colocated:
      enabled: true
      resources:
        gpus_per_node: null
        num_nodes: null
data:
  max_input_seq_length: 4096
  prompt_file: examples/prompts/cot.txt
  system_prompt_file: null
  dataset_name: OpenMathInstruct-2
  shuffle: true
env:
  math:
    num_workers: 8
logger:
  log_dir: logs/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8
  num_val_samples_to_print: 0
  wandb_enabled: true
  tensorboard_enabled: true
  mlflow_enabled: false
  monitor_gpus: true
  wandb:
    project: nemo-rl
    name: grpo-llama3.1-8b-instruct-1n8g-megatron-fp8
  tensorboard: {}
  gpu_monitoring:
    collection_interval: 10
    flush_interval: 10
cluster:
  gpus_per_node: 8
  num_nodes: 4

nemo_rl/algorithms/grpo.py

Lines changed: 5 additions & 0 deletions
@@ -319,6 +319,11 @@ def setup(
         )
     elif backend == "vllm":
         generation_config = cast(VllmConfig, generation_config)
+        if generation_config["vllm_cfg"]["precision"] == "fp8":
+            assert loss_config["use_importance_sampling_correction"] is True, (
+                "Importance sampling must be enabled for vLLM FP8 generation for good convergence!"
+            )
+
         policy_generation = VllmGeneration(
             cluster=inference_cluster, config=generation_config
         )
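
The assertion above ties FP8 generation to `use_importance_sampling_correction`. A rough sketch of what that correction does, reweighting each token's loss by the trainer-vs-generation probability ratio (illustrative only; the names and exact formulation are assumptions, not the NeMo-RL loss code):

```python
import torch

def apply_importance_sampling_correction(per_token_loss: torch.Tensor,
                                         trainer_logprobs: torch.Tensor,
                                         generation_logprobs: torch.Tensor) -> torch.Tensor:
    # Ratio between the training framework's probability for each sampled token
    # and the (FP8) vLLM generation engine's probability for the same token.
    # Reweighting the loss by this ratio corrects for the numerical mismatch
    # between the two implementations.
    ratio = torch.exp(trainer_logprobs - generation_logprobs).detach()
    return per_token_loss * ratio
```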
