-
Notifications
You must be signed in to change notification settings - Fork 10
/
model_utils.py
235 lines (205 loc) · 10.7 KB
/
model_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
"""
https://github.com/allenai/open-instruct
"""
import torch
import tqdm
from transformers import StoppingCriteria, StoppingCriteriaList
class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once every sequence in the batch contains a keyword.

    On each call the newest token of every row is appended to a per-row
    history, the history is decoded with ``tokenizer``, and the row is
    considered finished when any string in ``keywords_str`` occurs in the
    decoded text.

    The per-row histories are created on the first call and kept on the
    instance afterwards, so one instance should serve a single ``generate``
    call.
    """

    def __init__(self, keywords_str, tokenizer):
        super().__init__()
        self.keywords_str = keywords_str
        self.tokenizer = tokenizer
        # One list of generated token ids per batch row; sized lazily on the
        # first call because the batch size is not known here.
        self.current_context = []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when all rows have produced at least one keyword."""
        batch_size = input_ids.shape[0]
        if not self.current_context:
            self.current_context = [[] for _ in range(batch_size)]

        stop_flags = []
        for row_idx in range(batch_size):
            token_history = self.current_context[row_idx]
            token_history.append(input_ids[row_idx][-1].item())
            # Match on decoded text so a keyword split across several
            # tokens is still detected.
            decoded_text = self.tokenizer.decode(token_history)
            stop_flags.append(
                any(keyword in decoded_text for keyword in self.keywords_str)
            )
        return all(stop_flags)
class KeyWordsCriteriaTrunc(StoppingCriteria):
    """Stop when every row's generated ids (after the prompt) hit a stop sequence.

    ``stop_id_sequences`` is a list of token-id lists. ``prompt_length`` is
    the number of leading positions in ``input_ids`` that are skipped before
    searching.
    """

    def __init__(self, stop_id_sequences, prompt_length):
        assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
        self.prompt_length = prompt_length
        self.stop_sequences = stop_id_sequences

    def _row_hits_stop(self, generated_ids, single_row):
        """Return True if ``generated_ids`` matches any stop sequence.

        With a single-row batch only the trailing window is inspected;
        otherwise the whole generated part is scanned backwards from its end
        in strides of the stop-sequence length.
        """
        for stop_ids in self.stop_sequences:
            width = len(stop_ids)
            window = generated_ids[-width:] if single_row else generated_ids
            for end in range(len(window), 0, -width):
                start = max(end - width, 0)
                if window[start:end] == stop_ids:
                    return True
        return False

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when all rows have matched a stop sequence."""
        row_count = input_ids.shape[0]
        single_row = row_count == 1
        stop_flags = [
            self._row_hits_stop(input_ids[row_idx][self.prompt_length:].tolist(), single_row)
            for row_idx in range(row_count)
        ]
        return all(stop_flags)
class KeyWordsCriteria(StoppingCriteria):
    """Stop when every row of ``input_ids`` ends with one of the stop sequences.

    ``stop_id_sequences`` is a list of token-id lists; each is compared with
    the equally long tail of every row.
    """

    def __init__(self, stop_id_sequences):
        assert isinstance(stop_id_sequences[0], list), "stop_id_sequences should be a list of list of ids"
        self.stop_sequences = stop_id_sequences

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the tail of every row equals some stop sequence."""
        stop_flags = []
        for row_idx in range(input_ids.shape[0]):
            row = input_ids[row_idx]
            ends_with_stop = any(
                row[-len(stop_ids):].tolist() == stop_ids
                for stop_ids in self.stop_sequences
            )
            stop_flags.append(ends_with_stop)
        return all(stop_flags)
def _truncate_at_stop_sequences(text, stop_sequences):
    """Return ``text`` cut just before the earliest occurrence of any stop sequence."""
    cut = len(text)
    for stop_sequence in stop_sequences:
        if not stop_sequence:
            # An empty string matches at every position and carries no
            # information, so it is ignored.
            continue
        position = text.find(stop_sequence)
        if position != -1:
            cut = min(cut, position)
    return text[:cut]


