Commit 32ada21

jiemingz authored and root committed
ensure importance sampling on
Signed-off-by: Jimmy Zhang <jiemingz@nvidia.com>
1 parent e748de1 commit 32ada21

1 file changed: +5 -0 lines changed

nemo_rl/algorithms/grpo.py

Lines changed: 5 additions & 0 deletions
@@ -309,6 +309,11 @@ def setup(
         )
     elif backend == "vllm":
         generation_config = cast(VllmConfig, generation_config)
+        if generation_config["vllm_cfg"]["precision"] == "fp8":
+            assert loss_config["use_importance_sampling_correction"] is True, (
+                "Importance sampling must be enabled for vLLM FP8 generation!"
+            )
+
         policy_generation = VllmGeneration(
             cluster=inference_cluster, config=generation_config
         )
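
The check added in this diff can be exercised in isolation. The following is a minimal sketch, not code from the repository: the helper name `check_fp8_requires_importance_sampling` and the use of plain dictionaries in place of `VllmConfig` and the loss config are assumptions for illustration. Only the config keys and the error message are taken from the diff.

```python
from typing import Any, Mapping


def check_fp8_requires_importance_sampling(
    generation_config: Mapping[str, Any],
    loss_config: Mapping[str, Any],
) -> None:
    """Raise if vLLM FP8 generation is requested without importance sampling.

    Hypothetical helper: the name and the plain-mapping arguments are
    assumptions; the keys mirror the ones read in the diff.
    """
    precision = generation_config["vllm_cfg"]["precision"]
    enabled = loss_config["use_importance_sampling_correction"]
    # Same condition as the diff: FP8 generation requires the flag to be True.
    if precision == "fp8" and enabled is not True:
        raise ValueError(
            "Importance sampling must be enabled for vLLM FP8 generation!"
        )


# Passes silently: FP8 generation with the correction enabled.
check_fp8_requires_importance_sampling(
    {"vllm_cfg": {"precision": "fp8"}},
    {"use_importance_sampling_correction": True},
)
```

This sketch raises `ValueError` where the diff uses `assert`. That is a deliberate difference: Python drops `assert` statements when run with `-O`, so an explicit exception keeps the check active in optimized runs. Whether that matters for this code path is a judgment call for the maintainers.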
