Skip to content

Fast API 接口能力 #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/flask/llms_cache/data_insert.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def run():
res = requests.post(url, headers=headers, json=json.dumps(data))
res_text = res.text

print("data_insert:", res.status_code, res_text)

if __name__ == '__main__':
run()
1 change: 1 addition & 0 deletions examples/flask/llms_cache/data_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ def run():
res = requests.post(url, headers=headers, json=json.dumps(data))
res_text = res.text

print("data_query:", res.status_code, res_text)

if __name__ == '__main__':
run()
1 change: 1 addition & 0 deletions examples/flask/llms_cache/data_query_long.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def run():
res = requests.post(url, headers=headers, json=json.dumps(data))
res_text = res.text

print("data_query_long:", res.status_code, res_text)

if __name__ == '__main__':
run()
193 changes: 193 additions & 0 deletions fastapi4modelcache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
# -*- coding: utf-8 -*-
import time
import uvicorn
import asyncio
import logging
import configparser
import json
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from concurrent.futures import ThreadPoolExecutor
from starlette.responses import PlainTextResponse
import functools

from modelcache import cache
from modelcache.adapter import adapter
from modelcache.manager import CacheBase, VectorBase, get_data_manager
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
from modelcache.processor.pre import query_multi_splicing
from modelcache.processor.pre import insert_multi_splicing
from modelcache.utils.model_filter import model_blacklist_filter
from modelcache.embedding import Data2VecAudio

# Create the FastAPI application instance
app = FastAPI()

class RequestData(BaseModel):
    """Declared shape of a /modelcache request body.

    The POST handler visible in this module reads the raw request body and
    parses it with ``json.loads`` itself; it does not take this model as a
    parameter.
    """
    # Operation selector; the handler accepts 'query', 'insert', 'remove', 'register'.
    type: str
    # The handler reads scope['model'] from this mapping.
    scope: dict = None
    # Passed by the handler to adapter.ChatCompletion.create_query.
    query: str = None
    # Passed by the handler to adapter.ChatCompletion.create_insert.
    # NOTE(review): fastapi4modelcache_demo.py declares this field as `list` -
    # confirm which type is intended.
    chat_info: dict = None
    # Forwarded to adapter.ChatCompletion.create_remove.
    remove_type: str = None
    # Forwarded to adapter.ChatCompletion.create_remove.
    id_list: list = []

# Embedding model: its `dimension` sizes the vector store and its
# `to_embeddings` method is registered as the cache's embedding function below.
data2vec = Data2VecAudio()
# NOTE: ConfigParser.read() silently skips files it cannot open, and these
# paths are relative to the current working directory.
mysql_config = configparser.ConfigParser()
mysql_config.read('modelcache/config/mysql_config.ini')

milvus_config = configparser.ConfigParser()
milvus_config.read('modelcache/config/milvus_config.ini')

# redis_config = configparser.ConfigParser()
# redis_config.read('modelcache/config/redis_config.ini')

# Initialize the data manager with a "mysql" CacheBase and a "milvus" VectorBase
data_manager = get_data_manager(
    CacheBase("mysql", config=mysql_config),
    VectorBase("milvus", dimension=data2vec.dimension, milvus_config=milvus_config)
)

# # Alternative: initialize the data manager with redis as the VectorBase
# data_manager = get_data_manager(
# CacheBase("mysql", config=mysql_config),
# VectorBase("redis", dimension=data2vec.dimension, redis_config=redis_config)
# )

# Wire the embedding function, storage, similarity evaluation and the
# query/insert pre-embedding hooks into the shared cache object.
cache.init(
    embedding_func=data2vec.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    query_pre_embedding_func=query_multi_splicing,
    insert_pre_embedding_func=insert_multi_splicing,
)

# NOTE(review): nothing in the code visible here references `executor`;
# save_query_info passes None to run_in_executor, i.e. the loop's default
# executor - confirm whether this pool was meant to be used there.
executor = ThreadPoolExecutor(max_workers=6)

