
Commit 41c8ca1

GRPO: ScaleRL -> Support casting LM Head to FP32 (#4303)
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
1 parent 5cefb39 commit 41c8ca1


3 files changed, +61 -2 lines changed


tests/test_grpo_trainer.py

Lines changed: 29 additions & 0 deletions
@@ -703,6 +703,35 @@ def test_training_beta_non_zero(self):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

+    def test_training_with_cast_lm_head_to_fp32(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,
+            per_device_train_batch_size=3,
+            num_generations=3,
+            max_completion_length=8,
+            report_to="none",
+            cast_lm_head_to_fp32=True,
+        )
+        trainer = GRPOTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+        assert trainer.model.lm_head.weight.dtype == torch.float32
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
+
     def test_training_with_entropy_filter(self):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
         training_args = GRPOConfig(
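
Restated as a standalone, hedged usage sketch (imports and argument values are lifted from the test above; the output directory name is arbitrary). Because the cast happens in `GRPOTrainer.__init__` (see the grpo_trainer.py hunk further down), the dtype check already holds before `trainer.train()` is called:

import torch
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

training_args = GRPOConfig(
    output_dir="grpo-fp32-head",   # arbitrary output directory
    per_device_train_batch_size=3,
    num_generations=3,
    max_completion_length=8,
    report_to="none",
    cast_lm_head_to_fp32=True,     # the new ScaleRL-style option
)
trainer = GRPOTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)

# The LM head is cast during trainer construction, before any training step.
assert trainer.model.lm_head.weight.dtype == torch.float32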

trl/trainer/grpo_config.py

Lines changed: 12 additions & 1 deletion
@@ -41,9 +41,12 @@ class GRPOConfig(TrainingArguments):
         disable_dropout (`bool`, *optional*, defaults to `False`):
             Whether to disable dropout in the model. This is useful for training with a reference model, as it prevents
             the model from generating different logprobs for the same input.
+        cast_lm_head_to_fp32 (`bool`, *optional*, defaults to `False`):
+            Whether to cast the language modeling head of the policy and reference models to float32, as recommended by
+            the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe. This flag is only supported when the model
+            has untied word embedding and language modeling head layers, i.e. `tie_word_embeddings` in the model config is False.

     > Parameters that control the data preprocessing
-
         remove_unused_columns (`bool`, *optional*, defaults to `False`):
             Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that
             requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`.
@@ -297,6 +300,14 @@ class GRPOConfig(TrainingArguments):
             "it prevents the model from generating different logprobs for the same input."
         },
     )
+    cast_lm_head_to_fp32: bool = field(
+        default=False,
+        metadata={
+            "help": "Whether to cast the language modeling head of the policy and reference models to float32, as "
+            "recommended by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe. This flag is only supported when the model"
+            " has untied word embedding and language modeling head layers, i.e. `tie_word_embeddings` in the model config is False."
+        },
+    )

     # Parameters that control the data preprocessing
     # The default value remove_unused_columns is overwritten from the parent class, because in GRPO we usually rely on
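
The docstring stresses that the flag only works for models with untied embeddings. A small hedged sketch of gating the option on the model config (the checkpoint id is the tiny test model used above; `tie_word_embeddings` is the standard `transformers` config attribute the trainer checks):

from transformers import AutoConfig
from trl import GRPOConfig

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
model_config = AutoConfig.from_pretrained(model_id)

# Enable the FP32 head only when the embedding and LM head weights are untied,
# since tied weights raise NotImplementedError in GRPOTrainer.
untied = not getattr(model_config, "tie_word_embeddings", True)

training_args = GRPOConfig(
    output_dir="grpo-output",      # arbitrary output directory
    cast_lm_head_to_fp32=untied,
    report_to="none",
)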

trl/trainer/grpo_trainer.py

Lines changed: 20 additions & 1 deletion
@@ -477,6 +477,24 @@ def __init__(
         if self.ref_model is not None:
             disable_dropout_in_model(self.ref_model)

+        # Cast LM Head To FP32
+        if args.cast_lm_head_to_fp32:
+            if not model.config.tie_word_embeddings:
+
+                def cast_inputs_to_fp32(module, input):
+                    return (input[0].float(),)
+
+                model.lm_head = model.lm_head.float()
+                model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
+                if self.ref_model is not None:
+                    self.ref_model.lm_head = self.ref_model.lm_head.float()
+                    self.ref_model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
+            else:
+                raise NotImplementedError(
+                    "`cast_lm_head_to_fp32=True` is only supported when the model has untied word embedding and "
+                    "language modeling head layers, i.e. `tie_word_embeddings` in the model config is False."
+                )
+
         # Liger loss
         if self.use_liger_kernel:
             if not is_liger_kernel_available():
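
A minimal, self-contained sketch (plain PyTorch, with a toy `nn.Linear` standing in for `lm_head`; not TRL code) of the pattern added in the hunk above: the head's parameters are kept in float32 and a forward pre-hook upcasts its inputs, while the rest of the model can stay in bfloat16:

import torch
import torch.nn as nn

body = nn.Linear(16, 16).to(torch.bfloat16)            # stand-in for the transformer body
lm_head = nn.Linear(16, 32, bias=False).to(torch.bfloat16)

# Cast the head's weights to float32 ...
lm_head = lm_head.float()

# ... and upcast whatever hidden states it receives.
def cast_inputs_to_fp32(module, args):
    return (args[0].float(),)

lm_head.register_forward_pre_hook(cast_inputs_to_fp32)

hidden_states = body(torch.randn(2, 16, dtype=torch.bfloat16))  # bfloat16 activations
logits = lm_head(hidden_states)

print(lm_head.weight.dtype)  # torch.float32
print(logits.dtype)          # torch.float32 -> log-probs are computed in full precision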
@@ -876,7 +894,6 @@ def _get_per_token_logps_and_entropies(
             # Divide logits by sampling temperature.
             # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
             logits = logits / self.temperature
-
             completion_ids = input_ids_batch[:, -logits_to_keep:]
             logps = selective_log_softmax(logits, completion_ids)  # compute logprobs
             all_logps.append(logps)
@@ -1300,6 +1317,8 @@ def _generate_single_turn(self, prompts: list):
                     unwrapped_model.to(torch.bfloat16)
                 elif self.args.fp16:
                     unwrapped_model.to(torch.float16)
+                if self.args.cast_lm_head_to_fp32:
+                    unwrapped_model.lm_head.to(torch.float32)
                 with torch.inference_mode():
                     # Continuous batching API expects 'inputs' arg only
                     all_outputs = unwrapped_model.generate_batch(
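
A small sketch (plain PyTorch, not TRL code) of why the hunk above re-casts the head inside `_generate_single_turn`: `Module.to(torch.bfloat16)` converts every floating-point parameter, including a head previously cast to float32, so the float32 cast has to be reapplied after the generation-time dtype switch:

import torch
import torch.nn as nn

model = nn.ModuleDict({
    "body": nn.Linear(16, 16),
    "lm_head": nn.Linear(16, 32, bias=False),
})
model["lm_head"].float()                # head held in float32 for training

model.to(torch.bfloat16)                # generation path downcasts the whole model ...
print(model["lm_head"].weight.dtype)    # torch.bfloat16

model["lm_head"].to(torch.float32)      # ... so the head is restored to float32 afterwards
print(model["lm_head"].weight.dtype)    # torch.float32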

0 commit comments
