
feat(llm): separate multi llm configs/models #112

Merged · 10 commits · Nov 21, 2024
hugegraph-llm/README.md (2 changes: 1 addition & 1 deletion)

@@ -98,7 +98,7 @@ This can be obtained from the `LLMs` class.
 from hugegraph_llm.operators.kg_construction_task import KgBuilder

 TEXT = ""
-builder = KgBuilder(LLMs().get_llm())
+builder = KgBuilder(LLMs().get_chat_llm())
 (
     builder
     .import_schema(from_hugegraph="talent_graph").print_result()
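For reference, the renamed accessors in one place. This is a minimal sketch, assuming the usual `hugegraph_llm` import path; `get_extract_llm()` is inferred by symmetry with the two accessors visible in this diff.

```python
# Sketch only: per-task LLM accessors after this PR.
# Import path is assumed; get_extract_llm() is inferred, not shown in the diff.
from hugegraph_llm.models.llms.init_llm import LLMs

chat_llm = LLMs().get_chat_llm()          # chat / answer synthesis
text2gql_llm = LLMs().get_text2gql_llm()  # natural language -> Gremlin
extract_llm = LLMs().get_extract_llm()    # KG extraction (assumed name)
```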
hugegraph-llm/src/hugegraph_llm/api/rag_api.py (1 change: 1 addition & 0 deletions)

@@ -101,6 +101,7 @@ def graph_config_api(req: GraphConfigRequest):
     res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http")
     return generate_response(RAGResponse(status_code=res, message="Missing Value"))

+# TODO: restructure the llm config into three types, e.g. "/config/chat_llm"
 @router.post("/config/llm", status_code=status.HTTP_201_CREATED)
 def llm_config_api(req: LLMConfigRequest):
     settings.llm_type = req.llm_type
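The TODO above points toward one config route per LLM role. A hypothetical sketch of that split, with the route name and wiring assumed rather than taken from this PR:

```python
# Hypothetical follow-up to the TODO above; the route name and the reuse
# of LLMConfigRequest are assumptions, not part of this PR.
@router.post("/config/chat_llm", status_code=status.HTTP_201_CREATED)
def chat_llm_config_api(req: LLMConfigRequest):
    settings.chat_llm_type = req.llm_type
    # ...apply chat-specific settings, mirroring llm_config_api...
    return generate_response(RAGResponse(status_code=200, message="OK"))
```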
hugegraph-llm/src/hugegraph_llm/config/config_data.py (66 changes: 50 additions & 16 deletions)

@@ -26,42 +26,76 @@ class ConfigData:
"""LLM settings"""

# env_path: Optional[str] = ".env"
llm_type: Literal["openai", "ollama", "qianfan_wenxin", "zhipu"] = "openai"
embedding_type: Optional[Literal["openai", "ollama", "qianfan_wenxin", "zhipu"]] = "openai"
chat_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai"
extract_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai"
text2gql_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai"
embedding_type: Optional[Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"]] = "openai"
reranker_type: Optional[Literal["cohere", "siliconflow"]] = None
# 1. OpenAI settings
openai_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_language_model: Optional[str] = "gpt-4o-mini"
openai_chat_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
openai_chat_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_chat_language_model: Optional[str] = "gpt-4o-mini"
openai_extract_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
openai_extract_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_extract_language_model: Optional[str] = "gpt-4o-mini"
openai_text2gql_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1")
openai_text2gql_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY")
openai_text2gql_language_model: Optional[str] = "gpt-4o-mini"
openai_embedding_api_base: Optional[str] = os.environ.get("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1")
openai_embedding_api_key: Optional[str] = os.environ.get("OPENAI_EMBEDDING_API_KEY")
openai_embedding_model: Optional[str] = "text-embedding-3-small"
openai_max_tokens: int = 4096
openai_chat_tokens: int = 4096
openai_extract_tokens: int = 4096
openai_text2gql_tokens: int = 4096
# 2. Rerank settings
cohere_base_url: Optional[str] = os.environ.get("CO_API_URL", "https://api.cohere.com/v1/rerank")
reranker_api_key: Optional[str] = None
reranker_model: Optional[str] = None
# 3. Ollama settings
ollama_host: Optional[str] = "127.0.0.1"
ollama_port: Optional[int] = 11434
ollama_language_model: Optional[str] = None
ollama_chat_host: Optional[str] = "127.0.0.1"
ollama_chat_port: Optional[int] = 11434
ollama_chat_language_model: Optional[str] = None
ollama_extract_host: Optional[str] = "127.0.0.1"
ollama_extract_port: Optional[int] = 11434
ollama_extract_language_model: Optional[str] = None
ollama_text2gql_host: Optional[str] = "127.0.0.1"
ollama_text2gql_port: Optional[int] = 11434
ollama_text2gql_language_model: Optional[str] = None
ollama_embedding_host: Optional[str] = "127.0.0.1"
ollama_embedding_port: Optional[int] = 11434
ollama_embedding_model: Optional[str] = None
# 4. QianFan/WenXin settings
qianfan_api_key: Optional[str] = None
qianfan_secret_key: Optional[str] = None
qianfan_access_token: Optional[str] = None
qianfan_chat_api_key: Optional[str] = None
qianfan_chat_secret_key: Optional[str] = None
qianfan_chat_access_token: Optional[str] = None
qianfan_extract_api_key: Optional[str] = None
qianfan_extract_secret_key: Optional[str] = None
qianfan_extract_access_token: Optional[str] = None
qianfan_text2gql_api_key: Optional[str] = None
qianfan_text2gql_secret_key: Optional[str] = None
qianfan_text2gql_access_token: Optional[str] = None
qianfan_embedding_api_key: Optional[str] = None
qianfan_embedding_secret_key: Optional[str] = None
# 4.1 URL settings
qianfan_url_prefix: Optional[str] = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
qianfan_chat_url: Optional[str] = qianfan_url_prefix + "/chat/"
qianfan_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K"
qianfan_chat_language_model: Optional[str] = "ERNIE-Speed-128K"
qianfan_extract_language_model: Optional[str] = "ERNIE-Speed-128K"
qianfan_text2gql_language_model: Optional[str] = "ERNIE-Speed-128K"
qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/"
# refer https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu to get more details
qianfan_embedding_model: Optional[str] = "embedding-v1"
# TODO: To be confirmed, whether to configure
# 5. ZhiPu(GLM) settings
zhipu_api_key: Optional[str] = None
zhipu_language_model: Optional[str] = "glm-4"
zhipu_embedding_model: Optional[str] = "embedding-2"
zhipu_chat_api_key: Optional[str] = None
zhipu_chat_language_model: Optional[str] = "glm-4"
zhipu_chat_embedding_model: Optional[str] = "embedding-2"
zhipu_extract_api_key: Optional[str] = None
zhipu_extract_language_model: Optional[str] = "glm-4"
zhipu_extract_embedding_model: Optional[str] = "embedding-2"
zhipu_text2gql_api_key: Optional[str] = None
zhipu_text2gql_language_model: Optional[str] = "glm-4"
zhipu_text2gql_embedding_model: Optional[str] = "embedding-2"

"""HugeGraph settings"""
graph_ip: Optional[str] = "127.0.0.1"
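Every provider now repeats the same fields per task ("chat", "extract", "text2gql"), so downstream code can pick a credential set by task name. A small illustrative helper, not part of the PR:

```python
# Illustrative only: select the per-task OpenAI settings from ConfigData
# by attribute name; this helper itself is not part of this PR.
def openai_settings_for(conf: "ConfigData", task: str) -> tuple:
    """Return (api_base, api_key, model, max_tokens) for a given task."""
    assert task in ("chat", "extract", "text2gql")
    return (
        getattr(conf, f"openai_{task}_api_base"),
        getattr(conf, f"openai_{task}_api_key"),
        getattr(conf, f"openai_{task}_language_model"),
        getattr(conf, f"openai_{task}_tokens"),
    )
```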
hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py (76 changes: 38 additions & 38 deletions)

@@ -34,15 +34,15 @@ def build_example_vector_index(temp_file):
     else:
         return "ERROR: please input json file."
     builder = GremlinGenerator(
-        llm=LLMs().get_llm(),
+        llm=LLMs().get_text2gql_llm(),
         embedding=Embeddings().get_embedding(),
     )
     return builder.example_index_build(examples).run()


 def gremlin_generate(inp, use_schema, use_example, example_num, schema):
     generator = GremlinGenerator(
-        llm=LLMs().get_llm(),
+        llm=LLMs().get_text2gql_llm(),
         embedding=Embeddings().get_embedding(),
     )
     if use_example == "true":
@@ -58,64 +58,64 @@ def gremlin_generate(inp, use_schema, use_example, example_num, schema):
"""# HugeGraph LLM Text2Gremlin Demo"""
)
gr.Markdown("## Set up the LLM")
llm_dropdown = gr.Dropdown(["openai", "qianfan_wenxin", "ollama"], value=settings.llm_type,
llm_dropdown = gr.Dropdown(["openai", "qianfan_wenxin", "ollama/local"], value=settings.text2gql_llm_type,
label="LLM")


@gr.render(inputs=[llm_dropdown])
def llm_settings(llm_type):
settings.llm_type = llm_type
settings.text2gql_llm_type = llm_type
if llm_type == "openai":
with gr.Row():
llm_config_input = [
gr.Textbox(value=settings.openai_api_key, label="api_key"),
gr.Textbox(value=settings.openai_api_base, label="api_base"),
gr.Textbox(value=settings.openai_language_model, label="model_name"),
gr.Textbox(value=str(settings.openai_max_tokens), label="max_token"),
gr.Textbox(value=settings.openai_text2gql_api_key, label="api_key"),
gr.Textbox(value=settings.openai_text2gql_api_base, label="api_base"),
gr.Textbox(value=settings.openai_text2gql_language_model, label="model_name"),
gr.Textbox(value=str(settings.openai_text2gql_tokens), label="max_token"),
]
elif llm_type == "qianfan_wenxin":
with gr.Row():
llm_config_input = [
gr.Textbox(value=settings.qianfan_api_key, label="api_key"),
gr.Textbox(value=settings.qianfan_secret_key, label="secret_key"),
gr.Textbox(value=settings.qianfan_text2gql_api_key, label="api_key"),
gr.Textbox(value=settings.qianfan_text2gql_secret_key, label="secret_key"),
gr.Textbox(value=settings.qianfan_chat_url, label="chat_url"),
gr.Textbox(value=settings.qianfan_language_model, label="model_name")
gr.Textbox(value=settings.qianfan_text2gql_language_model, label="model_name")
]
elif llm_type == "ollama":
elif llm_type == "ollama/local":
with gr.Row():
llm_config_input = [
gr.Textbox(value=settings.ollama_host, label="host"),
gr.Textbox(value=str(settings.ollama_port), label="port"),
gr.Textbox(value=settings.ollama_language_model, label="model_name"),
gr.Textbox(value=settings.ollama_text2gql_host, label="host"),
gr.Textbox(value=str(settings.ollama_text2gql_port), label="port"),
gr.Textbox(value=settings.ollama_text2gql_language_model, label="model_name"),
gr.Textbox(value="", visible=False)
]
else:
llm_config_input = []
llm_config_button = gr.Button("Apply Configuration")

def apply_configuration(arg1, arg2, arg3, arg4):
llm_option = settings.llm_type
llm_option = settings.text2gql_llm_type
if llm_option == "openai":
settings.openai_api_key = arg1
settings.openai_api_base = arg2
settings.openai_language_model = arg3
settings.openai_max_tokens = int(arg4)
settings.openai_text2gql_api_key = arg1
settings.openai_text2gql_api_base = arg2
settings.openai_text2gql_language_model = arg3
settings.openai_text2gql_tokens = int(arg4)
elif llm_option == "qianfan_wenxin":
settings.qianfan_api_key = arg1
settings.qianfan_secret_key = arg2
settings.qianfan_text2gql_api_key = arg1
settings.qianfan_text2gql_secret_key = arg2
settings.qianfan_chat_url = arg3
settings.qianfan_language_model = arg4
elif llm_option == "ollama":
settings.ollama_host = arg1
settings.ollama_port = int(arg2)
settings.ollama_language_model = arg3
settings.qianfan_text2gql_language_model = arg4
elif llm_option == "ollam/local":
settings.ollama_text2gql_host = arg1
settings.ollama_text2gql_port = int(arg2)
settings.ollama_text2gql_language_model = arg3
gr.Info("configured!")

llm_config_button.click(apply_configuration, inputs=llm_config_input) # pylint: disable=no-member

gr.Markdown("## Set up the Embedding")
embedding_dropdown = gr.Dropdown(
choices=["openai", "ollama"],
choices=["openai", "ollama/local"],
value=settings.embedding_type,
label="Embedding"
)
@@ -126,15 +126,15 @@ def embedding_settings(embedding_type):
         if embedding_type == "openai":
             with gr.Row():
                 embedding_config_input = [
-                    gr.Textbox(value=settings.openai_api_key, label="api_key"),
-                    gr.Textbox(value=settings.openai_api_base, label="api_base"),
+                    gr.Textbox(value=settings.openai_text2gql_api_key, label="api_key"),
+                    gr.Textbox(value=settings.openai_text2gql_api_base, label="api_base"),
                     gr.Textbox(value=settings.openai_embedding_model, label="model_name")
                 ]
-        elif embedding_type == "ollama":
+        elif embedding_type == "ollama/local":
             with gr.Row():
                 embedding_config_input = [
-                    gr.Textbox(value=settings.ollama_host, label="host"),
-                    gr.Textbox(value=str(settings.ollama_port), label="port"),
+                    gr.Textbox(value=settings.ollama_text2gql_host, label="host"),
+                    gr.Textbox(value=str(settings.ollama_text2gql_port), label="port"),
                     gr.Textbox(value=settings.ollama_embedding_model, label="model_name"),
                 ]
         else:
@@ -144,12 +144,12 @@ def embedding_settings(embedding_type):
         def apply_configuration(arg1, arg2, arg3):
             embedding_option = settings.embedding_type
             if embedding_option == "openai":
-                settings.openai_api_key = arg1
-                settings.openai_api_base = arg2
+                settings.openai_text2gql_api_key = arg1
+                settings.openai_text2gql_api_base = arg2
                 settings.openai_embedding_model = arg3
-            elif embedding_option == "ollama":
-                settings.ollama_host = arg1
-                settings.ollama_port = int(arg2)
+            elif embedding_option == "ollama/local":
+                settings.ollama_text2gql_host = arg1
+                settings.ollama_text2gql_port = int(arg2)
                 settings.ollama_embedding_model = arg3
             gr.Info("configured!")
             # pylint: disable=no-member
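Net effect for the demo: the Text2Gremlin path now reads the dedicated text2gql settings while the embedding configuration stays shared. A condensed sketch of the builder call after this PR, with import paths assumed from the repo layout:

```python
# Condensed from the demo diff above; import paths are assumed and may differ.
from hugegraph_llm.models.llms.init_llm import LLMs
from hugegraph_llm.models.embeddings.init_embedding import Embeddings
from hugegraph_llm.operators.gremlin_generate_task import GremlinGenerator

examples = []  # parsed from the demo's JSON example file
generator = GremlinGenerator(
    llm=LLMs().get_text2gql_llm(),           # dedicated text2gql model
    embedding=Embeddings().get_embedding(),  # embedding config unchanged
)
result = generator.example_index_build(examples).run()  # as in the demo
```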