[RFC] RLHF follow-ups #1395

Closed
5 of 8 tasks
SalmanMohammadi opened this issue Aug 22, 2024 · 1 comment
Comments

@SalmanMohammadi
Collaborator

SalmanMohammadi commented Aug 22, 2024

There are several optimizations to our PPO recipe which could help push it closer to SOTA performance. There are also several pieces of documentation we could offer alongside the recipe to increase visibility and improve accessibility. This list is non-comprehensive, and not every item is required.

Documentation


  • Recipe documentation page which sufficiently explains how to use the recipe, including:
  • Recipe tutorial guiding a user through the PPO fine-tuning process. This tutorial will only provide a high-level overview of the algorithm, and cover:
    • how to identify a suitable reward model and dataset for a given task
    • downloading models and guidance on early experimentation (start small and fast)
    • tuning parameters/things to look out for/how you know it's working
    • evaluating trained models (qualitative analysis w/ generations, quantitative analysis through an appropriate Eleuther eval)
    • uploading to hub?
  • E2E RLHF workflow in torchtune - SFT -> Reward modelling -> PPO
  • Recipe deep dive covering:
    • Explanation of relevant papers and prior art
    • Implementation details a la 1, 2

Optimizations

Rough benchmarks from DeepSpeed

I think the results from this page all use LoRA. Nonetheless, it's one of the only sources of compute usage figures for a modern RLHF implementation.

[image: benchmark table from the DeepSpeed blog post]

*It's unclear what size of reward model is used here. Throughout the blog post they use reward model sizes << policy model sizes.

[image: second benchmark figure from the DeepSpeed blog post]

They also state:

For now, we suggest that users use "Total-GPU-Memory-in-GB / 6" as the upper parameter bound in billions for the sum of the actor model and critical model, for safety. Nevertheless, users are welcome to try the real limit.

This gives ~13.3B parameters as the upper bound for the combined actor + critic models on a single A100 80GB.
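
As a quick sanity check on that rule of thumb (a heuristic from the DeepSpeed post, not a torchtune guarantee), the arithmetic is just:

```python
# DeepSpeed-Chat heuristic: the upper bound on the combined actor + critic
# parameter count, in billions, is roughly total GPU memory in GB divided by 6.
# Rule of thumb only; real headroom also depends on optimizer states,
# activation memory, sequence length, batch size, etc.
def max_actor_plus_critic_params_billions(total_gpu_memory_gb: float) -> float:
    return total_gpu_memory_gb / 6

print(max_actor_plus_critic_params_billions(80))  # single A100 80GB -> ~13.3
```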


Compile issues @ebsmothers

  • Fix compile for the recipe? (blocker: currently hitting the recompile cache limit due to swapping between inference and training modes, plus many more)

  • Enable compile for batched RLHF generation utils  #1402

    • Managed to solve the initial graph breaks caused by the RNG object by moving the sampling logic into a util outside the compiled fn (see the sketch after this list)
    • New blocker: no more graph breaks and only an initial recompile, but still very slow
    • Fixed the above issues. Refactoring and cleaning up; seeing ~20% speedups with compile.
  • Enable compile for trajectory generation step?

  • Enable compile for loss step?

  • ?? how else can we make inference go fast?

  • Reference and/or reward model offload to CPU @ebsmothers

  • Optimizer offload to CPU (Add CPU offload optimizer from torchao #1351) (to benchmark once it lands)

  • (From the DeepSpeed link above) - granted these aren't strictly performance optimizations:

Exponential Moving Average (EMA) collection, where an EMA based checkpoint can be chosen for the final evaluation.
Mixture Training, which mixes the pretraining objective (i.e., the next word prediction) with the PPO objective to prevent regression performance on public benchmarks like SQuAD2.0.
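
For the batched-generation compile item above, here is a minimal sketch of the "keep sampling outside the compiled fn" idea; the function names are illustrative and this is not the actual torchtune generation util:

```python
import torch

@torch.compile
def next_token_logits(model, tokens):
    # Only the deterministic forward pass is compiled; nothing RNG-dependent
    # appears in the graph, so the generator object can't trigger graph breaks.
    return model(tokens)[:, -1]

def sample(logits, temperature=1.0, rng=None):
    # Eager-mode sampling util kept outside the compiled fn; the torch.Generator
    # lives out here.
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1, generator=rng)

# Illustrative use inside a generation loop:
# logits = next_token_logits(policy, tokens)
# next_token = sample(logits, temperature=0.7, rng=torch.Generator().manual_seed(0))
```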

cc @kartikayk
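
For reference, a minimal sketch of the two DeepSpeed suggestions quoted above (EMA checkpointing and mixing a pretraining LM loss into the PPO objective); all names here are illustrative rather than torchtune or DeepSpeed APIs:

```python
import torch

@torch.no_grad()
def ema_update(ema_model, model, decay=0.999):
    # Keep an exponential moving average of the policy weights; the EMA copy can
    # be checkpointed and used for the final evaluation.
    for ema_p, p in zip(ema_model.parameters(), model.parameters()):
        ema_p.lerp_(p, 1.0 - decay)

def mixed_objective(ppo_loss, pretrain_lm_loss, ptx_coef=0.5):
    # "Mixture training": blend a next-token-prediction loss on pretraining data
    # with the PPO objective to limit regressions on public benchmarks.
    return ppo_loss + ptx_coef * pretrain_lm_loss
```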

@gau-nernst
Contributor

Not sure if it will be useful for you, but there are 8-bit and 4-bit AdamW optimizers in torchao: https://github.com/pytorch/ao/tree/main/torchao/prototype/low_bit_optim. Both support FSDP1/2.
8-bit AdamW should match bnb and the 4-bit version should match lpmm exactly. They are included in the torchao 0.4 release, but there is a bug in handling LR schedules (fixed in the main branch).
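
A minimal usage sketch of the prototype optimizers mentioned above, in case it helps with benchmarking; the import path follows the linked prototype directory and may move in later torchao releases:

```python
import torch
import torch.nn as nn
from torchao.prototype.low_bit_optim import AdamW8bit  # AdamW4bit is also available

model = nn.Linear(4096, 4096, device="cuda")
optimizer = AdamW8bit(model.parameters(), lr=1e-5)

# Standard training step; the optimizer keeps its state tensors quantized to 8-bit.
loss = model(torch.randn(8, 4096, device="cuda")).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```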
