-
Notifications
You must be signed in to change notification settings - Fork 63
/
Copy pathtrain_lora.py
328 lines (281 loc) · 13 KB
/
train_lora.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
import argparse
import logging
import math
import os
import pathlib
import sys
import time
from typing import Tuple, Union
import deepspeed
import torch
import wandb
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import (AutoModelForCausalLM, AutoTokenizer,
BitsAndBytesConfig, DataCollatorForSeq2Seq,
PreTrainedModel, PreTrainedTokenizer)
from transformers import Seq2SeqTrainingArguments as TrainingArguments
from transformers import Trainer
sys.path.append(os.getcwd())
from llamatuner.configs import (DataArguments, FinetuningArguments,
GeneratingArguments, ModelArguments)
from llamatuner.configs.parser import get_train_args
from llamatuner.data.data_loader import get_dataset
from llamatuner.model.callbacks import ComputeMetrics
from llamatuner.model.utils.misc import find_all_linear_modules
from llamatuner.utils.constants import IGNORE_INDEX
from llamatuner.utils.logger_utils import get_outdir, get_root_logger
from llamatuner.utils.model_utils import (get_logits_processor,
get_peft_state_maybe_zero_3,
print_model_dtypes,
print_trainable_parameters)
def load_model_tokenizer(
    model_args: ModelArguments,
    training_args: TrainingArguments,
    finetuning_args: FinetuningArguments,
    logger: logging.Logger,
) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
    """
    Load a causal LM, wrap it with LoRA adapters, and load its tokenizer.

    The base model is loaded with ``AutoModelForCausalLM.from_pretrained``;
    when ``finetuning_args.use_qlora`` is set it is loaded through a
    ``BitsAndBytesConfig`` and prepared for k-bit training before the LoRA
    adapters are attached with ``get_peft_model``.

    Args:
        model_args (ModelArguments): Model location, cache directory,
            ``trust_remote_code`` flag and ``model_max_length``.
        training_args (TrainingArguments): Supplies the ``fp16`` / ``bf16``
            flags, ``fsdp`` setting and ``gradient_checkpointing`` flag.
        finetuning_args (FinetuningArguments): LoRA hyper-parameters and the
            quantization settings used for QLoRA.
        logger (logging.Logger): Logger used for progress messages.

    Returns:
        Tuple[PreTrainedModel, PreTrainedTokenizer]: The PEFT-wrapped model
        and the tokenizer (with a pad token assigned when one can be derived).
    """
    # bf16 runs load the weights in bfloat16; every other case (including
    # fp16) loads them in float32.
    # NOTE(review): fp16 mapping to float32 looks intentional (fp16 mixed
    # precision over float32 weights) -- confirm, the QLoRA reference script
    # uses float16 here.
    use_bf16_weights = training_args.bf16 and not training_args.fp16
    torch_dtype = torch.bfloat16 if use_bf16_weights else torch.float32

    device_map: Union[str, dict, None] = 'auto'
    if finetuning_args.use_qlora:
        # With several processes, pin the whole model to this process's
        # local device; with a single process no device map is passed.
        world_size = int(os.environ.get('WORLD_SIZE', 1))
        device_map = ({
            '': int(os.environ.get('LOCAL_RANK') or 0)
        } if world_size != 1 else None)
        if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled(
        ):
            # This is a misconfiguration notice, so surface it as a warning.
            logger.warning(
                'FSDP and ZeRO3 are both currently incompatible with QLoRA.')

    config_kwargs = {
        'cache_dir': model_args.cache_dir,
        'trust_remote_code': model_args.trust_remote_code,
    }

    logger.info(f'Loading Model from {model_args.model_name_or_path}...')
    quantization_config = None
    if finetuning_args.use_qlora:
        # BitsAndBytesConfig selects the storage format, the compute format
        # and the quantization options.
        quantization_config = BitsAndBytesConfig(
            load_in_4bit=finetuning_args.quant_bit == 4,
            load_in_8bit=finetuning_args.quant_bit == 8,
            # int8 outlier threshold
            llm_int8_threshold=finetuning_args.llm_int8_threshold,
            # whether the int8 model keeps fp16 weights
            llm_int8_has_fp16_weight=finetuning_args.llm_int8_has_fp16_weight,
            # whether to apply double quantization
            bnb_4bit_use_double_quant=finetuning_args.double_quant,
            # one of {'fp4', 'nf4'}
            bnb_4bit_quant_type=finetuning_args.quant_type,
            # dtype used for computation
            bnb_4bit_compute_dtype=torch_dtype,
        )
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        device_map=device_map,
        low_cpu_mem_usage=True,
        quantization_config=quantization_config,
        torch_dtype=torch_dtype,
        **config_kwargs,
    )

    # Enable model parallelism: set the two attributes related to parallel
    # execution.
    if torch.cuda.device_count() > 1:
        # Keeps Trainer from trying its own DataParallelism when more than 1 GPU is available
        setattr(model, 'model_parallel', True)
        setattr(model, 'is_parallelizable', True)

    # Prepare the model for k-bit training if specified.
    if finetuning_args.use_qlora:
        logger.info('Preparemodel for kbit training!!!')
        model = prepare_model_for_kbit_training(
            model,
            use_gradient_checkpointing=training_args.gradient_checkpointing)

    # Hint at --bf16 when a 4-bit model is trained with fp16 on a GPU whose
    # compute-capability major version is >= 8 (bfloat16-capable).
    # The previous guard compared torch_dtype with torch.float16 (a value the
    # dtype selection above never yields) and read `quant_bit.bits`, although
    # quant_bit is a plain int; the check is keyed on the fp16 flag instead.
    if (training_args.fp16 and finetuning_args.quant_bit == 4
            and torch.cuda.is_available()):
        major, _ = torch.cuda.get_device_capability()
        if major >= 8:
            logger.info(
                'Your GPU supports bfloat16, you can accelerate training with the argument --bf16'
            )

    # Attach the LoRA adapters.
    logger.info('Adding LoRA modules...')
    requested_targets = finetuning_args.lora_target
    if len(requested_targets) == 1 and requested_targets[0] == 'all':
        # The single sentinel 'all' means: adapt every linear module.
        target_modules = find_all_linear_modules(model)
    else:
        target_modules = requested_targets

    lora_config = LoraConfig(
        # column size of LoRA matrix A / row size of matrix B
        r=finetuning_args.lora_rank,
        # scaling factor
        lora_alpha=finetuning_args.lora_alpha,
        # names of the modules that receive LoRA layers
        target_modules=target_modules,
        # dropout used as regularization
        lora_dropout=finetuning_args.lora_dropout,
        # how bias parameters are handled
        bias=finetuning_args.lora_bias,
        # task label
        task_type='CAUSAL_LM',
    )

    logger.info('Getting the PEFT model...')
    model = get_peft_model(model, lora_config)

    # Enable gradient checkpointing if specified
    if training_args.gradient_checkpointing:
        logger.info('Using gradient checkpointing...')
        model.enable_input_require_grads()
        model.config.use_cache = False

    logger.info(f'Loading tokenizer from {model_args.model_name_or_path}...')
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        padding_side='right',
        model_max_length=model_args.model_max_length,
        use_fast=False,
        **config_kwargs,
    )

    # Pad with the unk token when the tokenizer has one (unchanged behaviour).
    # Tokenizers without an unk token used to get their pad token overwritten
    # with None; for those, keep an existing pad token and otherwise fall back
    # to the eos token.
    if tokenizer.unk_token is not None:
        if tokenizer.pad_token != tokenizer.unk_token:
            tokenizer.pad_token = tokenizer.unk_token
    elif tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer
