+ import importlib
import logging
import os
+ import pathlib
import random
import signal
import sys

from attrdict import AttrDefault

# add src to the pythonpath so we don't need to pip install this
+ from axolotl.utils.tokenization import check_dataset_labels
+
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
@@ -42,48 +46,20 @@ def get_device():
        cfg.device_map = {"": cfg.device}


- def check_dataset_labels(dataset, tokenizer):
-     from termcolor import colored
-
-     # the dataset is already shuffled, so let's just check the first 5 elements
-     for idx in range(5):
-         # Get the input_ids, labels, and attention_mask from the dataset
-         input_ids = dataset[idx]["input_ids"]
-         labels = dataset[idx]["labels"]
-         attention_mask = dataset[idx]["attention_mask"]
-
-         # You can compare the input_ids and labels element-wise
-         # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
-         colored_tokens = []
-         for i, (input_id, label_id, mask) in enumerate(
-             zip(input_ids, labels, attention_mask)
-         ):
-             decoded_input_token = tokenizer.decode(input_id)
-             # Choose the color based on whether the label has the ignore value or not
-             color = (
-                 "red" if label_id == -100 else ("yellow" if label_id == 0 else "green")
-             )
-             colored_token = colored(decoded_input_token, color) + colored(
-                 f"({label_id}, {mask})", "white"
-             )
-             colored_tokens.append(colored_token)
-
-         logging.info(" ".join(colored_tokens))
-         logging.info("\n\n\n")
-
-
- def do_inference(cfg, model, tokenizer):
+ def do_inference(cfg, model, tokenizer, prompter="AlpacaPrompter"):
    tokenizer.add_special_tokens({"unk_token": "<unk>"})
    tokenizer.add_special_tokens({"bos_token": "<s>"})
    tokenizer.add_special_tokens({"eos_token": "</s>"})

-     from axolotl.prompters import ReflectAlpacaPrompter
+     prompter_module = getattr(importlib.import_module("axolotl.prompters"), prompter)

    while True:
-         instruction = str(input("Give me an instruction: "))
+         # support for multiline inputs
+         print("Give me an instruction (Ctrl + D to finish): ")
+         instruction = pathlib.Path("/proc/self/fd/0").read_text()
        if not instruction:
            return
-         prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction)
+         prompt = prompter_module().build_prompt(instruction=instruction)
        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        model.eval()
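Three things change in this hunk: `check_dataset_labels` moves into `axolotl.utils.tokenization` (hence the new import at the top), the prompter class is now resolved by name at runtime, and the instruction is read from stdin until EOF so multi-line prompts work. A minimal sketch of the last two ideas, assuming `axolotl.prompters` exposes an `AlpacaPrompter` class; the helper names below are illustrative, not part of the commit:

```python
import importlib
import sys


def resolve_prompter(name="AlpacaPrompter"):
    # Same idea as the diff: look the class up by name so any class defined in
    # axolotl.prompters can be selected without editing the import.
    return getattr(importlib.import_module("axolotl.prompters"), name)


def read_instruction():
    # /proc/self/fd/0 is the process's stdin on Linux; sys.stdin.read() is the
    # portable equivalent and also stops at EOF (Ctrl + D).
    return sys.stdin.read()
```

The trade-off is that reading `/proc/self/fd/0` ties the script to Linux, whereas the old `input()` call was portable but single-line.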
@@ -174,8 +150,8 @@ def train(
        cfg.bf16 = False

    # Load the model and tokenizer
-     logging.info("loading model, tokenizer, and lora_config...")
-     model, tokenizer, lora_config = load_model(
+     logging.info("loading model, tokenizer, and peft_config...")
+     model, tokenizer, peft_config = load_model(
        cfg.base_model,
        cfg.base_model_config,
        cfg.model_type,
@@ -190,6 +166,10 @@ def train(
        do_inference(cfg, model, tokenizer)
        return

+     if "shard" in kwargs:
+         model.save_pretrained(cfg.output_dir)
+         return
+
    train_dataset, eval_dataset = load_prepare_datasets(
        tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
    )
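The new `shard` branch simply re-saves the already-loaded model with `save_pretrained`, which can split a large checkpoint into shard files plus an index. A hedged sketch with an explicit shard size, using `model` and `cfg` as in the script; `max_shard_size` is a standard `transformers` argument, not something this commit passes:

```python
# Illustrative only: re-save a loaded transformers model as sharded checkpoint
# files (pytorch_model-0000x-of-0000N.bin plus an index json) in the output dir.
model.save_pretrained(cfg.output_dir, max_shard_size="5GB")
```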
@@ -199,8 +179,9 @@ def train(
        return

    if cfg.debug:
+         logging.info("check_dataset_labels...")
        check_dataset_labels(
-             train_dataset.select([random.randrange(0, len(train_dataset) - 1)]),
+             train_dataset.select([random.randrange(0, len(train_dataset) - 1) for i in range(5)]),
            tokenizer,
        )

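The debug path now inspects five randomly chosen rows instead of one. The comprehension can pick the same index twice, and `randrange(0, n - 1)` never reaches the last row; a sketch of a duplicate-free alternative using `random.sample` (illustrative, not what the commit does, and it assumes the dataset has at least five rows):

```python
import random

# Five distinct row indices from the training set, then the same label check.
indices = random.sample(range(len(train_dataset)), 5)
check_dataset_labels(train_dataset.select(indices), tokenizer)
```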
@@ -213,9 +194,9 @@ def train(
        model = torch.compile(model)

    # go ahead and presave, so we have the adapter config available to inspect
-     if lora_config:
+     if peft_config:
        logging.info(f"Pre-saving adapter config to {cfg.output_dir}")
-         lora_config.save_pretrained(cfg.output_dir)
+         peft_config.save_pretrained(cfg.output_dir)

    # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
    if cfg.local_rank == 0:
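Context for the pre-save: a PEFT config object's `save_pretrained` writes `adapter_config.json` into the target directory, so the adapter settings can be inspected while training is still running. A minimal sketch assuming a LoRA adapter; the hyperparameter values are placeholders, not the commit's defaults:

```python
from peft import LoraConfig

peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
peft_config.save_pretrained("./lora-out")  # writes ./lora-out/adapter_config.json
```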
@@ -234,12 +215,11 @@ def train(
        logging.info(f"Using Auto-resume functionality to start with checkpoint at {resume_from_checkpoint}")
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

-     if cfg.local_rank == 0:
-         # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
-         logging.info(
-             f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
-         )
-         model.save_pretrained(cfg.output_dir)
+     logging.info(
+         f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}"
+     )
+     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
+     trainer.save_model(cfg.output_dir)


if __name__ == "__main__":
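On the final hunk: `Trainer.save_model` only writes from the main process in distributed runs, so it covers what the removed `cfg.local_rank == 0` guard was doing by hand. A hedged usage sketch; saving the tokenizer alongside is an extra step, not part of this commit:

```python
# Save the trained weights and config through the Trainer, then keep the
# tokenizer next to them so the output directory loads as a complete model.
trainer.save_model(cfg.output_dir)
tokenizer.save_pretrained(cfg.output_dir)
```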