
Fix SFT tuner (#1278)
vwxyzjn authored Jan 26, 2024
1 parent 9a71e67 commit 3843cfc
Showing 3 changed files with 4 additions and 3 deletions.
commands/run_sft.sh: 2 changes (1 addition, 1 deletion)
@@ -41,7 +41,7 @@ accelerate launch $EXTRA_ACCELERATE_ARGS \
     --dataset_name $DATASET_NAME \
     --output_dir $OUTPUT_DIR \
     --max_steps $MAX_STEPS \
-    --batch_size $BATCH_SIZE \
+    --per_device_train_batch_size $BATCH_SIZE \
     --seq_length $SEQ_LEN \
     $EXTRA_TRAINING_ARGS
 """
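The rename brings the script's flag in line with `transformers.TrainingArguments`, whose CLI options are derived from its dataclass field names. A minimal sketch of that mechanism (illustrative, not part of the commit; the `out` directory name is arbitrary):

```python
# Illustrative sketch: HfArgumentParser generates CLI flags from dataclass
# field names, so TrainingArguments accepts --per_device_train_batch_size
# but has no --batch_size option.
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    args=["--output_dir", "out", "--per_device_train_batch_size", "4"]
)
print(training_args.per_device_train_batch_size)  # 4

# Passing --batch_size instead raises, since it matches no known field.
```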
docs/source/lora_tuning_peft.mdx: 4 changes (2 additions, 2 deletions)
@@ -71,7 +71,7 @@ The `trl` library is powered by `accelerate`. As such it is best to configure an

 ```bash
 accelerate config # will prompt you to define the training configuration
-accelerate launch scripts/gpt2-sentiment_peft.py # launches training
+accelerate launch examples/scripts/ppo.py --use_peft # launches training
 ```

 ## Using `trl` + `peft` and Data Parallelism
@@ -140,5 +140,5 @@ python PATH_TO_SCRIPT
 You can easily fine-tune the Llama2 model using `SFTTrainer` and the official script! For example, to fine-tune llama2-7b on the Guanaco dataset, run (tested on a single NVIDIA T4-16GB):

 ```bash
-python examples/scripts/sft.py --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --batch_size 4 --gradient_accumulation_steps 2
+python examples/scripts/sft.py --output_dir sft_openassistant-guanaco --model_name meta-llama/Llama-2-7b-hf --dataset_name timdettmers/openassistant-guanaco --load_in_4bit --use_peft --per_device_train_batch_size 4 --gradient_accumulation_steps 2
 ```
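Two things changed in this command: the batch-size flag now uses the `TrainingArguments` spelling, and `--output_dir` is passed explicitly because `TrainingArguments` requires one. The batching arithmetic behind the chosen values, as a hedged single-GPU sketch:

```python
# Effective batch size for the command above (single-GPU assumption,
# matching the T4 setup the docs mention).
per_device_train_batch_size = 4
gradient_accumulation_steps = 2
num_devices = 1  # assumption

# Gradients from 2 micro-batches of 4 are accumulated before each update.
effective_batch_size = (
    per_device_train_batch_size * gradient_accumulation_steps * num_devices
)
print(effective_batch_size)  # 8 examples per optimizer step
```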
examples/scripts/sft.py: 1 change (1 addition, 0 deletions)
@@ -88,6 +88,7 @@ class ScriptArguments:
     quantization_config=quantization_config,
 )
 tokenizer = AutoTokenizer.from_pretrained(model_config.model_name_or_path, use_fast=True)
+tokenizer.pad_token = tokenizer.eos_token

 ################
 # Dataset
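The added line covers checkpoints whose tokenizer ships without a padding token (Llama 2 among them); without it, any padded batch the trainer builds fails. A minimal reproduction sketch, using the openly downloadable `gpt2` tokenizer as a stand-in for the script's configurable checkpoint:

```python
# Minimal sketch (gpt2 stands in for the configurable model): causal-LM
# tokenizers often define no pad token, so padded batching fails until
# one is assigned.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", use_fast=True)
print(tokenizer.pad_token)  # None

try:
    tokenizer(["short", "a longer example"], padding=True)
except ValueError as err:
    print(err)  # asks you to set a pad token before padding

tokenizer.pad_token = tokenizer.eos_token  # the one-line fix from this commit
batch = tokenizer(["short", "a longer example"], padding=True, return_tensors="pt")
print(batch["input_ids"].shape)  # both rows padded to the longer length
```

Reusing EOS as the pad token is the usual choice for decoder-only models, since they were pretrained without a dedicated padding token.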
