Skip to content

Commit

Permalink
code review
Browse files Browse the repository at this point in the history
  • Loading branch information
shreemaan-abhishek committed Sep 23, 2024
1 parent e87c1e7 commit ceafa66
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 26 deletions.
33 changes: 23 additions & 10 deletions apisix/plugins/ai-rag.lua
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local http = require("resty.http")
local next = next
local require = require
local ngx_req = ngx.req
local core = require("apisix.core")

local http = require("resty.http")
local core = require("apisix.core")
local decorate = require("apisix.plugins.ai-prompt-decorator").__decorate
local next = next
local require = require

local INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local BAD_REQUEST = ngx.HTTP_BAD_REQUEST

local azure_ai_search_schema = {
type = "object",
Expand Down Expand Up @@ -56,15 +60,17 @@ local schema = {
properties = {
azure_openai = azure_openai_embeddings
},
-- change to enum while implementing support for other search services
-- ensure only one provider can be configured while implementing support for
-- other providers
required = { "azure_openai" },
},
vector_search_provider = {
type = "object",
properties = {
azure_ai_search = azure_ai_search_schema
},
-- change to enum while implementing support for other search services
-- ensure only one provider can be configured while implementing support for
-- other providers
required = { "azure_ai_search" }
},
},
Expand Down Expand Up @@ -102,11 +108,11 @@ function _M.access(conf, ctx)
local httpc = http.new()
local body_tab, err = core.request.get_json_request_body_table()
if not body_tab then
return 400, err
return BAD_REQUEST, err
end
if not body_tab["ai_rag"] then
core.log.error("request body must have \"ai-rag\" field")
return 400
return BAD_REQUEST
end

local embeddings_provider = next(conf.embeddings_provider)
Expand All @@ -127,7 +133,7 @@ function _M.access(conf, ctx)
local ok, err = core.schema.check(request_schema, body_tab)
if not ok then
core.log.error("request body fails schema check: ", err)
return 400
return BAD_REQUEST
end

local embeddings, status, err = embeddings_driver.get_embeddings(embeddings_provider_conf,
Expand All @@ -146,7 +152,10 @@ function _M.access(conf, ctx)
return status, err
end

-- remove ai_rag from request body because their purpose is served
-- also, these values will cause failure when proxying requests to LLM.
body_tab["ai_rag"] = nil

local prepend = {
{
role = "user",
Expand All @@ -160,7 +169,11 @@ function _M.access(conf, ctx)
body_tab.messages = {}
end
decorate(decorator_conf, body_tab)
local req_body_json = core.json.encode(body_tab)
local req_body_json, err = core.json.encode(body_tab)
if not req_body_json then
return INTERNAL_SERVER_ERROR, err
end

ngx_req.set_body_data(req_body_json)
end

Expand Down
32 changes: 25 additions & 7 deletions apisix/plugins/ai-rag/embeddings/azure_openai.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,23 +15,41 @@
-- limitations under the License.
--
local core = require("apisix.core")
local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
local INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local type = type

local _M = {}

_M.schema = {
type = "object",
properties = {
endpoint = {
type = "string",
},
api_key = {
type = "string",
},
},
required = { "endpoint", "api_key" }
}

function _M.get_embeddings(conf, body, httpc)
local body_tab, err = core.json.encode(body)
if not body_tab then
return nil, INTERNAL_SERVER_ERROR, err
end

local res, err = httpc:request_uri(conf.endpoint, {
method = "POST",
headers = {
["Content-Type"] = "application/json",
["api-key"] = conf.api_key,
},
body = core.json.encode(body)
body = body_tab
})

if not res or not res.body then
return nil, err
return nil, res.status or INTERNAL_SERVER_ERROR, err
end

if res.status ~= 200 then
Expand All @@ -40,16 +58,16 @@ function _M.get_embeddings(conf, body, httpc)

local res_tab, err = core.json.decode(res.body)
if not res_tab then
return nil, internal_server_error, err
return nil, INTERNAL_SERVER_ERROR, err
end

if type(res_tab.data) ~= "table" or #res_tab.data < 1 then
return nil, internal_server_error, res.body
if type(res_tab.data) ~= "table" or core.table.isempty(res_tab.data) then
return nil, INTERNAL_SERVER_ERROR, res.body
end

local embeddings, err = core.json.encode(res_tab.data[1].embedding)
if not embeddings then
return nil, internal_server_error, err
return nil, INTERNAL_SERVER_ERROR, err
end

return res_tab.data[1].embedding
Expand Down
22 changes: 19 additions & 3 deletions apisix/plugins/ai-rag/vector-search/azure_ai_search.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,22 @@
-- limitations under the License.
--
local core = require("apisix.core")
local internal_server_error = ngx.HTTP_INTERNAL_SERVER_ERROR
local INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR

local _M = {}

_M.schema = {
type = "object",
properties = {
endpoint = {
type = "string",
},
api_key = {
type = "string",
},
}
}


function _M.search(conf, search_body, httpc)
local body = {
Expand All @@ -30,7 +42,11 @@ function _M.search(conf, search_body, httpc)
}
}
}
local final_body = core.json.encode(body)
local final_body, err = core.json.encode(body)
if not final_body then
return nil, INTERNAL_SERVER_ERROR, err
end

local res, err = httpc:request_uri(conf.endpoint, {
method = "POST",
headers = {
Expand All @@ -41,7 +57,7 @@ function _M.search(conf, search_body, httpc)
})

if not res or not res.body then
return nil, internal_server_error, err
return nil, INTERNAL_SERVER_ERROR, err
end

if res.status ~= 200 then
Expand Down
9 changes: 3 additions & 6 deletions t/plugin/ai-rag.t
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ add_block_preprocessor(sub {
if ngx.req.get_method() ~= "POST" then
ngx.status = 400
ngx.say("Unsupported request method: ", ngx.req.get_method())
return
end
ngx.req.read_body()
local body, err = ngx.req.get_body_data()
Expand All @@ -62,12 +63,8 @@ add_block_preprocessor(sub {
return
end
if header_auth == "key" then
ngx.status = 200
ngx.say([[$embeddings]])
else
ngx.status = 401
end
ngx.status = 200
ngx.say([[$embeddings]])
}
}
Expand Down

0 comments on commit ceafa66

Please sign in to comment.