@torch.no_grad()
def generate_completions(model, tokenizer, prompts, batch_size=1, stop_id_sequences=None, add_special_tokens=True, disable_tqdm=False, **generation_kwargs):
    """Generate completions for ``prompts`` in batches and return them as text.

    Args:
        model: Object exposing ``device`` and ``generate(...)`` in the style
            of a Hugging Face causal LM.
        tokenizer: Callable tokenizer that also provides ``batch_decode``.
        prompts: List of prompt strings.
        batch_size: Number of prompts tokenized and generated together.
        stop_id_sequences: Iterable of stop *strings* (despite the name).
            They are matched against the decoded output, both to stop
            generation early and to trim the returned text. ``None`` or an
            empty iterable disables stopping and trimming.
        add_special_tokens: Forwarded to the tokenizer call.
        disable_tqdm: When True, no progress bar is shown.
        **generation_kwargs: Forwarded to ``model.generate``;
            ``num_return_sequences`` is also read here to align prompts with
            outputs.

    Returns:
        A list with ``len(prompts) * num_return_sequences`` strings: each
        output with its prompt prefix removed and truncated before the first
        stop sequence it contains.
    """
    stop_sequences = list(stop_id_sequences) if stop_id_sequences is not None else []
    num_return_sequences = generation_kwargs.get("num_return_sequences", 1)

    progress = None
    if not disable_tqdm:
        progress = tqdm.tqdm(total=len(prompts), desc="Generating Completions")

    generations = []
    for i in range(0, len(prompts), batch_size):
        batch_prompts = prompts[i:i + batch_size]
        tokenized_prompts = tokenizer(batch_prompts, padding="longest", return_tensors="pt", add_special_tokens=add_special_tokens)
        batch_input_ids = tokenized_prompts.input_ids
        attention_mask = tokenized_prompts.attention_mask

        if model.device.type == "cuda":
            # Follow the model's own device rather than the current CUDA device.
            batch_input_ids = batch_input_ids.to(model.device)
            attention_mask = attention_mask.to(model.device)

        # The criterion accumulates per-row state, so a fresh one is built
        # for every batch; without stop sequences none is installed.
        stopping_criteria = None
        if stop_sequences:
            stopping_criteria = StoppingCriteriaList(
                [KeywordsStoppingCriteria(stop_sequences, tokenizer)]
            )

        batch_outputs = model.generate(
            input_ids=batch_input_ids,
            attention_mask=attention_mask,
            stopping_criteria=stopping_criteria,
            **generation_kwargs
        )

        # Remove the prompt from the output. The prompt is re-decoded from
        # its ids so special tokens are treated the same way as in the
        # outputs. Slicing token ids directly was abandoned because some
        # tokenizers (e.g. llama) do not add a space token before the first
        # token, and that space matters for tasks such as code completion.
        decoded_outputs = tokenizer.batch_decode(batch_outputs, skip_special_tokens=True)
        decoded_prompts = tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
        # Duplicate the prompts to match the number of return sequences.
        decoded_prompts = [prompt for prompt in decoded_prompts for _ in range(num_return_sequences)]
        batch_generations = [
            output[len(prompt):] for prompt, output in zip(decoded_prompts, decoded_outputs)
        ]

        # Stopping is decided for the batch as a whole, so individual outputs
        # may still contain a stop sequence followed by extra text. Cut each
        # one at the earliest stop sequence it contains.
        generations += [
            _truncate_at_stop_sequences(generation, stop_sequences)
            for generation in batch_generations
        ]

        if progress is not None:
            progress.update(len(decoded_prompts) // num_return_sequences)

    if progress is not None:
        progress.close()

    assert len(generations) == len(prompts) * num_return_sequences, "number of generations should be equal to number of prompts * num_return_sequences"
    return generations
def load_hf_lm_and_tokenizer(
        model_name_or_path,
        tokenizer_name_or_path=None,
        device_map="auto",
        load_in_8bit=False,
        load_in_half=True,
        gptq_model=False,
        use_fast_tokenizer=False,
        padding_side="left",
        use_safetensors=False,
    ):
    """Load a causal language model and its tokenizer.

    Args:
        model_name_or_path: Identifier or path passed to the model loader.
        tokenizer_name_or_path: Identifier or path for the tokenizer; falls
            back to ``model_name_or_path`` when empty.
        device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained`` on
            the 8-bit and default paths.
        load_in_8bit: Load with ``load_in_8bit=True`` (ignored when
            ``gptq_model`` is set).
        load_in_half: On the default path, call ``model.half()`` after
            loading.
        gptq_model: Load through ``auto_gptq`` on ``"cuda:0"`` with Triton;
            takes precedence over ``load_in_8bit``.
        use_fast_tokenizer: Forwarded to the tokenizer loader as ``use_fast``.
        padding_side: Forwarded to the tokenizer loader.
        use_safetensors: Forwarded to the model loader on the default path.

    Returns:
        ``(model, tokenizer)`` with the model switched to eval mode.

    Raises:
        ValueError: If the tokenizer has no pad, unk, or eos token to use
            for padding.
    """
    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer

    if not tokenizer_name_or_path:
        tokenizer_name_or_path = model_name_or_path
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=use_fast_tokenizer, padding_side=padding_side, trust_remote_code=True)

    # If no pad token is set, reuse the unk token, or failing that the eos token.
    if tokenizer.pad_token is None:
        if tokenizer.unk_token:
            tokenizer.pad_token = tokenizer.unk_token
            tokenizer.pad_token_id = tokenizer.unk_token_id
        elif tokenizer.eos_token:
            tokenizer.pad_token = tokenizer.eos_token
            tokenizer.pad_token_id = tokenizer.eos_token_id
        else:
            raise ValueError("You are using a new tokenizer without a pad token. "
                             "This is not supported by this script.")

    if gptq_model:
        # Imported lazily so auto_gptq is only required for GPTQ checkpoints.
        from auto_gptq import AutoGPTQForCausalLM
        model_wrapper = AutoGPTQForCausalLM.from_quantized(
            model_name_or_path, device="cuda:0", use_triton=True
        )
        model = model_wrapper.model
    elif load_in_8bit:
        model = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            device_map=device_map,
            load_in_8bit=True
        )
    else:
        # Default: load in float16.
        # NOTE(review): the weights are already requested as float16 here, so
        # load_in_half below only re-applies .half() - confirm whether
        # load_in_half=False is meant to select another dtype.
        model = AutoModelForCausalLM.from_pretrained(model_name_or_path,
                                                     torch_dtype=torch.float16,
                                                     device_map=device_map,
                                                     trust_remote_code=True,
                                                     use_safetensors=use_safetensors)
        if torch.cuda.is_available():
            model = model.cuda()
        if load_in_half:
            model = model.half()
    model.eval()
    return model, tokenizer
def _test_generate_completions():
    """Manual smoke test: load a local checkpoint, complete two prompts, print them."""
    # Few-shot arithmetic prompts whose examples are separated by "---".
    prompts = [
        "---\n1+1=2\n---2+2=4\n---3+3=6\n---4+4=8\n---5+5=10\n---6+6=",
        "---\n1+1=2\n---12+12=24\n---3+3=6\n---12345+12345=",
    ]
    # Stop sequences are passed as plain strings, not token ids.
    stop_sequences = ["\n\n\n", "---"]

    checkpoint_path = "../models/codellama_7b/v1-16k"
    model, tokenizer = load_hf_lm_and_tokenizer(
        model_name_or_path=checkpoint_path,
        load_in_half=True,
        use_fast_tokenizer=True,
        use_safetensors=True,
    )

    completions = generate_completions(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        max_new_tokens=128,
        batch_size=16,
        stop_id_sequences=stop_sequences,
    )
    print(completions)
# Run the manual smoke test when this module is executed as a script.
if __name__ == "__main__":
    _test_generate_completions()