# run_clm.py — generate causal-LM completions for a file of prompts.
import argparse
from pathlib import Path
import pandas as pd
import torch
from tqdm import tqdm
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer
from dataset import CLMGenerationDataset
def gen_completions(model, tokenizer, dataset, batch_size, filename='completions'):
    """Generate completions for every prompt in *dataset* and save them as JSONL.

    Args:
        model: A causal LM in eval mode (a HF ``AutoModelForCausalLM``).
        tokenizer: The matching tokenizer, used for decoding outputs/labels.
        dataset: A DataFrame-like with columns ``prompts``, ``texts``,
            ``domain`` and ``category``.
        batch_size: Batch size for the generation DataLoader.
        filename: Output path for the JSONL file of results.
    """
    import os
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    model.eval()
    # Models loaded with device_map="auto" expose `hf_device_map`; otherwise
    # fall back to the device of the model's first parameter so this also
    # works for single-device models (the original raised AttributeError).
    device_map = getattr(model, 'hf_device_map', None)
    if device_map:
        first_device = next(iter(device_map.values()))
    else:
        first_device = next(model.parameters()).device
    texts = dataset['prompts'].reset_index(drop=True)       # Prompts
    labels = dataset['texts'].reset_index(drop=True)        # True texts
    sensitives = dataset['domain'].reset_index(drop=True)   # Additional info (e.g. domain)
    category = dataset['category'].reset_index(drop=True)   # Additional info (e.g. category)
    prompt_dataset = CLMGenerationDataset(texts, labels, tokenizer)
    prompt_loader = DataLoader(prompt_dataset, batch_size)
    completions, completions_split, prompts_only, completions_only, completions_split_only, references = [], [], [], [], [], []
    sample_offset = 0  # global index of the first sample of the current batch
    with torch.no_grad():
        for batch in tqdm(prompt_loader, desc='Generating', unit='batch'):
            input_ids = batch['input_ids'].to(first_device)
            # Renamed from `labels`: the original shadowed the outer `labels`
            # series (the full reference texts) with the per-batch tensor.
            batch_labels = batch['labels'].to(first_device)
            attention_mask = batch['attention_mask'].to(first_device)
            outputs = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                max_length=128,
                num_beams=5,
                no_repeat_ngram_size=2,
                early_stopping=True
            )
            outputs = outputs.cpu()
            decoded_outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
            for i, decoded_output in enumerate(decoded_outputs):
                # BUG FIX: index by the global sample position, not the
                # within-batch index — the original re-read the first
                # `batch_size` prompts for every batch.
                prompt = texts[sample_offset + i]
                # Remove the prompt part from the completion
                if decoded_output.startswith(prompt):
                    completion_only = decoded_output[len(prompt):].strip()
                else:
                    completion_only = decoded_output  # Fallback if no prompt match
                completions.append(decoded_output)       # Full completion (prompt + generated text)
                completions_only.append(completion_only)  # Generated completion without prompt
                completions_split.append(decoded_output.split())  # Full completion split
                # Split the completion-only part (after prompt) and save
                completions_split_only.append(completion_only.split())
                prompts_only.append(prompt)  # Store the original prompt
            for label in batch_labels:
                reference = tokenizer.decode(label, skip_special_tokens=True)
                references.append([reference.split()])
            sample_offset += len(decoded_outputs)
    # Save prompts and completions separately
    data_to_save = {
        "prompts": prompts_only,
        "completions": completions_only,        # Generated completions without the prompts
        "completions_full": completions,        # Full completion (prompt + generated text)
        "completions_split": completions_split_only,  # Completion split without the prompts
        "sensitives": sensitives.to_list(),
        "category": category.to_list(),
        "references": references
    }
    output_file_path = Path(filename)
    output_file_path.parent.mkdir(parents=True, exist_ok=True)
    df = pd.DataFrame(data_to_save)
    print(f"output file save to {output_file_path.absolute()}")
    df.to_json(output_file_path, orient='records', lines=True)
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Script to process model and dataset')
    parser.add_argument('--model_name_or_path',
                        type=str,
                        required=True,
                        help='The name or path of the model to train')
    parser.add_argument('--tokenizer_name',
                        type=str,
                        required=True,
                        help='The name of tokenizer to load')
    parser.add_argument('--prompt_file',
                        type=str,
                        required=True,
                        help='Input prompt data file (a jsonl file) '
                             'should include one column name "prompt" as training corpus')
    # BUG FIX: the original had no --batch_size and passed the filename
    # positionally into gen_completions' batch_size parameter.
    parser.add_argument('--batch_size',
                        type=int,
                        default=8,
                        help='Batch size used when generating completions')
    parser.add_argument('--filename',
                        type=str,
                        required=False,
                        default='completions',  # match gen_completions' default; avoids Path(None)
                        help='A jsonl file to store output completions')
    args = parser.parse_args()
    # device_map="auto" gives the model an hf_device_map, which
    # gen_completions reads to place input batches.
    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        torch_dtype=torch.float16,
        device_map='auto')
    # BUG FIX: load the tokenizer the user asked for — --tokenizer_name was
    # parsed but silently ignored in favor of the model path.
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name)
    df = pd.read_json(args.prompt_file, lines=True)
    gen_completions(model, tokenizer, df, args.batch_size, filename=args.filename)