From b28c438def05472b07f40a2e75eb900b4c2eaa44 Mon Sep 17 00:00:00 2001 From: "Wang, Yi A" Date: Thu, 21 Dec 2023 23:08:19 -0800 Subject: [PATCH] clear up the parameters of supervised_finetuning.py no_gradient_checkpointing is always false Signed-off-by: Wang, Yi A --- examples/research_projects/stack_llama/scripts/README.md | 2 +- .../stack_llama/scripts/supervised_finetuning.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/research_projects/stack_llama/scripts/README.md b/examples/research_projects/stack_llama/scripts/README.md index 60ed5fd943..da9f067f20 100644 --- a/examples/research_projects/stack_llama/scripts/README.md +++ b/examples/research_projects/stack_llama/scripts/README.md @@ -1,7 +1,7 @@ # RLHF pipeline for the creation of StackLLaMa: a Stack exchange llama-7b model. There were three main steps to the training process: 1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se: - - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path= --streaming --no_gradient_checkpointing --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se` + - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path= --streaming --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se` 2. Reward modeling using dialog pairs from the SE dataset using the llama-7b-se to create llama-7b-se-rm: - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=` 3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model: diff --git a/examples/research_projects/stack_llama/scripts/supervised_finetuning.py b/examples/research_projects/stack_llama/scripts/supervised_finetuning.py index 47669ac8a7..e6e59a868d 100644 --- a/examples/research_projects/stack_llama/scripts/supervised_finetuning.py +++ b/examples/research_projects/stack_llama/scripts/supervised_finetuning.py @@ -38,9 +38,9 @@ def get_args(): parser.add_argument("--weight_decay", type=float, default=0.05) parser.add_argument("--local_rank", type=int, default=0) - parser.add_argument("--no_fp16", action="store_false") + parser.add_argument("--fp16", action="store_true", default=False) parser.add_argument("--bf16", action="store_true", default=False) - parser.add_argument("--no_gradient_checkpointing", action="store_false", default=False) + parser.add_argument("--gradient_checkpointing", action="store_true", default=False) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--num_workers", type=int, default=None) parser.add_argument("--output_dir", type=str, default="./checkpoints") @@ -159,8 +159,8 @@ def run_training(args, train_data, val_data): lr_scheduler_type=args.lr_scheduler_type, warmup_steps=args.num_warmup_steps, gradient_accumulation_steps=args.gradient_accumulation_steps, - gradient_checkpointing=not args.no_gradient_checkpointing, - fp16=not args.no_fp16, + gradient_checkpointing=args.gradient_checkpointing, + fp16=args.fp16, bf16=args.bf16, weight_decay=args.weight_decay, run_name="llama-7b-finetuned",