feat: added local LLM API, changed the config displayed in the initial gradio demo, and added OpenAI's api_base configuration. (#41)

* add api_bot

* add local Qwen API
change the initial gradio_demo config shown
add api_base config for OpenAI

* fix some issues

* fix some issues

* fix dependency

* run ./style/code_format_and_analysis.sh and resolve some warnings

* 1. fixed dependencies of the local LLM API
2. provided usage examples of the local LLM API in the README

* fix the argument descriptions

* 1. Removed an unused function.
2. Added a configurable argument max_new_tokens to the command-line arguments.
3. Fixed some code style issues.

* fix some style issues

* 1. fix import error when Python version > 3.12
vichayturen authored Apr 26, 2024
1 parent 243f547 commit ec4b9da
Showing 11 changed files with 413 additions and 110 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pylint.yml
@@ -23,6 +23,7 @@ jobs:
        python -m pip install --upgrade pip
        pip install pylint pytest
        pip install -r ./hugegraph-llm/requirements.txt
        pip install -r ./hugegraph-llm/llm_api/requirements.txt
        pip install -r ./hugegraph-python-client/requirements.txt
    - name: Analysing the code with pylint
      run: |
19 changes: 19 additions & 0 deletions hugegraph-llm/llm_api/README.md
@@ -0,0 +1,19 @@
# Local LLM API

## Usage
If you want hugegraph-llm to use a local LLM, you can configure it as follows.

Run the program:
```shell
python main.py \
--model_name_or_path "Qwen/Qwen1.5-0.5B-Chat" \
--device "cuda" \
--port 7999
```

The LLM section of [config.ini](../src/hugegraph_llm/config/config.ini) can be configured as follows:
```ini
[LLM]
type = local_api
llm_url = http://localhost:7999/v1/chat/completions
```
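
For a quick sanity check of the service (a hedged sketch, not part of the committed README), the endpoint can also be called directly; the request and response fields mirror the ChatRequest/ChatResponse models defined in main.py below:

```python
# Minimal smoke test, assuming the service above is running on localhost:7999.
# The payload shape follows ChatRequest; the reply follows ChatResponse.
import requests

resp = requests.post(
    "http://localhost:7999/v1/chat/completions",
    json={"messages": [{"role": "user", "content": "What is the capital of China?"}]},
    timeout=30,
)
resp.raise_for_status()
print(resp.json()["content"])  # generated reply text
```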
163 changes: 163 additions & 0 deletions hugegraph-llm/llm_api/main.py
@@ -0,0 +1,163 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.


import gc
import sys
from contextlib import asynccontextmanager
from enum import Enum
from typing import Literal, List

import torch
import uvicorn
from fastapi import FastAPI, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2ForCausalLM, \
    Qwen2Tokenizer, GenerationConfig
if sys.version_info >= (3, 12):  # >=3.12
    from typing import TypedDict
else:  # <3.12
    from typing_extensions import TypedDict


class Message(TypedDict):
    role: Literal["system", "user", "assistant", "tool", "function"]
    content: str


class ChatRequest(BaseModel):
    messages: List[Message]


class ChatResponse(BaseModel):
    content: str


class Role(Enum):
    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"
    FUNCTION = "function"


class QwenChatModel:
    def __init__(self, model_name_or_path: str, device: str = "cuda", max_new_tokens: int = 512,
                 generation_config: GenerationConfig = None):
        self.model: Qwen2ForCausalLM = AutoModelForCausalLM.from_pretrained(
            model_name_or_path,
            torch_dtype="auto",
            device_map=device
        )
        self.tokenizer: Qwen2Tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.device = torch.device(device)
        self.max_new_tokens = max_new_tokens
        self.generation_config = generation_config

    @torch.inference_mode()
    async def achat(self, messages: List[Message]):
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
        generated_ids = self.model.generate(
            model_inputs.input_ids,
            max_new_tokens=self.max_new_tokens,
        )
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        response = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return response


def torch_gc() -> None:
    r"""
    Collects GPU memory.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


@asynccontextmanager
async def lifespan(app: "FastAPI"):  # collects GPU memory
    yield
    torch_gc()


def create_app(chat_model: "QwenChatModel") -> "FastAPI":
    app = FastAPI(lifespan=lifespan)

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    @app.post("/v1/chat/completions", response_model=ChatResponse, status_code=status.HTTP_200_OK)
    async def create_chat_completion(request: ChatRequest):
        if len(request.messages) == 0:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")

        if len(request.messages) % 2 == 0:
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST,
                                detail="Only supports u/a/u/a/u...")

        print("* ============= [input] ============= *")
        print(request.messages[-1]["content"])

        content = await chat_model.achat(
            messages=request.messages,
        )
        print("* ============= [output] ============= *")
        print(content)

        return ChatResponse(content=content)

    return app


def main():
    import argparse
    parser = argparse.ArgumentParser(description="Local LLM Api for Hugegraph LLM.")
    parser.add_argument("--model_name_or_path", type=str, required=True, help="Model name or path")
    parser.add_argument("--device", type=str, default="cpu", help="Device to use")
    parser.add_argument("--port", type=int, default=7999, help="Port of the service")
    parser.add_argument("--max_new_tokens", type=int, default=512,
                        help="The max number of tokens to generate")

    args = parser.parse_args()

    model_path = args.model_name_or_path
    device = args.device
    port = args.port
    max_new_tokens = args.max_new_tokens
    chat_model = QwenChatModel(model_path, device, max_new_tokens)
    app = create_app(chat_model)
    uvicorn.run(app, host="0.0.0.0", port=port, workers=1)


if __name__ == '__main__':
    main()
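
Note that create_chat_completion rejects requests with an even number of messages, so a multi-turn request has to alternate user/assistant turns and end on a user turn. A hedged example payload (message contents are illustrative only) that passes this check:

```python
# Illustrative multi-turn payload: three messages (odd count), ending with a user turn,
# so it passes the len(request.messages) % 2 == 0 check above.
payload = {
    "messages": [
        {"role": "user", "content": "Hello."},
        {"role": "assistant", "content": "Hi, how can I help?"},
        {"role": "user", "content": "What is the capital of China?"},
    ]
}
```
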
6 changes: 6 additions & 0 deletions hugegraph-llm/llm_api/requirements.txt
@@ -0,0 +1,6 @@
torch==2.1.2
transformers==4.39.3
fastapi==0.110.1
accelerate==0.29.2
charset-normalizer==3.3.2
uvicorn==0.29.0
25 changes: 23 additions & 2 deletions hugegraph-llm/src/hugegraph_llm/config/config.ini
@@ -24,9 +24,30 @@ pwd = admin
graph = hugegraph

[llm]
type = openai
## local llm
# type = local_api
# llm_url = http://localhost:7999/v1/chat/completions
#
## openai
# type = openai
# api_key = xxx
# api_base = xxx
# model_name = gpt-3.5-turbo-16k
# max_token = 4000
#
## ernie
# type = ernie
# api_key = xxx
# secret_key = xxx
# llm_url = xxx
# model_name = ernie
#
# type = openai
type = local_api
api_key = xxx
api_base = https://api.openai.com/v1
secret_key = xxx
llm_url = https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=
# llm_url = https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/chat/completions_pro?access_token=
llm_url = http://localhost:7999/v1/chat/completions
model_name = gpt-3.5-turbo-16k
max_token = 4000
78 changes: 78 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/llms/api_bot.py
@@ -0,0 +1,78 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import json
from typing import Optional, List, Dict, Any, Callable

import requests
from retry import retry

from hugegraph_llm.llms.base import BaseLLM
from hugegraph_llm.utils.config import Config
from hugegraph_llm.utils.constants import Constants


class ApiBotClient(BaseLLM):
    def __init__(self):
        self.c = Config(section=Constants.LLM_CONFIG)
        self.base_url = self.c.get_llm_url()

    @retry(tries=3, delay=1)
    def generate(
            self,
            messages: Optional[List[Dict[str, Any]]] = None,
            prompt: Optional[str] = None,
    ) -> str:
        if messages is None:
            assert prompt is not None, "Messages or prompt must be provided."
            messages = [{"role": "user", "content": prompt}]
        url = self.base_url

        payload = json.dumps({
            "messages": messages,
        })
        headers = {"Content-Type": "application/json"}
        response = requests.request("POST", url, headers=headers, data=payload, timeout=30)
        if response.status_code != 200:
            raise Exception(
                f"Request failed with code {response.status_code}, message: {response.text}"
            )
        response_json = json.loads(response.text)
        return response_json["content"]

    def generate_streaming(
            self,
            messages: Optional[List[Dict[str, Any]]] = None,
            prompt: Optional[str] = None,
            on_token_callback: Callable = None,
    ) -> str:
        return self.generate(messages, prompt)

    def num_tokens_from_string(self, string: str) -> int:
        return len(string)

    def max_allowed_token_length(self) -> int:
        return 4096

    def get_llm_type(self) -> str:
        return "local_api"


if __name__ == "__main__":
    client = ApiBotClient()
    print(client.generate(prompt="What is the capital of China?"))
    print(client.generate(messages=[{"role": "user", "content": "What is the capital of China?"}]))
4 changes: 4 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/llms/init_llm.py
@@ -17,6 +17,7 @@

from hugegraph_llm.llms.openai import OpenAIChat
from hugegraph_llm.llms.ernie_bot import ErnieBotClient
from hugegraph_llm.llms.api_bot import ApiBotClient
from hugegraph_llm.utils.config import Config
from hugegraph_llm.utils.constants import Constants

@@ -32,9 +33,12 @@ def get_llm(self):
        if self.config.get_llm_type() == "openai":
            return OpenAIChat(
                api_key=self.config.get_llm_api_key(),
                api_base=self.config.get_llm_api_base(),
                model_name=self.config.get_llm_model_name(),
                max_tokens=self.config.get_llm_max_token(),
            )
        if self.config.get_llm_type() == "local_api":
            return ApiBotClient()
        raise Exception("llm type is not supported !")
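
With this dispatch in place, switching between backends only requires editing the [llm] section of config.ini. A hedged usage sketch (assuming the enclosing class is LLMs, which is not shown in this hunk):

```python
# Hypothetical usage sketch: get_llm() returns ApiBotClient when type = local_api,
# or OpenAIChat (now honouring api_base) when type = openai.
from hugegraph_llm.llms.init_llm import LLMs

llm = LLMs().get_llm()
print(llm.get_llm_type())  # e.g. "local_api" or "openai", depending on config.ini
```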


3 changes: 3 additions & 0 deletions hugegraph-llm/src/hugegraph_llm/llms/openai.py
@@ -31,11 +31,14 @@ class OpenAIChat(BaseLLM):
    def __init__(
        self,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model_name: str = "gpt-3.5-turbo",
        max_tokens: int = 1000,
        temperature: float = 0.0,
    ) -> None:
        openai.api_key = api_key or os.getenv("OPENAI_API_KEY")
        if api_base is not None:
            openai.api_base = api_base or os.getenv("OPENAI_API_BASE")
        self.model = model_name
        self.max_tokens = max_tokens
        self.temperature = temperature
@@ -84,6 +84,7 @@ def __init__(
        self._prop_to_match = prop_to_match
        self._schema = ""


    def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
        if self._client is None:
            if isinstance(context.get("graph_client"), PyHugeClient):
8 changes: 6 additions & 2 deletions hugegraph-llm/src/hugegraph_llm/utils/config.py
@@ -17,6 +17,7 @@

import configparser
import os
from .constants import Constants


class Config:
@@ -35,8 +36,8 @@ def init_config_file(self, config_file):

        if not os.path.exists(config_file):
            config = configparser.ConfigParser()
            config.add_section("llm")
            config.add_section("hugegraph")
            config.add_section(Constants.HUGEGRAPH_CONFIG)
            config.add_section(Constants.LLM_CONFIG)
            with open(config_file, "w", encoding="utf-8") as file:
                config.write(file)
        return config_file
@@ -68,6 +69,9 @@ def get_graph_name(self):
    def get_llm_api_key(self):
        return self.config.get(self.section, "api_key")

    def get_llm_api_base(self):
        return self.config.get(self.section, "api_base")

    def get_llm_secret_key(self):
        return self.config.get(self.section, "secret_key")
