
Support token-level loss, make default#90

Merged
tyler-griggs merged 7 commits into main from tgriggs/token-loss
Jul 15, 2025

Conversation

@tyler-griggs
Member

@tyler-griggs tyler-griggs commented Jul 15, 2025

What does this PR do?

Adds support for token-level loss (i.e., the `token_mean` loss reduction type) as introduced by DAPO.

With token_mean loss reduction, all tokens in all sequences contribute equally to loss.
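For intuition, here is a minimal sketch of the two reduction types over a padded batch. The `reduce_loss` helper and the tensor values are illustrative only, not the PR's actual `PolicyLoss` implementation:

```python
import torch

def reduce_loss(token_loss, loss_mask, reduction="token_mean"):
    # token_loss, loss_mask: (batch, seq_len); mask is 1.0 for valid tokens
    if reduction == "token_mean":
        # every valid token in the batch contributes equally
        return (token_loss * loss_mask).sum() / loss_mask.sum()
    elif reduction == "sequence_mean":
        # average within each sequence first, then across sequences,
        # so a short sequence weighs as much as a long one
        per_seq = (token_loss * loss_mask).sum(dim=1) / loss_mask.sum(dim=1)
        return per_seq.mean()
    raise ValueError(f"unknown reduction: {reduction}")

token_loss = torch.tensor([[1.0, 1.0, 1.0], [4.0, 0.0, 0.0]])
loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]])
print(reduce_loss(token_loss, loss_mask, "token_mean"))     # (1+1+1+4)/4 = 1.75
print(reduce_loss(token_loss, loss_mask, "sequence_mean"))  # (3/3 + 4/1)/2 = 2.5
```

The example also shows why the reported loss magnitudes differ between the two modes: the same per-token losses reduce to different scalars depending on how sequence length is weighted.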

The loss reduction type is configurable via `trainer.algorithm.loss_reduction`, but the default is updated to `token_mean`, as opposed to our previous implementation (`sequence_mean`). This loss reduction is what the community is standardizing on as the default (TRL's [default](huggingface/trl#2881), verl's [default](https://github.com/volcengine/verl/blob/517cc23c9dbb0da5c2cd2b012466790e29cb781a/verl/trainer/config/actor/actor.yaml#L63)).

Wandb report of comparing token_mean vs sequence_mean: https://wandb.ai/sky-posttraining-uc-berkeley/gsm8k/reports/Token-level-loss-token_mean---VmlldzoxMzYwMDc4MQ

The only plot with a notable difference is `policy_loss`, which is much larger for `token_mean` than for `sequence_mean`:
(screenshot: `policy_loss` comparison, 2025-07-15)

However, this `policy_loss` matches the magnitude of the `pg_loss` we observe in verl:
(screenshot: verl `pg_loss`, 2025-07-15)

@tyler-griggs tyler-griggs changed the title Support token-level loss Support token-level loss, make default Jul 15, 2025
@SumanthRH SumanthRH self-assigned this Jul 15, 2025
Member

@SumanthRH SumanthRH left a comment


Great!

Left a couple nits

loss_mask = torch.tensor([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]], device=device)

# Test token_mean without mask
loss_fn_token = PolicyLoss(loss_type="regular", loss_reduction="token_mean")
Member


nit: you should explicitly pass in the eps-low and eps-high values here to make the test less brittle

@SumanthRH
Member

Would be nice to add a screenshot of convergence on gsm8k (and how it changes from before) before merging

tyler-griggs and others added 3 commits July 14, 2025 22:06
Co-authored-by: Sumanth R Hegde <39546518+SumanthRH@users.noreply.github.com>
@tyler-griggs tyler-griggs mentioned this pull request Jul 15, 2025
@tyler-griggs
Copy link
Member Author

Added details from the gsm8k run in the initial PR description.

@tyler-griggs tyler-griggs merged commit fc59170 into main Jul 15, 2025
3 checks passed
@SumanthRH SumanthRH deleted the tgriggs/token-loss branch July 16, 2025 23:21
fannie1208 pushed a commit to vinid/SkyRL that referenced this pull request Aug 19, 2025