Commit 357e331: Move tests for GSPOTokenTrainer to experimental (#4572)
Parent: a59f2cf

2 files changed: +60 −34 lines
Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
+# Copyright 2020-2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from datasets import load_dataset
+from transformers.utils import is_peft_available
+
+from trl import GRPOConfig
+from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer
+
+from ..testing_utils import TrlTestCase
+
+
+if is_peft_available():
+    pass
+
+
+class TestGSPOTokenTrainer(TrlTestCase):
+    def test_training(self):
+        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
+
+        training_args = GRPOConfig(
+            output_dir=self.tmp_dir,
+            learning_rate=0.1,  # increase the learning rate to speed up the test
+            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
+            num_generations=3,  # reduce the number of generations to reduce memory usage
+            max_completion_length=8,  # reduce the completion length to reduce memory usage
+            num_iterations=2,  # the importance sampling weights won't be 0 in this case
+            importance_sampling_level="sequence_token",
+            report_to="none",
+        )
+        trainer = GSPOTokenTrainer(
+            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
+            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
+            args=training_args,
+            train_dataset=dataset,
+        )
+
+        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
+
+        trainer.train()
+
+        assert trainer.state.log_history[-1]["train_loss"] is not None
+
+        # Check that the params have changed
+        for n, param in previous_trainable_params.items():
+            new_param = trainer.model.get_parameter(n)
+            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
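
For context: the relocated test exercises the experimental GSPO-token variant, for which trl.experimental.gspo_token provides its own GRPOTrainer, used here with importance_sampling_level="sequence_token". Below is a minimal standalone sketch of the same setup, mirroring the test above; the tiny checkpoints are taken from the diff, and output_dir is a hypothetical placeholder.

from datasets import load_dataset

from trl import GRPOConfig
from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer

dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

# Small, memory-friendly settings copied from the test; "sequence_token"
# selects the GSPO-token importance-sampling variant handled by the
# experimental trainer.
training_args = GRPOConfig(
    output_dir="gspo_token_demo",  # hypothetical placeholder path
    per_device_train_batch_size=3,
    num_generations=3,
    max_completion_length=8,
    num_iterations=2,
    importance_sampling_level="sequence_token",
    report_to="none",
)
trainer = GSPOTokenTrainer(
    model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
    args=training_args,
    train_dataset=dataset,
)
trainer.train()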

tests/test_grpo_trainer.py

Lines changed: 0 additions & 34 deletions
@@ -36,7 +36,6 @@
 from transformers.utils import is_peft_available
 
 from trl import GRPOConfig, GRPOTrainer
-from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer
 from trl.trainer.utils import get_kbit_device_map
 
 from .testing_utils import (
@@ -1799,39 +1798,6 @@ def test_single_reward_model_with_single_processing_class(self):
         assert trainer.reward_processing_classes[0] == single_processing_class
 
 
-class TestGSPOTokenTrainer(TrlTestCase):
-    def test_training(self):
-        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
-
-        training_args = GRPOConfig(
-            output_dir=self.tmp_dir,
-            learning_rate=0.1,  # increase the learning rate to speed up the test
-            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
-            num_generations=3,  # reduce the number of generations to reduce memory usage
-            max_completion_length=8,  # reduce the completion length to reduce memory usage
-            num_iterations=2,  # the importance sampling weights won't be 0 in this case
-            importance_sampling_level="sequence_token",
-            report_to="none",
-        )
-        trainer = GSPOTokenTrainer(
-            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
-            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
-            args=training_args,
-            train_dataset=dataset,
-        )
-
-        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
-
-        trainer.train()
-
-        assert trainer.state.log_history[-1]["train_loss"] is not None
-
-        # Check that the params have changed
-        for n, param in previous_trainable_params.items():
-            new_param = trainer.model.get_parameter(n)
-            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
-
-
 @pytest.mark.slow
 @require_torch_accelerator
 class TestGRPOTrainerSlow(TrlTestCase):
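
Note: after this commit, the GSPO-token test no longer runs via tests/test_grpo_trainer.py. A minimal sketch for invoking just the relocated test, assuming the new module lives under tests/experimental/ (the exact filename is not shown in this diff); matching by class name with -k avoids depending on the file name:

import pytest

# Select the moved test class by name; the -k expression matches regardless
# of which file under tests/experimental/ now contains it.
exit_code = pytest.main(["tests/experimental", "-k", "TestGSPOTokenTrainer"])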