def run_lora_sft(
    model_args: ModelArguments,
    data_args: DataArguments,
    training_args: TrainingArguments,
    finetuning_args: FinetuningArguments,
    generating_args: GeneratingArguments,
) -> None:
    """
    Run LoRA supervised fine-tuning with the Hugging Face ``Trainer``.

    Sets up the output directory and log file, loads the model and tokenizer,
    builds the dataset and data collator, optionally initializes wandb, then
    trains and/or evaluates according to ``training_args.do_train`` and
    ``training_args.do_eval``.

    Args:
        model_args (ModelArguments): The arguments for the model configuration.
        data_args (DataArguments): The arguments for the data configuration.
        training_args (TrainingArguments): The arguments for the training
            configuration. ``output_dir`` and several generation-related
            fields are modified in place.
        finetuning_args (FinetuningArguments): The arguments for the
            finetuning configuration.
        generating_args (GeneratingArguments): The arguments for the
            generating configuration.

    Returns:
        None
    """
    # Merge all argument groups into one namespace (used as the wandb config).
    # Merging through a dict avoids the TypeError that
    # ``Namespace(**a, **b)`` raises when two groups share a field name;
    # later groups win on collision.
    merged_args = {}
    for arg_group in (model_args, data_args, training_args, finetuning_args,
                      generating_args):
        merged_args.update(vars(arg_group))
    args = argparse.Namespace(**merged_args)

    # Checkpoints go to a 'checkpoints' sub-directory; the log file lives in
    # the run's output directory.
    timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
    output_dir = get_outdir(training_args.output_dir)
    training_args.output_dir = get_outdir(output_dir, 'checkpoints')
    log_name = os.path.join(output_dir, timestamp).replace(os.path.sep, '_')
    log_file = os.path.join(output_dir, log_name + '.log')
    logger = get_root_logger(log_file=log_file, log_level='INFO')

    logger.info(
        f'Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}'
    )
    logger.info(
        f'distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}'
    )

    logger.info('Loading model and tokenizer...')
    model, tokenizer = load_model_tokenizer(model_args,
                                            training_args,
                                            finetuning_args,
                                            logger=logger)
    logger.info('Successfully loaded model and tokenizer.')

    logger.info('Printing trainable parameters...')
    print_trainable_parameters(model, kbit=finetuning_args.quant_bit)

    # Verify dtypes
    logger.info('Print model dtypes...')
    print_model_dtypes(model)

    # Create a supervised dataset and Trainer, then train the model
    logger.info('Creating a supervised dataset and DataCollator...')
    dataset_module = get_dataset(
        data_args,
        model_args,
        training_args,
        stage='sft',
        tokenizer=tokenizer,
        processor=None,
    )
    logger.info('Successfully created the supervised dataset.')

    logger.info('Creating DataCollator for Seq2Seq...')
    data_collator = DataCollatorForSeq2Seq(
        tokenizer=tokenizer,
        pad_to_multiple_of=8 if tokenizer.padding_side == 'right' else None,
        # Label padding uses IGNORE_INDEX when pad tokens are to be ignored
        # for the loss, otherwise the tokenizer's pad id.
        label_pad_token_id=IGNORE_INDEX
        if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id,
    )

    # Fill in generation-related training arguments from the data arguments.
    training_args.generation_max_length = (training_args.generation_max_length
                                           or data_args.cutoff_len)
    training_args.generation_num_beams = (data_args.eval_num_beams or
                                          training_args.generation_num_beams)
    if model_args.visual_inputs:
        training_args.remove_unused_columns = False

    # NOTE(review): gen_kwargs is assembled but not passed to anything in
    # this function -- confirm whether it should reach evaluation/prediction.
    gen_kwargs = generating_args.to_dict()
    gen_kwargs['eos_token_id'] = [tokenizer.eos_token_id
                                  ] + tokenizer.additional_special_tokens_ids
    gen_kwargs['pad_token_id'] = tokenizer.pad_token_id
    gen_kwargs['logits_processor'] = get_logits_processor()

    if 'wandb' in training_args.report_to:
        logger.info('Initializing wandb project...')
        # Fall back to the log name when no run name was configured. The old
        # check tested `finetuning_args` itself, which is always truthy, so
        # the fallback could never apply.
        wandb_run_name = finetuning_args.wandb_run_name or log_name
        wandb.init(
            dir=output_dir,
            project=finetuning_args.wandb_project,
            name=wandb_run_name,
            tags=['lora-finetune', 'sft'],
            group='lora-finetune',
            config=args,
        )

    logger.info('Creating a Trainer...')
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        data_collator=data_collator,
        compute_metrics=ComputeMetrics(tokenizer)
        if training_args.predict_with_generate else None,
        **dataset_module,
    )

    logger.info('Starting training...')
    if training_args.do_train:
        # Resume only when requested AND a checkpoint actually exists in the
        # checkpoint directory; otherwise start from scratch.
        has_checkpoint = bool(
            list(pathlib.Path(training_args.output_dir).glob('checkpoint-*')))
        if training_args.resume_from_checkpoint and has_checkpoint:
            logger.info(
                f'Resuming training from checkpoint {training_args.resume_from_checkpoint}'
            )
            train_result = trainer.train(resume_from_checkpoint=True)
        else:
            logger.info('Starting training from scratch...')
            train_result = trainer.train()

        # Collect the weights to save. Under ZeRO-3 the consolidation call is
        # made on every rank.
        # NOTE(review): presumably only the main rank receives the
        # consolidated dict and the others get None -- confirm.
        zero3_enabled = deepspeed.is_deepspeed_zero3_enabled()
        if zero3_enabled:
            state_dict = (
                trainer.model_wrapped._zero3_consolidated_16bit_state_dict())
        else:
            state_dict = get_peft_state_maybe_zero_3(
                model.named_parameters(), finetuning_args.lora_bias)

        # `should_save` marks the process(es) allowed to write to disk. The
        # previous `local_rank == 0` test skipped saving in a single-process
        # run where local_rank is -1.
        if training_args.should_save and not (zero3_enabled
                                              and state_dict is None):
            model.save_pretrained(training_args.output_dir,
                                  state_dict=state_dict)

        metrics = train_result.metrics
        metrics['train_samples'] = len(trainer.train_dataset)
        trainer.log_metrics('train', metrics)
        trainer.save_metrics('train', metrics)
        trainer.save_state()

    if training_args.do_eval:
        metrics = trainer.evaluate(metric_key_prefix='eval')
        # Perplexity is exp(eval loss); report infinity when exp overflows.
        try:
            perplexity = math.exp(metrics['eval_loss'])
        except OverflowError:
            perplexity = float('inf')
        metrics['perplexity'] = perplexity
        metrics['eval_samples'] = len(trainer.eval_dataset)
        trainer.log_metrics('eval', metrics)
        trainer.save_metrics('eval', metrics)

    logger.info('Done.')
if __name__ == '__main__':
    # Script entry point: parse the five argument groups, then launch LoRA
    # supervised fine-tuning with them.
    parsed_groups = get_train_args()
    (model_args, data_args, training_args, finetuning_args,
     generating_args) = parsed_groups
    run_lora_sft(
        model_args=model_args,
        data_args=data_args,
        training_args=training_args,
        finetuning_args=finetuning_args,
        generating_args=generating_args,
    )