Commit e733a9e

DebuggingLife46 and LostRuins authored

Add logit_bias to the OpenAI api (ggml-org#577)

* Add logit_bias to the OpenAI api
* Cleanup and refactor, test in swagger.

Co-authored-by: Concedo <39025047+LostRuins@users.noreply.github.com>
1 parent 5006b23 commit e733a9e

File tree

4 files changed: +70 -5 lines changed

expose.h
gpttype_adapter.cpp
kcpp_docs.embd
koboldcpp.py

expose.h (+6)

@@ -3,6 +3,7 @@
 const int stop_token_max = 16;
 const int ban_token_max = 16;
 const int tensor_split_max = 16;
+const int logit_bias_max = 16;
 // match kobold's sampler list and order
 enum samplers
 {
@@ -22,6 +23,10 @@ enum stop_reason
     EOS_TOKEN=1,
     CUSTOM_STOPPER=2,
 };
+struct logit_bias {
+    int32_t token_id;
+    float bias;
+};
 struct load_model_inputs
 {
     const int threads;
@@ -76,6 +81,7 @@ struct generation_inputs
     const char * grammar;
     const bool grammar_retain_state;
     const bool quiet = false;
+    const logit_bias logit_biases[logit_bias_max];
 };
 struct generation_outputs
 {
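The new struct and the fixed-size logit_biases[logit_bias_max] array form the ABI contract between the C++ core and the Python launcher: the caller always passes exactly 16 slots and marks unused ones with a sentinel token_id of -1. A minimal ctypes sketch of that contract (the fill logic below is illustrative, not the shipped code; the shipped version appears in the koboldcpp.py diff further down):

import ctypes

logit_bias_max = 16  # must match the constant in expose.h

class logit_bias(ctypes.Structure):
    _fields_ = [("token_id", ctypes.c_int32),
                ("bias", ctypes.c_float)]

# Fill the fixed-size array: real entries first, then sentinel slots.
requested = {2: -20.0, 145: -1.4}  # token_id -> bias, illustrative values
biases = (logit_bias * logit_bias_max)()
for n, (t_id, b) in enumerate(requested.items()):
    biases[n] = logit_bias(t_id, b)
for n in range(len(requested), logit_bias_max):
    biases[n] = logit_bias(-1, 0.0)  # unused slot: sentinel token_id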

gpttype_adapter.cpp (+18)

@@ -101,6 +101,7 @@ static int stopper_unused_tokens = 0;
 static std::mutex concat_output_mtx;
 static std::string concat_output = "";
 static std::string concat_output_reader_copy = "";
+static std::vector<logit_bias> logit_biases;

 const int extra_context_handle_fragmentation = 80;

@@ -489,6 +490,12 @@ int mirostat, float mirostat_tau, float mirostat_eta, const std::vector<samplers
     candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
 }

+for(int i=0;i<logit_biases.size();++i)
+{
+    auto & itm = logit_biases[i];
+    candidates[itm.token_id].logit += itm.bias;
+}
+
 llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };

 if (grammar != nullptr) {
@@ -1437,6 +1444,17 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     }
 }

+logit_biases.clear();
+for(int x=0;x<logit_bias_max;++x)
+{
+    int32_t t_id = inputs.logit_biases[x].token_id;
+    float bias = inputs.logit_biases[x].bias;
+    if(t_id >= 0 && t_id < n_vocab && bias!=0)
+    {
+        logit_biases.push_back(inputs.logit_biases[x]);
+    }
+}
+
 std::string addedmemory = inputs.memory;
 params.prompt = inputs.prompt;
 params.seed = inputs.seed;
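The sampling-side change is deliberately small: once the candidate list has been built from the raw logits, each stored (token_id, bias) pair is added onto that token's logit, and the validation in gpttype_generate has already dropped out-of-range ids and zero biases, so the hot loop needs no bounds checks. The same logic in a few lines of plain Python, for illustration only (token ids and values are made up):

def apply_logit_biases(logits, logit_biases):
    # logits: list of floats indexed by token id
    # logit_biases: pre-validated (token_id, bias) pairs, like the C++ vector
    for token_id, bias in logit_biases:
        logits[token_id] += bias
    return logits

print(apply_logit_biases([0.5, 1.2, 3.0, 0.1], [(2, -20.0), (3, 3.2)]))
# -> [0.5, 1.2, -17.0, 3.3000000000000003]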

kcpp_docs.embd (+12 -2)

@@ -8,7 +8,7 @@

 <!-- schema -->
 <script>
-let spec = {
+let spec = {
 "components": {
 "schemas": {
 "BasicError": {
@@ -176,7 +176,17 @@
 "default": false,
 "description": "KoboldCpp ONLY. If true, also removes detected stop_sequences from the output and truncates all text after them. Does not work with SSE streaming.",
 "type": "boolean"
-}
+},
+"logit_bias": {
+    "default": {},
+    "description": "KoboldCpp ONLY. A dictionary of key-value pairs, which indicate the token IDs (int) and logit bias (float) to apply for that token. Up to 16 values can be provided.",
+    "type": "object",
+    "example": {
+        "2": -20,
+        "145": -1.4,
+        "3105": 3.2
+    },
+},
 },
 "required": [
 "prompt"

koboldcpp.py (+34 -3)

@@ -18,6 +18,9 @@
 stop_token_max = 16
 ban_token_max = 16
 tensor_split_max = 16
+logit_bias_max = 16
+bias_min_value = -100.0
+bias_max_value = 100.0

 class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
@@ -44,6 +47,10 @@ class load_model_inputs(ctypes.Structure):
                 ("banned_tokens", ctypes.c_char_p * ban_token_max),
                 ("tensor_split", ctypes.c_float * tensor_split_max)]

+class logit_bias(ctypes.Structure):
+    _fields_ = [("token_id", ctypes.c_int32),
+                ("bias", ctypes.c_float)]
+
 class generation_inputs(ctypes.Structure):
     _fields_ = [("seed", ctypes.c_int),
                 ("prompt", ctypes.c_char_p),
@@ -70,7 +77,8 @@ class generation_inputs(ctypes.Structure):
                 ("stream_sse", ctypes.c_bool),
                 ("grammar", ctypes.c_char_p),
                 ("grammar_retain_state", ctypes.c_bool),
-                ("quiet", ctypes.c_bool)]
+                ("quiet", ctypes.c_bool),
+                ("logit_biases", logit_bias * logit_bias_max)]

 class generation_outputs(ctypes.Structure):
     _fields_ = [("status", ctypes.c_int),
@@ -301,7 +309,7 @@ def load_model(model_filename):
     ret = handle.load_model(inputs)
     return ret

-def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False):
+def generate(prompt, memory="", max_length=32, max_context_length=512, temperature=0.7, top_k=100, top_a=0.0, top_p=0.92, min_p=0.0, typical_p=1.0, tfs=1.0, rep_pen=1.1, rep_pen_range=128, presence_penalty=0.0, mirostat=0, mirostat_tau=5.0, mirostat_eta=0.1, sampler_order=[6,0,1,3,4,2,5], seed=-1, stop_sequence=[], use_default_badwordsids=False, stream_sse=False, grammar='', grammar_retain_state=False, genkey='', trimstop=False, quiet=False, logit_biases={}):
     global maxctx, args, currentusergenkey, totalgens
     inputs = generation_inputs()
     outputs = ctypes.create_unicode_buffer(ctypes.sizeof(generation_outputs))
@@ -355,6 +363,28 @@
             inputs.stop_sequence[n] = "".encode("UTF-8")
         else:
             inputs.stop_sequence[n] = stop_sequence[n].encode("UTF-8")
+
+    bias_list = []
+    try:
+        if logit_biases and len(logit_biases) > 0:
+            bias_list = [{"key": key, "value": value} for key, value in logit_biases.items()]
+    except Exception as ex:
+        print(f"Logit bias dictionary is invalid: {ex}")
+
+    for n in range(logit_bias_max):
+        if n >= len(bias_list):
+            inputs.logit_biases[n] = logit_bias(-1, 0.0)
+        else:
+            try:
+                t_id = int(bias_list[n]['key'])
+                bias = float(bias_list[n]['value'])
+                t_id = -1 if t_id < 0 else t_id
+                bias = (bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias))
+                inputs.logit_biases[n] = logit_bias(t_id, bias)
+            except Exception as ex:
+                inputs.logit_biases[n] = logit_bias(-1, 0.0)
+                print(f"Skipped unparsable logit bias:{ex}")
+
     currentusergenkey = genkey
     totalgens += 1
     ret = handle.generate(inputs,outputs)
@@ -515,7 +545,8 @@ def run_blocking(): #api format 1=basic,2=kai,3=oai,4=oai-chat
         grammar_retain_state = genparams.get('grammar_retain_state', False),
         genkey=genparams.get('genkey', ''),
         trimstop=genparams.get('trim_stop', False),
-        quiet=is_quiet)
+        quiet=is_quiet,
+        logit_biases=genparams.get('logit_bias', {}))

     recvtxt = ""
     if stream_flag:
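Note the normalization generate() now performs before anything crosses the ctypes boundary: biases are clamped to [bias_min_value, bias_max_value] and negative token ids collapse to the -1 sentinel, so malformed entries degrade to no-ops instead of corrupting the fixed-size array. The clamping expression isolated for clarity (assert values are illustrative):

bias_min_value, bias_max_value = -100.0, 100.0

def clamp_bias(bias: float) -> float:
    # Same nested-conditional clamp used in generate() above.
    return bias_max_value if bias > bias_max_value else (bias_min_value if bias < bias_min_value else bias)

assert clamp_bias(250.0) == 100.0   # clipped to the maximum
assert clamp_bias(-3.5) == -3.5     # in-range values pass through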
