Support tool call in openai api server
li-plus committed Jan 20, 2024
1 parent 23442b0 commit baff67d
Showing 3 changed files with 113 additions and 17 deletions.
README.md: 19 changes (12 additions & 7 deletions)
@@ -498,7 +498,7 @@ For more options, please refer to [examples/langchain_client.py](examples/langch
 Start an API server compatible with [OpenAI chat completions protocol](https://platform.openai.com/docs/api-reference/chat):
 ```sh
-MODEL=./chatglm2-ggml.bin uvicorn chatglm_cpp.openai_api:app --host 127.0.0.1 --port 8000
+MODEL=./chatglm3-ggml.bin uvicorn chatglm_cpp.openai_api:app --host 127.0.0.1 --port 8000
 ```
 Test your endpoint with `curl`:
@@ -509,17 +509,22 @@ curl http://127.0.0.1:8000/v1/chat/completions -H 'Content-Type: application/jso
 Use the OpenAI client to chat with your model:
 ```python
->>> import openai
+>>> from openai import OpenAI
 >>>
->>> openai.api_base = "http://127.0.0.1:8000/v1"
->>> response = openai.ChatCompletion.create(model="default-model", messages=[{"role": "user", "content": "你好"}])
->>> response["choices"][0]["message"]["content"]
-'你好👋!我是人工智能助手 ChatGLM2-6B,很高兴见到你,欢迎问我任何问题。'
+>>> client = OpenAI(base_url="http://127.0.0.1:8000/v1")
+>>> response = client.chat.completions.create(model="default-model", messages=[{"role": "user", "content": "你好"}])
+>>> response.choices[0].message.content
+'你好👋!我是人工智能助手 ChatGLM3-6B,很高兴见到你,欢迎问我任何问题。'
 ```
 For stream response, check out the example client script:
 ```sh
-OPENAI_API_BASE=http://127.0.0.1:8000/v1 python3 examples/openai_client.py --stream --prompt 你好
+OPENAI_BASE_URL=http://127.0.0.1:8000/v1 python3 examples/openai_client.py --stream --prompt 你好
 ```
+Tool calling is also supported:
+```sh
+OPENAI_BASE_URL=http://127.0.0.1:8000/v1 python3 examples/openai_client.py --tool_call --prompt 上海天气怎么样
+```
 With this API server as backend, ChatGLM.cpp models can be seamlessly integrated into any frontend that uses OpenAI-style API, including [mckaywrigley/chatbot-ui](https://github.com/mckaywrigley/chatbot-ui), [fuergaosi233/wechat-chatgpt](https://github.com/fuergaosi233/wechat-chatgpt), [Yidadaa/ChatGPT-Next-Web](https://github.com/Yidadaa/ChatGPT-Next-Web), and more.
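
A tool-enabled request can also be sent directly with `curl`. This is an illustrative sketch, not part of the commit: the `get_current_weather` schema mirrors the one in examples/openai_client.py, and the exact reply depends on the loaded checkpoint:

```sh
# Illustrative only: exercises the new `tools` field of the chat completions API.
curl http://127.0.0.1:8000/v1/chat/completions -H 'Content-Type: application/json' -d '{
  "model": "default-model",
  "messages": [{"role": "user", "content": "上海天气怎么样"}],
  "tools": [{
    "type": "function",
    "function": {
      "name": "get_current_weather",
      "description": "Get the current weather in a given location",
      "parameters": {
        "type": "object",
        "properties": {
          "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
          "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}
        },
        "required": ["location"]
      }
    }
  }]
}'
```

When the model decides to call the tool, the response carries `finish_reason: "function_call"` and a `tool_calls` list on the assistant message, as implemented in chatglm_cpp/openai_api.py below.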
chatglm_cpp/openai_api.py: 72 changes (68 additions & 4 deletions)
@@ -1,9 +1,11 @@
 import asyncio
+import json
 import logging
 import time
-from typing import List, Literal, Optional, Union
+from typing import Dict, List, Literal, Optional, Union
 
 import chatglm_cpp
+import uvicorn
 from fastapi import FastAPI, HTTPException, status
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field, computed_field
@@ -14,18 +16,41 @@


 class Settings(BaseSettings):
-    model: str = "chatglm-ggml.bin"
+    model: str = "chatglm3-ggml.bin"
     num_threads: int = 0
 
 
+class ToolCallFunction(BaseModel):
+    arguments: str
+    name: str
+
+
+class ToolCall(BaseModel):
+    function: Optional[ToolCallFunction] = None
+    type: Literal["function"]
+
+
 class ChatMessage(BaseModel):
     role: Literal["system", "user", "assistant"]
     content: str
+    tool_calls: Optional[List[ToolCall]] = None
 
 
 class DeltaMessage(BaseModel):
     role: Optional[Literal["system", "user", "assistant"]] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = None
+
+
+class ChatCompletionToolFunction(BaseModel):
+    description: Optional[str] = None
+    name: str
+    parameters: Dict
+
+
+class ChatCompletionTool(BaseModel):
+    type: Literal["function"] = "function"
+    function: ChatCompletionToolFunction
 
 
 class ChatCompletionRequest(BaseModel):
@@ -35,6 +60,7 @@ class ChatCompletionRequest(BaseModel):
     top_p: float = Field(default=0.7, ge=0.0, le=1.0)
     stream: bool = False
     max_tokens: int = Field(default=2048, ge=0)
+    tools: Optional[List[ChatCompletionTool]] = None
 
     model_config = {
         "json_schema_extra": {"examples": [{"model": "default-model", "messages": [{"role": "user", "content": "你好"}]}]}
@@ -44,7 +70,7 @@ class ChatCompletionRequest(BaseModel):
 class ChatCompletionResponseChoice(BaseModel):
     index: int = 0
     message: ChatMessage
-    finish_reason: Literal["stop", "length"] = "stop"
+    finish_reason: Literal["stop", "length", "function_call"]
 
 
 class ChatCompletionResponseStreamChoice(BaseModel):
@@ -144,10 +170,25 @@ async def stream_chat_event_publisher(history, body):

 @app.post("/v1/chat/completions")
 async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionResponse:
+    def to_json_arguments(arguments):
+        def tool_call(**kwargs):
+            return kwargs
+
+        try:
+            return json.dumps(eval(arguments, dict(tool_call=tool_call)))
+        except Exception:
+            return arguments
+
     if not body.messages:
         raise HTTPException(status.HTTP_400_BAD_REQUEST, "empty messages")
 
     messages = [chatglm_cpp.ChatMessage(role=msg.role, content=msg.content) for msg in body.messages]
+    if body.tools:
+        system_content = (
+            "Answer the following questions as best as you can. You have access to the following tools:\n"
+            + json.dumps([tool.model_dump() for tool in body.tools], indent=4)
+        )
+        messages.insert(0, chatglm_cpp.ChatMessage(role="system", content=system_content))
 
     if body.stream:
         generator = stream_chat_event_publisher(messages, body)
@@ -166,9 +207,28 @@ async def create_chat_completion(body: ChatCompletionRequest) -> ChatCompletionR
     prompt_tokens = len(pipeline.tokenizer.encode_messages(messages, max_context_length))
     completion_tokens = len(pipeline.tokenizer.encode(output.content, body.max_tokens))
 
+    finish_reason = "stop"
+    tool_calls = None
+    if output.tool_calls:
+        tool_calls = [
+            ToolCall(
+                type=tool_call.type,
+                function=ToolCallFunction(
+                    name=tool_call.function.name, arguments=to_json_arguments(tool_call.function.arguments)
+                ),
+            )
+            for tool_call in output.tool_calls
+        ]
+        finish_reason = "function_call"
+
     return ChatCompletionResponse(
         object="chat.completion",
-        choices=[ChatCompletionResponseChoice(message=ChatMessage(role="assistant", content=output.content))],
+        choices=[
+            ChatCompletionResponseChoice(
+                message=ChatMessage(role="assistant", content=output.content, tool_calls=tool_calls),
+                finish_reason=finish_reason,
+            )
+        ],
         usage=ChatCompletionUsage(prompt_tokens=prompt_tokens, completion_tokens=completion_tokens),
     )

@@ -199,3 +259,7 @@ class ModelList(BaseModel):
 @app.get("/v1/models")
 async def list_models() -> ModelList:
     return ModelList(data=[ModelCard(id="gpt-3.5-turbo")])
+
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)
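
Two details of the tool-call path in chatglm_cpp/openai_api.py are worth spelling out: when a request carries `tools`, the server prepends a system message listing them as pretty-printed JSON, and the tool-call arguments come back from the model as a Python-style call expression rather than JSON, which `to_json_arguments` normalizes by evaluating it against a `tool_call(**kwargs)` shim. A minimal standalone sketch of both pieces; the tool schema and the argument string here are assumed for illustration, not captured from a real run:

```python
# Standalone sketch of the two tool-calling pieces above; the tool schema and
# the argument string are illustrative, not taken from a real model run.
import json

# 1) Requests that carry tools get a system message prepended, listing the
#    tools as pretty-printed JSON so the model knows what it may call.
tools = [{"type": "function", "function": {"name": "get_current_weather", "parameters": {}}}]
system_content = (
    "Answer the following questions as best as you can. You have access to the following tools:\n"
    + json.dumps(tools, indent=4)
)
print(system_content)


# 2) The model reports tool-call arguments as a Python-style call expression;
#    evaluating it against a tool_call(**kwargs) shim recovers the keyword
#    arguments as a dict, which is re-serialized into the JSON string that
#    OpenAI clients expect in `function.arguments`.
def to_json_arguments(arguments: str) -> str:
    def tool_call(**kwargs):
        return kwargs

    try:
        return json.dumps(eval(arguments, dict(tool_call=tool_call)))
    except Exception:
        return arguments  # not a parsable expression: pass the raw string through


print(to_json_arguments('tool_call(location="上海", unit="celsius")'))
# -> {"location": "\u4e0a\u6d77", "unit": "celsius"}
```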
examples/openai_client.py: 39 changes (33 additions & 6 deletions)
@@ -1,19 +1,46 @@
 import argparse
 
-import openai
+from openai import OpenAI
 
 parser = argparse.ArgumentParser()
 parser.add_argument("--stream", action="store_true")
 parser.add_argument("--prompt", default="你好", type=str)
+parser.add_argument("--tool_call", action="store_true")
 args = parser.parse_args()
 
+client = OpenAI()
+
+tools = None
+if args.tool_call:
+    tools = [
+        {
+            "type": "function",
+            "function": {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            },
+        }
+    ]
+
 messages = [{"role": "user", "content": args.prompt}]
 if args.stream:
-    response = openai.ChatCompletion.create(model="default-model", messages=messages, stream=True)
+    response = client.chat.completions.create(model="default-model", messages=messages, stream=True, tools=tools)
     for chunk in response:
-        content = chunk["choices"][0]["delta"].get("content", "")
-        print(content, end="", flush=True)
+        content = chunk.choices[0].delta.content
+        if content is not None:
+            print(content, end="", flush=True)
     print()
 else:
-    response = openai.ChatCompletion.create(model="default-model", messages=messages)
-    print(response["choices"][0]["message"]["content"])
+    response = client.chat.completions.create(model="default-model", messages=messages, tools=tools)
+    print(response.choices[0].message.content)
