From 628e7952f09843f06fe6dc707aeda441003b0f80 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?=
Date: Tue, 25 Nov 2025 04:39:46 +0000
Subject: [PATCH] Move tests for GSPOTokenTrainer to experimental

---
 tests/experimental/test_gspo_token_trainer.py | 60 +++++++++++++++++++
 tests/test_grpo_trainer.py                    | 34 -----------
 2 files changed, 60 insertions(+), 34 deletions(-)
 create mode 100644 tests/experimental/test_gspo_token_trainer.py

diff --git a/tests/experimental/test_gspo_token_trainer.py b/tests/experimental/test_gspo_token_trainer.py
new file mode 100644
index 0000000000..cee2663d5c
--- /dev/null
+++ b/tests/experimental/test_gspo_token_trainer.py
@@ -0,0 +1,60 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from datasets import load_dataset
+from transformers.utils import is_peft_available
+
+from trl import GRPOConfig
+from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer
+
+from ..testing_utils import TrlTestCase
+
+
+if is_peft_available():
+    pass
+
+
+class TestGSPOTokenTrainer(TrlTestCase):
+    def test_training(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # increase the learning rate to speed up the test
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            num_iterations=2,  # the importance sampling weights won't be 0 in this case
+            importance_sampling_level="sequence_token",
+            report_to="none",
+        )
+        trainer = GSPOTokenTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py
index 7fb56e4420..baaea524b2
--- a/tests/test_grpo_trainer.py
+++ b/tests/test_grpo_trainer.py
@@ -36,7 +36,6 @@
 from transformers.utils import is_peft_available
 
 from trl import GRPOConfig, GRPOTrainer
-from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer
 from trl.trainer.utils import get_kbit_device_map
 
 from .testing_utils import (
@@ -1799,39 +1798,6 @@ def test_single_reward_model_with_single_processing_class(self):
         assert trainer.reward_processing_classes[0] == single_processing_class
 
 
-class TestGSPOTokenTrainer(TrlTestCase):
-    def test_training(self):
-        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
-
-        training_args = GRPOConfig(
-            output_dir=self.tmp_dir,
-            learning_rate=0.1,  # increase the learning rate to speed up the test
-            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
-            num_generations=3,  # reduce the number of generations to reduce memory usage
-            max_completion_length=8,  # reduce the completion length to reduce memory usage
-            num_iterations=2,  # the importance sampling weights won't be 0 in this case
-            importance_sampling_level="sequence_token",
-            report_to="none",
-        )
-        trainer = GSPOTokenTrainer(
-            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
-            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
-            args=training_args,
-            train_dataset=dataset,
-        )
-
-        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
-
-        trainer.train()
-
-        assert trainer.state.log_history[-1]["train_loss"] is not None
-
-        # Check that the params have changed
-        for n, param in previous_trainable_params.items():
-            new_param = trainer.model.get_parameter(n)
-            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
-
-
 @pytest.mark.slow
 @require_torch_accelerator
 class TestGRPOTrainerSlow(TrlTestCase):