hf_trainer.py
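"""Fine-tuning entry point built around the Hugging Face Trainer.

Fine-tunes a causal LM (EleutherAI/pythia-70m-deduped by default) on
memory-mapped Arrow datasets, either supervised (SFT) or unsupervised (UFT).
Optional extras handled below: LoRA adapters via peft, xFormers
memory-efficient attention, extra user-supplied special tokens, and
torch.profiler-based profiling of the training loop.
"""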
import os
import pathlib
import typing as t
from dataclasses import dataclass, field
import torch
import transformers
from dataset import DataCollatorForMmapedDataset, MmappedArrowDataset
from profiling import ProfilerCallback, build_profiler_configuration


@dataclass
class ModelArguments:
    low_cpu_mem_usage: bool = field(
        metadata={"help": "Try to reduce CPU memory usage while loading the model."},
        default=True)
    model_name_or_path: t.Optional[str] = field(
        default="EleutherAI/pythia-70m-deduped")
    use_xformers: bool = field(default=False, metadata={"help": "Use xFormers' memory_efficient_attention"})


@dataclass
class DataArguments:
    train_file: str = field(metadata={"help": "Path to the training set."})
    eval_file: str = field(metadata={"help": "Path to the evaluation set."})


@dataclass
class OtherArguments:
    model_load_delay_per_rank: t.Optional[int] = field(metadata={
        "help": "Delay loading the model by (this many seconds) * (local_rank)."},
        default=None)
    enable_profiler: bool = field(
        metadata={"help": "Whether to profile the training loop."},
        default=False)
    add_special_tokens: t.Optional[str] = field(
        metadata={"help": "Extra special tokens to add to the tokenizer before training. Comma-separated."},
        default=None)
    uft: bool = field(
        metadata={"help": "Use unsupervised fine-tuning instead of supervised fine-tuning."},
        default=False)


@dataclass
class LoraArguments:
    use_lora: t.Optional[bool] = field(metadata={"help": "Whether to train a LoRA instead of the full model."},
                                       default=False)
    lora_rank: t.Optional[int] = field(metadata={"help": "LoRA rank."},
                                       default=4)
    lora_alpha: t.Optional[int] = field(metadata={"help": "LoRA alpha."},
                                        default=32)
    lora_dropout: t.Optional[float] = field(metadata={"help": "LoRA dropout."},
                                            default=0.05)
    lora_target_modules: t.Optional[str] = field(metadata={"help": "Target modules, comma-separated."},
                                                 default=None)
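
# Illustrative invocation (paths and values below are placeholders, not files
# shipped with the repo). HfArgumentParser derives the flag names from the
# dataclass fields above plus transformers.TrainingArguments:
#
#   python hf_trainer.py \
#       --model_name_or_path EleutherAI/pythia-70m-deduped \
#       --train_file data/train.arrow \
#       --eval_file data/eval.arrow \
#       --output_dir checkpoints/ \
#       --bf16 true \
#       --use_lora true --lora_rank 4 --lora_alpha 32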


def main() -> None:
    parser = transformers.HfArgumentParser((
        ModelArguments,
        DataArguments,
        LoraArguments,
        OtherArguments,
        transformers.TrainingArguments,
    ))
    model_args, data_args, lora_args, \
        other_args, training_args = parser.parse_args_into_dataclasses()

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        padding_side="right",
        use_fast=True,
    )

    # xFormers optimizations.
    if model_args.use_xformers:
        from monkeypatches import apply_xformers_monkeypatches
        apply_xformers_monkeypatches()

    if other_args.model_load_delay_per_rank is not None:
        # When working with constrained system memory, loading the model at the
        # exact same time on all training processes will likely fail due to all
        # the model copies going around. We can delay loading based on
        # local_rank so not all processes are doing this at once, which
        # alleviates the situation. Kinda silly, but it works.
        import time
        time.sleep(other_args.model_load_delay_per_rank *
                   training_args.local_rank)
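        # Illustrative timing: with --model_load_delay_per_rank 60, local rank 0
        # sleeps 0 seconds, rank 1 sleeps 60, rank 2 sleeps 120, and so on.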

    # Model loading.
    model_load_dtype = None
    if training_args.bf16:
        model_load_dtype = torch.bfloat16
    elif training_args.fp16:
        model_load_dtype = torch.float16

    model = transformers.AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
        torch_dtype=model_load_dtype,
    ).cuda()

    if other_args.add_special_tokens is not None:
        # MAINTENANCE(11b): Big fat warning: the snippet below is copy-pasted
        # into ``./preparation/tokenize_data_{sft,uft}.py``. Make sure to always
        # keep both implementations in sync.
        special_token_contents = other_args.add_special_tokens.split(",")
        special_tokens = [
            transformers.AddedToken(
                # Heads up: this is very poorly documented in HuggingFace and
                # some old forum discussions mention that it's apparently
                # exclusive to the Rust-based tokenizers? If anything seems
                # funky about the special token behavior, this is a good place
                # to look.
                content, lstrip=True, rstrip=True)
            for content in special_token_contents
        ]

        _add_special_tokens_to_tokenizer_and_resize_model_embeddings(
            {"additional_special_tokens": special_tokens},
            tokenizer,
            model,
        )
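        # Illustrative flag value (token strings are examples only):
        #   --add_special_tokens "<|user|>,<|model|>"
        # registers both strings as additional special tokens and, via the call
        # above, grows the model's embedding matrices if needed.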

    # LoRA setup.
    if lora_args.use_lora:
        from peft import LoraConfig, TaskType, get_peft_model

        target_modules = None
        if lora_args.lora_target_modules is not None:
            target_modules = lora_args.lora_target_modules.split(",")

        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=lora_args.lora_rank,
            lora_alpha=lora_args.lora_alpha,
            lora_dropout=lora_args.lora_dropout,
            target_modules=target_modules,
        )

        if training_args.gradient_checkpointing:
            model.enable_input_require_grads()

        model = get_peft_model(model, peft_config)
        model.print_trainable_parameters()
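        # Illustrative example: for GPT-NeoX-style checkpoints such as the
        # default pythia model, --lora_target_modules query_key_value applies
        # LoRA to the fused attention projection; when the flag is left unset,
        # peft will try to pick its built-in defaults for the architecture.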

    # Silence this annoying warning.
    if training_args.gradient_checkpointing:
        model.config.use_cache = False

    # Dataset setup.
    train_dataset = MmappedArrowDataset(data_args.train_file, sft=not other_args.uft)
    eval_dataset = MmappedArrowDataset(data_args.eval_file, sft=not other_args.uft)
    data_collator = DataCollatorForMmapedDataset(tokenizer=tokenizer, sft=not other_args.uft)

    trainer = transformers.Trainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        args=training_args,
        callbacks=[SavePeftModelCallback] if lora_args.use_lora else None,
    )

    try:
        # Resume from checkpoint if we have any checkpoints automatically saved
        # by the HF Trainer within the output directory.
        resume_from_checkpoint = len(
            list(pathlib.Path(
                training_args.output_dir).glob("checkpoint-*"))) > 0

        if other_args.enable_profiler:
            profiler_args = build_profiler_configuration()
            with torch.profiler.profile(**profiler_args) as profiler:
                trainer.add_callback(ProfilerCallback(profiler=profiler))
                trainer.train(resume_from_checkpoint=resume_from_checkpoint)
        else:
            trainer.train(resume_from_checkpoint=resume_from_checkpoint)
    except KeyboardInterrupt as ex:
        # TODO(11b): Test whether this does what I expect. Idea is to have the
        # trainer save the current state when I interrupt the run so I don't
        # need to keep waiting for a checkpoint step.
        # trainer.save_model()
        # trainer.save_state()
        raise ex

    trainer.save_state()
    trainer.save_model()


class SavePeftModelCallback(transformers.TrainerCallback):
    '''
    At some point, PEFT stopped saving just the adapter and instead started
    storing full model weights. Extracting the adapter from the weights is
    doable, but seems to result in subpar results for some unknown reason, so
    this Trainer callback saves the adapter itself during training to avoid
    this.

    https://github.com/huggingface/peft/issues/286#issuecomment-1512611968
    https://github.com/huggingface/peft/blob/main/examples/int8_training/peft_bnb_whisper_large_v2_training.ipynb
    '''

    def on_save(
        self,
        args: transformers.TrainingArguments,
        state: transformers.TrainerState,
        control: transformers.TrainerControl,
        **kwargs,
    ):
        checkpoint_folder_name = f"{transformers.trainer_utils.PREFIX_CHECKPOINT_DIR}-{state.global_step}"
        checkpoint_folder = os.path.join(args.output_dir, checkpoint_folder_name)

        peft_model_path = os.path.join(checkpoint_folder, "adapter_model")
        kwargs["model"].save_pretrained(peft_model_path)

        # pytorch_model_path = os.path.join(checkpoint_folder, "pytorch_model.bin")
        # if os.path.exists(pytorch_model_path):
        #     os.remove(pytorch_model_path)
        return control
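
# With the callback above, the adapter itself is saved into each checkpoint
# directory, roughly as follows (exact file names depend on the installed peft
# version):
#
#   <output_dir>/checkpoint-<global_step>/adapter_model/
#       adapter_config.json
#       adapter_model.bin (or adapter_model.safetensors)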


def _add_special_tokens_to_tokenizer_and_resize_model_embeddings(
    special_tokens: t.Dict[str, t.Union[str, transformers.AddedToken]],
    tokenizer: transformers.PreTrainedTokenizerBase,
    model: transformers.PreTrainedModel,
):
    tokenizer.add_special_tokens(special_tokens)

    # Size is rounded up to the nearest number divisible by 64 for performance
    # reasons.
    new_size = _nearest_divisible(num=len(tokenizer), divisor=64)
    old_size = model.config.vocab_size

    if new_size == old_size:
        # No resizing needs to be done, let's bail!
        return

    # Need to resize the token embeddings. We initialize the new positions with
    # the mean of the existing ones to cut down on required training time.
    model.resize_token_embeddings(new_size)
    new_positions_count = new_size - old_size

    input_embeddings = model.get_input_embeddings().weight.data
    output_embeddings = model.get_output_embeddings().weight.data

    # This is just to keep the LSP happy.
    assert isinstance(input_embeddings, torch.Tensor)
    assert isinstance(output_embeddings, torch.Tensor)

    input_embeddings_avg = input_embeddings[:-new_positions_count].mean(
        dim=0, keepdim=True)
    output_embeddings_avg = output_embeddings[:-new_positions_count].mean(
        dim=0, keepdim=True)

    input_embeddings[-new_positions_count:] = input_embeddings_avg
    output_embeddings[-new_positions_count:] = output_embeddings_avg
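
    # Worked example (illustrative numbers): for a GPT-2-style model where
    # model.config.vocab_size == 50257 and one special token is added,
    # len(tokenizer) == 50258, so new_size == _nearest_divisible(50258, 64)
    # == 50304 and new_positions_count == 47; the last 47 rows of both
    # embedding matrices are then initialized to the mean of the original rows.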


def _nearest_divisible(num: int, divisor: int) -> int:
    '''Rounds `num` up to the nearest number that is divisible by `divisor`.'''
    return (num + divisor - 1) // divisor * divisor
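    # e.g. _nearest_divisible(100, 64) == 128 and _nearest_divisible(128, 64) == 128.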


if __name__ == "__main__":
    main()