# SPDX-FileCopyrightText: 2024-present Oori Data <info@oori.dev>
#
# SPDX-License-Identifier: Apache-2.0
# toolio.schema_helper
'''
JSON schema decoding with MLX

Basically just a combo of:

* https://github.com/otriscon/llm-structured-output/blob/main/src/examples/llm_schema.py
* https://github.com/otriscon/llm-structured-output/blob/main/src/examples/reusable_kv_cache.py
'''
import time
from math import inf
from operator import itemgetter
from typing import Iterable, Optional, Union

import mlx.core as mx
from mlx_lm.models.cache import KVCache
from mlx_lm.utils import load

from toolio.vendor.llm_structured_output import JsonSchemaAcceptorDriver
from toolio.vendor.llm_structured_output.util.bitmap import (
    bias_logits,
    count_set_bits,
    enumerate_set_bits,
)
from toolio.vendor.llm_structured_output.util.output import debug
from toolio.vendor.llm_structured_output.util.tokenization import HuggingfaceTokenizerHelper


class RejectedCompletion(Exception):
    '''
    Reached a state from where it's not possible to advance the acceptor (a rare condition).

    For example, when closing a JSON string we get a higher probability for curly quotes than ASCII
    ones and thus select the wrong token. The LLM then continues generating as if the string
    has been closed, but the acceptor remains awaiting a close quote. Could be a bug in the
    tokenizer vocabulary passed to the acceptor, or in the code decoding tokens from the LLM.
    Could also be an inability of the LLM to generate JSON, although most can.
    '''


class Model:
    def __init__(self):
        mx.random.seed(0)
        self.model = None
        self.tokenizer = None
        self.vocabulary = None
        self.eos_id = None
        self.json_schema_acceptor_driver_factory = None
        self._cached_prompt = None
        self._cached_cache = None

    def load(self, model_path: str):
        '''
        Load locally or download from Huggingface hub.
        '''
        self.model, tokenizer = load(model_path)
        self.tokenizer = HuggingfaceTokenizerHelper(tokenizer)
        self.simple_tokenizer = tokenizer
        self.vocabulary, self.eos_id = self.tokenizer.extract_vocabulary()
        self.json_schema_acceptor_driver_factory = (
            JsonSchemaAcceptorDriver.driver_factory_for_model(
                self.vocabulary, self.eos_id
            )
        )

    def get_driver_for_json_schema(self, schema, encapsulated: bool = False):
        return self.json_schema_acceptor_driver_factory(
            schema, is_encapsulated_json=encapsulated
        )

    def _evaluate_prompt(
        self, prompt: list[int], prior_prompt: list[int] = None, prior_cache=None
    ):
        if prior_prompt:
            i = 0
            for i, t in enumerate(prior_prompt):
                # Need to leave at least one token to evaluate because we don't
                # save the past logits.
                if i >= len(prompt) - 1 or prompt[i] != t:
                    break
            cache = prior_cache
            for layer_cache in cache:
                layer_cache.reuse(len(prompt), i)
            tokens = prompt[i:]
            # print('CACHED', tokens, prompt, i)
        else:
            cache = ReusableKVCache.for_model(self.model)
            tokens = prompt
            # print('UNCACHED', tokens)
        logits = self.model(mx.array(tokens)[None], cache=cache)
        return logits, cache
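
    # Illustration of the prefix reuse above (token ids are made up, not
    # executed): with a cached prior_prompt of [1, 2, 3, 4] and a new prompt of
    # [1, 2, 3, 7, 8], the scan stops at i == 3, each layer cache is clipped to
    # the three shared positions via reuse(), and only [7, 8] is re-evaluated
    # through the model.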

    def _decode(self, tokens):
        return self.tokenizer.no_strip_decode(tokens)

    def _debug_top_tokens(self, logits, count=10):
        token_logits = sorted(
            enumerate(logits.tolist()), key=itemgetter(1), reverse=True
        )
        top_tokens = [
            (self._decode([t]), p) for t, p in token_logits[:count] if p != -inf
        ]
        debug('TOP TOKENS:', top_tokens)

    def _sample(self, logits, temp: float = 0):
        if temp == 0:
            result = mx.argmax(logits, axis=-1)
        else:
            result = mx.random.categorical(logits * (1 / temp))
        return result.item()

    def _sample_with_bias(
        self, logits, temp: float = 0, token_acceptor=None, lazy_bias: bool = True
    ):
        if token_acceptor is None:
            return self._sample(logits, temp)

        if lazy_bias:
            token = self._sample(logits, temp)
            try:
                token_acceptor.advance_token(token)
                return token
            except JsonSchemaAcceptorDriver.TokenRejected:
                pass

        accepted_token_bitmap = token_acceptor.select_valid_tokens()
        if not accepted_token_bitmap:
            raise RejectedCompletion()
        token = self._sample(bias_logits(mx, logits, accepted_token_bitmap), temp)
        token_acceptor.advance_token(token)
        return token

    def generate_without_schema(self, logits, cache, temp: Optional[float] = 0.0):
        '''
        For testing / comparison purposes.
        '''
        while True:
            tokens = [self._sample(logits[0, -1, :], temp)]
            yield tokens
            if tokens[-1] == self.eos_id:
                break
            logits = self.model(mx.array(tokens)[None], cache=cache)

    def generate_with_schema(
        self, logits, cache, token_acceptor, temp: Optional[float] = 0.0
    ):
        while True:
            tokens = [self._sample_with_bias(logits[0, -1, :], temp, token_acceptor)]
            yield tokens
            if tokens[-1] == self.eos_id:
                break
            logits = self.model(mx.array(tokens)[None], cache=cache)

    def generate_with_preemptive_decoding(
        self,
        logits,
        cache,
        token_acceptor,
        temp: Optional[float] = 0.0,
        max_batch_size=5,
    ):
        '''
        Try to generate faster by precomputing two tokens at a time when possible.

        If we know that the acceptor will only accept a small set of tokens after
        the current one, we can evaluate a batch with one entry per possible
        future token. Each entry in the batch contains the current sampled token,
        which we have to evaluate anyway, plus a second token corresponding to one
        of the tokens that could be sampled from the output for the first token.
        We get back logits for both positions of each item in the batch: the
        logits for the first position are the same across items (as long as the
        model applies a causal mask), and sampling them tells us which batch item
        to take the second token from.

        In practice, this only seems to accelerate things for unquantized models.
        '''
        # Sample token from prompt evaluation
        accepted_token_bitmap = token_acceptor.select_valid_tokens()
        first_token_logits = bias_logits(mx, logits[0, -1, :], accepted_token_bitmap)
        first_token = self._sample(first_token_logits, temp)
        tokens = [first_token]
        yield tokens
        token_acceptor.advance_token(first_token)
        accepted_token_bitmap = token_acceptor.select_valid_tokens()

        while True:
            last_token = tokens[-1]
            if count_set_bits(accepted_token_bitmap) in range(1, max_batch_size + 1):
                # If the number of possible follow-up tokens is small, submit for
                # evaluation a batch of 2-token continuations.
                batch = []
                for followup_token in enumerate_set_bits(accepted_token_bitmap):
                    batch.append([last_token, followup_token])
                # Re-shape the cache to match the input.
                for layer_cache in cache:
                    layer_cache.keys = mx.concatenate([layer_cache.keys] * len(batch))
                    layer_cache.values = mx.concatenate(
                        [layer_cache.values] * len(batch)
                    )
            else:  # Otherwise, submit the normal one-token continuation.
                batch = [[last_token]]

            logits = self.model(mx.array(batch), cache=cache)
            mx.eval(logits)

            first_token_logits = bias_logits(mx, logits[0, 0, :], accepted_token_bitmap)
            first_token = self._sample(first_token_logits, temp)
            tokens = [first_token]
            if first_token == self.eos_id:
                yield tokens
                break
            token_acceptor.advance_token(first_token)
            accepted_token_bitmap = token_acceptor.select_valid_tokens()
            if not accepted_token_bitmap:
                raise RejectedCompletion()

            # If we had submitted 2-token continuations, we can decode a second token
            if len(batch[0]) > 1:
                index = next(  # Find which of the second tokens was selected
                    i
                    for i, batch_item in enumerate(batch)
                    if batch_item[1] == first_token
                )
                second_token_logits = bias_logits(
                    mx, logits[index, 1, :], accepted_token_bitmap
                )
                second_token = self._sample(second_token_logits, temp)
                tokens.append(second_token)
                token_acceptor.advance_token(second_token)
                accepted_token_bitmap = token_acceptor.select_valid_tokens()
                # Select the accepted generation in the cache, restoring it to batch dimension 1.
                for layer_cache in cache:
                    layer_cache.keys = layer_cache.keys.split([index, index + 1])[1]
                    layer_cache.values = layer_cache.values.split([index, index + 1])[1]

            yield tokens
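
    # Worked illustration of the preemptive path above (token ids are made up,
    # not executed): if the acceptor will only allow tokens {11, 12} after the
    # last sampled token 7, the batch becomes [[7, 11], [7, 12]]. A single
    # forward pass then returns logits for position 0 (identical across rows,
    # given causal masking) and for position 1 of each row, so the next token
    # and, when it matches one of the rows, the token after it can both be
    # committed from one model call.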

    def _generate_tokens(
        self,
        generator: Iterable,
        max_tokens: int = 1000,
    ) -> Iterable:
        start_time = time.time_ns()
        token_count = 0
        for tokens in generator:
            token_count += len(tokens)
            try:
                eos_index = tokens.index(self.eos_id)
                tokens = tokens[0:eos_index]
            except ValueError:
                eos_index = -1

            if tokens:
                text = self._decode(tokens)
                yield {
                    'op': 'generatedTokens',
                    'text': text,
                    'token_count': len(tokens),
                    'time_ms': (time.time_ns() - start_time) / 1e6,
                }

            if eos_index >= 0:
                yield {'op': 'stop', 'reason': 'end'}
                return

            if token_count >= max_tokens:
                yield {'op': 'stop', 'reason': 'max_tokens'}
                return

            start_time = time.time_ns()

        # The token generators only stop after yielding an EOS token, and the
        # EOS / max_tokens branches above return first, so falling out of the
        # loop indicates a logic error.
        assert False

    def completion(
        self,
        prompt: Union[str, Iterable[dict[str, str]]],
        schema: dict,
        encapsulated: bool = False,
        max_tokens: int = 1000,
        temp: float = 0.0,
        seed: int = None,
        preemptive_batch_size: int = 0,
        cache_prompt: bool = False,
    ):
        if seed is not None:
            mx.random.seed(seed)

        start_time = time.time_ns()
        prompt_tokens = self.tokenizer.encode_prompt(prompt)
        logits, cache = self._evaluate_prompt(
            prompt_tokens, self._cached_prompt, self._cached_cache
        )
        if cache_prompt:
            self._cached_prompt = prompt_tokens
            self._cached_cache = cache
        # Eager eval to more accurately reflect the prompt evaluation time.
        mx.eval(logits)
        prompt_time = time.time_ns() - start_time
        yield {
            'op': 'evaluatedPrompt',
            'prompt': prompt,
            'token_count': len(prompt_tokens),
            'time_ms': prompt_time / 1e6,
            'prompt_tps': len(prompt_tokens) / (prompt_time / 1e9),
        }

        if schema:
            token_acceptor = self.get_driver_for_json_schema(schema, encapsulated)
            if preemptive_batch_size > 0:
                generator = self.generate_with_preemptive_decoding(
                    logits,
                    cache,
                    token_acceptor,
                    temp,
                    max_batch_size=preemptive_batch_size,
                )
            else:
                generator = self.generate_with_schema(
                    logits, cache, token_acceptor, temp
                )
        else:
            generator = self.generate_without_schema(logits, cache, temp)

        token_count = 0
        generation_time = 0
        for generation_result in self._generate_tokens(generator, max_tokens):
            if generation_result['op'] == 'generatedTokens':
                token_count += generation_result['token_count']
                generation_time += generation_result['time_ms']
            elif generation_result['op'] == 'stop':
                generation_result['token_count'] = token_count
                generation_result['time_ms'] = generation_time
                if generation_time == 0.0:
                    # Happens, believe it or not
                    generation_result['generation_tps'] = float('inf')
                else:
                    # Slightly incorrect, because the first token is generated from the prompt evaluation
                    generation_result['generation_tps'] = token_count / (
                        generation_time / 1e3
                    )
            yield generation_result
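
    # For reference, completion() yields a stream of event dicts shaped roughly
    # like the following (field values are illustrative, not real output):
    #   {'op': 'evaluatedPrompt', 'prompt': ..., 'token_count': 42,
    #    'time_ms': 12.3, 'prompt_tps': 3400.0}
    #   {'op': 'generatedTokens', 'text': '{"name": ', 'token_count': 4, 'time_ms': 56.7}
    #   ...
    #   {'op': 'stop', 'reason': 'end', 'token_count': 17, 'time_ms': 890.1,
    #    'generation_tps': 19.1}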


class ReusableKVCache(KVCache):
    '''
    Usability improvements over MLX's KVCache.
    '''
    @classmethod
    def for_model(cls, model):
        return [cls() for _ in model.layers]

    def reuse(self, new_prompt_length, common_prefix_length):
        '''
        Reuse (part of) this cache for a new prompt that shares a prefix with it.
        '''
        if self.keys is None:
            return
        # Clip the cache to the common length.
        self.offset = common_prefix_length
        # Ensure cache can fit the whole prompt. Because the offset is (very likely) not a multiple
        # of the step size, update_and_fetch() won't resize the cache when evaluating the rest of
        # the prompt as it would if it were an empty cache.
        current_size = self.keys.shape[2]
        if current_size < new_prompt_length:
            _, n_kv_heads, _, k_head_dim = self.keys.shape
            v_head_dim = self.values.shape[3]
            n_steps = (self.step + new_prompt_length - 1) // self.step
            k_add_shape = (1, n_kv_heads, n_steps * self.step - current_size, k_head_dim)
            v_add_shape = (1, n_kv_heads, n_steps * self.step - current_size, v_head_dim)
            k_zeros = mx.zeros(k_add_shape, self.keys.dtype)
            v_zeros = mx.zeros(v_add_shape, self.values.dtype)
            self.keys = mx.concatenate([self.keys, k_zeros], axis=2)
            self.values = mx.concatenate([self.values, v_zeros], axis=2)

    def update_and_fetch(self, keys, values):
        '''
        Override the base class method to allow the cache to be used with batch sizes > 1.
        (Just a tiny change in the line that determines the shape.)
        '''
        prev = self.offset
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            B, n_kv_heads, _, k_head_dim = keys.shape
            v_head_dim = values.shape[3]
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
            v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                if prev % self.step != 0:
                    self.keys = self.keys[..., :prev, :]
                    self.values = self.values[..., :prev, :]
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
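

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the model path below is a
# hypothetical placeholder and the prompt/schema are toy examples; adapt both
# to your setup. completion() is a generator of event dicts, and the
# 'generatedTokens' events carry the schema-constrained text.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    example_schema = {
        'type': 'object',
        'properties': {'name': {'type': 'string'}},
        'required': ['name'],
    }
    m = Model()
    m.load('mlx-community/REPLACE-WITH-A-REAL-MODEL')  # Hypothetical model path
    chunks = []
    for event in m.completion('Reply with a JSON object naming yourself.', example_schema, max_tokens=128):
        if event['op'] == 'generatedTokens':
            chunks.append(event['text'])
    print(''.join(chunks))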