Skip to content

Commit

Permalink
feat: ai-rag plugin (#11568)
Browse files Browse the repository at this point in the history
  • Loading branch information
shreemaan-abhishek authored Oct 16, 2024
1 parent 5eb9f6a commit 11c9d29
Show file tree
Hide file tree
Showing 11 changed files with 954 additions and 1 deletion.
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,11 @@ install: runtime
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers
$(ENV_INSTALL) apisix/plugins/ai-proxy/drivers/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-proxy/drivers

$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings
$(ENV_INSTALL) apisix/plugins/ai-rag/embeddings/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/embeddings
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/vector-search
$(ENV_INSTALL) apisix/plugins/ai-rag/vector-search/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai-rag/vector-search

# ai-content-moderation plugin
$(ENV_INSTALL) -d $(ENV_INST_LUADIR)/apisix/plugins/ai
$(ENV_INSTALL) apisix/plugins/ai/*.lua $(ENV_INST_LUADIR)/apisix/plugins/ai
Expand Down
1 change: 1 addition & 0 deletions apisix/cli/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ local _M = {
"body-transformer",
"ai-prompt-template",
"ai-prompt-decorator",
"ai-rag",
"ai-content-moderation",
"proxy-mirror",
"proxy-rewrite",
Expand Down
156 changes: 156 additions & 0 deletions apisix/plugins/ai-rag.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local next = next
local require = require
local ngx_req = ngx.req

local http = require("resty.http")
local core = require("apisix.core")

local azure_openai_embeddings = require("apisix.plugins.ai-rag.embeddings.azure_openai").schema
local azure_ai_search_schema = require("apisix.plugins.ai-rag.vector-search.azure_ai_search").schema

local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_BAD_REQUEST = ngx.HTTP_BAD_REQUEST

local schema = {
type = "object",
properties = {
type = "object",
embeddings_provider = {
type = "object",
properties = {
azure_openai = azure_openai_embeddings
},
-- ensure only one provider can be configured while implementing support for
-- other providers
required = { "azure_openai" },
maxProperties = 1,
},
vector_search_provider = {
type = "object",
properties = {
azure_ai_search = azure_ai_search_schema
},
-- ensure only one provider can be configured while implementing support for
-- other providers
required = { "azure_ai_search" },
maxProperties = 1
},
},
required = { "embeddings_provider", "vector_search_provider" }
}

local request_schema = {
type = "object",
properties = {
ai_rag = {
type = "object",
properties = {
vector_search = {},
embeddings = {},
},
required = { "vector_search", "embeddings" }
}
}
}

local _M = {
version = 0.1,
priority = 1060,
name = "ai-rag",
schema = schema,
}


function _M.check_schema(conf)
return core.schema.check(schema, conf)
end


function _M.access(conf, ctx)
local httpc = http.new()
local body_tab, err = core.request.get_json_request_body_table()
if not body_tab then
return HTTP_BAD_REQUEST, err
end
if not body_tab["ai_rag"] then
core.log.error("request body must have \"ai-rag\" field")
return HTTP_BAD_REQUEST
end

local embeddings_provider = next(conf.embeddings_provider)
local embeddings_provider_conf = conf.embeddings_provider[embeddings_provider]
local embeddings_driver = require("apisix.plugins.ai-rag.embeddings." .. embeddings_provider)

local vector_search_provider = next(conf.vector_search_provider)
local vector_search_provider_conf = conf.vector_search_provider[vector_search_provider]
local vector_search_driver = require("apisix.plugins.ai-rag.vector-search." ..
vector_search_provider)

local vs_req_schema = vector_search_driver.request_schema
local emb_req_schema = embeddings_driver.request_schema

request_schema.properties.ai_rag.properties.vector_search = vs_req_schema
request_schema.properties.ai_rag.properties.embeddings = emb_req_schema

local ok, err = core.schema.check(request_schema, body_tab)
if not ok then
core.log.error("request body fails schema check: ", err)
return HTTP_BAD_REQUEST
end

local embeddings, status, err = embeddings_driver.get_embeddings(embeddings_provider_conf,
body_tab["ai_rag"].embeddings, httpc)
if not embeddings then
core.log.error("could not get embeddings: ", err)
return status, err
end

local search_body = body_tab["ai_rag"].vector_search
search_body.embeddings = embeddings
local res, status, err = vector_search_driver.search(vector_search_provider_conf,
search_body, httpc)
if not res then
core.log.error("could not get vector_search result: ", err)
return status, err
end

-- remove ai_rag from request body because their purpose is served
-- also, these values will cause failure when proxying requests to LLM.
body_tab["ai_rag"] = nil

if not body_tab.messages then
body_tab.messages = {}
end

local augment = {
role = "user",
content = res
}
core.table.insert_tail(body_tab.messages, augment)

local req_body_json, err = core.json.encode(body_tab)
if not req_body_json then
return HTTP_INTERNAL_SERVER_ERROR, err
end

ngx_req.set_body_data(req_body_json)
end


return _M
88 changes: 88 additions & 0 deletions apisix/plugins/ai-rag/embeddings/azure_openai.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_OK = ngx.HTTP_OK
local type = type

local _M = {}

_M.schema = {
type = "object",
properties = {
endpoint = {
type = "string",
},
api_key = {
type = "string",
},
},
required = { "endpoint", "api_key" }
}

function _M.get_embeddings(conf, body, httpc)
local body_tab, err = core.json.encode(body)
if not body_tab then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

local res, err = httpc:request_uri(conf.endpoint, {
method = "POST",
headers = {
["Content-Type"] = "application/json",
["api-key"] = conf.api_key,
},
body = body_tab
})

if not res or not res.body then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

if res.status ~= HTTP_OK then
return nil, res.status, res.body
end

local res_tab, err = core.json.decode(res.body)
if not res_tab then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

if type(res_tab.data) ~= "table" or core.table.isempty(res_tab.data) then
return nil, HTTP_INTERNAL_SERVER_ERROR, res.body
end

local embeddings, err = core.json.encode(res_tab.data[1].embedding)
if not embeddings then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

return res_tab.data[1].embedding
end


_M.request_schema = {
type = "object",
properties = {
input = {
type = "string"
}
},
required = { "input" }
}

return _M
83 changes: 83 additions & 0 deletions apisix/plugins/ai-rag/vector-search/azure_ai_search.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
--
-- Licensed to the Apache Software Foundation (ASF) under one or more
-- contributor license agreements. See the NOTICE file distributed with
-- this work for additional information regarding copyright ownership.
-- The ASF licenses this file to You under the Apache License, Version 2.0
-- (the "License"); you may not use this file except in compliance with
-- the License. You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
--
local core = require("apisix.core")
local HTTP_INTERNAL_SERVER_ERROR = ngx.HTTP_INTERNAL_SERVER_ERROR
local HTTP_OK = ngx.HTTP_OK

local _M = {}

_M.schema = {
type = "object",
properties = {
endpoint = {
type = "string",
},
api_key = {
type = "string",
},
},
required = {"endpoint", "api_key"}
}


function _M.search(conf, search_body, httpc)
local body = {
vectorQueries = {
{
kind = "vector",
vector = search_body.embeddings,
fields = search_body.fields
}
}
}
local final_body, err = core.json.encode(body)
if not final_body then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

local res, err = httpc:request_uri(conf.endpoint, {
method = "POST",
headers = {
["Content-Type"] = "application/json",
["api-key"] = conf.api_key,
},
body = final_body
})

if not res or not res.body then
return nil, HTTP_INTERNAL_SERVER_ERROR, err
end

if res.status ~= HTTP_OK then
return nil, res.status, res.body
end

return res.body
end


_M.request_schema = {
type = "object",
properties = {
fields = {
type = "string"
}
},
required = { "fields" }
}

return _M
1 change: 1 addition & 0 deletions conf/config.yaml.example
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,7 @@ plugins: # plugin list (sorted by priority)
- body-transformer # priority: 1080
- ai-prompt-template # priority: 1071
- ai-prompt-decorator # priority: 1070
- ai-rag # priority: 1060
- ai-content-moderation # priority: 1040 TODO: compare priority with other ai plugins
- proxy-mirror # priority: 1010
- proxy-rewrite # priority: 1008
Expand Down
3 changes: 2 additions & 1 deletion docs/en/latest/config.json
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,8 @@
"plugins/degraphql",
"plugins/body-transformer",
"plugins/ai-proxy",
"plugins/attach-consumer-label"
"plugins/attach-consumer-label",
"plugins/ai-rag"
]
},
{
Expand Down
Loading

0 comments on commit 11c9d29

Please sign in to comment.