diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 3002ce2b7fa..7e41b48a813 100644
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -703,7 +703,12 @@ def test_training_beta_non_zero(self):
             new_param = trainer.model.get_parameter(n)
             assert not torch.equal(param, new_param), f"Parameter {n} has not changed."

-    def test_training_with_cast_lm_head_to_fp32(self):
+    @pytest.mark.parametrize(
+        "model_name",
+        ["trl-internal-testing/tiny-Qwen3ForCausalLM", "trl-internal-testing/tiny-Gemma2ForCausalLM"],
+        # Gemma2 has the input word embeddings and lm_head tied, Qwen3 does not
+    )
+    def test_training_with_cast_lm_head_to_fp32(self, model_name):
         dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
         training_args = GRPOConfig(
             output_dir=self.tmp_dir,
@@ -715,7 +720,7 @@ def test_training_with_cast_lm_head_to_fp32(self):
             cast_lm_head_to_fp32=True,
         )
         trainer = GRPOTrainer(
-            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            model=model_name,
             reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
             args=training_args,
             train_dataset=dataset,
diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py
index 8d9492eed8a..b6b89e6d427 100644
--- a/trl/trainer/grpo_trainer.py
+++ b/trl/trainer/grpo_trainer.py
@@ -479,21 +479,31 @@ def __init__(

         # Cast LM Head To FP32
         if args.cast_lm_head_to_fp32:
-            if not model.config.tie_word_embeddings:

-                def cast_inputs_to_fp32(module, input):
-                    return (input[0].float(),)
+            def _cast_lm_head_to_fp32(target_model: PreTrainedModel):
+                """Cast lm_head to fp32 while preserving embedding output dtype if tied."""

-                model.lm_head = model.lm_head.float()
-                model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
-                if self.ref_model is not None:
-                    self.ref_model.lm_head = self.ref_model.lm_head.float()
-                    self.ref_model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
-            else:
-                raise NotImplementedError(
-                    "`cast_lm_head_to_fp32=True` is only supported when the model has untied word embedding and language modeling head layers"
-                    "i.e. `tie_word_embeddings` in the model config is False."
-                )
+                def cast_inputs_to_fp32(module, inputs):
+                    # Preserve other positional args and kwargs untouched
+                    if not inputs:
+                        return inputs
+                    return (inputs[0].to(torch.float32),) + inputs[1:]
+
+                original_dtype_local = target_model.lm_head.weight.dtype
+                target_model.lm_head = target_model.lm_head.float()
+                target_model.lm_head.register_forward_pre_hook(cast_inputs_to_fp32)
+
+                if target_model.config.tie_word_embeddings:
+
+                    def cast_outputs_to_original_dtype(module, args, output):
+                        return output.to(original_dtype_local)
+
+                    # Only cast activations; weights are now fp32 (intentional for numerical stability of logits)
+                    target_model.model.embed_tokens.register_forward_hook(cast_outputs_to_original_dtype)
+
+            _cast_lm_head_to_fp32(model)
+            if self.ref_model is not None:
+                _cast_lm_head_to_fp32(self.ref_model)

         # Liger loss
         if self.use_liger_kernel:
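
For reference, here is a minimal, self-contained sketch (plain PyTorch with made-up toy modules, not TRL code) of the hook pattern the new `_cast_lm_head_to_fp32` helper relies on for tied-embedding models such as Gemma2: the head's weights and inputs are upcast so logits are computed in fp32, while a forward hook on the embedding keeps its output activations in the original low-precision dtype.

```python
# Hypothetical toy setup: a tied embedding/head pair standing in for a real model.
import torch
import torch.nn as nn

dtype = torch.bfloat16
embed = nn.Embedding(10, 4, dtype=dtype)
lm_head = nn.Linear(4, 10, bias=False, dtype=dtype)
lm_head.weight = embed.weight  # tie weights, as Gemma2 does

original_dtype = lm_head.weight.dtype
lm_head = lm_head.float()  # upcasts the shared weight, so the embedding weight is now fp32 too

# Pre-hook: upcast the hidden states entering the head so logits are computed in fp32.
lm_head.register_forward_pre_hook(
    lambda module, inputs: (inputs[0].to(torch.float32),) + inputs[1:]
)
# Forward hook: hand the rest of the network activations in the original dtype.
embed.register_forward_hook(lambda module, args, output: output.to(original_dtype))

hidden = embed(torch.tensor([[1, 2, 3]]))
logits = lm_head(hidden)
print(hidden.dtype, logits.dtype)  # torch.bfloat16 torch.float32
```

The embedding-output hook only matters in the tied case: calling `.float()` on the head also upcasts the shared weight, so without it the embedding would start feeding fp32 activations into the rest of the (low-precision) model.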