
Conversation

@shirinyamani (Contributor) commented Jul 29, 2025

What does this PR do?

This PR adds a new RLOO trainer, an updated version of the RLOO implementation shipped in TRL < 0.25.0. Here is a simple script to run it:

# train_rloo.py
from datasets import load_dataset
from trl import RLOOConfig, RLOOTrainer

dataset = load_dataset("trl-lib/tldr", split="train")

# Define the reward function, which rewards completions that are close to 20 characters
def reward_len(completions, **kwargs):
    return [-abs(20 - len(completion)) for completion in completions]

training_args = RLOOConfig(output_dir="Qwen2-0.5B-RLOO")
trainer = RLOOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",
    reward_funcs=reward_len,
    args=training_args,
    train_dataset=dataset,
)
trainer.train()
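For context, RLOO (REINFORCE Leave-One-Out) baselines each completion's reward with the mean reward of the other completions sampled for the same prompt. A minimal sketch of that advantage computation (illustrative only, not the trainer's internal code; assumes rewards are grouped by prompt):

import torch

def rloo_advantages(rewards: torch.Tensor, k: int) -> torch.Tensor:
    """rewards: flat tensor of shape (num_prompts * k,), grouped per prompt.
    Returns leave-one-out advantages of the same shape."""
    grouped = rewards.view(-1, k)  # (num_prompts, k)
    # baseline for each sample = mean reward of the other k - 1 completions in its group
    baseline = (grouped.sum(dim=1, keepdim=True) - grouped) / (k - 1)
    return (grouped - baseline).flatten()

# e.g. 2 prompts with k=3 completions each
print(rloo_advantages(torch.tensor([1.0, 2.0, 3.0, 0.0, 0.0, 6.0]), k=3))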

Please refer to the migration guide section of the README for the complete list of renamed/removed parameters.
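For a rough sense of the rename, here is a schematic side-by-side of the constructor arguments, based only on the two snippets in this PR description and the reproduction script below (not runnable on its own; names refer to the objects defined in those snippets, and the README remains the authoritative mapping):

# Old API (TRL < 0.25.0)
trainer = RLOOTrainer(
    config=training_args,        # RLOOConfig
    processing_class=tokenizer,
    policy=policy,               # AutoModelForCausalLM
    ref_policy=ref_policy,       # AutoModelForCausalLM
    reward_model=reward_model,   # AutoModelForSequenceClassification
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)

# New API
trainer = RLOOTrainer(
    model="Qwen/Qwen2-0.5B-Instruct",  # model id or model; no separate ref_policy argument
    reward_funcs=reward_len,           # reward function, as in the example above
    args=training_args,                # RLOOConfig
    train_dataset=dataset,
)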

Benchmark on TLDR:

  • red is old
  • blue is new (training was interrupted)
(benchmark plot attached to the PR)

To reproduce:

Before running, apply the following changes.

Add this to the new trainer so that it logs the same metrics as the old one:

# in _generate_and_score_completions
self._metrics[mode]["objective/scores"].append(mean_grouped_rewards.mean().item())
self._metrics[mode]["policy/clipfrac_avg"].append(gathered_clip_ratio.nanmean().item())
# in compute_loss
self._metrics[mode]["loss/policy_avg"].append(loss.item())

In the old trainer, the reward was not computed properly. Replace this:

_, score, _ = get_reward(
    reward_model, postprocessed_query_response, processing_class.pad_token_id, context_length
)

with:

score = reward_model(postprocessed_query_response).logits[:, 0]
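For reference, with num_labels=1 the sequence-classification head returns logits of shape (batch_size, 1), so logits[:, 0] yields one scalar score per sequence. A minimal standalone sketch of that call, using the reward model from the reproduction script below (illustrative; not the trainer's internal batching):

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

reward_model_id = "cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr"
tokenizer = AutoTokenizer.from_pretrained("cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr")
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id, num_labels=1)

text = "SUBREDDIT: r/test\nPOST: ...\nTL;DR: a short summary"
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
    scores = reward_model(**inputs).logits[:, 0]  # shape: (batch_size,), one scalar reward per sequence
print(scores)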

Training script

train.py:

import os
import shutil

from accelerate import PartialState
from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    HfArgumentParser,
)

from trl import ModelConfig, RLOOConfig, RLOOTrainer, ScriptArguments
from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE


# Enable logging in a Hugging Face Space
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")

if __name__ == "__main__":
    parser = HfArgumentParser((ScriptArguments, RLOOConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_into_dataclasses()
    # remove output_dir if exists
    shutil.rmtree(training_args.output_dir, ignore_errors=True)

    model_id = "cleanrl/EleutherAI_pythia-1b-deduped__sft__tldr"
    reward_model_id = "cleanrl/EleutherAI_pythia-1b-deduped__reward__tldr"

    ################
    # Model & Tokenizer
    ################
    tokenizer = AutoTokenizer.from_pretrained(
        model_id, padding_side="left", trust_remote_code=model_args.trust_remote_code
    )
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    if tokenizer.chat_template is None:
        tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
    reward_model = AutoModelForSequenceClassification.from_pretrained(
        reward_model_id, trust_remote_code=model_args.trust_remote_code, num_labels=1
    )
    reward_model.config.pad_token_id = 0
    ref_policy = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=model_args.trust_remote_code)
    policy = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=model_args.trust_remote_code)

    ################
    # Dataset
    ################
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    train_dataset = dataset[script_args.dataset_train_split]
    eval_dataset = dataset[script_args.dataset_test_split]

    def prepare_dataset(dataset, tokenizer):
        """pre-tokenize the dataset before training; only collate during training"""

        def tokenize(element):
            input_ids = tokenizer.apply_chat_template(
                element["messages"][:1],
                padding=False,
                add_generation_prompt=True,
            )
            return {"input_ids": input_ids, "lengths": len(input_ids)}

        return dataset.map(
            tokenize,
            remove_columns=dataset.column_names,
            num_proc=training_args.dataset_num_proc,
        )

    # Compute that only on the main process for faster data processing.
    # see: https://github.com/huggingface/trl/pull/1255
    with PartialState().local_main_process_first():
        train_dataset = prepare_dataset(train_dataset, tokenizer)
        eval_dataset = prepare_dataset(eval_dataset, tokenizer)
        # filtering
        train_dataset = train_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc)
        eval_dataset = eval_dataset.filter(lambda x: x["lengths"] <= 512, num_proc=training_args.dataset_num_proc)

    assert train_dataset[0]["input_ids"][-1] != tokenizer.eos_token_id, "The last token should not be an EOS token"

    ################
    # Training
    ################
    trainer = RLOOTrainer(
        config=training_args,
        processing_class=tokenizer,
        policy=policy,
        ref_policy=ref_policy,
        reward_model=reward_model,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
Launch command (the last three flags are only available in the new trainer):

python train.py \
    --dataset_name trl-internal-testing/tldr-preference-sft-trl-style \
    --dataset_test_split validation \
    --learning_rate 3e-6 \
    --output_dir pythia-1b-deduped-tldr-preference-sft-trl-style-rloo \
    --per_device_train_batch_size 1 \
    --gradient_accumulation_steps 64 \
    --num_ppo_epochs 4 \
    --total_episodes 30000 \
    --stop_token eos \
    --response_length 53 \
    --logging_steps 4 \
    --log_completions \
    --num_completions_to_print 1

Results

See above

Why don't we get an exact match?

One of the reasons is that the learning rate schedule in the old trainer was wrong: it decays two times slower than intended, so the learning rate only goes from 3e-6 down to 1.5e-6 over the course of training.
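As a sanity check of that number, assuming a standard linear decay schedule (an assumption; the exact scheduler type is not stated here), decaying two times slower is equivalent to scheduling over twice the actual number of optimizer steps, which leaves the learning rate at half its initial value when training ends:

# Illustrative arithmetic only: a linear schedule planned over 2x the real number of
# optimizer steps decays half as fast, ending at lr0 / 2 instead of 0.
lr0 = 3e-6
actual_steps = 1000                  # hypothetical number of optimizer steps actually run
scheduled_steps = 2 * actual_steps   # what the buggy schedule is (assumed) configured with

def linear_lr(step, total_steps, lr0):
    return lr0 * max(0.0, 1.0 - step / total_steps)

print(linear_lr(actual_steps, actual_steps, lr0))     # correct schedule ends at 0.0
print(linear_lr(actual_steps, scheduled_steps, lr0))  # buggy schedule ends at 1.5e-06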

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@edbeeching (Collaborator) left a comment:


Thanks for this clean refactor @shirinyamani and @qgallouedec.

Would it be possible to include more details in the PR description of the experiments you have run to validate the results with the new implementation vs the original one?

@qgallouedec qgallouedec changed the base branch from fake_support_branch_for_rloo to main August 28, 2025 16:37
@shirinyamani shirinyamani merged commit e7b37d4 into main Aug 29, 2025
11 checks passed
@shirinyamani shirinyamani deleted the rloo_final branch August 29, 2025 15:27
@huggingface huggingface deleted a comment from qgallouedec Aug 29, 2025
@shirinyamani (Contributor, Author) replied:

> Thanks for this clean refactor @shirinyamani and @qgallouedec.
>
> Would it be possible to include more details in the PR description of the experiments you have run to validate the results with the new implementation vs the original one?

Done! 🔥

