
Commit ab97a8e

ngxsonarthw authored and committed
server : (embeddings) using same format for "input" and "content" (ggml-org#10872)
* server : (embeddings) using same format for "input" and "content"
* fix test case
* handle empty input case
* fix test
1 parent 49c08d5 commit ab97a8e
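
Illustration (not part of the commit): after this change the native "content" field accepts the same prompt shapes as the OpenAI-compatible "input" field: a single string, a list of token ids, a mixed string/token list, or a list of such prompts. A minimal client-side sketch, assuming a llama.cpp server reachable at http://localhost:8080 and the Python requests package:

# Illustrative only: exercise /embeddings with the prompt shapes now accepted
# by both "content" (native) and "input" (OAI-compatible).
# Assumes a llama.cpp server is listening on http://localhost:8080.
import requests

BASE_URL = "http://localhost:8080"  # assumed server address

# single prompts: a string, a token list, or a mixed string/token list
for prompt in ["I believe the meaning of life is",
               [12, 34, 56],
               [12, 34, "string", 56, 78]]:
    res = requests.post(f"{BASE_URL}/embeddings", json={"content": prompt})
    res.raise_for_status()
    print(len(res.json()["embedding"]))  # length of the single embedding vector

# multiple prompts: a list of the shapes above -> one result object per prompt
res = requests.post(f"{BASE_URL}/embeddings", json={"content": ["string1", [12, 34, 56]]})
res.raise_for_status()
print(len(res.json()))  # == number of prompts

For a single prompt the native endpoint returns one object with an "embedding" field; for a list of prompts it returns one object per prompt, as the new test below exercises.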

File tree

3 files changed (+47 -9 lines)

examples/server/server.cpp (+14 -6)

@@ -3651,25 +3651,33 @@ int main(int argc, char ** argv) {
         const json body = json::parse(req.body);
         bool oaicompat = false;

-        // an input prompt can be a string or a list of tokens (integer)
+        // for the shape of input/content, see tokenize_input_prompts()
         json prompt;
-        if (body.count("input") != 0) {
+        if (body.contains("input")) {
             oaicompat = true;
             prompt = body.at("input");
-        } else if (body.count("content") != 0) {
-            // with "content", we only support single prompt
-            prompt = std::vector<std::string>{body.at("content")};
+        } else if (body.contains("content")) {
+            oaicompat = false;
+            prompt = body.at("content");
         } else {
             res_error(res, format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
             return;
         }

+        std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, true, true);
+        for (const auto & tokens : tokenized_prompts) {
+            // this check is necessary for models that do not add BOS token to the input
+            if (tokens.empty()) {
+                res_error(res, format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST));
+                return;
+            }
+        }
+
         // create and queue the task
         json responses = json::array();
         bool error = false;
         {
             std::vector<server_task> tasks;
-            std::vector<llama_tokens> tokenized_prompts = tokenize_input_prompts(ctx_server.ctx, prompt, /* add_special */ false, true);
             for (size_t i = 0; i < tokenized_prompts.size(); i++) {
                 server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING);
                 task.id = ctx_server.queue_tasks.get_new_id();
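
Sketch of the new empty-input behavior (assumptions: the same local server as above, and a model that does not add a BOS token, so an empty string tokenizes to zero tokens):

# Sketch: an empty "content" should now yield an invalid-request error
# instead of a task with an empty token list.
import requests

res = requests.post("http://localhost:8080/embeddings", json={"content": ""})
if res.status_code != 200:
    print(res.status_code, res.json())  # error body from format_error_response()

Previously such a request could reach the task queue with an empty token list; the check above rejects it before any task is created.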

examples/server/tests/unit/test_embedding.py (+32 -3)

@@ -45,6 +45,35 @@ def test_embedding_multiple():
         assert len(d['embedding']) > 1


+@pytest.mark.parametrize(
+    "content,is_multi_prompt",
+    [
+        # single prompt
+        ("string", False),
+        ([12, 34, 56], False),
+        ([12, 34, "string", 56, 78], False),
+        # multiple prompts
+        (["string1", "string2"], True),
+        (["string1", [12, 34, 56]], True),
+        ([[12, 34, 56], [12, 34, 56]], True),
+        ([[12, 34, 56], [12, "string", 34, 56]], True),
+    ]
+)
+def test_embedding_mixed_input(content, is_multi_prompt: bool):
+    global server
+    server.start()
+    res = server.make_request("POST", "/embeddings", data={"content": content})
+    assert res.status_code == 200
+    if is_multi_prompt:
+        assert len(res.body) == len(content)
+        for d in res.body:
+            assert 'embedding' in d
+            assert len(d['embedding']) > 1
+    else:
+        assert 'embedding' in res.body
+        assert len(res.body['embedding']) > 1
+
+
 def test_embedding_openai_library_single():
     global server
     server.start()
@@ -102,8 +131,8 @@ def test_same_prompt_give_same_result():
 @pytest.mark.parametrize(
     "content,n_tokens",
     [
-        ("I believe the meaning of life is", 7),
-        ("This is a test", 4),
+        ("I believe the meaning of life is", 9),
+        ("This is a test", 6),
     ]
 )
 def test_embedding_usage_single(content, n_tokens):
@@ -126,4 +155,4 @@ def test_embedding_usage_multiple():
     })
     assert res.status_code == 200
     assert res.body['usage']['prompt_tokens'] == res.body['usage']['total_tokens']
-    assert res.body['usage']['prompt_tokens'] == 2 * 7
+    assert res.body['usage']['prompt_tokens'] == 2 * 9

examples/server/utils.hpp (+1 -0)

@@ -138,6 +138,7 @@ static llama_tokens tokenize_mixed(const llama_context * ctx, const json & json_
  * and multiple prompts (multi-tasks):
  * - "prompt": ["string1", "string2"]
  * - "prompt": ["string1", [12, 34, 56]]
+ * - "prompt": [[12, 34, 56], [78, 90, 12]]
  * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56]]
  */
 static std::vector<llama_tokens> tokenize_input_prompts(llama_context * ctx, const json & json_prompt, bool add_special, bool parse_special) {
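
For reference, a rough Python mirror of how the documented "prompt" shapes appear to be disambiguated; this is an approximation for readers, not the server's actual code, and classify_prompt is a hypothetical helper:

def classify_prompt(p):
    """Approximate classification of the shapes listed in the comment above."""
    if isinstance(p, str):
        return "single"      # "prompt": "string"
    if isinstance(p, list) and p and all(isinstance(x, int) for x in p):
        return "single"      # "prompt": [12, 34, 56]
    if (isinstance(p, list) and p
            and all(isinstance(x, (str, int)) for x in p)
            and any(isinstance(x, str) for x in p)
            and any(isinstance(x, int) for x in p)):
        return "single"      # "prompt": [12, 34, "string", 56, 78]
    if isinstance(p, list) and p:
        return "multi"       # a list of the single shapes above
    raise ValueError("unsupported prompt shape")

# expectations below match the parametrization in test_embedding_mixed_input
assert classify_prompt("string") == "single"
assert classify_prompt([12, 34, "string", 56, 78]) == "single"
assert classify_prompt(["string1", "string2"]) == "multi"
assert classify_prompt([[12, 34, 56], [78, 90, 12]]) == "multi"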
