RLHF with PPO #1005

Merged
merged 44 commits on Aug 5, 2024
Changes from 31 commits

Commits (44)
11d88a2
Refactoring TransformerDecoder and adding value-head transformers
SalmanMohammadi May 9, 2024
2849ec5
adding ppo config and recipe to registry
SalmanMohammadi May 10, 2024
f0c1410
Merge branch 'pytorch:main' into ppo
SalmanMohammadi May 12, 2024
57c67bf
implemented ppo recipe structure, advantage and return estimation, tr…
SalmanMohammadi May 15, 2024
03cba4b
finished first pass implementation of ppo. added tests for ppo loss
SalmanMohammadi May 15, 2024
f50f047
reverting changes
SalmanMohammadi May 15, 2024
b034af7
adding lora to ppo recipe, adding lora value head component and model…
SalmanMohammadi May 16, 2024
466b683
added lora training, added value head checkpointing and recipe resumi…
SalmanMohammadi May 19, 2024
928037d
removing test model builders, adding batched generation to ppo recipe…
SalmanMohammadi May 21, 2024
68b6162
fixing bug in _checkpointer.py
SalmanMohammadi May 21, 2024
65ca12a
Adding support for user-provided masks in attention
SalmanMohammadi May 30, 2024
9d8c5a8
Merge branch 'pytorch:main' into ppo
SalmanMohammadi May 31, 2024
b99102c
merging transformer custom masking, adding support for generation wit…
SalmanMohammadi Jun 4, 2024
a1cde1c
adding functionality for truncation in generation, and further tests …
SalmanMohammadi Jun 4, 2024
b032778
updated lora recipe to use custom generation
SalmanMohammadi Jun 6, 2024
f126e9a
Merge branch 'pytorch:main' into ppo
SalmanMohammadi Jun 6, 2024
04d514a
added support for correct truncation and padding of responses, added …
SalmanMohammadi Jun 7, 2024
4854908
added correct mask and position id trajectory generation, score rejec…
SalmanMohammadi Jun 8, 2024
c885833
bugfixing in ppo recipe. refactoring ppo_utils and tests to individua…
SalmanMohammadi Jun 8, 2024
57d57fa
updating ppo_utils namespace
SalmanMohammadi Jun 8, 2024
cce5548
fixing bug in collation, updating loss tests
SalmanMohammadi Jun 10, 2024
c289566
bugfixes in masking and indexing logprobs and values, added fixed kl …
SalmanMohammadi Jun 12, 2024
a3fa1ea
added loss and value masking
SalmanMohammadi Jun 14, 2024
c3db142
some refactoring, lots of testing and docs
SalmanMohammadi Jun 16, 2024
589bf7d
improved early training stability by adding value head init. from rew…
SalmanMohammadi Jun 16, 2024
346c30b
updating metrics
SalmanMohammadi Jun 18, 2024
2e9d779
reworking causal masking
SalmanMohammadi Jun 18, 2024
46b75be
freeing up memory after steps to avoid mem leaks
SalmanMohammadi Jun 18, 2024
0fd885e
Merge branch 'main' into ppo
SalmanMohammadi Jul 16, 2024
1942b0f
cleaning up; verifying results; switching to full finetune
SalmanMohammadi Jul 16, 2024
58d92ab
tidying up
SalmanMohammadi Jul 16, 2024
1fbb6dc
detaching losses for metric logging
SalmanMohammadi Jul 18, 2024
65ef9dc
removing 1b, merging main
SalmanMohammadi Jul 25, 2024
c7bbff1
merging
SalmanMohammadi Jul 25, 2024
1129f9e
deleting logits in loss
SalmanMohammadi Jul 29, 2024
fe87dfb
Merge branch 'main' into ppo
SalmanMohammadi Aug 2, 2024
662ab2c
cleaning conf
SalmanMohammadi Aug 2, 2024
76b124f
pYdOcLiNt
SalmanMohammadi Aug 2, 2024
dc4887c
downloading weights
SalmanMohammadi Aug 3, 2024
ef85dba
addressing comments
SalmanMohammadi Aug 5, 2024
fd87fe6
updating test
SalmanMohammadi Aug 5, 2024
ba365a8
let's finish this the way we started... together
SalmanMohammadi Aug 5, 2024
e76304c
Merge branch 'main' into ppo
SalmanMohammadi Aug 5, 2024
4e6be43
lInTiNG
SalmanMohammadi Aug 5, 2024
15 changes: 15 additions & 0 deletions docs/source/api_ref_modules.rst
@@ -81,6 +81,7 @@ Loss
    :nosignatures:

    loss.DPOLoss
    loss.PPOLoss


Vision Transforms
@@ -96,3 +97,17 @@ Functions used for preprocessing images.
    transforms.tile_crop
    transforms.find_supported_resolutions
    transforms.VisionCrossAttentionMask

Reinforcement Learning From Human Feedback (RLHF)
--------------------------------------------------
Components for RLHF algorithms like PPO.

.. autosummary::
    :toctree: generated/
    :nosignatures:

    rlhf.estimate_advantages
    rlhf.get_rewards_ppo
    rlhf.truncate_sequence_at_first_stop_token
    rlhf.left_padded_collate
    rlhf.padded_collate_dpo
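For orientation, the helpers documented above are the building blocks the PPO recipe strings together each step. The sketch below shows one plausible way they compose; the signatures, return values, and the `torchtune.modules.rlhf` import path are assumptions based on this diff, not a documented API.

```python
import torch
from torchtune.modules import rlhf  # namespace introduced in this PR (assumed import path)

# Toy stand-ins for quantities the recipe collects while generating trajectories:
# 2 responses, 6 generated tokens each.
responses = torch.tensor([[264, 338, 29889, 0, 0, 0], [372, 471, 1781, 29889, 2, 0]])
logprobs = torch.randn(2, 6)       # log-probs under the current policy
ref_logprobs = torch.randn(2, 6)   # log-probs under the frozen reference policy
values = torch.randn(2, 6)         # value-head predictions
scores = torch.tensor([0.7, 1.3])  # reward-model score per response

# Truncate each response after its first stop token (eos=2, llama2 "."=29889).
# Assumed return convention: a padding mask plus the truncated sequences.
padding_mask, responses = rlhf.truncate_sequence_at_first_stop_token(
    responses, stop_tokens=torch.tensor([2, 29889]), fill_value=0
)

# Fold a per-token KL penalty against the reference policy into the rewards.
rewards, kl, kl_rewards = rlhf.get_rewards_ppo(scores, logprobs, ref_logprobs, kl_coeff=0.01)

# Estimate per-token advantages and value targets with GAE.
advantages, returns = rlhf.estimate_advantages(values, rewards, gamma=1.0, lmbda=0.95)
```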
1 change: 0 additions & 1 deletion docs/source/api_ref_utilities.rst
@@ -115,7 +115,6 @@ Utilities for working with data and datasets.
    :nosignatures:

    padded_collate
    padded_collate_dpo

.. _gen_label:

182 changes: 182 additions & 0 deletions recipes/configs/llama2/1B_full_ppo.yaml
@@ -0,0 +1,182 @@
# Config for single device RLHF full finetuning using PPO in ppo_full_finetune_single_device.py
# using a TinyLlama2 1B model.
#
# This config uses hyperparameters based on a small set of experiments and information
# available from existing implementations.
#
# This config assumes that you've run the following command before launching
# this run:
# tune download TinyLlama/TinyLlama_v1.1 --hf-token <HF_TOKEN> --output-dir /tmp/TinyLlama_v1.1
# tune download smohammadi/tinyllama_rm_sentiment_1b --hf-token <HF_TOKEN> --output-dir /tmp/tinyllama_rm_sentiment_1b --ignore-patterns ""
#
# You'll also need to ensure that {output_dir} exists beforehand, as checkpoints for policy and value models are saved in sub-folders.
# To launch on a single device, run the following command from root:
# tune run ppo_full_finetune_single_device --config llama2/1B_full_ppo
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run ppo_full_finetune_single_device --config llama2/1B_full_ppo checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama2.llama2_tokenizer
  path: /tmp/TinyLlama_v1.1/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: trl-internal-testing/sentiment-trl-style
  max_seq_len: null
  split: train
  column: prompt
  add_eos: False


# manually constructing 1B models
policy_model:
  _component_: torchtune.models.llama2.llama2
  vocab_size: 32000
  num_layers: 22
  num_heads: 32
  num_kv_heads: 4
  embed_dim: 2048
  max_seq_len: 2048
  intermediate_dim: 5632
  attn_dropout: 0.0
  norm_eps: 1e-5

reward_and_value_model:
  _component_: torchtune.models.llama2.llama2_classifier
  num_classes: 1
  vocab_size: 32000
  num_layers: 22
  num_heads: 32
  num_kv_heads: 4
  embed_dim: 2048
  max_seq_len: 2048
  intermediate_dim: 5632
  attn_dropout: 0.0
  norm_eps: 1e-5

# checkpointer for the policy model - update this if resuming from checkpoint
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/TinyLlama_v1.1
  checkpoint_files: [
    "pytorch_model.bin",
  ]
  # this is the only place where you should update `recipe_checkpoint` if resuming training
  recipe_checkpoint: null
  output_dir: ${output_dir}/policy
  model_type: LLAMA2

# this should be set up identically to the policy model checkpointer at the start of training
# ensure `checkpoint_files` always points to the original policy weights, even if resuming training
ref_policy_checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/TinyLlama_v1.1
  checkpoint_files: [
    "pytorch_model.bin",
  ]
  output_dir: ${output_dir}/policy
  model_type: LLAMA2

# checkpointer for the value model - update this if resuming from checkpoint
# since this model will be identical to the reward model, it's helpful to initialise it
# from the trained reward model weights
value_checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/tinyllama_rm_sentiment_1b
  # only `checkpoint_files` needs to be updated if resuming training
  checkpoint_files: [
    "model.safetensors"
  ]
  output_dir: ${output_dir}/value
  model_type: MISTRAL_REWARD

# checkpointer for the reward model, ensure `checkpoint_files`
# always points to the original reward model weights, even if resuming training
reward_checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/tinyllama_rm_sentiment_1b
  checkpoint_files: [
    "model.safetensors"
  ]
  output_dir: ${output_dir}/value
  model_type: MISTRAL_REWARD


# Training env
device: cuda

# Training arguments
batch_size: 256
num_steps: 100000
ppo_epochs: 2
ppo_batch_size: 128
gradient_accumulation_steps: 8

# Memory management and performance
compile: True
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-6

optimizer_in_bwd: False
log_peak_memory_stats: False
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16

# Trajectory generation arguments

# batch size for forward pass during generation
forward_batch_size: 32
max_generated_tokens: 58
temperature: 0.7
top_k: null

# Reward args

# parameter for penalising generations shorter than `min_response_length`
min_response_length: 18
# parameter for penalising generations without a stop token
penalise_no_eos: True
# scalar penalty to apply when penalising
reward_penalty: -3

# tokens to consider as "end of sequence" tokens
stop_token_ids: [
  2, # eos_id
  29889 # llama2 "." token
]
whiten_rewards: False
# GAE hyperparameters
gamma: 1
lmbda: 0.95

# PPO hyperparameters
loss:
  _component_: torchtune.modules.loss.PPOLoss
  epsilon: 0.2
  value_coeff: 0.1
  value_clip_range: 0.2
  kl_coeff: 0.01


# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}

log_every_n_steps: 1

resume_from_checkpoint: False
output_dir: /tmp/llama2-1b-ppo-finetune
seed: null
shuffle: True
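A few of the hyperparameters in this config are easier to follow with a worked illustration. First, the reward arguments (`min_response_length`, `penalise_no_eos`, `reward_penalty`) describe a simple shaping rule: a response that is shorter than `min_response_length`, or that never emits one of `stop_token_ids`, receives `reward_penalty` in place of its reward-model score. A minimal sketch of that rule; the helper name and the choice to replace (rather than add to) the score are assumptions, not the recipe's exact behaviour:

```python
import torch

def apply_reward_penalties(
    scores: torch.Tensor,            # [batch] reward-model score per response
    response_lengths: torch.Tensor,  # [batch] number of generated tokens per response
    has_stop_token: torch.Tensor,    # [batch] bool, True if a stop token was generated
    min_response_length: int = 18,
    penalise_no_eos: bool = True,
    reward_penalty: float = -3.0,
) -> torch.Tensor:
    # penalise responses that are too short
    penalised = response_lengths < min_response_length
    # optionally penalise responses that never produced a stop token
    if penalise_no_eos:
        penalised |= ~has_stop_token
    # replace the score with the scalar penalty for penalised responses
    return torch.where(penalised, torch.full_like(scores, reward_penalty), scores)
```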
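Next, `gamma` and `lmbda` are the standard discount and trace-decay parameters of Generalized Advantage Estimation. A compact sketch of the estimator they control, assuming per-token `rewards` and `values` of shape `[batch, response_len]` and ignoring the padding masks the recipe also handles:

```python
import torch

def gae(values: torch.Tensor, rewards: torch.Tensor, gamma: float = 1.0, lmbda: float = 0.95):
    """Generalized Advantage Estimation.

    delta_t = r_t + gamma * V_{t+1} - V_t
    A_t     = delta_t + gamma * lmbda * A_{t+1}
    """
    batch_size, seq_len = values.shape
    advantages = torch.zeros_like(values)
    running = torch.zeros_like(values[:, 0])
    for t in reversed(range(seq_len)):
        next_values = values[:, t + 1] if t + 1 < seq_len else torch.zeros_like(values[:, 0])
        delta = rewards[:, t] + gamma * next_values - values[:, t]
        running = delta + gamma * lmbda * running
        advantages[:, t] = running
    # value targets ("returns") used to train the value head
    returns = advantages + values
    return advantages, returns
```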
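Finally, the `loss` section corresponds to the usual clipped PPO objective. The sketch below shows how `epsilon`, `value_coeff`, and `value_clip_range` typically enter that objective; it is a generic reference, not necessarily identical to `torchtune.modules.loss.PPOLoss`, and the `kl_coeff` penalty against the reference policy is commonly folded into the per-token rewards rather than this term:

```python
import torch

def clipped_ppo_loss(
    logprobs: torch.Tensor,       # log-probs of the current policy for the sampled tokens
    old_logprobs: torch.Tensor,   # log-probs recorded when the trajectory was generated
    advantages: torch.Tensor,     # GAE advantages
    values: torch.Tensor,         # current value-head predictions
    old_values: torch.Tensor,     # value predictions recorded at generation time
    returns: torch.Tensor,        # GAE value targets
    epsilon: float = 0.2,
    value_coeff: float = 0.1,
    value_clip_range: float = 0.2,
) -> torch.Tensor:
    # clipped policy objective
    ratio = torch.exp(logprobs - old_logprobs)
    policy_loss = -torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1 - epsilon, 1 + epsilon) * advantages,
    ).mean()

    # clipped value loss, regressing towards the GAE returns
    clipped_values = old_values + (values - old_values).clamp(-value_clip_range, value_clip_range)
    value_loss = torch.max((values - returns) ** 2, (clipped_values - returns) ** 2).mean()

    return policy_loss + value_coeff * value_loss
```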