training.py
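"""
Supervised fine-tuning entry point: load a JSON dataset, format each example
as a prompt/completion pair (via a template or as chat messages), and train a
Hydra-instantiated model with TRL's SFTTrainer, optionally tracking the run in
Weights & Biases and pushing the result to the Hugging Face Hub.
"""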
import logging
import os
from operator import itemgetter
from pathlib import Path

import hydra
import wandb
from datasets import load_dataset
from hydra.utils import to_absolute_path
from omegaconf import OmegaConf
from transformers import TrainingArguments
from trl import DataCollatorForCompletionOnlyLM, SFTTrainer

logger = logging.getLogger(__name__)


def setup_wandb(args: dict):
    """Initialize a Weights & Biases run for tracking training."""
    # Note: the run config captures the full process environment, so any
    # secrets exposed as environment variables will be logged as well.
    env = {key: os.getenv(key) for key in os.environ}
    run = wandb.init(
        job_type="train",
        project=args["project"],
        group=args["experiment"],
        entity=args["wandb_entity"],
        config={**args, **env},
        tags=["train"],
    )
    return run


@hydra.main(version_base=None, config_path="./configs", config_name="training")
def main(args):
    logger.info(OmegaConf.to_yaml(args))
    OmegaConf.set_struct(args, False)
    logger.info(f"Experiment name: {args.experiment}")
    logger.info(f"Output path: {args.train.output_dir}")

    if args.use_wandb:
        setup_wandb(OmegaConf.to_container(args))

    logger.info(f"Loading dataset: {args.data_file}")
    dataset = load_dataset(
        "json", data_files=to_absolute_path(args.data_file), split="train"
    )

    logger.info(f"Loading instruction from file {args.instruction}...")
    instruction = Path(args.instruction).read_text()
    logger.info(f"Loaded instruction: {instruction}")

    # Optionally shuffle (args.shuffle doubles as the seed) and cap the dataset size.
    if args.shuffle:
        dataset = dataset.shuffle(seed=args.shuffle)
    if args.limit:
        dataset = dataset.select(range(min(args.limit, len(dataset))))

    model_class = hydra.utils.instantiate(args.model, _convert_="object")
    logger.info("Model was loaded.")
    def format_answer(example):
        query = example[args.input_key]
        if args.model.instruction_in_prompt:
            query = instruction + "\n" + query
        # The output field may be a list of candidates; take the first one.
        output = (
            out[0] if isinstance(out := example[args.output_key], list) else out
        ) or ""
        if args.template:
            # datasets.map() requires a dict return value, so wrap the rendered
            # template in a field instead of returning the bare string.
            text = Path(args.template).read_text().format(query=query, output=output)
            return dict(text=text)
        messages = [
            {"role": "system", "content": instruction},
            {"role": "user", "content": query},
            {"role": "assistant", "content": output},
        ]
        return dict(messages=messages)
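    # After mapping, each example carries either a rendered "text" string or a
    # "messages" list shaped like:
    #   [{"role": "system", "content": <instruction>},
    #    {"role": "user", "content": <query>},
    #    {"role": "assistant", "content": <output>}]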
    dataset = dataset.map(format_answer)

    # Split the dataset into train and dev
    train, dev = itemgetter("train", "test")(dataset.train_test_split(args.dev_split))

    collator = DataCollatorForCompletionOnlyLM(
        model_class.tokenizer.encode(
            args.model.completion_start, add_special_tokens=False
        ),
        tokenizer=model_class.tokenizer,
    )
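    # The collator above masks labels for everything before the encoded
    # completion_start sequence, so the loss is computed only on the
    # completion tokens (TRL's completion-only fine-tuning setup).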
logger.info("Initializing training arguments...")
training_args = TrainingArguments(**args.train)
logger.info("Starting to train...")
trainer = SFTTrainer(
model=model_class.model,
args=training_args,
data_collator=collator,
train_dataset=train,
eval_dataset=dev,
dataset_batch_size=1,
packing=False,
max_seq_length=args.model.max_sequence_len,
dataset_kwargs=dict(add_special_tokens=False),
)
trainer.train(resume_from_checkpoint=args.resume_checkpoint)
logger.info(
f"Finished training; saving model to {args.train.output_dir}/checkpoint..."
)
trainer.model.save_pretrained(Path(args.train.output_dir) / "checkpoint/")
if args.hfhub_tag:
trainer.model.push_to_hub(args.hfhub_tag, private=True)
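    # The saved checkpoint can later be reloaded with the matching transformers
    # auto class, e.g. (assuming a causal-LM checkpoint; adjust to the model used):
    #   from transformers import AutoModelForCausalLM
    #   model = AutoModelForCausalLM.from_pretrained(f"{args.train.output_dir}/checkpoint")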
if __name__ == "__main__":
    main()
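# A minimal sketch of the expected configs/training.yaml, inferred from the
# fields read above (field names come from this file; values are illustrative
# placeholders, not the project's actual defaults):
#
#   experiment: my-sft-run
#   data_file: data/train.jsonl
#   instruction: prompts/instruction.txt
#   template: null             # or a path to a "{query}"/"{output}" format string
#   input_key: input
#   output_key: output
#   shuffle: 42                # truthy value doubles as the shuffle seed
#   limit: null
#   dev_split: 0.05
#   resume_checkpoint: false
#   use_wandb: false
#   project: my-project        # wandb project
#   wandb_entity: my-team
#   hfhub_tag: null            # e.g. org/model-name to push to the Hub
#   model:
#     _target_: ...            # class instantiated by hydra.utils.instantiate
#     instruction_in_prompt: false
#     completion_start: "..."  # response-template string for the collator
#     max_sequence_len: 2048
#   train:                     # passed directly to transformers.TrainingArguments
#     output_dir: outputs/my-sft-run
#     num_train_epochs: 1
#     per_device_train_batch_size: 1
#
# Any field can be overridden from the CLI via Hydra, e.g.:
#   python training.py experiment=my-run train.output_dir=outputs/my-run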