Use conversation template for api proxy, fix eventsource format #2383

Open · wants to merge 6 commits into master
10 changes: 10 additions & 0 deletions examples/server/README.md
@@ -195,6 +195,14 @@ bash chat.sh
API example using Python Flask: [api_like_OAI.py](api_like_OAI.py)
This example must be used with server.cpp

Requirements:

```shell
pip install flask flask-cors fschat  # flask-cors and fschat are optional: flask-cors enables cross-origin requests, fschat provides the chat conversation templates
```

Run the server:

```sh
python api_like_OAI.py
```
@@ -204,6 +212,8 @@ After running the API server, you can use it in Python by setting the API base URL:
```python
openai.api_base = "http://<Your api-server IP>:port"
```

For better integration with the model, it is recommended to start the proxy with the `--chat-prompt-model` parameter rather than relying on parameters like `--user-name`. It takes the name of a conversation template registered in [FastChat/conversation.py](https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py), for example `llama-2`, as shown below.
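A minimal invocation, assuming a Llama-2 chat model is loaded in server.cpp:

```sh
python api_like_OAI.py --chat-prompt-model llama-2
```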

Then you can use llama.cpp as an OpenAI-compatible **chat.completion** or **text_completion** API.
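A rough client-side sketch, assuming the proxy listens on its default port 8081 and the legacy `openai` 0.x Python package is installed:

```python
import openai

openai.api_key = "none"                    # the proxy does not check API keys
openai.api_base = "http://127.0.0.1:8081"  # adjust to your api-server IP and port

# The model name is ignored by the proxy; server.cpp serves whatever model it loaded
response = openai.ChatCompletion.create(
    model="local",
    messages=[{"role": "user", "content": "Hello, who are you?"}],
)
print(response["choices"][0]["message"]["content"])
```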

### Extending or building alternative Web Front End
76 changes: 53 additions & 23 deletions examples/server/api_like_OAI.py
@@ -4,11 +4,20 @@
import requests
import time
import json

+try:
+    from fastchat import conversation
+except ImportError:
+    conversation = None

app = Flask(__name__)
+try:
+    from flask_cors import CORS
+    CORS(app)
+except ImportError:
+    pass

parser = argparse.ArgumentParser(description="An example of using server.cpp with a similar API to OAI. It must be used together with server.cpp.")
+parser.add_argument("--chat-prompt-model", type=str, help="the model name of the conversation template, as registered in FastChat's conversation.py (default: '')", default="")
parser.add_argument("--chat-prompt", type=str, help="the top prompt in chat completions (default: 'A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')", default='A chat between a curious user and an artificial intelligence assistant. The assistant follows the given rules no matter what.\\n')
parser.add_argument("--user-name", type=str, help="USER name in chat completions (default: '\\nUSER: ')", default="\\nUSER: ")
parser.add_argument("--ai-name", type=str, help="ASSISTANT name in chat completions (default: '\\nASSISTANT: ')", default="\\nASSISTANT: ")
@@ -29,25 +38,46 @@ def is_present(json, key):
    return True


+use_conversation_template = args.chat_prompt_model != "" and conversation is not None

-#convert chat to prompt
-def convert_chat(messages):
-    prompt = "" + args.chat_prompt.replace("\\n", "\n")
-
-    system_n = args.system_name.replace("\\n", "\n")
-    user_n = args.user_name.replace("\\n", "\n")
-    ai_n = args.ai_name.replace("\\n", "\n")
-    stop = args.stop.replace("\\n", "\n")
+if use_conversation_template:
+    conv = conversation.get_conv_template(args.chat_prompt_model)
+    stop_token = conv.stop_str
+else:
+    stop_token = args.stop

-    for line in messages:
-        if (line["role"] == "system"):
-            prompt += f"{system_n}{line['content']}"
-        if (line["role"] == "user"):
-            prompt += f"{user_n}{line['content']}"
-        if (line["role"] == "assistant"):
-            prompt += f"{ai_n}{line['content']}{stop}"
-    prompt += ai_n.rstrip()
+#convert chat to prompt
+def convert_chat(messages):
+    if use_conversation_template:
+        conv = conversation.get_conv_template(args.chat_prompt_model)
+        for line in messages:
+            if (line["role"] == "system"):
+                try:
+                    conv.set_system_message(line["content"])
+                except Exception:
+                    pass
+            elif (line["role"] == "user"):
+                conv.append_message(conv.roles[0], line["content"])
+            elif (line["role"] == "assistant"):
+                conv.append_message(conv.roles[1], line["content"])
+        conv.append_message(conv.roles[1], None)
+        prompt = conv.get_prompt()
+    else:
+        prompt = "" + args.chat_prompt.replace("\\n", "\n")
+        system_n = args.system_name.replace("\\n", "\n")
+        user_n = args.user_name.replace("\\n", "\n")
+        ai_n = args.ai_name.replace("\\n", "\n")
+        stop = stop_token.replace("\\n", "\n")
+
+        for line in messages:
+            if (line["role"] == "system"):
+                prompt += f"{system_n}{line['content']}"
+            if (line["role"] == "user"):
+                prompt += f"{user_n}{line['content']}"
+            if (line["role"] == "assistant"):
+                prompt += f"{ai_n}{line['content']}{stop}"
+        prompt += ai_n.rstrip()

    return prompt
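For a quick sanity check of what the template branch builds, the same FastChat calls can be run standalone (assuming `fschat` is installed and recent enough to provide `set_system_message`, which the proxy wraps in try/except for older versions):

```python
from fastchat import conversation

# Mirror convert_chat's template branch for a llama-2 style model
conv = conversation.get_conv_template("llama-2")
conv.set_system_message("You are a helpful assistant.")
conv.append_message(conv.roles[0], "What is llama.cpp?")  # user turn
conv.append_message(conv.roles[1], None)                  # leave the assistant turn open
print(conv.get_prompt())  # the fully formatted prompt sent to server.cpp
print(conv.stop_str)      # the stop string the proxy forwards
```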

@@ -69,11 +99,11 @@ def make_postData(body, chat=False, stream=False):
    if(is_present(body, "mirostat_eta")): postData["mirostat_eta"] = body["mirostat_eta"]
    if(is_present(body, "seed")): postData["seed"] = body["seed"]
    if(is_present(body, "logit_bias")): postData["logit_bias"] = [[int(token), body["logit_bias"][token]] for token in body["logit_bias"].keys()]
-    if (args.stop != ""):
-        postData["stop"] = [args.stop]
+    if stop_token: # stop_token can be "" or None
+        postData["stop"] = [stop_token]
    else:
        postData["stop"] = []
-    if(is_present(body, "stop")): postData["stop"] += body["stop"]
+    if(is_present(body, "stop")): postData["stop"] += body["stop"] or []
    postData["n_keep"] = -1
    postData["stream"] = stream

@@ -173,12 +203,12 @@ def generate():
        data = requests.request("POST", urllib.parse.urljoin(args.llama_api, "/completion"), data=json.dumps(postData), stream=True)
        time_now = int(time.time())
        resData = make_resData_stream({}, chat=True, time_now=time_now, start=True)
-        yield 'data: {}\n'.format(json.dumps(resData))
+        yield 'data: {}\n\n'.format(json.dumps(resData))
        for line in data.iter_lines():
            if line:
                decoded_line = line.decode('utf-8')
                resData = make_resData_stream(json.loads(decoded_line[6:]), chat=True, time_now=time_now)
-                yield 'data: {}\n'.format(json.dumps(resData))
+                yield 'data: {}\n\n'.format(json.dumps(resData))
    return Response(generate(), mimetype='text/event-stream')
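The trailing blank line is the point of the eventsource fix: the server-sent-events format terminates each event with an empty line, so a bare `data: ...\n` never completes an event and `EventSource` handlers never fire. A minimal sketch of the framing (standalone, not the proxy's code):

```python
import json

def sse_event(payload: dict) -> str:
    # One well-formed SSE frame: a "data:" field terminated by a blank line
    return 'data: {}\n\n'.format(json.dumps(payload))

assert sse_event({"ok": True}) == 'data: {"ok": true}\n\n'
```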


@@ -212,7 +242,7 @@ def generate():
            if line:
                decoded_line = line.decode('utf-8')
                resData = make_resData_stream(json.loads(decoded_line[6:]), chat=False, time_now=time_now)
-                yield 'data: {}\n'.format(json.dumps(resData))
+                yield 'data: {}\n\n'.format(json.dumps(resData))
    return Response(generate(), mimetype='text/event-stream')
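On the client side, the blank-line delimiter lets a consumer treat each non-empty `data:` line as one event; a rough sketch with `requests` (hypothetical endpoint and payload):

```python
import json
import requests

# Stream tokens from the proxy's text-completion endpoint (path assumed)
payload = {"prompt": "Hello", "stream": True}
with requests.post("http://127.0.0.1:8081/completions", json=payload, stream=True) as r:
    for line in r.iter_lines():
        if line:  # blank delimiter lines arrive as b"" and are skipped here
            event = json.loads(line.decode("utf-8")[len("data: "):])
            print(event)
```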

if __name__ == '__main__':