"""FastAPI demo serving streaming and non-streaming GPT-2 completions."""
import json
import logging
import random
from typing import Dict

import torch
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers_stream_generator import init_stream_support

# Patch transformers' generate() so it accepts do_stream=True and yields
# newly sampled tokens one step at a time.
init_stream_support()

# `log` was used but never defined in the original; a module logger is assumed.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger("run_web")

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = model.eval()  # inference only; disables dropout
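
# A minimal sketch (not executed) of the raw streaming API this file relies
# on, assuming the patched generate() yields one token tensor per step, as
# the decoding loop in inference_stream below suggests:
#
#   ids = tokenizer("Hello", return_tensors="pt").input_ids
#   for token in model.generate(ids, do_sample=True, do_stream=True,
#                               max_new_tokens=5):
#       print(tokenizer.decode(token.cpu().numpy().tolist()))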


class CompletionsRequest(BaseModel):
    prompt: str = Field(title="input prompt")
    parameters: Dict = Field(title="runtime parameters")


app = FastAPI(
    title="Stream Restful Demo",
    redoc_url=None,
    docs_url=None,  # the original passed docs=None, which FastAPI ignores
)
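
# An example of the JSON body both endpoints expect, inferred from
# CompletionsRequest above; the parameter keys follow the pars.get() calls
# below, and any key may be omitted to fall back to the defaults shown:
#
#   {
#       "prompt": "The quick brown fox",
#       "parameters": {"max_tokens": 200, "top_k": 30, "top_p": 0.85,
#                      "temperature": 0.35, "repetition_penalty": 1.2}
#   }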


@app.post("/inference_stream")
def inference_stream(request: CompletionsRequest):
    with torch.no_grad():
        input_ids = tokenizer(request.prompt, return_tensors="pt").input_ids
        pars = request.parameters
        # do_stream=True (added by init_stream_support) turns generate()
        # into a generator that yields tokens as they are sampled.
        generator = model.generate(
            input_ids,
            max_new_tokens=pars.get("max_tokens") or 200,
            do_sample=True,
            do_stream=True,
            top_k=pars.get("top_k") or 30,
            top_p=pars.get("top_p") or 0.85,
            temperature=pars.get("temperature") or 0.35,
            repetition_penalty=pars.get("repetition_penalty") or 1.2,
            early_stopping=True,
            seed=0,  # fixed seed: the demo samples reproducibly
        )
        # `uid` was undefined in the original; a random per-request id is
        # assumed here so the log line at the end of the stream works.
        uid = "%08d" % random.randint(0, 10**8 - 1)

        def _decode_generator(generator):
            words = ""
            # Buffer token ids that decode to "�" (an incomplete multi-byte
            # character) until enough tokens arrive to decode them cleanly.
            last_tokens = []
            last_decoded_tokens = []
            for x in generator:
                tokens = x.cpu().numpy().tolist()
                tokens = last_tokens + tokens
                word = tokenizer.decode(tokens, skip_special_tokens=True)
                if "�" in word:
                    last_tokens = tokens
                else:
                    # Heuristic from the original code: restore the leading
                    # space at a word boundary between decoded pieces.
                    if " " in tokenizer.decode(
                        last_decoded_tokens + tokens, skip_special_tokens=True
                    ):
                        word = " " + word
                    last_tokens = []
                    last_decoded_tokens = tokens
                    yield json.dumps(
                        {
                            "code": 200,
                            "message": "success",
                            "data": [
                                {
                                    "delta": {"content": word},
                                    "logprobs": None,
                                    "index": 0,
                                    "finish_reason": None,
                                }
                            ],
                        },
                        ensure_ascii=False,
                    )
                    words += word
            log.info(uid + "$$ " + words)

        return StreamingResponse(_decode_generator(generator))
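
# A sketch of consuming the stream from a client, assuming the server runs
# on localhost:8000 as configured below. Note the endpoint emits back-to-back
# JSON objects with no delimiter between chunks, so this sketch just prints
# each raw chunk; a robust client would need a framing convention (e.g.
# newline-delimited JSON) to parse the objects reliably.
#
#   import requests
#
#   with requests.post(
#       "http://localhost:8000/inference_stream",
#       json={"prompt": "The quick brown fox", "parameters": {}},
#       stream=True,
#   ) as resp:
#       for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
#           print(chunk)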


@app.post("/inference")
def inference(request: CompletionsRequest):
    # Non-streaming variant: same sampling parameters, but the full
    # completion is generated and returned in a single response.
    with torch.no_grad():
        input_ids = tokenizer(request.prompt, return_tensors="pt").input_ids
        pars = request.parameters
        result = model.generate(
            input_ids,
            max_new_tokens=pars.get("max_tokens") or 200,
            do_sample=True,
            top_k=pars.get("top_k") or 30,
            top_p=pars.get("top_p") or 0.85,
            temperature=pars.get("temperature") or 0.35,
            repetition_penalty=pars.get("repetition_penalty") or 1.2,
            early_stopping=True,
            seed=0,
        )
        words = tokenizer.decode(result[0], skip_special_tokens=True)
        print(words)
        return {"result": words}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="localhost", port=8000)