
Commit bd6e627

Merge pull request #13 from winglian/dev
merge dev branch for various fixes
2 parents a56cf2d + e09cb40 commit bd6e627

9 files changed: +301 -110 lines changed

TODO.md

+10
@@ -0,0 +1,10 @@
+# todo list
+
+- [] Validation of parameters for combinations that won't work
+
+
+
+## things that are known not to work
+
+- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
+- adamw_bnb_8bit doesn't play well with FSDP offload
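The first item above calls for validating the config up front against combinations that are known to fail. A minimal sketch of what such a check could look like, assuming a dict-like cfg; the field names below are purely illustrative and are not defined anywhere in this commit:

# Hypothetical validation sketch; field names are assumptions, not axolotl identifiers.
def validate_config(cfg: dict):
    if cfg.get("fsdp_offload_params") and cfg.get("gradient_checkpointing"):
        # known-broken combination, see pytorch/pytorch#82203
        raise ValueError("FSDP offload and gradient_checkpointing do not work together")
    if cfg.get("fsdp_offload_params") and cfg.get("optimizer") == "adamw_bnb_8bit":
        raise ValueError("adamw_bnb_8bit doesn't play well with FSDP offload")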

ds_config.json

+24 -3

@@ -10,21 +10,42 @@
     "hysteresis": 2,
     "min_loss_scale": 1
   },
+  "optimizer": {
+    "type": "Adam",
+    "params": {
+      "lr": "auto",
+      "betas": "auto",
+      "eps": "auto",
+      "weight_decay": "auto"
+    }
+  },
   "scheduler": {
-    "type": "OneCycle",
+    "type": "WarmupDecayLR",
     "params": {
-      "cycle_min_lr": 1e-7,
-      "cycle_max_lr": 1e-4
+      "warmup_min_lr": "auto",
+      "warmup_max_lr": "auto",
+      "warmup_num_steps": "auto",
+      "total_num_steps": "auto"
     }
   },
   "zero_optimization": {
     "stage": 2,
+    "offload_optimizer": {
+      "device": "cpu",
+      "pin_memory": true
+    },
+    "offload_param": {
+      "device": "cpu",
+      "pin_memory": true
+    },
     "overlap_comm": true,
     "allgather_partitions": true,
     "allgather_bucket_size": 5e8,
     "contiguous_gradients": true,
     "reduce_bucket_size": "auto",
     "reduce_scatter": true,
+    "stage3_max_live_parameters": 0,
+    "stage3_max_reuse_distance": 0,
     "stage3_gather_16bit_weights_on_model_save": true
   },
   "gradient_accumulation_steps": "auto",
scripts/finetune.py

+25 -45

@@ -1,5 +1,7 @@
+import importlib
 import logging
 import os
+import pathlib
 import random
 import signal
 import sys
@@ -11,6 +13,8 @@
 from attrdict import AttrDefault

 # add src to the pythonpath so we don't need to pip install this
+from axolotl.utils.tokenization import check_dataset_labels
+
 project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 src_dir = os.path.join(project_root, "src")
 sys.path.insert(0, src_dir)
@@ -42,48 +46,20 @@ def get_device():
     cfg.device_map = {"": cfg.device}


-def check_dataset_labels(dataset, tokenizer):
-    from termcolor import colored
-
-    # the dataset is already shuffled, so let's just check the first 5 elements
-    for idx in range(5):
-        # Get the input_ids, labels, and attention_mask from the dataset
-        input_ids = dataset[idx]["input_ids"]
-        labels = dataset[idx]["labels"]
-        attention_mask = dataset[idx]["attention_mask"]
-
-        # You can compare the input_ids and labels element-wise
-        # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
-        colored_tokens = []
-        for i, (input_id, label_id, mask) in enumerate(
-            zip(input_ids, labels, attention_mask)
-        ):
-            decoded_input_token = tokenizer.decode(input_id)
-            # Choose the color based on whether the label has the ignore value or not
-            color = (
-                "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
-            )
-            colored_token = colored(decoded_input_token, color) + colored(
-                f"({label_id}, {mask})", "white"
-            )
-            colored_tokens.append(colored_token)
-
-        logging.info(" ".join(colored_tokens))
-        logging.info("\n\n\n")
-
-
-def do_inference(cfg, model, tokenizer):
+def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
     tokenizer.add_special_tokens({"unk_token": "<unk>"})
     tokenizer.add_special_tokens({"bos_token": "<s>"})
     tokenizer.add_special_tokens({"eos_token": "</s>"})

-    from axolotl.prompters import ReflectAlpacaPrompter
+    prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)

     while True:
-        instruction = str(input("Give me an instruction: "))
+        # support for multiline inputs
+        print("Give me an instruction (Ctrl + D to finish): ")
+        instruction = pathlib.Path("/proc/self/fd/0").read_text()
         if not instruction:
             return
-        prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
+        prompt = prompter_module().build_prompt(instruction=instruction)
         batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

         model.eval()
@@ -174,8 +150,8 @@ def train(
         cfg.bf16 = False

     # Load the model and tokenizer
-    logging.info("loading model, tokenizer, and lora_config...")
-    model, tokenizer, lora_config = load_model(
+    logging.info("loading model, tokenizer, and peft_config...")
+    model, tokenizer, peft_config = load_model(
         cfg.base_model,
         cfg.base_model_config,
         cfg.model_type,
@@ -190,6 +166,10 @@ def train(
         do_inference(cfg, model, tokenizer)
         return

+    if "shard" in kwargs:
+        model.save_pretrained(cfg.output_dir)
+        return
+
     train_dataset, eval_dataset = load_prepare_datasets(
         tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
     )
@@ -199,8 +179,9 @@ def train(
         return

     if cfg.debug:
+        logging.info("check_dataset_labels...")
         check_dataset_labels(
-            train_dataset.select([random.randrange(0, len(train_dataset) - 1)]),
+            train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]),
             tokenizer,
         )

@@ -213,9 +194,9 @@ def train(
         model = torch.compile(model)

     # go ahead and presave, so we have the adapter config available to inspect
-    if lora_config:
+    if peft_config:
         logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
-        lora_config.save_pretrained(cfg.output_dir)
+        peft_config.save_pretrained(cfg.output_dir)

     # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
     if cfg.local_rank == 0:
@@ -234,12 +215,11 @@ def train(
         logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
     trainer.train(resume_from_checkpoint=resume_from_checkpoint)

-    if cfg.local_rank == 0:
-        # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-        logging.info(
-            f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
-        )
-        model.save_pretrained(cfg.output_dir)
+    logging.info(
+        f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
+    )
+    # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
+    trainer.save_model(cfg.output_dir)


 if __name__ == "__main__":
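Two changes above are worth a closer look: the prompter class is now resolved by name via importlib (so it can be chosen at the CLI), and inference reads a multiline instruction from stdin by reading /proc/self/fd/0 until Ctrl + D. A small standalone sketch of both patterns, assuming only that a class with the given name exists in axolotl.prompters; the helper names here are illustrative:

import importlib
import pathlib

def resolve_prompter(name: str = "AlpacaPrompter"):
    # look the prompter class up by name instead of importing it directly
    return getattr(importlib.import_module("axolotl.prompters"), name)

def read_multiline_instruction() -> str:
    # /proc/self/fd/0 is the process's stdin (Linux-specific); read_text() consumes it until EOF (Ctrl + D)
    print("Give me an instruction (Ctrl + D to finish): ")
    return pathlib.Path("/proc/self/fd/0").read_text()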

scripts/setup-runpod.sh

+9
@@ -26,6 +26,15 @@ if [ -z "${TORCH_CUDA_ARCH_LIST}" ]; then # only set this if not set yet
     export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
 fi

+# install flash-attn and deepspeed from pre-built wheels for this specific container b/c these take forever to install
+mkdir -p /workspace/wheels
+cd /workspace/wheels
+curl -L -O https://github.com/winglian/axolotl/raw/wheels/wheels/deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl
+curl -L -O https://github.com/winglian/axolotl/raw/wheels/wheels/flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl
+pip install deepspeed-0.9.2%2B7ddc3b01-cp38-cp38-linux_x86_64.whl
+pip install flash_attn-1.0.4-cp38-cp38-linux_x86_64.whl
+pip install "peft @ git+https://github.com/huggingface/peft.git@main" --force-reinstall --no-dependencies
+
 cd /workspace/
 git clone https://github.com/winglian/axolotl.git
 cd axolotl
src/axolotl/prompters.py

+16 -5

@@ -127,7 +127,7 @@ def append_message(self, role, message):


 class ShareGPTPrompter:
-    def build_prompt(self, source, tokenizer):
+    def build_prompt(self, source, tokenizer, sequence_len=2048):
         # ignore the system prompt if provided
         if source[0]["from"] == "system":
             source.pop(0)
@@ -157,13 +157,14 @@ def build_prompt(self, source, tokenizer):
             role = roles[sentence["from"]]
             assert role == conv.roles[j % 2]
             conv.append_message(role, sentence["value"])
+        # TODO, this concatenates everything, but doesn't seem to properly add the eos_token_id, as the eos_token gets split up
         conversation = conv.get_prompt()

         # Tokenize conversations
         tokenized_result = tokenizer(
             conversation,
             truncation=True,
-            max_length=2048,  # FIXME
+            max_length=sequence_len,  # FIXME
             padding=False,
             return_tensors=None,
         )
@@ -173,7 +174,9 @@ def build_prompt(self, source, tokenizer):
         sep = conv.sep + conv.roles[1] + ": "

         rounds = conversation.split(conv.sep2)
+        rounds = [r + conv.sep2 for r in rounds]
         cur_len = 1
+        target[0] = IGNORE_TOKEN_ID  # mask out the bos
         for i, rou in enumerate(rounds):
             if rou == "":
                 break
@@ -182,19 +185,27 @@ def build_prompt(self, source, tokenizer):
             if len(parts) != 2:
                 break
             parts[0] += sep
-            round_len = len(tokenizer(rou)["input_ids"])
-            instruction_len = len(tokenizer(parts[0])["input_ids"]) - 2
+            round_len = len(tokenizer(rou)["input_ids"]) - 1  # -1 ignores the bos_token generated for this
+            # we have to strip the initial part, any dangling whitespace creates an additional ghost token
+            instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1  # -1 ignores the bos_token generated for this
             target[cur_len : cur_len + instruction_len] = [
                 IGNORE_TOKEN_ID
             ] * instruction_len

             cur_len += round_len
-        target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
+            if cur_len >= sequence_len:
+                break
+
+        # Fix: Truncate the target to have the same length as input_ids
+        target = target[:len(tokenized_result["input_ids"])]
+        # target[cur_len:] = [IGNORE_TOKEN_ID] * (len(target) - cur_len)
+
         attention_mask = [
             1 if x != tokenizer.pad_token_id else 0
             for x in tokenized_result["input_ids"]
         ]

+        # TODO truncate len to sequence_len
         return dict(
             input_ids=tokenized_result["input_ids"],
             labels=target,
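The substance of this change is how the labels are masked per conversation round: instruction tokens are set to IGNORE_TOKEN_ID so only response tokens contribute to the loss, and the new -1 offsets account for the bos token the tokenizer prepends to each piece. A simplified, standalone restatement of that loop; the tokenizer and the pre-split rounds are stand-ins, and only the offsets and masking mirror the diff:

IGNORE_TOKEN_ID = -100  # positions carrying this label are excluded from the loss

def mask_instruction_tokens(tokenizer, rounds, target, sep, sequence_len=2048):
    # Simplified sketch of the masking loop above, not the library's API.
    cur_len = 1
    target[0] = IGNORE_TOKEN_ID  # mask out the bos
    for rou in rounds:
        if rou == "":
            break
        parts = rou.split(sep)
        if len(parts) != 2:
            break
        parts[0] += sep
        # -1 ignores the bos_token the tokenizer prepends; stripping parts[0]
        # avoids the ghost token that dangling whitespace would create
        round_len = len(tokenizer(rou)["input_ids"]) - 1
        instruction_len = len(tokenizer(parts[0].strip())["input_ids"]) - 1
        target[cur_len : cur_len + instruction_len] = [IGNORE_TOKEN_ID] * instruction_len
        cur_len += round_len
        if cur_len >= sequence_len:
            break
    return target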
