feat(ai-proxy): add streaming support and transformers
tysoekong committed Mar 27, 2024
1 parent 6d31868 commit c01b0ab
Showing 17 changed files with 639 additions and 118 deletions.
4 changes: 4 additions & 0 deletions changelog/unreleased/kong/feat-ai-proxy-add-streaming.yml
@@ -0,0 +1,4 @@
+message: |
+  **AI-Proxy**: add support for streaming event-by-event responses back to client on supported providers
+scope: Plugin
+type: feature
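
For orientation (an editor's sketch, not part of the commit): on a streaming route, each driver's transformer below rewrites the provider's native events into OpenAI-style `chat.completion.chunk` JSON before the plugin relays them to the client. A minimal Lua sketch of one such frame, with hypothetical `id` and `model` values, assuming the bundled lua-cjson:

-- Sketch only: the field layout mirrors the transformers in this commit;
-- "gen-123" and "command" are illustrative values.
local cjson = require("cjson.safe")

local chunk = {
  id = "gen-123",
  model = "command",
  object = "chat.completion.chunk",
  choices = {
    { index = 0, delta = { content = "Hello" } },
  },
}

-- Framed as a server-sent event, the client receives:
print("data: " .. cjson.encode(chunk) .. "\n")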
4 changes: 2 additions & 2 deletions kong/llm/drivers/anthropic.lua
@@ -229,13 +229,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
     headers[conf.auth.header_name] = conf.auth.header_value
   end
 
-  local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
+  local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
   if err then
     return nil, nil, "request to ai service failed: " .. err
   end
 
   if return_res_table then
-    return res, res.status, nil
+    return res, res.status, nil, httpc
   else
     -- At this point, the entire request / response is complete and the connection
     -- will be closed or back on the connection pool.
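
This pattern recurs in every driver below: `ai_shared.http_request` now receives `return_res_table` and additionally returns the raw HTTP client, which `subrequest` passes through as a fourth return value so callers can stream the body instead of buffering it. A hedged sketch of the caller side (the calling code is outside this excerpt; a lua-resty-http-style client with `body_reader` and `set_keepalive` is assumed):

-- Hypothetical caller; names are illustrative only.
local res, status, err, httpc = driver.subrequest(body, conf, http_opts, true)
if err then
  return nil, err
end

-- The connection is still open, so the body can be relayed chunk by chunk.
repeat
  local chunk = res.body_reader(8192)
  if chunk then
    ngx.print(chunk)  -- forward bytes to the client as they arrive
    ngx.flush(true)
  end
until not chunk
httpc:set_keepalive()  -- return the connection to the pool when finished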
5 changes: 2 additions & 3 deletions kong/llm/drivers/azure.lua
@@ -59,13 +59,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
     headers[conf.auth.header_name] = conf.auth.header_value
   end
 
-  local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
+  local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
   if err then
     return nil, nil, "request to ai service failed: " .. err
   end
 
   if return_res_table then
-    return res, res.status, nil
+    return res, res.status, nil, httpc
   else
     -- At this point, the entire request / response is complete and the connection
     -- will be closed or back on the connection pool.
@@ -82,7 +82,6 @@ end
 
 -- returns err or nil
 function _M.configure_request(conf)
-
   local parsed_url
 
   if conf.model.options.upstream_url then
109 changes: 105 additions & 4 deletions kong/llm/drivers/cohere.lua
@@ -12,6 +12,104 @@ local table_new = require("table.new")
 local DRIVER_NAME = "cohere"
 --
 
+local function handle_stream_event(event_string, model_info, route_type)
+  local metadata
+
+  local event, err = cjson.decode(event_string)
+  if err then
+    return nil, "failed to decode event frame from cohere: " .. err, nil
+  end
+
+  local new_event
+
+  if event.event_type == "stream-start" then
+    kong.ctx.plugin.ai_proxy_cohere_stream_id = event.generation_id
+
+    -- ignore the rest of this one
+    new_event = {
+      choices = {
+        [1] = {
+          delta = {
+            content = "",
+            role = "assistant",
+          },
+          index = 0,
+        },
+      },
+      id = event.generation_id,
+      model = model_info.name,
+      object = "chat.completion.chunk",
+    }
+
+  elseif event.event_type == "text-generation" then
+    -- this is a token
+    if route_type == "stream/llm/v1/chat" then
+      new_event = {
+        choices = {
+          [1] = {
+            delta = {
+              content = event.text or "",
+            },
+            index = 0,
+          },
+        },
+        id = kong.ctx.plugin.ai_proxy_cohere_stream_id,
+        model = model_info.name,
+        object = "chat.completion.chunk",
+      }
+
+    elseif route_type == "stream/llm/v1/completions" then
+      new_event = {
+        choices = {
+          [1] = {
+            text = event.text or "",
+            index = 0,
+          },
+        },
+        id = kong.ctx.plugin.ai_proxy_cohere_stream_id,
+        model = model_info.name,
+        object = "text_completion",
+      }
+
+    end
+
+  elseif event.event_type == "stream-end" then
+    -- return a metadata object, with a null event
+    metadata = {
+      -- prompt_tokens = event.response.token_count.prompt_tokens,
+      -- completion_tokens = event.response.token_count.response_tokens,
+
+      completion_tokens = event.response
+                          and event.response.meta
+                          and event.response.meta.billed_units
+                          and event.response.meta.billed_units.output_tokens
+                          or
+                          event.response
+                          and event.response.token_count
+                          and event.response.token_count.response_tokens
+                          or 0,
+
+      prompt_tokens = event.response
+                      and event.response.meta
+                      and event.response.meta.billed_units
+                      and event.response.meta.billed_units.input_tokens
+                      or
+                      event.response
+                      and event.response.token_count
+                      and event.response.token_count.prompt_tokens
+                      or 0,
+    }
+
+  end
+
+  if new_event then
+    new_event = cjson.encode(new_event)
+    return new_event, nil, metadata
+  else
+    return nil, nil, metadata -- caller code will handle "unrecognised" event types
+  end
+end
+
 local transformers_to = {
   ["llm/v1/chat"] = function(request_table, model)
     request_table.model = model.name
@@ -243,6 +341,9 @@ local transformers_from = {
 
     return cjson.encode(prompt)
   end,
+
+  ["stream/llm/v1/chat"] = handle_stream_event,
+  ["stream/llm/v1/completions"] = handle_stream_event,
 }
 
function _M.from_format(response_string, model_info, route_type)
Expand All @@ -253,7 +354,7 @@ function _M.from_format(response_string, model_info, route_type)
return nil, fmt("no transformer available from format %s://%s", model_info.provider, route_type)
end

local ok, response_string, err = pcall(transformers_from[route_type], response_string, model_info)
local ok, response_string, err, metadata = pcall(transformers_from[route_type], response_string, model_info, route_type)
if not ok or err then
return nil, fmt("transformation failed from type %s://%s: %s",
model_info.provider,
@@ -262,7 +363,7 @@
     )
   end
 
-  return response_string, nil
+  return response_string, nil, metadata
 end
 
 function _M.to_format(request_table, model_info, route_type)
@@ -344,13 +445,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
     headers[conf.auth.header_name] = conf.auth.header_value
   end
 
-  local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
+  local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
   if err then
     return nil, nil, "request to ai service failed: " .. err
   end
 
   if return_res_table then
-    return res, res.status, nil
+    return res, res.status, nil, httpc
   else
     -- At this point, the entire request / response is complete and the connection
     -- will be closed or back on the connection pool.
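
A quick illustration of the new Cohere transformer (an editor's sketch, not part of the commit): `handle_stream_event` is module-local, so this assumes it were exposed for testing, and it stubs the `kong.ctx` state the function reads after `stream-start` has cached the generation id.

local cjson = require("cjson.safe")

-- Stub of Kong's per-plugin context; values are illustrative.
kong = { ctx = { plugin = { ai_proxy_cohere_stream_id = "gen-123" } } }

local frame = cjson.encode({ event_type = "text-generation", text = "Hello" })
local out = handle_stream_event(frame, { name = "command" }, "stream/llm/v1/chat")
print(out)
-- {"choices":[{"delta":{"content":"Hello"},"index":0}],
--  "id":"gen-123","model":"command","object":"chat.completion.chunk"}

A `stream-end` frame instead yields no event and a metadata table carrying `prompt_tokens` and `completion_tokens`, taken from `billed_units` when Cohere reports them and from `token_count` otherwise.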
14 changes: 8 additions & 6 deletions kong/llm/drivers/llama2.lua
@@ -133,6 +133,8 @@ local transformers_from = {
   ["llm/v1/completions/raw"] = from_raw,
   ["llm/v1/chat/ollama"] = ai_shared.from_ollama,
   ["llm/v1/completions/ollama"] = ai_shared.from_ollama,
+  ["stream/llm/v1/chat/ollama"] = ai_shared.from_ollama,
+  ["stream/llm/v1/completions/ollama"] = ai_shared.from_ollama,
 }
 
 local transformers_to = {
@@ -155,8 +157,8 @@ function _M.from_format(response_string, model_info, route_type)
   if not transformers_from[transformer_type] then
     return nil, fmt("no transformer available from format %s://%s", model_info.provider, transformer_type)
   end
-  local ok, response_string, err = pcall(
+
+  local ok, response_string, err, metadata = pcall(
     transformers_from[transformer_type],
     response_string,
     model_info,
@@ -166,7 +168,7 @@
     return nil, fmt("transformation failed from type %s://%s: %s", model_info.provider, route_type, err or "unexpected_error")
   end
 
-  return response_string, nil
+  return response_string, nil, metadata
 end
 
 function _M.to_format(request_table, model_info, route_type)
@@ -217,13 +219,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
     headers[conf.auth.header_name] = conf.auth.header_value
   end
 
-  local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
+  local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
   if err then
     return nil, nil, "request to ai service failed: " .. err
   end
 
   if return_res_table then
-    return res, res.status, nil
+    return res, res.status, nil, httpc
   else
     -- At this point, the entire request / response is complete and the connection
     -- will be closed or back on the connection pool.
@@ -265,7 +267,7 @@ function _M.configure_request(conf)
 
   kong.service.request.set_path(parsed_url.path)
   kong.service.request.set_scheme(parsed_url.scheme)
-  kong.service.set_target(parsed_url.host, tonumber(parsed_url.port))
+  kong.service.set_target(parsed_url.host, (tonumber(parsed_url.port) or 443))
 
   local auth_header_name = conf.auth and conf.auth.header_name
   local auth_header_value = conf.auth and conf.auth.header_value
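
The `set_target` change above is a small hardening fix: LuaSocket's `socket.url.parse` leaves `port` as `nil` when the upstream URL names no explicit port, and `kong.service.set_target` needs a number. A standalone sketch of the failure mode, assuming LuaSocket is installed (the hostname is illustrative):

local socket_url = require("socket.url")

local parsed = socket_url.parse("https://myhost.example.com/v1/chat")
print(parsed.port)                   --> nil, since the URL has no explicit port
print(tonumber(parsed.port) or 443)  --> 443, the fallback this commit adds

Note that hard-coding 443 presumes an HTTPS upstream; a scheme-aware default would differ for plain HTTP.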
6 changes: 4 additions & 2 deletions kong/llm/drivers/mistral.lua
@@ -17,6 +17,8 @@ local DRIVER_NAME = "mistral"
 local transformers_from = {
   ["llm/v1/chat/ollama"] = ai_shared.from_ollama,
   ["llm/v1/completions/ollama"] = ai_shared.from_ollama,
+  ["stream/llm/v1/chat/ollama"] = ai_shared.from_ollama,
+  ["stream/llm/v1/completions/ollama"] = ai_shared.from_ollama,
 }
 
 local transformers_to = {
@@ -104,13 +106,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
     headers[conf.auth.header_name] = conf.auth.header_value
   end
 
-  local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
+  local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
   if err then
     return nil, nil, "request to ai service failed: " .. err
   end
 
   if return_res_table then
-    return res, res.status, nil
+    return res, res.status, nil, httpc
   else
     -- At this point, the entire request / response is complete and the connection
     -- will be closed or back on the connection pool.
23 changes: 20 additions & 3 deletions kong/llm/drivers/openai.lua
@@ -11,6 +11,18 @@ local socket_url = require "socket.url"
 local DRIVER_NAME = "openai"
 --
 
+local function handle_stream_event(event_string)
+  if #event_string > 0 then
+    local lbl, val = event_string:match("(%w*): (.*)")
+
+    if lbl == "data" then
+      return val
+    end
+  end
+
+  return nil
+end
+
 local transformers_to = {
   ["llm/v1/chat"] = function(request_table, model, max_tokens, temperature, top_p)
     -- if user passed a prompt as a chat, transform it to a chat message
@@ -29,8 +41,9 @@ local transformers_to = {
       max_tokens = max_tokens,
       temperature = temperature,
       top_p = top_p,
+      stream = request_table.stream or false,
     }
 
     return this, "application/json", nil
   end,
 
@@ -40,6 +53,7 @@
       model = model,
       max_tokens = max_tokens,
       temperature = temperature,
+      stream = request_table.stream or false,
     }
 
     return this, "application/json", nil
@@ -72,6 +86,9 @@ local transformers_from = {
       return nil, "'choices' not in llm/v1/completions response"
     end
   end,
+
+  ["stream/llm/v1/chat"] = handle_stream_event,
+  ["stream/llm/v1/completions"] = handle_stream_event,
 }
 
 function _M.from_format(response_string, model_info, route_type)
@@ -155,13 +172,13 @@ function _M.subrequest(body, conf, http_opts, return_res_table)
     headers[conf.auth.header_name] = conf.auth.header_value
   end
 
-  local res, err = ai_shared.http_request(url, body_string, method, headers, http_opts)
+  local res, err, httpc = ai_shared.http_request(url, body_string, method, headers, http_opts, return_res_table)
   if err then
     return nil, nil, "request to ai service failed: " .. err
   end
 
   if return_res_table then
-    return res, res.status, nil
+    return res, res.status, nil, httpc
   else
     -- At this point, the entire request / response is complete and the connection
     -- will be closed or back on the connection pool.
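
The OpenAI driver's `handle_stream_event` is a minimal server-sent-events field parser: it keeps the payload of `data:` lines and drops everything else. A sketch of its behavior on typical frames (the function is module-local, so it is repeated verbatim here to keep the example standalone):

local function handle_stream_event(event_string)
  if #event_string > 0 then
    local lbl, val = event_string:match("(%w*): (.*)")

    if lbl == "data" then
      return val
    end
  end

  return nil
end

print(handle_stream_event('data: {"choices":[{"delta":{"content":"Hi"}}]}'))
--> {"choices":[{"delta":{"content":"Hi"}}]}
print(handle_stream_event("data: [DONE]"))  --> [DONE], OpenAI's end-of-stream sentinel
print(handle_stream_event(": keep-alive"))  --> nil; comment frames are ignored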
(Diff truncated; the remaining changed files are not shown in this excerpt.)