# Persist the query log asynchronously
async def save_query_info(result, model, query, delta_time_log):
    """Write one query-log record via ``cache.data_manager.save_query_resp``.

    The synchronous ``save_query_resp`` call is handed to the running loop's
    default executor so the coroutine awaits it instead of calling it inline.

    Args:
        result: response dict passed positionally to ``save_query_resp``.
        model: value forwarded as the ``model`` keyword.
        query: JSON-serialised (non-ASCII preserved) and forwarded as ``query``.
        delta_time_log: value forwarded as the ``delta_time`` keyword.
    """
    running_loop = asyncio.get_running_loop()
    serialized_query = json.dumps(query, ensure_ascii=False)
    persist = functools.partial(
        cache.data_manager.save_query_resp,
        result,
        model=model,
        query=serialized_query,
        delta_time=delta_time_log,
    )
    # Passing None selects the event loop's default executor.
    await running_loop.run_in_executor(None, persist)



@app.get("/welcome", response_class=PlainTextResponse)
async def first_fastapi():
    """Return a fixed plain-text greeting for GET /welcome."""
    greeting = "hello, modelcache!"
    return greeting

# Strong references to in-flight logging tasks. The event loop only keeps weak
# references to tasks, so a fire-and-forget task that nothing else references
# may be garbage-collected before it finishes.
_background_tasks = set()


