RLHF with PPO (#1005)
SalmanMohammadi authored Aug 5, 2024
1 parent 5019074 commit c593c10
Showing 35 changed files with 3,639 additions and 245 deletions.
15 changes: 15 additions & 0 deletions docs/source/api_ref_modules.rst
@@ -80,6 +80,7 @@ Loss
:toctree: generated/
:nosignatures:

loss.PPOLoss
loss.DPOLoss
loss.RSOLoss
loss.IPOLoss
@@ -98,3 +99,17 @@ Functions used for preprocessing images.
transforms.tile_crop
transforms.find_supported_resolutions
transforms.VisionCrossAttentionMask

Reinforcement Learning From Human Feedback (RLHF)
--------------------------------------------------
Components for RLHF algorithms like PPO.

.. autosummary::
:toctree: generated/
:nosignatures:

rlhf.estimate_advantages
rlhf.get_rewards_ppo
rlhf.truncate_sequence_at_first_stop_token
rlhf.left_padded_collate
rlhf.padded_collate_dpo
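
The components listed above cover the core PPO data path: rlhf.get_rewards_ppo produces the per-token rewards (in PPO-style RLHF these typically combine the reward-model score with a KL penalty against the reference policy), and rlhf.estimate_advantages turns per-token values and rewards into advantages via generalized advantage estimation (GAE). For orientation only, the standalone PyTorch sketch below shows how GAE is typically computed; it is not the torchtune implementation, and the [batch, seq_len] tensor layout and the absence of padding masks are simplifying assumptions.

import torch

def gae_advantages(values, rewards, gamma=1.0, lmbda=0.95):
    # values, rewards: [batch, seq_len] per-token value estimates and rewards
    batch, seq_len = values.shape
    advantages = torch.zeros_like(values)
    gae = torch.zeros(batch)
    # walk the sequence backwards, accumulating discounted TD errors
    for t in reversed(range(seq_len)):
        next_value = values[:, t + 1] if t < seq_len - 1 else torch.zeros(batch)
        delta = rewards[:, t] + gamma * next_value - values[:, t]
        gae = delta + gamma * lmbda * gae
        advantages[:, t] = gae
    returns = advantages + values  # regression targets for the value head
    return advantages, returns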
1 change: 0 additions & 1 deletion docs/source/api_ref_utilities.rst
@@ -115,7 +115,6 @@ Utilities for working with data and datasets.
:nosignatures:

padded_collate
padded_collate_dpo

.. _gen_label:

180 changes: 180 additions & 0 deletions recipes/configs/mistral/7B_full_ppo_low_memory.yaml
@@ -0,0 +1,180 @@
# Config for single device RLHF full finetuning using PPO in ppo_full_finetune_single_device.py
# using a Mistral 7B model.
#
# This config has been tested on an A100 80GB.
# This config uses hyperparameters based on a small set of experiments and information
# available from existing implementations.
#
# This config assumes that you've run the following command before launching
# this run:
# tune download weqweasdas/RM-Mistral-7B --output-dir /tmp/RM-Mistral-7B/ --ignore-patterns=""
# tune download mistralai/Mistral-7B-Instruct-v0.2 --output-dir /tmp/Mistral-7B-Instruct-v0.2/ --hf-token HF_TOKEN
#
# You'll also need to ensure that {output_dir} exists beforehand, as checkpoints for policy and value models are saved in sub-folders.
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
# pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
# tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#

# Tokenizer
tokenizer:
_component_: torchtune.models.mistral.mistral_tokenizer
path: /tmp/Mistral-7B-Instruct-v0.2/tokenizer.model

# Dataset
dataset:
_component_: torchtune.datasets.text_completion_dataset
source: trl-internal-testing/sentiment-trl-style
max_seq_len: null
split: train
column: prompt
add_eos: False

policy_model:
_component_: torchtune.models.mistral.mistral_7b

# we need to manually build the mistral classifier model
# because our reward model checkpoint has a larger vocabulary size (due to an added padding token)
reward_and_value_model:
_component_: torchtune.models.mistral._component_builders.mistral_classifier
attn_dropout: 0.0
embed_dim: 4096
intermediate_dim: 14336
max_seq_len: 32768
norm_eps: 1.0e-05
num_classes: 1
num_heads: 32
num_kv_heads: 8
num_layers: 32
vocab_size: 32001

# checkpointer for the policy model - update this if resuming from checkpoint
checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/
checkpoint_files: [
"pytorch_model-00001-of-00003.bin",
"pytorch_model-00002-of-00003.bin",
"pytorch_model-00003-of-00003.bin"
]
# this is the only place where you should update `recipe_checkpoint` if resuming training
recipe_checkpoint: null
output_dir: ${output_dir}/policy
model_type: MISTRAL

# this should be set up identically to the policy model checkpointer at the start of training
# ensure `checkpoint_files` always points to the original policy weights, even if resuming training
ref_policy_checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/
checkpoint_files: [
"pytorch_model-00001-of-00003.bin",
"pytorch_model-00002-of-00003.bin",
"pytorch_model-00003-of-00003.bin"
]
output_dir: ${output_dir}/policy
model_type: MISTRAL

# checkpointer for the value model - update `checkpoint_files` if resuming from checkpoint
# since this model will be identical to the reward model, it's helpful to initialise this
# from the trained reward model weights
value_checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/RM-Mistral-7B/
checkpoint_files: [
"model-00001-of-00003.safetensors",
"model-00002-of-00003.safetensors",
"model-00003-of-00003.safetensors"
]
output_dir: ${output_dir}/value
model_type: REWARD

# checkpointer for the reward model; ensure `checkpoint_files`
# always points to the original reward model weights, even if resuming training
reward_checkpointer:
_component_: torchtune.utils.FullModelHFCheckpointer
checkpoint_dir: /tmp/RM-Mistral-7B/
checkpoint_files: [
"model-00001-of-00003.safetensors",
"model-00002-of-00003.safetensors",
"model-00003-of-00003.safetensors"
]
output_dir: ${output_dir}/value
model_type: REWARD


resume_from_checkpoint: False
output_dir: /tmp/mistral7b-ppo-finetune
seed: null
shuffle: True

# Training env
device: cuda

# Training arguments
batch_size: 64
num_steps: 10000
ppo_epochs: 2
ppo_batch_size: 32
gradient_accumulation_steps: 1

# Memory management and performance
compile: True
optimizer:
_component_: bitsandbytes.optim.PagedAdamW
lr: 3e-6
optimizer_in_bwd: True
log_peak_memory_stats: False
enable_activation_checkpointing: True

# Reduced precision
dtype: bf16


# batch size for forward pass during generation
forward_batch_size: 16
max_generated_tokens: 58
temperature: 0.7
top_k: null

# parameter for penalising generations shorter than `min_response_length`
min_response_length: 18
# parameter for penalising generations without a stop token
penalise_no_eos: True
# scalar penalty to apply when penalising
reward_penalty: -3

# tokens to consider as "end of sequence" tokens
stop_token_ids: [
2, # eos_id
28723 # mistral "." token
]
whiten_rewards: False

# GAE hyperparameters
gamma: 1
lmbda: 0.95

# PPO hyperparameters
loss:
_component_: torchtune.modules.loss.PPOLoss
epsilon: 0.2
value_coeff: 0.1
value_clip_range: 0.2
kl_coeff: 0.01


# Logging
metric_logger:
_component_: torchtune.utils.metric_logging.DiskLogger
log_dir: ${output_dir}

log_every_n_steps: 1
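
For reference, the epsilon, value_coeff, and value_clip_range hyperparameters above correspond to the standard clipped-surrogate PPO objective, while kl_coeff typically scales a per-token KL penalty against the frozen reference policy that is folded into the rewards before advantage estimation. The snippet below is a minimal standalone sketch of that clipped objective, not the torchtune.modules.loss.PPOLoss implementation; argument names and shapes are illustrative.

import torch

def clipped_ppo_loss(
    logprobs, old_logprobs, advantages,
    values, old_values, returns,
    epsilon=0.2, value_coeff=0.1, value_clip_range=0.2,
):
    # policy term: clip the probability ratio to [1 - epsilon, 1 + epsilon]
    ratio = torch.exp(logprobs - old_logprobs)
    policy_loss = -torch.min(
        ratio * advantages,
        torch.clamp(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages,
    ).mean()
    # value term: clip the value update around the old value estimates
    values_clipped = old_values + (values - old_values).clamp(
        -value_clip_range, value_clip_range
    )
    value_loss = torch.max(
        (values - returns) ** 2, (values_clipped - returns) ** 2
    ).mean()
    return policy_loss + value_coeff * value_loss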
3 changes: 2 additions & 1 deletion recipes/lora_dpo_distributed.py
@@ -27,6 +27,7 @@
from torchtune import config, modules, utils
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import ConcatDataset
from torchtune.modules import rlhf
from torchtune.modules.peft.peft_utils import (
disable_adapter,
get_adapter_params,
@@ -449,7 +450,7 @@ def _setup_data(
batch_size=batch_size,
sampler=sampler,
collate_fn=partial(
utils.padded_collate_dpo,
rlhf.padded_collate_dpo,
padding_idx=self._tokenizer.pad_id,
ignore_idx=CROSS_ENTROPY_IGNORE_IDX,
),
3 changes: 2 additions & 1 deletion recipes/lora_dpo_single_device.py
@@ -19,6 +19,7 @@
from torchtune import config, modules, utils
from torchtune.data import CROSS_ENTROPY_IGNORE_IDX
from torchtune.datasets import ConcatDataset
from torchtune.modules import rlhf
from torchtune.modules.peft.peft_utils import (
disable_adapter,
get_adapter_params,
@@ -345,7 +346,7 @@ def _setup_data(
sampler=sampler,
batch_size=batch_size,
collate_fn=partial(
utils.padded_collate_dpo,
rlhf.padded_collate_dpo,
padding_idx=self._tokenizer.pad_id,
ignore_idx=CROSS_ENTROPY_IGNORE_IDX,
),
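
Both DPO recipes now import their collate function from torchtune.modules.rlhf rather than torchtune.utils, matching the docs change above. As a standalone illustration only (not the library's padded_collate_dpo, whose exact batch keys and output format may differ), a DPO-style collate pads each sample's chosen and rejected token sequences to a shared length and stacks them so one forward pass can score both completions:

import torch
import torch.nn.functional as F

def collate_dpo_sketch(batch, padding_idx=0, ignore_idx=-100):
    # batch: list of dicts with "chosen_input_ids", "chosen_labels",
    # "rejected_input_ids", "rejected_labels" as lists of token ids (assumed keys)
    def pad_and_stack(seqs, value):
        max_len = max(len(s) for s in seqs)
        return torch.stack(
            [F.pad(torch.tensor(s), (0, max_len - len(s)), value=value) for s in seqs]
        )
    # chosen sequences first, then rejected, concatenated along the batch dimension
    input_ids = pad_and_stack(
        [s["chosen_input_ids"] for s in batch]
        + [s["rejected_input_ids"] for s in batch],
        value=padding_idx,
    )
    labels = pad_and_stack(
        [s["chosen_labels"] for s in batch]
        + [s["rejected_labels"] for s in batch],
        value=ignore_idx,
    )
    return input_ids, labels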
(Diffs for the remaining changed files in this commit are not shown.)