stream_api.py
import json

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from sse_starlette import ServerSentEvent, EventSourceResponse
from transformers import AutoTokenizer, AutoModel

app = FastAPI()


class ChatBody(BaseModel):
    question: str


def gen_chat(chat: ChatBody, his=None):
    """Stream the model's answer as incremental text chunks."""
    with torch.autocast("cuda"):
        start_idx = 0
        for response, history in model.stream_chat(tokenizer, chat.question, his):
            # stream_chat yields the full response so far; emit only the newly
            # generated slice since the previous iteration.
            end_idx = len(response)
            this_response = response[start_idx:end_idx]
            start_idx = end_idx
            yield {"response": this_response, "finished": False}
        yield {"finished": True}


@app.post("/chat_stream")
async def chat_stream(chat_body: ChatBody):
    """
    Request JSON format: {"question": "介绍一下你自己吧"}  ("Introduce yourself")
    :param chat_body: the question to answer
    :return: a Server-Sent Events stream of answer chunks
    """
    def decorate(generator):
        for item in generator:
            # SSE data must be a string, so serialize each chunk dict as JSON.
            yield ServerSentEvent(json.dumps(item, ensure_ascii=False), event="answer")

    return EventSourceResponse(decorate(gen_chat(chat_body)))


if __name__ == "__main__":
    PRE_TRAINED_MODEL_PATH = "E:\\AI\\chatglm3-6b"
    tokenizer = AutoTokenizer.from_pretrained(PRE_TRAINED_MODEL_PATH, trust_remote_code=True)
    # Load ChatGLM3-6B with 8-bit quantization and move it to the GPU.
    model = AutoModel.from_pretrained(PRE_TRAINED_MODEL_PATH, trust_remote_code=True).quantize(8).cuda()
    uvicorn.run(app, host="127.0.0.1", port=9999)
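
A minimal client sketch for consuming the stream, assuming the server above is running on 127.0.0.1:9999 and the `requests` package is available. The endpoint path and payload come from the code; the line-by-line SSE parsing is an illustration, not part of the original file.

# client_example.py (hypothetical): POST a question, then read the SSE stream.
import json
import requests

with requests.post(
    "http://127.0.0.1:9999/chat_stream",
    json={"question": "介绍一下你自己吧"},
    stream=True,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        # SSE frames arrive as "event: answer" followed by "data: {...}".
        if line and line.startswith("data:"):
            chunk = json.loads(line[len("data:"):].strip())
            if chunk.get("finished"):
                break
            print(chunk["response"], end="", flush=True)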