Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(AI-Proxy): improve the robustness of Anthropic's statistics #12854

Merged
merged 4 commits into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 16 additions & 7 deletions kong/llm/drivers/anthropic.lua
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,20 @@ local transformers_from = {
end

if response_table.content then
local usage = response_table.usage

if usage then
usage = {
prompt_tokens = usage.input_tokens or nil,
hanshuebner marked this conversation as resolved.
Show resolved Hide resolved
completion_tokens = usage.output_tokens or nil,
total_tokens = usage.input_tokens and usage.output_tokens and
usage.input_tokens + usage.output_tokens or nil,
}

else
usage = "no usage data returned from upstream"
end

local res = {
choices = {
{
Expand All @@ -148,16 +162,11 @@ local transformers_from = {
finish_reason = response_table.stop_reason,
},
},
usage = {
prompt_tokens = response_table.usage.input_tokens or 0,
completion_tokens = response_table.usage.output_tokens or 0,
total_tokens = response_table.usage.input_tokens and response_table.usage.output_tokens and
response_table.usage.input_tokens + response_table.usage.output_tokens or 0,
},
usage = usage,
model = response_table.model,
object = "chat.content",
}

return cjson.encode(res)
else
-- it's probably an error block, return generic error
Expand Down
142 changes: 140 additions & 2 deletions spec/03-plugins/38-ai-proxy/03-anthropic_integration_spec.lua
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
local helpers = require "spec.helpers"
local cjson = require "cjson"
local pl_file = require "pl.file"
local deepcompare = require("pl.tablex").deepcompare

local PLUGIN_NAME = "ai-proxy"
local MOCK_PORT = helpers.get_available_port()
Expand Down Expand Up @@ -75,6 +76,56 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
}

# Mock upstream whose successful (200) reply omits the "usage" object entirely,
# so the ai-proxy driver's missing-usage fallback path can be exercised end-to-end.
location = "/llm/v1/chat/no_usage_upstream_response" {
content_by_lua_block {
local pl_file = require "pl.file"
local json = require("cjson.safe")

-- authenticate the same way the real upstream does: API key in x-api-key
local token = ngx.req.get_headers()["x-api-key"]
if token == "anthropic-key" then
ngx.req.read_body()
local body, err = ngx.req.get_body_data()
body, err = json.decode(body)

-- a chat request without a messages array is malformed -> canned 400
if err or (not body.messages) then
ngx.status = 400
ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json"))
else
-- canned 200 fixture that contains no "usage" field at all
ngx.status = 200
ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/no_usage_response.json"))
end
else
-- wrong/missing key -> canned 401
ngx.status = 401
ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/unauthorized.json"))
end
}
}

# Mock upstream whose successful (200) reply carries a "usage" object with
# unrecognized keys, exercising the driver's tolerance of malformed usage data.
location = "/llm/v1/chat/malformed_usage_upstream_response" {
content_by_lua_block {
local pl_file = require "pl.file"
local json = require("cjson.safe")

-- authenticate the same way the real upstream does: API key in x-api-key
local token = ngx.req.get_headers()["x-api-key"]
if token == "anthropic-key" then
ngx.req.read_body()
local body, err = ngx.req.get_body_data()
body, err = json.decode(body)

-- a chat request without a messages array is malformed -> canned 400
if err or (not body.messages) then
ngx.status = 400
ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/bad_request.json"))
else
-- canned 200 fixture whose "usage" object has only unknown keys
ngx.status = 200
ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/malformed_usage_response.json"))
end
else
-- wrong/missing key -> canned 401
ngx.status = 401
ngx.print(pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/responses/unauthorized.json"))
end
}
}

location = "/llm/v1/chat/bad_request" {
content_by_lua_block {
local pl_file = require "pl.file"
Expand Down Expand Up @@ -170,15 +221,15 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
--

-- 200 chat bad upstream response with one option
local chat_good = assert(bp.routes:insert {
local chat_bad = assert(bp.routes:insert {
service = empty_service,
protocols = { "http" },
strip_path = true,
paths = { "/anthropic/llm/v1/chat/bad_upstream_response" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = chat_good.id },
route = { id = chat_bad.id },
config = {
route_type = "llm/v1/chat",
auth = {
Expand All @@ -199,6 +250,65 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
}
--

-- 200 chat no-usage response
-- Route + plugin pair pointed at the mock upstream that omits the "usage"
-- field, so the driver's missing-usage fallback is reachable via the gateway.
local chat_no_usage = assert(bp.routes:insert {
service = empty_service,
protocols = { "http" },
strip_path = true,
paths = { "/anthropic/llm/v1/chat/no_usage_upstream_response" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = chat_no_usage.id },
config = {
route_type = "llm/v1/chat",
auth = {
header_name = "x-api-key",
header_value = "anthropic-key",
},
model = {
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 256,
temperature = 1.0,
-- must match the fixture location name served by the mock upstream
upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/chat/no_usage_upstream_response",
anthropic_version = "2023-06-01",
},
},
},
}
--

-- 200 chat malformed-usage response
-- Route + plugin pair pointed at the mock upstream whose "usage" object has
-- only unrecognized keys, so the driver's tolerant mapping can be asserted.
local chat_malformed_usage = assert(bp.routes:insert {
service = empty_service,
protocols = { "http" },
strip_path = true,
paths = { "/anthropic/llm/v1/chat/malformed_usage_upstream_response" }
})
bp.plugins:insert {
name = PLUGIN_NAME,
route = { id = chat_malformed_usage.id },
config = {
route_type = "llm/v1/chat",
auth = {
header_name = "x-api-key",
header_value = "anthropic-key",
},
model = {
name = "claude-2.1",
provider = "anthropic",
options = {
max_tokens = 256,
temperature = 1.0,
-- must match the fixture location name served by the mock upstream
upstream_url = "http://"..helpers.mock_upstream_host..":"..MOCK_PORT.."/llm/v1/chat/malformed_usage_upstream_response",
anthropic_version = "2023-06-01",
},
},
},
}

-- 200 completions good with one option
local completions_good = assert(bp.routes:insert {
service = empty_service,
Expand Down Expand Up @@ -487,6 +597,34 @@ for _, strategy in helpers.all_strategies() do if strategy ~= "cassandra" then
-- check this is in the 'kong' response format
assert.equals(json.error.message, "request format not recognised")
end)

-- When the upstream reply carries no "usage" object, the driver substitutes a
-- fixed explanatory string instead of fabricating token counts.
it("no usage response", function()
  local chat_request = {
    headers = {
      ["content-type"] = "application/json",
      ["accept"] = "application/json",
    },
    body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"),
  }
  local response = client:get("/anthropic/llm/v1/chat/no_usage_upstream_response", chat_request)

  local decoded = cjson.decode(assert.res_status(200, response))
  assert.equals(decoded.usage, "no usage data returned from upstream")
end)

-- A "usage" object containing only unrecognized keys must not break response
-- translation: none of the keys map to OpenAI-style fields, so the translated
-- usage comes back as an empty table.
it("malformed usage response", function()
  local r = client:get("/anthropic/llm/v1/chat/malformed_usage_upstream_response", {
    headers = {
      ["content-type"] = "application/json",
      ["accept"] = "application/json",
    },
    body = pl_file.read("spec/fixtures/ai-proxy/anthropic/llm-v1-chat/requests/good.json"),
  })

  local body = assert.res_status(200 , r)
  local json = cjson.decode(body)
  -- assert.same performs a deep comparison and prints a structural diff on
  -- failure, unlike assert.is_truthy(deepcompare(...)) which only reports
  -- "expected truthy, got false"
  assert.same({}, json.usage)
end)
end)

describe("anthropic llm/v1/completions", function()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"content": [
{
"text": "The sum of 1 + 1 is 2.",
"type": "text"
}
],
"model": "claude-2.1",
"stop_reason": "end_turn",
"stop_sequence": "string",
"usage": {
"foo": 0,
"bar": 0
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
{
"content": [
{
"text": "The sum of 1 + 1 is 2.",
"type": "text"
}
],
"model": "claude-2.1",
"stop_reason": "end_turn",
"stop_sequence": "string"
}
Loading