def _log_in_background(coro):
    """Schedule ``coro`` as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


def _request_error(error_code, error_desc):
    """Build the error envelope returned when a request fails validation."""
    return {"errorCode": error_code, "errorDesc": error_desc, "cacheHit": False, "delta_time": 0,
            "hit_query": '', "answer": ''}


@app.post("/modelcache")
async def user_backend(request: Request):
    """Handle POST /modelcache: validate the body, then dispatch on its ``type``.

    The body must be a JSON object (a JSON string that itself contains a JSON
    object is decoded a second time). Recognised keys: ``type`` (one of
    'query', 'insert', 'remove', 'register'), ``scope`` (its ``model`` entry is
    read, with '-' and '.' replaced by '_'), ``query``, ``chat_info``,
    ``remove_type`` and ``id_list``.

    Returns:
        A dict that always carries ``errorCode``:
          * 101 - body is not valid JSON (the failure is also logged).
          * 102 - ``type`` missing or not supported (also logged).
          * 103 - any other failure while reading/parsing the request; the
            received body is echoed back under ``para_dict``.
          * 0 / 201 / 202 - query outcome (``cacheHit``, ``hit_query``, ``answer``).
          * 0 / 301 / 303 - insert outcome (``writeStatus``).
          * 0 / 401 / 402 - remove outcome.
          * 0 / 502 - register outcome.
        If ``model_blacklist_filter`` returns a dict, that dict is returned as is.
    """
    raw_body = None
    try:
        raw_body = await request.body()
        if isinstance(raw_body, bytes):
            raw_body = raw_body.decode("utf-8")

        try:
            request_data = json.loads(raw_body)
            # A body that decodes to a JSON *string* is decoded once more; this
            # is what a client posting `json=json.dumps(data)` produces.
            if isinstance(request_data, str):
                request_data = json.loads(request_data)
        except json.JSONDecodeError as e:
            # Return the 101 envelope directly. Raising HTTPException here would
            # be swallowed by the broad handler below and reported as 103.
            result = _request_error(101, str(e))
            _log_in_background(save_query_info(result, model='', query='', delta_time_log=0))
            return result

        request_type = request_data.get('type')
        model = None
        if 'scope' in request_data:
            model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_')
        query = request_data.get('query')
        chat_info = request_data.get('chat_info')

        if not request_type or request_type not in ['query', 'insert', 'remove', 'register']:
            result = _request_error(
                102, "type exception, should one of ['query', 'insert', 'remove', 'register']")
            _log_in_background(save_query_info(result, model=model, query='', delta_time_log=0))
            return result

    except Exception as e:
        # Anything unexpected while reading or interpreting the request.
        if isinstance(raw_body, bytes):
            # Decoding failed above; keep the echoed body JSON-serialisable.
            raw_body = raw_body.decode("utf-8", errors="replace")
        return {
            "errorCode": 103,
            "errorDesc": str(e),
            "cacheHit": False,
            "delta_time": 0,
            "hit_query": '',
            "answer": '',
            "para_dict": raw_body
        }

    # model filter
    filter_resp = model_blacklist_filter(model, request_type)
    if isinstance(filter_resp, dict):
        return filter_resp

    if request_type == 'query':
        try:
            start_time = time.time()
            response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query)
            delta_time = f"{round(time.time() - start_time, 2)}s"

            if response is None:
                # No response from the adapter is reported as a cache miss.
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            elif response in ['adapt_query_exception']:
                result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            else:
                answer = response['data']
                hit_query = response['hitQuery']
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time,
                          "hit_query": hit_query, "answer": answer}

            # The logged duration is numeric and also covers building the result.
            delta_time_log = round(time.time() - start_time, 2)
            _log_in_background(save_query_info(result, model, query, delta_time_log))
            return result
        except Exception as e:
            result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
                      "hit_query": '', "answer": ''}
            logging.info(f'result: {str(result)}')
            return result

    if request_type == 'insert':
        try:
            response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info)
            if response == 'success':
                return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
            else:
                return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
        except Exception as e:
            return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}

    if request_type == 'remove':
        response = adapter.ChatCompletion.create_remove(
            model=model, remove_type=request_data.get("remove_type"), id_list=request_data.get("id_list"))
        if not isinstance(response, dict):
            # NOTE(review): this branch reports "removeStatus" while the other
            # remove outcomes use "writeStatus"; kept as is for compatibility.
            return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}

        state = response.get('status')
        if state == 'success':
            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
        else:
            return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}

    if request_type == 'register':
        response = adapter.ChatCompletion.create_register(model=model)
        if response in ['create_success', 'already_exists']:
            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
        else:
            return {"errorCode": 502, "errorDesc": "", "response": response, "writeStatus": "exception"}

# TODO: this could instead be started from the command line with
# `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`
if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=5000)
162 changes: 162 additions & 0 deletions fastapi4modelcache_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
# -*- coding: utf-8 -*-
import time
import uvicorn
import asyncio
import logging
# import configparser
import json
from fastapi import FastAPI, Request, HTTPException
from pydantic import BaseModel
from concurrent.futures import ThreadPoolExecutor
from starlette.responses import PlainTextResponse
import functools

from modelcache import cache
from modelcache.adapter import adapter
from modelcache.manager import CacheBase, VectorBase, get_data_manager
from modelcache.similarity_evaluation.distance import SearchDistanceEvaluation
from modelcache.processor.pre import query_multi_splicing
from modelcache.processor.pre import insert_multi_splicing
from modelcache.utils.model_filter import model_blacklist_filter
from modelcache.embedding import Data2VecAudio

# Create the FastAPI application instance
app = FastAPI()

class RequestData(BaseModel):
    """Declared shape of a /modelcache request body.

    The POST handler visible in this module reads the raw request body and
    parses it with ``json.loads`` itself; it does not take this model as a
    parameter.
    """
    # Operation selector; the handler accepts 'query', 'insert', 'remove', 'detox'.
    type: str
    # The handler reads scope['model'] from this mapping.
    scope: dict = None
    # Passed by the handler to adapter.ChatCompletion.create_query.
    query: str = None
    # Passed by the handler to adapter.ChatCompletion.create_insert.
    chat_info: list = None
    # Forwarded to adapter.ChatCompletion.create_remove.
    remove_type: str = None
    # Forwarded to adapter.ChatCompletion.create_remove.
    id_list: list = []

# Embedding model: its `dimension` sizes the vector store and its
# `to_embeddings` method is registered as the cache's embedding function below.
data2vec = Data2VecAudio()

# Demo storage: a "sqlite" CacheBase and a "faiss" VectorBase
data_manager = get_data_manager(CacheBase("sqlite"), VectorBase("faiss", dimension=data2vec.dimension))

# Wire the embedding function, storage, similarity evaluation and the
# query/insert pre-embedding hooks into the shared cache object.
cache.init(
    embedding_func=data2vec.to_embeddings,
    data_manager=data_manager,
    similarity_evaluation=SearchDistanceEvaluation(),
    query_pre_embedding_func=query_multi_splicing,
    insert_pre_embedding_func=insert_multi_splicing,
)

# NOTE(review): nothing in the code visible here references `executor`;
# save_query_info_fastapi passes None to run_in_executor, i.e. the loop's
# default executor - confirm whether this pool was meant to be used there.
executor = ThreadPoolExecutor(max_workers=6)

# Persist the query log asynchronously
async def save_query_info_fastapi(result, model, query, delta_time_log):
    """Write one query-log record via ``cache.data_manager.save_query_resp``.

    The synchronous ``save_query_resp`` call is handed to the running loop's
    default executor so the coroutine awaits it instead of calling it inline.

    Args:
        result: response dict passed positionally to ``save_query_resp``.
        model: value forwarded as the ``model`` keyword.
        query: JSON-serialised (non-ASCII preserved) and forwarded as ``query``.
        delta_time_log: value forwarded as the ``delta_time`` keyword.
    """
    running_loop = asyncio.get_running_loop()
    serialized_query = json.dumps(query, ensure_ascii=False)
    persist = functools.partial(
        cache.data_manager.save_query_resp,
        result,
        model=model,
        query=serialized_query,
        delta_time=delta_time_log,
    )
    # Passing None selects the event loop's default executor.
    await running_loop.run_in_executor(None, persist)



@app.get("/welcome", response_class=PlainTextResponse)
async def first_fastapi():
    """Return a fixed plain-text greeting for GET /welcome."""
    greeting = "hello, modelcache!"
    return greeting

# Strong references to in-flight logging tasks. The event loop only keeps weak
# references to tasks, so a fire-and-forget task that nothing else references
# may be garbage-collected before it finishes.
_background_tasks = set()


def _log_in_background(coro):
    """Schedule ``coro`` as a task and keep it referenced until it completes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


