utils.py
from transformers import PreTrainedTokenizerBase


class CustomTokenizer:
    """Builds pairwise-preference classifier inputs from prompt/response pairs."""

    def __init__(
        self,
        tokenizer: PreTrainedTokenizerBase,
        max_length: int,
        prompt_template: str,
        a_template: str,
        b_template: str,
        instruction: str,
        show_length: bool = False,
        use_chat_template: bool = False,
    ) -> None:
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.prompt_template = prompt_template
        self.a_template = a_template
        self.b_template = b_template
        self.instruction = instruction
        self.show_length = show_length
        self.use_chat_template = use_chat_template
    def __call__(self, batch: dict) -> dict:
        # Fill the <\A>/<\B> placeholders with the cleaned response texts.
        if not self.show_length:
            response_a = [
                self.a_template.replace(r"<\A>", self.process_text(t))
                for t in batch["response_a"]
            ]
            response_b = [
                self.b_template.replace(r"<\B>", self.process_text(t))
                for t in batch["response_b"]
            ]
        else:
            # Also surface an approximate word count (space count) in each
            # response header.
            response_a = [
                self.a_template.replace(r"<\A>", text).replace(
                    "<response_a>:",
                    f"<response_a> ({text.count(' ')} words):",
                )
                for text in map(self.process_text, batch["response_a"])
            ]
            response_b = [
                self.b_template.replace(r"<\B>", text).replace(
                    "<response_b>:",
                    f"<response_b> ({text.count(' ')} words):",
                )
                for text in map(self.process_text, batch["response_b"])
            ]
        prompt = [
            self.prompt_template.replace(r"<\P>", self.process_text(t))
            for t in batch["prompt"]
        ]
        texts = [
            self.instruction + "\n".join(sample)
            for sample in zip(prompt, response_a, response_b)
        ]
        # The hard-coded wrapper below is the Gemma chat format; since it
        # already contains <bos>, the tokenizer's special tokens are skipped.
        add_special_tokens = not self.use_chat_template
        if self.use_chat_template:
            texts = [
                "<bos><start_of_turn>user\n<<content>><end_of_turn>\n<start_of_turn>model\n".replace(
                    "<<content>>", t
                )
                for t in texts
            ]
        # First pass without truncation, only to record the full token lengths.
        tokenized = self.tokenizer(
            texts,
            truncation=False,
            add_special_tokens=add_special_tokens,
        )
        token_length = [len(t) for t in tokenized["input_ids"]]
        # Second pass truncated to max_length for the actual model inputs.
        tokenized_truncation = self.tokenizer(
            texts,
            max_length=self.max_length,
            truncation=True,
            add_special_tokens=add_special_tokens,
        )
        # Label encoding: 0 = model A wins, 1 = model B wins, 2 = tie.
        labels = []
        for a_win, b_win in zip(batch["winner_model_a"], batch["winner_model_b"]):
            if a_win:
                label = 0
            elif b_win:
                label = 1
            else:
                label = 2
            labels.append(label)
        return {**tokenized_truncation, "labels": labels, "token_length": token_length}

    @staticmethod
    def process_text(text: str) -> str:
        # Each field is a JSON-style list of strings; eval parses it with the
        # name "null" bound to "" so JSON nulls become empty strings. This
        # assumes trusted input, since eval will execute arbitrary expressions.
        return " ".join(eval(text, {"null": ""}))