feat(AI-proxy-plugin): support the new /v1/messages API provided by Anthropic #12699

Merged · 5 commits · Apr 9, 2024
6 changes: 6 additions & 0 deletions changelog/unreleased/kong/add-messages-api-to-anthropic.yml
@@ -0,0 +1,6 @@
"message": |
**AI-Proxy**: To support the new messages API of `Anthropic`, the upstream path of the `Anthropic` for `llm/v1/chat` route type is changed from `/v1/complete` to `/v1/messages`
"type": breaking_change
"scope": Plugin
"jiras":
- FTI-5770
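
In practice, the breaking change swaps Anthropic's legacy text-completion payload for the structured messages payload. A minimal sketch of the two upstream request shapes as Lua tables, based on the request fixtures further down in this PR (field values are illustrative):

-- Old body, previously sent to POST /v1/complete:
local legacy_body = {
  model = "claude-2",
  prompt = "You are a mathematician.\n\nHuman: What is 1 + 2?\n\nAssistant:",
  max_tokens_to_sample = 512,
  temperature = 0.5,
}

-- New body, now sent to POST /v1/messages; the system prompt becomes a
-- top-level field, and max_tokens replaces max_tokens_to_sample:
local messages_body = {
  model = "claude-2.1",
  system = "You are a mathematician.",
  messages = {
    { role = "user", content = "What is 1 + 2?" },
  },
  max_tokens = 512,
  temperature = 0.5,
}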
69 changes: 53 additions & 16 deletions kong/llm/drivers/anthropic.lua
@@ -18,7 +18,6 @@ end

local function kong_messages_to_claude_prompt(messages)
local buf = buffer.new()
- buf:reset()

-- We need to flatten the messages into an assistant chat history for Claude
for _, v in ipairs(messages) do
@@ -44,6 +43,24 @@ local function kong_messages_to_claude_prompt(messages)
return buf:get()
end

-- Reuse the request's messages structure:
-- extract the chat messages and the system prompt from the Kong request.
local function kong_messages_to_claude_messages(messages)
local msgs, system, n = {}, nil, 1

for _, v in ipairs(messages) do
if v.role ~= "assistant" and v.role ~= "user" then
system = v.content

else
msgs[n] = v
n = n + 1
end
end

return msgs, system
end


local function to_claude_prompt(req)
if req.prompt then
@@ -57,22 +74,29 @@ local function to_claude_prompt(req)
return nil, "request is missing .prompt and .messages commands"
end

local function to_claude_messages(req)
if req.messages then
return kong_messages_to_claude_messages(req.messages)
end

return nil, nil, "request is missing .messages command"
end

local transformers_to = {
["llm/v1/chat"] = function(request_table, model)
- local prompt = {}
+ local messages = {}
local err

- prompt.prompt, err = to_claude_prompt(request_table)
- if err then
+ messages.messages, messages.system, err = to_claude_messages(request_table)
+ if err then
return nil, nil, err
end

- prompt.temperature = (model.options and model.options.temperature) or nil
- prompt.max_tokens_to_sample = (model.options and model.options.max_tokens) or nil
- prompt.model = model.name
-
- return prompt, "application/json", nil
+ messages.temperature = (model.options and model.options.temperature) or nil
+ messages.max_tokens = (model.options and model.options.max_tokens) or nil
+ messages.model = model.name
+
+ return messages, "application/json", nil
end,

["llm/v1/completions"] = function(request_table, model)
@@ -83,7 +107,7 @@ local transformers_to = {
if err then
return nil, nil, err
end

prompt.temperature = (model.options and model.options.temperature) or nil
prompt.max_tokens_to_sample = (model.options and model.options.max_tokens) or nil
prompt.model = model.name
@@ -96,36 +120,49 @@ local transformers_from = {
["llm/v1/chat"] = function(response_string)
local response_table, err = cjson.decode(response_string)
if err then
return nil, "failed to decode cohere response"
return nil, "failed to decode anthropic response"
end

- if response_table.completion then
local function extract_text_from_content(content)
local buf = buffer.new()
for i, v in ipairs(content) do
if i ~= 1 then
buf:put("\n")
end

buf:put(v.text)
end

return buf:tostring()
end

+ if response_table.content then
local res = {
choices = {
{
index = 0,
message = {
role = "assistant",
- content = response_table.completion,
+ content = extract_text_from_content(response_table.content),
},
finish_reason = response_table.stop_reason,
},
},
model = response_table.model,
object = "chat.completion",
object = "chat.content",
}

return cjson.encode(res)
else
-- it's probably an error block, return generic error
return nil, "'completion' not in anthropic://llm/v1/chat response"
return nil, "'content' not in anthropic://llm/v1/chat response"
end
end,

["llm/v1/completions"] = function(response_string)
local response_table, err = cjson.decode(response_string)
if err then
return nil, "failed to decode cohere response"
return nil, "failed to decode anthropic response"
end

if response_table.completion then
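On the request side, the new kong_messages_to_claude_messages() keeps the user/assistant turns as-is and lifts any other role out as the top-level system prompt (if several such messages appear, only the last one is kept). A minimal usage sketch with illustrative values:

local kong_messages = {
  { role = "system",    content = "You are a mathematician." },
  { role = "user",      content = "What is 1 + 2?" },
  { role = "assistant", content = "The sum of 1 + 2 is 3." },
}

local msgs, system = kong_messages_to_claude_messages(kong_messages)
-- msgs   => the user and assistant turns, in order
-- system => "You are a mathematician."

On the response side, extract_text_from_content() flattens Anthropic's content array into a single newline-joined string:

local content = {
  { type = "text", text = "Hello!" },
  { type = "text", text = "My name is Claude." },
}
-- extract_text_from_content(content) => "Hello!\nMy name is Claude."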
2 changes: 1 addition & 1 deletion kong/llm/drivers/shared.lua
@@ -45,7 +45,7 @@ _M.operation_map = {
method = "POST",
},
["llm/v1/chat"] = {
path = "/v1/complete",
path = "/v1/messages",
method = "POST",
},
},
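Since the operation map drives upstream routing, an `llm/v1/chat` request for the anthropic provider now resolves to the new endpoint. A rough sketch of the lookup, assuming the map is keyed by provider and then route type as the surrounding context suggests:

local op = _M.operation_map["anthropic"]["llm/v1/chat"]
-- op.path   => "/v1/messages"   (previously "/v1/complete")
-- op.method => "POST"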
6 changes: 3 additions & 3 deletions spec/03-plugins/38-ai-proxy/01-unit_spec.lua
@@ -73,7 +73,7 @@ local FORMATS = {
},
anthropic = {
["llm/v1/chat"] = {
name = "claude-2",
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 512,
@@ -82,7 +82,7 @@
},
},
["llm/v1/completions"] = {
name = "claude-2",
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 512,
@@ -300,7 +300,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()
assert.is_nil(err)

-- compare the tables
- assert.same(actual_response_table.choices[1].message, expected_response_table.choices[1].message)
+ assert.same(expected_response_table.choices[1].message, actual_response_table.choices[1].message)
assert.same(actual_response_table.model, expected_response_table.model)
end)

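The argument swap in the assertion above follows luassert's assert.same(expected, actual) convention, so a failing test labels the two values correctly in its output; a sketch:

-- expected value first, actual value second; with the arguments
-- reversed, failure output would label the actual response as "Expected".
assert.same(
  expected_response_table.choices[1].message,  -- expected
  actual_response_table.choices[1].message     -- actual
)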
26 changes: 13 additions & 13 deletions spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua
@@ -36,7 +36,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
local body, err = ngx.req.get_body_data()
body, err = json.decode(body)

- if err or (not body.prompt) then
+ if err or (not body.messages) then
ngx.status = 400
ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json"))
else
@@ -61,7 +61,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
local body, err = ngx.req.get_body_data()
body, err = json.decode(body)

- if err or (not body.prompt) then
+ if err or (not body.messages) then
ngx.status = 400
ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json"))
else
@@ -156,7 +156,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
header_value = "anthropic-key",
},
model = {
name = "gpt-3.5-turbo",
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 256,
@@ -186,7 +186,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
header_value = "anthropic-key",
},
model = {
name = "gpt-3.5-turbo",
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 256,
@@ -216,7 +216,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
header_value = "anthropic-key",
},
model = {
name = "gpt-3.5-turbo-instruct",
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 256,
@@ -246,7 +246,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
header_value = "wrong-key",
},
model = {
name = "gpt-3.5-turbo",
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 256,
@@ -276,7 +276,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
header_value = "anthropic-key",
},
model = {
name = "gpt-3.5-turbo",
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 256,
@@ -306,7 +306,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
header_value = "anthropic-key",
},
model = {
name = "gpt-3.5-turbo-instruct",
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 256,
@@ -336,7 +336,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
header_value = "anthropic-key",
},
model = {
name = "gpt-3.5-turbo",
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 256,
@@ -440,8 +440,8 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then

-- check this is in the 'kong' response format
-- assert.equals(json.id, "chatcmpl-8T6YwgvjQVVnGbJ2w8hpOA17SeNy2")
assert.equals(json.model, "claude-2")
assert.equals(json.object, "chat.completion")
assert.equals(json.model, "claude-2.1")
assert.equals(json.object, "chat.content")

assert.is_table(json.choices)
assert.is_table(json.choices[1].message)
@@ -463,7 +463,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
-- check we got internal server error
local body = assert.res_status(500 , r)
local json = cjson.decode(body)
- assert.equals(json.error.message, "transformation failed from type anthropic://llm/v1/chat: 'completion' not in anthropic://llm/v1/chat response")
+ assert.equals(json.error.message, "transformation failed from type anthropic://llm/v1/chat: 'content' not in anthropic://llm/v1/chat response")
end)

it("bad request", function()
Expand Down Expand Up @@ -496,7 +496,7 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
local json = cjson.decode(body)

-- check this is in the 'kong' response format
assert.equals(json.model, "claude-2")
assert.equals(json.model, "claude-2.1")
assert.equals(json.object, "text_completion")

assert.is_table(json.choices)
@@ -40,10 +40,10 @@ local FORMATS = {
header_value = "Bearer cohere-key",
},
},
- authropic = {
+ anthropic = {
route_type = "llm/v1/chat",
model = {
name = "claude-2",
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 512,
@@ -185,7 +185,6 @@ describe(PLUGIN_NAME .. ": (unit)", function()
if err or (body.messages == ngx.null) then
ngx.status = 400
ngx.say(pl_file.read(base_dir .. ngx.var.provider .. "/llm-v1-chat/responses/bad_request.json"))

else
ngx.status = 200
ngx.say(pl_file.read(base_dir .. ngx.var.provider .. "/request-transformer/response-in-json.json"))
@@ -214,6 +213,7 @@ describe(PLUGIN_NAME .. ": (unit)", function()
assert(mock:stop())
end)


for name, format_options in pairs(FORMATS) do
describe(name .. " transformer tests, exact json response", function()
it("transforms request based on LLM instructions", function()
@@ -1,5 +1,10 @@
{
"completion": "The sum of 1 + 1 is 2.",
"content": [
{
"text": "The sum of 1 + 1 is 2.",
"type": "text"
}
],
"stop_reason": "stop_sequence",
"model": "claude-2"
"model": "claude-2.1"
}
@@ -1,5 +1,5 @@
{
"completion": " Hello! My name is Claude.",
"stop_reason": "stop_sequence",
"model": "claude-2"
"model": "claude-2.1"
}
@@ -1,5 +1,18 @@
{
"completion": "{\n \"persons\": [\n {\n \"name\": \"Kong A\",\n \"age\": 62\n },\n {\n \"name\": \"Kong B\",\n \"age\": 84\n }\n ]\n }\n",
"stop_reason": "stop_sequence",
"model": "claude-2"
}
"content": [
{
"text": "{\n \"persons\": [\n {\n \"name\": \"Kong A\",\n \"age\": 62\n },\n {\n \"name\": \"Kong B\",\n \"age\": 84\n }\n ]\n }\n",
"type": "text"
}
],
"id": "msg_013Zva2CMHLNnXjNJJKqJ2EF",
"model": "claude-2.1",
"role": "assistant",
"stop_reason": "end_turn",
"stop_sequence": null,
"type": "message",
"usage": {
"input_tokens": 10,
"output_tokens": 25
}
}
@@ -1,6 +1,28 @@
{
"model": "claude-2",
"prompt": "You are a mathematician.\n\nHuman: What is 1 + 2?\n\nAssistant: The sum of 1 + 2 is 3. If you have any more math questions or if there's anything else I can help you with, feel free to ask!\n\nHuman: Multiply that by 2\n\nAssistant: Certainly! If you multiply 3 by 2, the result is 6. If you have any more questions or if there's anything else I can help you with, feel free to ask!\n\nHuman: Why can't you divide by zero?\n\nAssistant:",
"max_tokens_to_sample": 512,
"model": "claude-2.1",
"messages": [
{
"role": "user",
"content": "What is 1 + 2?"
},
{
"role": "assistant",
"content": "The sum of 1 + 2 is 3. If you have any more math questions or if there's anything else I can help you with, feel free to ask!"
},
{
"role": "user",
"content": "Multiply that by 2"
},
{
"role": "assistant",
"content": "Certainly! If you multiply 3 by 2, the result is 6. If you have any more questions or if there's anything else I can help you with, feel free to ask!"
},
{
"role": "user",
"content": "Why can't you divide by zero?"
}
],
"system": "You are a mathematician.",
"max_tokens": 512,
"temperature": 0.5
}