
Conversation

@pramodith (Collaborator)

What does this PR do?

CISPO loss was first introduced in the MiniMax-M1 paper; the ScaleRL paper subsequently showed that CISPO loss scales best in terms of performance and efficiency as models are trained for longer.

The CISPO objective is

$$\mathcal{J}_{\mathrm{CISPO}}(\theta)=\mathbb{E}\left[\frac{1}{\sum_{i=1}^{G}|o_i|}\sum_{i=1}^{G}\sum_{t=1}^{|o_i|}\mathrm{sg}\big(\min(\rho_{i,t},\,\epsilon_{\max})\big)\,\hat{A}_{i,t}\,\log \pi_\theta(o_{i,t}\mid q,\,o_{i,<t})\right]$$

Where rho is the token-level importance sampling ratio

$$\rho_{i,t}=\frac{\pi_\theta(o_{i,t}\mid q,\,o_{i,<t})}{\pi_{\theta_{\mathrm{old}}}(o_{i,t}\mid q,\,o_{i,<t})}$$
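
In code, a minimal PyTorch sketch of this objective (variable names like `per_token_logps` are illustrative, not the trainer's exact internals):

```python
import torch

def cispo_loss(per_token_logps, old_per_token_logps, advantages, completion_mask, epsilon_max=5.0):
    """Minimal CISPO loss sketch: clip the importance ratio, detach it, weight the log-probs."""
    # rho_{i,t}: per-token importance ratio pi_theta / pi_theta_old
    rho = torch.exp(per_token_logps - old_per_token_logps)
    # Upper-clip the ratio and stop its gradient: the sg(min(rho, eps_max)) term
    clipped_rho = torch.clamp(rho, max=epsilon_max).detach()
    # Per-token surrogate: -weight * advantage * log-prob
    per_token_loss = -clipped_rho * advantages.unsqueeze(-1) * per_token_logps
    # Aggregate over active (non-padding) completion tokens
    return (per_token_loss * completion_mask).sum() / completion_mask.sum()
```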


Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@pramodith pramodith requested a review from Copilot November 6, 2025 15:32
Copilot AI (Contributor) left a comment


Pull Request Overview

This PR implements the CISPO (Clipped Importance Sampling Policy Optimization) loss function in the GRPO trainer, a technique introduced in the MiniMax-M1 paper and shown to scale effectively in the ScaleRL paper. CISPO clips the importance sampling weights directly instead of clipping advantage-scaled importance weights as in standard GRPO.

Key changes:

  • Added CISPO loss computation logic that clips importance weights using epsilon_high and multiplies with advantages and log probabilities
  • Extended loss aggregation to support CISPO alongside DAPO normalization
  • Added CISPO-specific clipping metrics to track training behavior
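
For contrast, a simplified, schematic comparison of the two clipping strategies (a sketch; `eps_low`, `eps_high`, and the helper names are illustrative, not the trainer's exact code):

```python
import torch

def grpo_token_loss(logps, old_logps, adv, eps_low=0.2, eps_high=0.28):
    # Standard PPO/GRPO-style surrogate: clip the advantage-scaled ratio
    rho = torch.exp(logps - old_logps)
    return -torch.min(rho * adv, torch.clamp(rho, 1 - eps_low, 1 + eps_high) * adv)

def cispo_token_loss(logps, old_logps, adv, eps_max=5.0):
    # CISPO: clip the ratio itself and detach it; the bounded weight then
    # multiplies advantage * log-prob, so clipped tokens still receive gradient
    rho = torch.exp(logps - old_logps)
    return -torch.clamp(rho, max=eps_max).detach() * adv * logps
```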

Reviewed Changes

Copilot reviewed 4 out of 4 changed files in this pull request and generated 4 comments.

  • trl/trainer/grpo_trainer.py: Implements CISPO loss computation, conditional aggregation logic, and CISPO-specific clipping metrics
  • trl/trainer/grpo_config.py: Documents the new CISPO loss type in configuration metadata and help text
  • tests/test_grpo_trainer.py: Adds CISPO to the parametrized loss type tests
  • docs/source/paper_index.md: Documents the ScaleRL paper and provides example configuration for reproducing results


@qgallouedec (Member) left a comment


super clean! only a few nits


A systematic study that defines a framework for analyzing and predicting reinforcement learning scaling in large language models, identifies key design choices that affect compute efficiency, and proposes a best-practice recipe called ScaleRL.

You can partially reproduce the ScaleRL recipe using the `GRPOTrainer` with the following configs:
@qgallouedec (Member) commented:


Suggested change:
- You can partially reproduce the ScaleRL recipe using the `GRPOTrainer` with the following configs:
+ You can partially reproduce the ScaleRL recipe using the [`GRPOTrainer`] with the following configs:


config = GRPOConfig(
    loss_type="cispo",
    epsilon_high=5,
@qgallouedec (Member) commented:


nit, it's a float, so I think it's better:

Suggested change:
- epsilon_high=5,
+ epsilon_high=5.0,

self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
gathered_clip_ratio = self.accelerator.gather(clip_ratio)
self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
if self.loss_type != "cispo":
@qgallouedec (Member) commented:


nit, again (explicit is better than implicit):

Suggested change:
- if self.loss_type != "cispo":
+ if self.loss_type in ["grpo", "bnpo", "dr_grpo", "dapo"]:

self._metrics[mode]["clip_ratio/high_max"].append(nanmax(gathered_high_clip).item())
gathered_clip_ratio = self.accelerator.gather(clip_ratio)
self._metrics[mode]["clip_ratio/region_mean"].append(gathered_clip_ratio.nanmean().item())
else:
@qgallouedec (Member) commented:


Suggested change:
- else:
+ elif loss_type == "cispo":

batch. Note that normalization is performed over the local batch only, so results may slightly vary
depending on the local batch size, despite a constant effective batch size. When using
`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss.
- `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. The clipped weights
@qgallouedec (Member) commented:


can you just make sure that the line length <= 120

@pramodith (Collaborator, Author) replied:


Done

`per_device_train_batch_size==1`, the loss is equivalent to the GRPO loss.
- `"cispo"`: Clips the importance sampling weights instead of the advantage scaled importance weights. The clipped weights
are then multiplied with the advantages and policy model's log probs. Individual token losses are aggregated by
normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the MiniMax-M1 paper.
@qgallouedec (Member) commented:


Suggested change:
- normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the MiniMax-M1 paper.
+ normalizing with the number of active tokens in the global accumulated batch. This method was introduced in the [MiniMax-M1 paper](https://huggingface.co/papers/2506.13585).

# From here, log_importance_weights (and all subsequent tensors, coef_1, coef_2, etc.) shape depends on
# importance_sampling_level: "token" level: (B, T); "sequence" level: (B, 1)
if self.loss_type == "cispo":
clamped_ratios = torch.clamp(coef_1, max=self.epsilon_high).detach()
@qgallouedec (Member) commented:


maybe in the documentation of `epsilon_high` we can mention that this is the value used for ε_max when used with the CISPO loss, and that the paper recommends 5.0


A systematic study that defines a framework for analyzing and predicting reinforcement learning scaling in large language models, identifies key design choices that affect compute efficiency, and proposes a best-practice recipe called ScaleRL.

You can partially reproduce the ScaleRL recipe using the `GRPOTrainer` with the following configs:
@qgallouedec (Member) commented:


not needed for this PR, but it would be neat to have a list of what's supported and what's not. If useful, these are some light reading notes

baseline: GRPO with beta=0 and asymmetric clipping

the scaling law applies after ~1.5k GPU-hours

3 stages:
1. Ablation: Test individual design choices on the baseline (3.5k–4k GPU-hours) to identify stable ones.
2. LOO Experiments: Combine stable choices into *ScaleRL* and run 16k GPU-hour leave-one-out tests to assess predictability (fit first 8k, extrapolate rest).
3. Scaling Demonstration: Validate ScaleRL predictability on larger, more complex setups (bigger batches, MoE models, multitask, longer sequences).

### Asynchronous RL Setup

**PPO-off-policy-k** is equivalent to `steps_per_generation` in TRL. -> faster convergence with `steps_per_generation=8` than with `steps_per_generation=1`
**PipelineRL-k** achieves true asynchronicity by allowing weight updates during rollouts. -> faster convergence than **PPO-off-policy-k**

Note that the asymptotic performance is the same for these methods.
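
As a rough TRL analogue of the PPO-off-policy-8 setting, assuming the current `GRPOConfig` API (a sketch, not a verified recipe):

```python
from trl import GRPOConfig

# Reuse each batch of generations for 8 optimizer steps per generation round,
# approximating the paper's PPO-off-policy-8 configuration. PipelineRL-style
# asynchronous updates are a serving-side mechanism, not configured here.
config = GRPOConfig(steps_per_generation=8)
```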

### Algorithmic Choices

* loss type: DAPO/GSPO/CISPO: CISPO > GSPO >> DAPO (not yet in TRL, but PR opened by a contributor)
* precision fixes: fp32 for lm_head substantially improves final performance (`cast_lm_head_to_fp32=True` in TRL)
* loss aggregation: (GRPO vs DAPO vs DrGRPO): DAPO-style works best (`loss_type="dapo"` in TRL)
* advantage normalization (prompt-level vs batch-level vs none): equivalent, choose batch-level (`scale_rewards="batch"` in TRL)
* batch definition: zero-variance filtering gives better final performance (not implemented in TRL!)
* data curriculum: "too easy" prompts are filtered out for future epochs (not yet in TRL)

@pramodith (Collaborator, Author) replied:


Oh that's a nice list! I might have a PR for batch definition coming up, have a draft version that needs polishing.

epsilon_high (`float`, *optional*):
Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the
@qgallouedec (Member) commented:


Suggested change:
- When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the
+ When used with `loss_type="cispo"`, this corresponds to the ε_max param specified in the

Upper-bound epsilon value for clipping. If not specified, it defaults to the same value as the lower-bound
specified in argument `epsilon`. Paper [DAPO](https://huggingface.co/papers/2503.14476) recommends `0.28`.
When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the
[ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`.
@qgallouedec (Member) commented:


Suggested change:
- [ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`.
+ [ScaleRL paper](https://huggingface.co/papers/2510.13786) and the recommended value is `5.0`.

The clipped weights are then multiplied with the advantages and policy model's log probs.
Individual token losses are aggregated by normalizing with the number of active tokens in
the global accumulated batch. This method was introduced in the
[MiniMax-M1 paper][MiniMax-M1 paper](https://huggingface.co/papers/2506.13585).
@qgallouedec (Member) commented:


Suggested change:
- [MiniMax-M1 paper][MiniMax-M1 paper](https://huggingface.co/papers/2506.13585).
+ [MiniMax-M1 paper](https://huggingface.co/papers/2506.13585).

"lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`."
"lower-bound specified in argument `epsilon`. Paper DAPO recommends `0.28`. "
"When used with `loss_type='cispo`, this corresponds to the ε_max param specified in the"
"[ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`."
@qgallouedec (Member) commented:


Suggested change:
- "[ScaleRL paper](https://arxiv.org/pdf/2510.13786) and the recommended value is `5.0`."
+ "[ScaleRL paper](https://huggingface.co/papers/2510.13786) and the recommended value is `5.0`."

"The clipped weights are then multiplied with the advantages and policy model's log probs. "
"Individual token losses are aggregated by normalizing with the number of active tokens in "
"the global accumulated batch. This method was introduced in the "
"[MiniMax-M1 paper][MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)."
@qgallouedec (Member) commented:


Suggested change:
- "[MiniMax-M1 paper][MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)."
+ "[MiniMax-M1 paper](https://huggingface.co/papers/2506.13585)."

@pramodith (Collaborator, Author)

Failing CI seems unrelated to this PR, going ahead and merging.

@pramodith pramodith merged commit 642b721 into huggingface:main Nov 6, 2025
3 of 9 checks passed
@Neelectric

In the Paper Index docs, this is mentioned as a code snippet:

from trl import GRPOConfig

config = GRPOConfig(
    loss_type="cispo",
    epsilon_high=5.0,
    num_completions=16,
    scale_rewards="batch",
    cast_lm_head_to_fp32=True
)

I believe it should be num_generations instead of num_completions?
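
For reference, the corrected snippet (later applied in the follow-up PR) would read:

```python
from trl import GRPOConfig

config = GRPOConfig(
    loss_type="cispo",
    epsilon_high=5.0,
    num_generations=16,       # was incorrectly documented as num_completions
    scale_rewards="batch",
    cast_lm_head_to_fp32=True,
)
```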

@qgallouedec (Member)

True! Do you want to open a PR?

@Neelectric

Fair point! Just saw that @pramodith beat me to it :)

qgallouedec added a commit that referenced this pull request Nov 21, 2025