@app.post("/modelcache")
async def user_backend(request: Request):
    """Handle POST /modelcache: validate the body, then dispatch on its ``type``.

    The body must be a JSON object (a JSON string that itself contains a JSON
    object is decoded a second time). Recognised keys: ``type``, ``scope`` (its
    ``model`` entry is read, with '-' and '.' replaced by '_'), ``query``,
    ``chat_info``, ``remove_type`` and ``id_list``.

    Returns:
        A dict carrying ``errorCode``:
          * 103 - unexpected failure while reading/parsing the request; the
            received body is echoed back under ``para_dict``.
          * 0 / 201 / 202 - query outcome (``cacheHit``, ``hit_query``, ``answer``).
          * 0 / 301 / 303 - insert outcome (``writeStatus``).
          * 0 / 401 / 402 - remove outcome.
        If ``model_blacklist_filter`` returns a dict, that dict is returned as is.

    Raises:
        HTTPException: status 400 when the body is not valid JSON or ``type``
            is missing / not one of 'query', 'insert', 'remove', 'detox'.
    """
    raw_body = None
    try:
        raw_body = await request.body()
        if isinstance(raw_body, bytes):
            raw_body = raw_body.decode("utf-8")

        try:
            request_data = json.loads(raw_body)
            # A body that decodes to a JSON *string* is decoded once more; this
            # is what a client posting `json=json.dumps(data)` produces.
            if isinstance(request_data, str):
                request_data = json.loads(request_data)
        except json.JSONDecodeError:
            raise HTTPException(status_code=400, detail="Invalid JSON format")

        request_type = request_data.get('type')
        model = None
        if 'scope' in request_data:
            model = request_data['scope'].get('model', '').replace('-', '_').replace('.', '_')
        query = request_data.get('query')
        chat_info = request_data.get('chat_info')

        if not request_type or request_type not in ['query', 'insert', 'remove', 'detox']:
            raise HTTPException(status_code=400, detail="Type exception, should be one of ['query', 'insert', 'remove', 'detox']")

    except HTTPException:
        # HTTPException is an Exception subclass: without this clause the broad
        # handler below would swallow the deliberate 400 and answer 200/103.
        raise
    except Exception as e:
        # Anything unexpected while reading or interpreting the request.
        if isinstance(raw_body, bytes):
            # Decoding failed above; keep the echoed body JSON-serialisable.
            raw_body = raw_body.decode("utf-8", errors="replace")
        return {
            "errorCode": 103,
            "errorDesc": str(e),
            "cacheHit": False,
            "delta_time": 0,
            "hit_query": '',
            "answer": '',
            "para_dict": raw_body
        }

    # model filter
    filter_resp = model_blacklist_filter(model, request_type)
    if isinstance(filter_resp, dict):
        return filter_resp

    if request_type == 'query':
        try:
            start_time = time.time()
            response = adapter.ChatCompletion.create_query(scope={"model": model}, query=query)
            delta_time = f"{round(time.time() - start_time, 2)}s"

            if response is None:
                # No response from the adapter is reported as a cache miss.
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            elif response in ['adapt_query_exception']:
                result = {"errorCode": 201, "errorDesc": response, "cacheHit": False, "delta_time": delta_time,
                          "hit_query": '', "answer": ''}
            else:
                answer = response['data']
                hit_query = response['hitQuery']
                result = {"errorCode": 0, "errorDesc": '', "cacheHit": True, "delta_time": delta_time,
                          "hit_query": hit_query, "answer": answer}

            # The logged duration is numeric and also covers building the result.
            delta_time_log = round(time.time() - start_time, 2)
            _log_in_background(save_query_info_fastapi(result, model, query, delta_time_log))
            return result
        except Exception as e:
            result = {"errorCode": 202, "errorDesc": str(e), "cacheHit": False, "delta_time": 0,
                      "hit_query": '', "answer": ''}
            logging.info(f'result: {str(result)}')
            return result

    if request_type == 'insert':
        try:
            response = adapter.ChatCompletion.create_insert(model=model, chat_info=chat_info)
            if response == 'success':
                return {"errorCode": 0, "errorDesc": "", "writeStatus": "success"}
            else:
                return {"errorCode": 301, "errorDesc": response, "writeStatus": "exception"}
        except Exception as e:
            return {"errorCode": 303, "errorDesc": str(e), "writeStatus": "exception"}

    if request_type == 'remove':
        response = adapter.ChatCompletion.create_remove(
            model=model, remove_type=request_data.get("remove_type"), id_list=request_data.get("id_list"))
        if not isinstance(response, dict):
            # NOTE(review): this branch reports "removeStatus" while the other
            # remove outcomes use "writeStatus"; kept as is for compatibility.
            return {"errorCode": 401, "errorDesc": "", "response": response, "removeStatus": "exception"}

        state = response.get('status')
        if state == 'success':
            return {"errorCode": 0, "errorDesc": "", "response": response, "writeStatus": "success"}
        else:
            return {"errorCode": 402, "errorDesc": "", "response": response, "writeStatus": "exception"}

    # NOTE(review): 'detox' passes the type check above but has no handler in
    # this module, so such requests end here and the endpoint returns None.
    return None

# TODO: this could instead be started from the command line with
# `uvicorn your_module_name:app --host 0.0.0.0 --port 5000 --reload`
if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=5000)
3 changes: 1 addition & 2 deletions modelcache/manager/scalar_data/sql_storage_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def insert_query_resp(self, query_resp, **kwargs):
hit_query = json.dumps(hit_query, ensure_ascii=False)

table_name = "modelcache_query_log"
insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (%s, %s, %s, %s, %s, %s, %s, %s)".format(table_name)

insert_sql = "INSERT INTO {} (error_code, error_desc, cache_hit, model, query, delta_time, hit_query, answer) VALUES (?, ?, ?, ?, ?, ?, ?, ?)".format(table_name)
conn = sqlite3.connect(self._url)
try:
cursor = conn.cursor()
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,5 @@ transformers==4.38.2
faiss-cpu==1.7.4
redis==5.0.1
modelscope==1.14.0
fastapi==0.115.5
uvicorn==0.32.0