Commit c593c10 (1 parent: 5019074)
Showing 35 changed files with 3,639 additions and 245 deletions.
# Config for single device RLHF full finetuning using PPO in ppo_full_finetune_single_device.py
# using a Mistral 7B model.
#
# This config has been tested on an A100 80GB.
# This config uses hyperparameters based on a small set of experiments and information
# available from existing implementations.
#
# This config assumes that you've run the following commands before launching
# this run:
#   tune download weqweasdas/RM-Mistral-7B --output-dir /tmp/RM-Mistral-7B/ --ignore-patterns=""
#   tune download mistralai/Mistral-7B-Instruct-v0.2 --output-dir /tmp/Mistral-7B-Instruct-v0.2/ --hf-token HF_TOKEN
#
# You'll also need to ensure that {output_dir} exists beforehand, as checkpoints for the policy and value models are saved in sub-folders.
# The default config uses an optimizer from bitsandbytes. If you do not have it installed,
# you can install it with
#   pip install bitsandbytes
#
# To launch on a single device, run the following command from root:
#   tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
#   tune run ppo_full_finetune_single_device --config mistral/7B_full_ppo_low_memory checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# Tokenizer
tokenizer:
  _component_: torchtune.models.mistral.mistral_tokenizer
  path: /tmp/Mistral-7B-Instruct-v0.2/tokenizer.model

# Dataset
dataset:
  _component_: torchtune.datasets.text_completion_dataset
  source: trl-internal-testing/sentiment-trl-style
  max_seq_len: null
  split: train
  column: prompt
  add_eos: False

policy_model:
  _component_: torchtune.models.mistral.mistral_7b

# we need to manually build the mistral classifier model
# because our reward model checkpoint has a larger vocabulary size (due to an added padding token)
reward_and_value_model:
  _component_: torchtune.models.mistral._component_builders.mistral_classifier
  attn_dropout: 0.0
  embed_dim: 4096
  intermediate_dim: 14336
  max_seq_len: 32768
  norm_eps: 1.0e-05
  num_classes: 1
  num_heads: 32
  num_kv_heads: 8
  num_layers: 32
  vocab_size: 32001
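# (Note: the base Mistral 7B tokenizer has a vocab size of 32000; the extra
# entry here accounts for the reward model's added padding token mentioned above.)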

# checkpointer for the policy model - update this if resuming from checkpoint
checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/
  checkpoint_files: [
    "pytorch_model-00001-of-00003.bin",
    "pytorch_model-00002-of-00003.bin",
    "pytorch_model-00003-of-00003.bin"
  ]
  # this is the only place where you should update `recipe_checkpoint` if resuming training
  recipe_checkpoint: null
  output_dir: ${output_dir}/policy
  model_type: MISTRAL

# this should be set up identically to the policy model checkpointer at the start of training
# ensure `checkpoint_files` always points to the original policy weights, even if resuming training
ref_policy_checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Mistral-7B-Instruct-v0.2/
  checkpoint_files: [
    "pytorch_model-00001-of-00003.bin",
    "pytorch_model-00002-of-00003.bin",
    "pytorch_model-00003-of-00003.bin"
  ]
  output_dir: ${output_dir}/policy
  model_type: MISTRAL

# checkpointer for the value model - update `checkpoint_files` if resuming from checkpoint
# since this model will be identical to the reward model, it's helpful to initialise this
# from the trained reward model weights
value_checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/RM-Mistral-7B/
  checkpoint_files: [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors"
  ]
  output_dir: ${output_dir}/value
  model_type: REWARD

# checkpointer for the reward model - ensure `checkpoint_files`
# always points to the original reward model weights, even if resuming training
reward_checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/RM-Mistral-7B/
  checkpoint_files: [
    "model-00001-of-00003.safetensors",
    "model-00002-of-00003.safetensors",
    "model-00003-of-00003.safetensors"
  ]
  output_dir: ${output_dir}/value
  model_type: REWARD

resume_from_checkpoint: False
output_dir: /tmp/mistral7b-ppo-finetune
seed: null
shuffle: True

# Training env
device: cuda

# Training arguments
batch_size: 64
num_steps: 10000
ppo_epochs: 2
ppo_batch_size: 32
gradient_accumulation_steps: 1
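# As a rough guide to how these interact (the recipe itself is the source of
# truth): each PPO step generates trajectories for `batch_size` prompts, then
# optimises over that data for `ppo_epochs` epochs in mini-batches of
# `ppo_batch_size`, with gradients optionally accumulated across
# `gradient_accumulation_steps` mini-batches before each optimizer step.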

# Memory management and performance
compile: True
optimizer:
  _component_: bitsandbytes.optim.PagedAdamW
  lr: 3e-6
optimizer_in_bwd: True
log_peak_memory_stats: False
enable_activation_checkpointing: True
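# Note: bitsandbytes' PagedAdamW keeps optimizer state in paged memory that can
# spill to CPU RAM under GPU memory pressure, and `optimizer_in_bwd: True` runs
# the optimizer step during the backward pass so per-parameter gradients can be
# freed immediately; both trade some speed for lower peak memory on a single device.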

# Reduced precision
dtype: bf16

# batch size for forward pass during generation
forward_batch_size: 16
max_generated_tokens: 58
temperature: 0.7
top_k: null

# parameter for penalising generations shorter than `min_response_length`
min_response_length: 18
# parameter for penalising generations without a stop token
penalise_no_eos: True
# scalar penalty to apply when penalising
reward_penalty: -3
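# Rough sketch of how the penalty is applied (see the recipe for the exact
# logic): if a generation never emits a stop token (when `penalise_no_eos` is
# True) or is shorter than `min_response_length`, its scalar reward is
# penalised with `reward_penalty` rather than taken from the reward model.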

# tokens to consider as "end of sequence" tokens
stop_token_ids: [
  2,      # eos_id
  28723   # mistral "." token
]
whiten_rewards: False
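# If enabled, reward whitening normalises the rewards to zero mean and unit
# variance across the batch before advantages are computed, a common
# stabilisation trick in PPO implementations.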

# GAE hyperparameters
gamma: 1
lmbda: 0.95
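# For reference, GAE computes advantages with the standard recursion (a sketch,
# not code from the recipe):
#   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
#   A_t     = delta_t + gamma * lmbda * A_{t+1}
# gamma: 1 applies no discounting; lmbda trades bias against variance in the
# advantage estimates.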

# PPO hyperparameters
loss:
  _component_: torchtune.modules.loss.PPOLoss
  epsilon: 0.2
  value_coeff: 0.1
  value_clip_range: 0.2
kl_coeff: 0.01
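# For reference, these parameters feed the standard clipped PPO objective
# (a sketch; the exact implementation lives in the PPOLoss component):
#   ratio       = exp(logprobs_new - logprobs_old)
#   policy_loss = -min(ratio * A, clip(ratio, 1 - epsilon, 1 + epsilon) * A)
#   value_loss  = value_coeff * regression of value predictions onto returns,
#                 with predictions clipped to within value_clip_range of the
#                 old value estimates
# `kl_coeff` scales the per-token KL penalty against the frozen reference
# policy that is folded into the rewards.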

# Logging
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}

log_every_n_steps: 1