# train-vicuna.yaml
resources:
  accelerators: A100-80GB:8
  disk_size: 1000
  use_spot: true

num_nodes: 1
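# Spot (preemptible) A100s keep costs down; the run section below restores the
# latest checkpoint from the /artifacts bucket, so a preempted job can resume
# from where it left off when the task is relaunched.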
file_mounts:
  /artifacts:
    name: skypilot-chatbot # Change to your own bucket
    store: gcs
    mode: MOUNT
  /data:
    name: model-weights # Change to your own bucket
    store: gcs
    mode: MOUNT
  # /llama:
  #   name: llama-ckpts # Change to the bucket that contains the LLaMA weights
  #   store: gcs
  #   mode: MOUNT
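# /artifacts caches the converted HF weights and receives training checkpoints,
# /data provides the ShareGPT training data, and the optional /llama mount
# supplies the raw LLaMA checkpoints for the one-time conversion in `setup`.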
workdir: .
setup: |
  # Set up the environment
  conda create -n chatbot python=3.10 -y
  conda activate chatbot
  # Install PyTorch (CUDA 11.6 build)
  pip install torch==1.13.1+cu116 torchvision==0.14.1+cu116 torchaudio==0.13.1 --extra-index-url https://download.pytorch.org/whl/cu116
  # Install Hugging Face transformers at a commit that includes LLaMA support
  cd ~
  git clone https://github.com/huggingface/transformers.git
  cd transformers
  git checkout c612628045822f909020f7eb6784c79700813eda # pin to a specific commit for reproducibility
  pip install .
  cd ~/sky_workdir
  # Install FastChat in editable mode from the synced workdir
  pip install -e .
  pip install flash-attn
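  # Note: flash-attn is only used when USE_FLASH_ATTN=1, in which case the run
  # section below switches to fastchat/train/train_mem.py.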
  # Convert the raw LLaMA weights (from the optional /llama mount above) to
  # Hugging Face format once and cache the result in the /artifacts bucket;
  # the `complete` marker file skips re-conversion on later runs.
  mkdir -p /artifacts/llama-hf/llama-${MODEL_SIZE}B
  if [ ! -f /artifacts/llama-hf/llama-${MODEL_SIZE}B/complete ]; then
    mkdir -p ~/llama-${MODEL_SIZE}b
    gsutil -m rsync -r /llama/${MODEL_SIZE}b/ ~/llama-${MODEL_SIZE}b
    cd ~/transformers
    python src/transformers/models/llama/convert_llama_weights_to_hf.py \
      --input_dir $HOME/llama-${MODEL_SIZE}b \
      --model_size ${MODEL_SIZE}B \
      --output_dir ~/hf-output || exit 1
    mv ~/hf-output/tokenizer/* ~/hf-output/llama-${MODEL_SIZE}b
    gsutil -m rsync -r ~/hf-output/llama-${MODEL_SIZE}b/ /artifacts/llama-hf/llama-${MODEL_SIZE}B
    touch /artifacts/llama-hf/llama-${MODEL_SIZE}B/complete
  else
    # Converted weights already exist in the bucket; just copy them locally.
    mkdir -p ~/hf-output/llama-${MODEL_SIZE}b
    gsutil -m cp -r /artifacts/llama-hf/llama-${MODEL_SIZE}B/* ~/hf-output/llama-${MODEL_SIZE}b
  fi
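  # Either way, the HF-format weights end up in ~/hf-output/llama-${MODEL_SIZE}b,
  # which is what --model_name_or_path points to in the run section below.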
run: |
  conda activate chatbot
  # Defaults; the `envs` section below overrides these at launch time.
  SEQ_LEN=${SEQ_LEN:-512}
  GC_SCALE=${GC_SCALE:-1}
  DATE=${DATE:-20230303}
  USE_FLASH_ATTN=${USE_FLASH_ATTN:-0}
  if [ $USE_FLASH_ATTN -eq 1 ]; then
    TRAIN_SCRIPT=fastchat/train/train_mem.py
    USE_FLASH_SUFFIX="-flash"
  else
    TRAIN_SCRIPT=fastchat/train/train.py
    USE_FLASH_SUFFIX=""
  fi
  echo "Training with seq_len=${SEQ_LEN} and gc_scale=${GC_SCALE}"
  PER_DEVICE_BATCH_SIZE=$((2048 * $GC_SCALE / $SEQ_LEN))
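  # With the defaults in `envs` below (SEQ_LEN=2048, GC_SCALE=4) this is
  # 2048 * 4 / 2048 = 4 sequences per GPU per forward/backward pass.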
  NUM_NODES=`echo "$SKYPILOT_NODE_IPS" | wc -l`
  HOST_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1`
  # Write checkpoints to local disk and sync them to the bucket in the
  # background, so that checkpoint saving does not slow down training.
  mkdir -p ~/.checkpoints
  LOCAL_CKPT_PATH=~/.checkpoints
  CKPT_PATH=/artifacts/chatbot/${MODEL_SIZE}b/sharegpt-${DATE}-seq-${SEQ_LEN}${USE_FLASH_SUFFIX}
  # Restore the latest checkpoint (if any) so a preempted spot run can resume.
  last_ckpt=$(ls ${CKPT_PATH} | grep -E '[0-9]+' | sort -t'-' -k1,1 -k2,2n | tail -1)
  mkdir -p ~/.checkpoints/${last_ckpt}
  gsutil -m rsync -r ${CKPT_PATH}/${last_ckpt}/ ~/.checkpoints/${last_ckpt}
  bash scripts/sync_local_checkpoint.sh ${LOCAL_CKPT_PATH} ${CKPT_PATH} > sync.log 2>&1 &
  torchrun \
    --nnodes=$NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --master_port=12375 \
    --master_addr=$HOST_ADDR \
    --node_rank=${SKYPILOT_NODE_RANK} \
    $TRAIN_SCRIPT \
    --model_name_or_path ~/hf-output/llama-${MODEL_SIZE}b \
    --data_path /data/sharegpt/sharegpt_20230322_clean_lang_split.json \
    --bf16 True \
    --output_dir $LOCAL_CKPT_PATH \
    --num_train_epochs 3 \
    --per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
    --per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
    --gradient_accumulation_steps $((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE)) \
    --evaluation_strategy "no" \
    --save_strategy "steps" \
    --save_steps 1200 \
    --save_total_limit 10 \
    --learning_rate 2e-5 \
    --weight_decay 0. \
    --warmup_ratio 0.03 \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --fsdp "full_shard auto_wrap" \
    --fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
    --tf32 True \
    --model_max_length ${SEQ_LEN} \
    --gradient_checkpointing True \
    --lazy_preprocess True
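  # Note on --gradient_accumulation_steps above: the expression targets an
  # effective batch of 128 * 512 = 65536 tokens per optimizer step. With the
  # defaults (SEQ_LEN=2048, PER_DEVICE_BATCH_SIZE=4, 1 node, 8 GPUs) it
  # evaluates to 1, i.e. 4 * 8 * 1 = 32 sequences of 2048 tokens per step.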
  # Sync any files not in the checkpoint-* folders
  gsutil -m rsync -r -x 'checkpoint-*' $LOCAL_CKPT_PATH/ $CKPT_PATH/
envs:
  MODEL_SIZE: 13
  SEQ_LEN: 2048
  GC_SCALE: 4
  DATE: 20230322
  USE_FLASH_ATTN: 1
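# Example launch (the cluster name `vicuna-train` is arbitrary; assumes the
# SkyPilot CLI is installed and the GCS buckets above already exist):
#   sky launch -c vicuna-train train-vicuna.yaml --env MODEL_SIZE=13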