From c01b344447a7181c02ada9247c0241e6a2eec45e Mon Sep 17 00:00:00 2001 From: yhj <1454yhj@gmail.com> Date: Thu, 14 Nov 2024 15:28:32 +0800 Subject: [PATCH 1/5] split llm responsibilities --- hugegraph-llm/README.md | 2 +- .../src/hugegraph_llm/api/rag_api.py | 1 + .../src/hugegraph_llm/config/config_data.py | 66 ++++-- .../demo/gremlin_generate_web_demo.py | 20 +- .../demo/rag_demo/configs_block.py | 193 +++++++++++++----- .../models/embeddings/init_embedding.py | 6 +- .../src/hugegraph_llm/models/llms/init_llm.py | 80 ++++++-- .../src/hugegraph_llm/models/llms/ollama.py | 2 +- .../operators/document_op/word_extract.py | 2 +- .../hugegraph_llm/operators/graph_rag_task.py | 4 +- .../operators/llm_op/answer_synthesize.py | 2 +- .../operators/llm_op/keyword_extract.py | 2 +- .../hugegraph_llm/utils/graph_index_utils.py | 8 +- .../hugegraph_llm/utils/vector_index_utils.py | 2 +- 14 files changed, 280 insertions(+), 110 deletions(-) diff --git a/hugegraph-llm/README.md b/hugegraph-llm/README.md index e9d6f90a..e9b07cb9 100644 --- a/hugegraph-llm/README.md +++ b/hugegraph-llm/README.md @@ -99,7 +99,7 @@ This can be obtained from the `LLMs` class. from hugegraph_llm.operators.kg_construction_task import KgBuilder TEXT = "" - builder = KgBuilder(LLMs().get_llm()) + builder = KgBuilder(LLMs().get_chat_llm()) ( builder .import_schema(from_hugegraph="talent_graph").print_result() diff --git a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py index 83795506..bfdc6121 100644 --- a/hugegraph-llm/src/hugegraph_llm/api/rag_api.py +++ b/hugegraph-llm/src/hugegraph_llm/api/rag_api.py @@ -91,6 +91,7 @@ def graph_config_api(req: GraphConfigRequest): res = apply_graph_conf(req.ip, req.port, req.name, req.user, req.pwd, req.gs, origin_call="http") return generate_response(RAGResponse(status_code=res, message="Missing Value")) + #TODO: restructure the implement of llm to three types, like "/config/chat_llm" @router.post("/config/llm", status_code=status.HTTP_201_CREATED) def llm_config_api(req: LLMConfigRequest): settings.llm_type = req.llm_type diff --git a/hugegraph-llm/src/hugegraph_llm/config/config_data.py b/hugegraph-llm/src/hugegraph_llm/config/config_data.py index 3610aa2b..e23d1f57 100644 --- a/hugegraph-llm/src/hugegraph_llm/config/config_data.py +++ b/hugegraph-llm/src/hugegraph_llm/config/config_data.py @@ -26,42 +26,76 @@ class ConfigData: """LLM settings""" # env_path: Optional[str] = ".env" - llm_type: Literal["openai", "ollama", "qianfan_wenxin", "zhipu"] = "openai" - embedding_type: Optional[Literal["openai", "ollama", "qianfan_wenxin", "zhipu"]] = "openai" + chat_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai" + extract_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai" + text2gql_llm_type: Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"] = "openai" + embedding_type: Optional[Literal["openai", "ollama/local", "qianfan_wenxin", "zhipu"]] = "openai" reranker_type: Optional[Literal["cohere", "siliconflow"]] = None # 1. 
OpenAI settings - openai_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") - openai_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") - openai_language_model: Optional[str] = "gpt-4o-mini" + openai_chat_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") + openai_chat_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") + openai_chat_language_model: Optional[str] = "gpt-4o-mini" + openai_extract_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") + openai_extract_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") + openai_extract_language_model: Optional[str] = "gpt-4o-mini" + openai_text2gql_api_base: Optional[str] = os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") + openai_text2gql_api_key: Optional[str] = os.environ.get("OPENAI_API_KEY") + openai_text2gql_language_model: Optional[str] = "gpt-4o-mini" openai_embedding_api_base: Optional[str] = os.environ.get("OPENAI_EMBEDDING_BASE_URL", "https://api.openai.com/v1") openai_embedding_api_key: Optional[str] = os.environ.get("OPENAI_EMBEDDING_API_KEY") openai_embedding_model: Optional[str] = "text-embedding-3-small" - openai_max_tokens: int = 4096 + openai_chat_tokens: int = 4096 + openai_extract_tokens: int = 4096 + openai_text2gql_tokens: int = 4096 # 2. Rerank settings cohere_base_url: Optional[str] = os.environ.get("CO_API_URL", "https://api.cohere.com/v1/rerank") reranker_api_key: Optional[str] = None reranker_model: Optional[str] = None # 3. Ollama settings - ollama_host: Optional[str] = "127.0.0.1" - ollama_port: Optional[int] = 11434 - ollama_language_model: Optional[str] = None + ollama_chat_host: Optional[str] = "127.0.0.1" + ollama_chat_port: Optional[int] = 11434 + ollama_chat_language_model: Optional[str] = None + ollama_extract_host: Optional[str] = "127.0.0.1" + ollama_extract_port: Optional[int] = 11434 + ollama_extract_language_model: Optional[str] = None + ollama_text2gql_host: Optional[str] = "127.0.0.1" + ollama_text2gql_port: Optional[int] = 11434 + ollama_text2gql_language_model: Optional[str] = None + ollama_embedding_host: Optional[str] = "127.0.0.1" + ollama_embedding_port: Optional[int] = 11434 ollama_embedding_model: Optional[str] = None # 4. 
QianFan/WenXin settings - qianfan_api_key: Optional[str] = None - qianfan_secret_key: Optional[str] = None - qianfan_access_token: Optional[str] = None + qianfan_chat_api_key: Optional[str] = None + qianfan_chat_secret_key: Optional[str] = None + qianfan_chat_access_token: Optional[str] = None + qianfan_extract_api_key: Optional[str] = None + qianfan_extract_secret_key: Optional[str] = None + qianfan_extract_access_token: Optional[str] = None + qianfan_text2gql_api_key: Optional[str] = None + qianfan_text2gql_secret_key: Optional[str] = None + qianfan_text2gql_access_token: Optional[str] = None + qianfan_embedding_api_key: Optional[str] = None + qianfan_embedding_secret_key: Optional[str] = None # 4.1 URL settings qianfan_url_prefix: Optional[str] = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop" qianfan_chat_url: Optional[str] = qianfan_url_prefix + "/chat/" - qianfan_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K" + qianfan_chat_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K" + qianfan_extract_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K" + qianfan_text2gql_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K" qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/" # refer https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu to get more details qianfan_embedding_model: Optional[str] = "embedding-v1" # TODO: To be confirmed, whether to configure # 5. ZhiPu(GLM) settings - zhipu_api_key: Optional[str] = None - zhipu_language_model: Optional[str] = "glm-4" - zhipu_embedding_model: Optional[str] = "embedding-2" + zhipu_chat_api_key: Optional[str] = None + zhipu_chat_language_model: Optional[str] = "glm-4" + zhipu_chat_embedding_model: Optional[str] = "embedding-2" + zhipu_extract_api_key: Optional[str] = None + zhipu_extract_language_model: Optional[str] = "glm-4" + zhipu_extract_embedding_model: Optional[str] = "embedding-2" + zhipu_text2gql_api_key: Optional[str] = None + zhipu_text2gql_language_model: Optional[str] = "glm-4" + zhipu_text2gql_embedding_model: Optional[str] = "embedding-2" """HugeGraph settings""" graph_ip: Optional[str] = "127.0.0.1" diff --git a/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py index 61663210..cb56ed59 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py @@ -34,7 +34,7 @@ def build_example_vector_index(temp_file): else: return "ERROR: please input json file." 
builder = GremlinGenerator( - llm=LLMs().get_llm(), + llm=LLMs().get_text2gql_llm(), embedding=Embeddings().get_embedding(), ) return builder.example_index_build(examples).run() @@ -42,7 +42,7 @@ def build_example_vector_index(temp_file): def gremlin_generate(inp, use_schema, use_example, example_num, schema): generator = GremlinGenerator( - llm=LLMs().get_llm(), + llm=LLMs().get_text2gql_llm(), embedding=Embeddings().get_embedding(), ) if use_example == "true": @@ -58,13 +58,13 @@ def gremlin_generate(inp, use_schema, use_example, example_num, schema): """# HugeGraph LLM Text2Gremlin Demo""" ) gr.Markdown("## Set up the LLM") - llm_dropdown = gr.Dropdown(["openai", "qianfan_wenxin", "ollama"], value=settings.llm_type, + llm_dropdown = gr.Dropdown(["openai", "qianfan_wenxin", "ollama/local"], value=settings.llm_type, label="LLM") @gr.render(inputs=[llm_dropdown]) def llm_settings(llm_type): - settings.llm_type = llm_type + settings.text2gql_llm_type = llm_type if llm_type == "openai": with gr.Row(): llm_config_input = [ @@ -81,7 +81,7 @@ def llm_settings(llm_type): gr.Textbox(value=settings.qianfan_chat_url, label="chat_url"), gr.Textbox(value=settings.qianfan_language_model, label="model_name") ] - elif llm_type == "ollama": + elif llm_type == "ollama/local": with gr.Row(): llm_config_input = [ gr.Textbox(value=settings.ollama_host, label="host"), @@ -94,7 +94,7 @@ def llm_settings(llm_type): llm_config_button = gr.Button("Apply Configuration") def apply_configuration(arg1, arg2, arg3, arg4): - llm_option = settings.llm_type + llm_option = settings.text2gql_llm_type if llm_option == "openai": settings.openai_api_key = arg1 settings.openai_api_base = arg2 @@ -105,7 +105,7 @@ def apply_configuration(arg1, arg2, arg3, arg4): settings.qianfan_secret_key = arg2 settings.qianfan_chat_url = arg3 settings.qianfan_language_model = arg4 - elif llm_option == "ollama": + elif llm_option == "ollama/local": settings.ollama_host = arg1 settings.ollama_port = int(arg2) settings.ollama_language_model = arg3 @@ -115,7 +115,7 @@ def apply_configuration(arg1, arg2, arg3, arg4): gr.Markdown("## Set up the Embedding") embedding_dropdown = gr.Dropdown( - choices=["openai", "ollama"], + choices=["openai", "ollama/local"], value=settings.embedding_type, label="Embedding" ) @@ -130,7 +130,7 @@ def embedding_settings(embedding_type): gr.Textbox(value=settings.openai_api_base, label="api_base"), gr.Textbox(value=settings.openai_embedding_model, label="model_name") ] - elif embedding_type == "ollama": + elif embedding_type == "ollama/local": with gr.Row(): embedding_config_input = [ gr.Textbox(value=settings.ollama_host, label="host"), @@ -147,7 +147,7 @@ def apply_configuration(arg1, arg2, arg3): settings.openai_api_key = arg1 settings.openai_api_base = arg2 settings.openai_embedding_model = arg3 - elif embedding_option == "ollama": + elif embedding_option == "ollam/local": settings.ollama_host = arg1 settings.ollama_port = int(arg2) settings.ollama_embedding_model = arg3 diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py index cf3e139c..059f6e68 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py @@ -25,6 +25,14 @@ from hugegraph_llm.config import settings from hugegraph_llm.utils.log import log +from functools import partial + +label_mapping = { + "chat LLM": "chat", + "extract LLM": "extract", + "text2gql LLM": "text2gql" +} 
+current_llm = "chat" def test_api_connection(url, method="GET", headers=None, params=None, body=None, auth=None, origin_call=None) -> int: @@ -60,11 +68,11 @@ def test_api_connection(url, method="GET", headers=None, params=None, body=None, return resp.status_code -def config_qianfan_model(arg1, arg2, arg3=None, origin_call=None) -> int: - settings.qianfan_api_key = arg1 - settings.qianfan_secret_key = arg2 +def config_qianfan_model(arg1, arg2, arg3=None, settings_prefix=None, origin_call=None) -> int: + setattr(settings, f"qianfan_{settings_prefix}_api_key", arg1) + setattr(settings, f"qianfan_{settings_prefix}_secret_key", arg2) if arg3: - settings.qianfan_language_model = arg3 + setattr(settings, f"qianfan_{settings_prefix}_language_model", arg3) params = { "grant_type": "client_credentials", "client_id": arg1, @@ -88,11 +96,11 @@ def apply_embedding_config(arg1, arg2, arg3, origin_call=None) -> int: data = {"model": arg3, "input": "test"} status_code = test_api_connection(test_url, method="POST", headers=headers, body=data, origin_call=origin_call) elif embedding_option == "qianfan_wenxin": - status_code = config_qianfan_model(arg1, arg2, origin_call=origin_call) + status_code = config_qianfan_model(arg1, arg2, settings_prefix="embedding", origin_call=origin_call) settings.qianfan_embedding_model = arg3 - elif embedding_option == "ollama": - settings.ollama_host = arg1 - settings.ollama_port = int(arg2) + elif embedding_option == "ollama/local": + settings.ollama_embedding_host = arg1 + settings.ollama_embedding_port = int(arg2) settings.ollama_embedding_model = arg3 status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call) settings.update_env() @@ -158,15 +166,20 @@ def apply_graph_config(ip, port, name, user, pwd, gs, origin_call=None) -> int: # Different llm models have different parameters, so no meaningful argument names are given here -def apply_llm_config(arg1, arg2, arg3, arg4, origin_call=None) -> int: - llm_option = settings.llm_type +def apply_llm_config(current_llm, arg1, arg2, arg3, arg4, origin_call=None) -> int: + log.debug("current llm in apply_llm_config is %s", current_llm) + llm_option = getattr(settings, f"{current_llm}_llm_type") + log.debug("llm option in apply_llm_config is %s", llm_option) status_code = -1 + if llm_option == "openai": - settings.openai_api_key = arg1 - settings.openai_api_base = arg2 - settings.openai_language_model = arg3 - settings.openai_max_tokens = int(arg4) - test_url = settings.openai_api_base + "/chat/completions" + setattr(settings, f"openai_{current_llm}_api_key", arg1) + setattr(settings, f"openai_{current_llm}_api_base", arg2) + setattr(settings, f"openai_{current_llm}_language_model", arg3) + setattr(settings, f"openai_{current_llm}_tokens", int(arg4)) + + test_url = getattr(settings, f"openai_{current_llm}_api_base") + "/chat/completions" + log.debug(f"Type of openai {current_llm} max token is %s", type(arg4)) data = { "model": arg3, "temperature": 0.0, @@ -174,17 +187,23 @@ def apply_llm_config(arg1, arg2, arg3, arg4, origin_call=None) -> int: } headers = {"Authorization": f"Bearer {arg1}"} status_code = test_api_connection(test_url, method="POST", headers=headers, body=data, origin_call=origin_call) + elif llm_option == "qianfan_wenxin": - status_code = config_qianfan_model(arg1, arg2, arg3, origin_call) - elif llm_option == "ollama": - settings.ollama_host = arg1 - settings.ollama_port = int(arg2) - settings.ollama_language_model = arg3 + status_code = config_qianfan_model(arg1, arg2, arg3, 
settings_prefix=current_llm, origin_call=origin_call) + + elif llm_option == "ollama/local": + log.debug("Exec to ollama/local config") + setattr(settings, f"ollama_{current_llm}_host", arg1) + setattr(settings, f"ollama_{current_llm}_port", int(arg2)) + setattr(settings, f"ollama_{current_llm}_language_model", arg3) status_code = test_api_connection(f"http://{arg1}:{arg2}", origin_call=origin_call) + gr.Info("Configured!") settings.update_env() + return status_code + # TODO: refactor the function to reduce the number of statements & separate the logic def create_configs_block() -> list: # pylint: disable=R0915 (too-many-statements) @@ -203,49 +222,113 @@ def create_configs_block() -> list: with gr.Accordion("2. Set up the LLM.", open=False): gr.Markdown("> Tips: the openai sdk also support openai style api from other providers.") - llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama"], value=settings.llm_type, label="LLM") + with gr.Tab(label='chat LLM'): + chat_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"], + value=getattr(settings, f"chat_llm_type"), label=f"chat LLM") + apply_llm_config_with_chat_op = partial(apply_llm_config, "chat") + @gr.render(inputs=[chat_llm_dropdown]) + def chat_llm_settings(llm_type): + settings.chat_llm_type = llm_type + llm_config_input = [] + if llm_type == "openai": + llm_config_input = [ + gr.Textbox(value=getattr(settings, f"openai_chat_api_key"), label="api_key", type="password"), + gr.Textbox(value=getattr(settings, f"openai_chat_api_base"), label="api_base"), + gr.Textbox(value=getattr(settings, f"openai_chat_language_model"), label="model_name"), + gr.Textbox(value=getattr(settings, f"openai_chat_tokens"), label="max_token"), + ] + elif llm_type == "ollama/local": + llm_config_input = [ + gr.Textbox(value=getattr(settings, f"ollama_chat_host"), label="host"), + gr.Textbox(value=str(getattr(settings, f"ollama_chat_port")), label="port"), + gr.Textbox(value=getattr(settings, f"ollama_chat_language_model"), label="model_name"), + gr.Textbox(value="", visible=False), + ] + elif llm_type == "qianfan_wenxin": + llm_config_input = [ + gr.Textbox(value=getattr(settings, f"qianfan_chat_api_key"), label="api_key", type="password"), + gr.Textbox(value=getattr(settings, f"qianfan_chat_secret_key"), label="secret_key", type="password"), + gr.Textbox(value=getattr(settings, f"qianfan_chat_language_model"), label="model_name"), + gr.Textbox(value="", visible=False), + ] + else: + llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)] + llm_config_button = gr.Button("Apply configuration") + llm_config_button.click(apply_llm_config_with_chat_op, inputs=llm_config_input) - @gr.render(inputs=[llm_dropdown]) - def llm_settings(llm_type): - settings.llm_type = llm_type - if llm_type == "openai": - with gr.Row(): + with gr.Tab(label='extract LLM'): + extract_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"], + value=getattr(settings, f"extract_llm_type"), label=f"extract LLM") + apply_llm_config_with_extract_op = partial(apply_llm_config, "extract") + + @gr.render(inputs=[extract_llm_dropdown]) + def extract_llm_settings(llm_type): + settings.extract_llm_type = llm_type + llm_config_input = [] + if llm_type == "openai": + llm_config_input = [ + gr.Textbox(value=getattr(settings, f"openai_extract_api_key"), label="api_key", type="password"), + gr.Textbox(value=getattr(settings, f"openai_extract_api_base"), label="api_base"), + gr.Textbox(value=getattr(settings, 
f"openai_extract_language_model"), label="model_name"), + gr.Textbox(value=getattr(settings, f"openai_extract_tokens"), label="max_token"), + ] + elif llm_type == "ollama/local": llm_config_input = [ - gr.Textbox(value=settings.openai_api_key, label="api_key", type="password"), - gr.Textbox(value=settings.openai_api_base, label="api_base"), - gr.Textbox(value=settings.openai_language_model, label="model_name"), - gr.Textbox(value=settings.openai_max_tokens, label="max_token"), + gr.Textbox(value=getattr(settings, f"ollama_extract_host"), label="host"), + gr.Textbox(value=str(getattr(settings, f"ollama_extract_port")), label="port"), + gr.Textbox(value=getattr(settings, f"ollama_extract_language_model"), label="model_name"), + gr.Textbox(value="", visible=False), ] - elif llm_type == "ollama": - with gr.Row(): + elif llm_type == "qianfan_wenxin": llm_config_input = [ - gr.Textbox(value=settings.ollama_host, label="host"), - gr.Textbox(value=str(settings.ollama_port), label="port"), - gr.Textbox(value=settings.ollama_language_model, label="model_name"), + gr.Textbox(value=getattr(settings, f"qianfan_extract_api_key"), label="api_key", type="password"), + gr.Textbox(value=getattr(settings, f"qianfan_extract_secret_key"), label="secret_key", type="password"), + gr.Textbox(value=getattr(settings, f"qianfan_extract_language_model"), label="model_name"), gr.Textbox(value="", visible=False), ] - elif llm_type == "qianfan_wenxin": - with gr.Row(): + else: + llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)] + llm_config_button = gr.Button("Apply configuration") + llm_config_button.click(apply_llm_config_with_extract_op, inputs=llm_config_input) + with gr.Tab(label='text2gql LLM'): + text2gql_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"], + value=getattr(settings, f"text2gql_llm_type"), label=f"text2gql LLM") + apply_llm_config_with_text2gql_op = partial(apply_llm_config, "text2gql") + + @gr.render(inputs=[text2gql_llm_dropdown]) + def text2gql_llm_settings(llm_type): + settings.text2gql_llm_type = llm_type + llm_config_input = [] + if llm_type == "openai": llm_config_input = [ - gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"), - gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"), - gr.Textbox(value=settings.qianfan_language_model, label="model_name"), + gr.Textbox(value=getattr(settings, f"openai_text2gql_api_key"), label="api_key", type="password"), + gr.Textbox(value=getattr(settings, f"openai_text2gql_api_base"), label="api_base"), + gr.Textbox(value=getattr(settings, f"openai_text2gql_language_model"), label="model_name"), + gr.Textbox(value=getattr(settings, f"openai_text2gql_tokens"), label="max_token"), + ] + elif llm_type == "ollama/local": + llm_config_input = [ + gr.Textbox(value=getattr(settings, f"ollama_text2gql_host"), label="host"), + gr.Textbox(value=str(getattr(settings, f"ollama_text2gql_port")), label="port"), + gr.Textbox(value=getattr(settings, f"ollama_text2gql_language_model"), label="model_name"), gr.Textbox(value="", visible=False), ] - else: - llm_config_input = [ - gr.Textbox(value="", visible=False), - gr.Textbox(value="", visible=False), - gr.Textbox(value="", visible=False), - gr.Textbox(value="", visible=False), - ] - llm_config_button = gr.Button("Apply configuration") - llm_config_button.click(apply_llm_config, inputs=llm_config_input) # pylint: disable=no-member + elif llm_type == "qianfan_wenxin": + llm_config_input = [ + 
gr.Textbox(value=getattr(settings, f"qianfan_text2gql_api_key"), label="api_key", type="password"), + gr.Textbox(value=getattr(settings, f"qianfan_text2gql_secret_key"), label="secret_key", type="password"), + gr.Textbox(value=getattr(settings, f"qianfan_text2gql_language_model"), label="model_name"), + gr.Textbox(value="", visible=False), + ] + else: + llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)] + llm_config_button = gr.Button("Apply configuration") + llm_config_button.click(apply_llm_config_with_text2gql_op, inputs=llm_config_input) with gr.Accordion("3. Set up the Embedding.", open=False): embedding_dropdown = gr.Dropdown( - choices=["openai", "qianfan_wenxin", "ollama"], value=settings.embedding_type, label="Embedding" + choices=["openai", "qianfan_wenxin", "ollama/local"], value=settings.embedding_type, label="Embedding" ) @gr.render(inputs=[embedding_dropdown]) @@ -258,18 +341,18 @@ def embedding_settings(embedding_type): gr.Textbox(value=settings.openai_embedding_api_base, label="api_base"), gr.Textbox(value=settings.openai_embedding_model, label="model_name"), ] - elif embedding_type == "ollama": + elif embedding_type == "ollama/local": with gr.Row(): embedding_config_input = [ - gr.Textbox(value=settings.ollama_host, label="host"), - gr.Textbox(value=str(settings.ollama_port), label="port"), + gr.Textbox(value=settings.ollama_embedding_host, label="host"), + gr.Textbox(value=str(settings.ollama_embedding_port), label="port"), gr.Textbox(value=settings.ollama_embedding_model, label="model_name"), ] elif embedding_type == "qianfan_wenxin": with gr.Row(): embedding_config_input = [ - gr.Textbox(value=settings.qianfan_api_key, label="api_key", type="password"), - gr.Textbox(value=settings.qianfan_secret_key, label="secret_key", type="password"), + gr.Textbox(value=settings.qianfan_embedding_api_key, label="api_key", type="password"), + gr.Textbox(value=settings.qianfan_embedding_secret_key, label="secret_key", type="password"), gr.Textbox(value=settings.qianfan_embedding_model, label="model_name"), ] else: diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py index ded48af9..db9be47c 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py +++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py @@ -33,11 +33,11 @@ def get_embedding(self): api_key=settings.openai_embedding_api_key, api_base=settings.openai_embedding_api_base ) - if self.embedding_type == "ollama": + if self.embedding_type == "ollama/local": return OllamaEmbedding( model=settings.ollama_embedding_model, - host=settings.ollama_host, - port=settings.ollama_port + host=settings.ollama_embedding_host, + port=settings.ollama_embedding_port ) if self.embedding_type == "qianfan_wenxin": return QianFanEmbedding( diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py index 2c907489..cb7e73d1 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/init_llm.py @@ -24,28 +24,78 @@ class LLMs: def __init__(self): - self.llm_type = settings.llm_type + self.chat_llm_type = settings.chat_llm_type + self.extract_llm_type = settings.extract_llm_type + self.text2gql_llm_type = settings.text2gql_llm_type - def get_llm(self): - if self.llm_type == "qianfan_wenxin": + def get_chat_llm(self): + if self.chat_llm_type == "qianfan_wenxin": return 
QianfanClient( - model_name=settings.qianfan_language_model, - api_key=settings.qianfan_api_key, - secret_key=settings.qianfan_secret_key + model_name=settings.qianfan_chat_language_model, + api_key=settings.qianfan_chat_api_key, + secret_key=settings.qianfan_chat_secret_key ) - if self.llm_type == "openai": + if self.chat_llm_type == "openai": return OpenAIClient( - api_key=settings.openai_api_key, - api_base=settings.openai_api_base, - model_name=settings.openai_language_model, - max_tokens=settings.openai_max_tokens, + api_key=settings.openai_chat_api_key, + api_base=settings.openai_chat_api_base, + model_name=settings.openai_chat_language_model, + max_tokens=settings.openai_chat_tokens, ) - if self.llm_type == "ollama": - return OllamaClient(model=settings.ollama_language_model) - raise Exception("llm type is not supported !") + if self.chat_llm_type == "ollama/local": + return OllamaClient( + model=settings.ollama_chat_language_model, + host=settings.ollama_chat_host, + port=settings.ollama_chat_port, + ) + raise Exception("chat llm type is not supported !") + + def get_extract_llm(self): + if self.extract_llm_type == "qianfan_wenxin": + return QianfanClient( + model_name=settings.qianfan_extract_language_model, + api_key=settings.qianfan_extract_api_key, + secret_key=settings.qianfan_extract_secret_key + ) + if self.extract_llm_type == "openai": + return OpenAIClient( + api_key=settings.openai_extract_api_key, + api_base=settings.openai_extract_api_base, + model_name=settings.openai_extract_language_model, + max_tokens=settings.openai_extract_tokens, + ) + if self.extract_llm_type == "ollama/local": + return OllamaClient( + model=settings.ollama_extract_language_model, + host=settings.ollama_extract_host, + port=settings.ollama_extract_port, + ) + raise Exception("extract llm type is not supported !") + + def get_text2gql_llm(self): + if self.text2gql_llm_type == "qianfan_wenxin": + return QianfanClient( + model_name=settings.qianfan_text2gql_language_model, + api_key=settings.qianfan_text2gql_api_key, + secret_key=settings.qianfan_text2gql_secret_key + ) + if self.text2gql_llm_type == "openai": + return OpenAIClient( + api_key=settings.openai_text2gql_api_key, + api_base=settings.openai_text2gql_api_base, + model_name=settings.openai_text2gql_language_model, + max_tokens=settings.openai_text2gql_tokens, + ) + if self.text2gql_llm_type == "ollama/local": + return OllamaClient( + model=settings.ollama_text2gql_language_model, + host=settings.ollama_text2gql_host, + port=settings.ollama_text2gql_port, + ) + raise Exception("text2gql llm type is not supported !") if __name__ == "__main__": - client = LLMs().get_llm() + client = LLMs().get_chat_llm() print(client.generate(prompt="What is the capital of China?")) print(client.generate(messages=[{"role": "user", "content": "What is the capital of China?"}])) diff --git a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py index b7b01481..62f5ef26 100644 --- a/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py +++ b/hugegraph-llm/src/hugegraph_llm/models/llms/ollama.py @@ -121,4 +121,4 @@ def max_allowed_token_length( def get_llm_type(self) -> str: """Returns the type of the LLM""" - return "ollama" + return "ollama/local" diff --git a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py index 546d5617..22b72651 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py 
+++ b/hugegraph-llm/src/hugegraph_llm/operators/document_op/word_extract.py @@ -45,7 +45,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: context["query"] = self._query if self._llm is None: - self._llm = LLMs().get_llm() + self._llm = LLMs().get_extract_llm() assert isinstance(self._llm, BaseLLM), "Invalid LLM Object." if isinstance(context.get("language"), str): diff --git a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py index e6da8e09..ab57970f 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/graph_rag_task.py @@ -46,7 +46,9 @@ def __init__(self, llm: Optional[BaseLLM] = None, embedding: Optional[BaseEmbedd :param llm: Optional LLM model to use. :param embedding: Optional embedding model to use. """ - self._llm = llm or LLMs().get_llm() + self._chat_llm = llm or LLMs().get_chat_llm() + self._extract_llm = llm or LLMs().get_extract_llm() + self._text2gqlt_llm = llm or LLMs().get_text2gql_llm() self._embedding = embedding or Embeddings().get_embedding() self._operators: List[Any] = [] diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py index baf61e64..3c333f98 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/answer_synthesize.py @@ -60,7 +60,7 @@ def __init__( def run(self, context: Dict[str, Any]) -> Dict[str, Any]: if self._llm is None: - self._llm = LLMs().get_llm() + self._llm = LLMs().get_chat_llm() if self._question is None: self._question = context.get("query") or None diff --git a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py index 2cad98d5..8da25c01 100644 --- a/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py +++ b/hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py @@ -65,7 +65,7 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]: context["query"] = self._query if self._llm is None: - self._llm = LLMs().get_llm() + self._llm = LLMs().get_extract_llm() assert isinstance(self._llm, BaseLLM), "Invalid LLM Object." 
if isinstance(context.get("language"), str): diff --git a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py index a8ea1156..73d7057f 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/graph_index_utils.py @@ -34,7 +34,7 @@ def get_graph_index_info(): - builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) context = builder.fetch_graph_data().run() vector_index = VectorIndex.from_index_file(str(os.path.join(resource_path, settings.graph_name, "graph_vids"))) context["vid_index"] = { @@ -54,7 +54,7 @@ def clean_all_graph_index(): def extract_graph(input_file, input_text, schema, example_prompt) -> str: texts = read_documents(input_file, input_text) - builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) if schema: try: @@ -77,7 +77,7 @@ def extract_graph(input_file, input_text, schema, example_prompt) -> str: def fit_vid_index(): - builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) builder.fetch_graph_data().build_vertex_id_semantic_index() log.debug("Operators: %s", builder.operators) try: @@ -94,7 +94,7 @@ def import_graph_data(data: str, schema: str) -> Union[str, Dict[str, Any]]: try: data_json = json.loads(data.strip()) log.debug("Import graph data: %s", data) - builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) if schema: try: schema = json.loads(schema.strip()) diff --git a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py index a7afdf8a..e955aac8 100644 --- a/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py +++ b/hugegraph-llm/src/hugegraph_llm/utils/vector_index_utils.py @@ -71,6 +71,6 @@ def clean_vector_index(): def build_vector_index(input_file, input_text): texts = read_documents(input_file, input_text) - builder = KgBuilder(LLMs().get_llm(), Embeddings().get_embedding(), get_hg_client()) + builder = KgBuilder(LLMs().get_chat_llm(), Embeddings().get_embedding(), get_hg_client()) context = builder.chunk_split(texts, "paragraph", "zh").build_vector_index().run() return json.dumps(context, ensure_ascii=False, indent=2) From 2f4b395bd800cc3c2137b51086dc436715549b77 Mon Sep 17 00:00:00 2001 From: yhj <1454yhj@gmail.com> Date: Thu, 14 Nov 2024 15:35:27 +0800 Subject: [PATCH 2/5] modify config func --- hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py index 059f6e68..cfe8aa44 100644 --- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py +++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py @@ -192,7 +192,6 @@ def apply_llm_config(current_llm, arg1, arg2, arg3, arg4, origin_call=None) -> i status_code = config_qianfan_model(arg1, arg2, arg3, settings_prefix=current_llm, origin_call=origin_call) elif llm_option == "ollama/local": - log.debug("Exec to 
ollama/local config")
         setattr(settings, f"ollama_{current_llm}_host", arg1)
         setattr(settings, f"ollama_{current_llm}_port", int(arg2))
         setattr(settings, f"ollama_{current_llm}_language_model", arg3)
@@ -220,6 +219,7 @@ def create_configs_block() -> list:
         graph_config_button.click(apply_graph_config, inputs=graph_config_input)  # pylint: disable=no-member
 
+    #TODO: use OOP to restructure
     with gr.Accordion("2. Set up the LLM.", open=False):
         gr.Markdown("> Tips: the openai sdk also support openai style api from other providers.")
         with gr.Tab(label='chat LLM'):

From 068de8d7ec60af364428a9234901d3f00bb89590 Mon Sep 17 00:00:00 2001
From: yhj <1454yhj@gmail.com>
Date: Tue, 19 Nov 2024 20:37:46 +0800
Subject: [PATCH 3/5] fix text2gql model config in gremlin demo

---
 .../demo/gremlin_generate_web_demo.py         | 62 +++++++++----------
 .../models/embeddings/init_embedding.py       |  4 +-
 2 files changed, 33 insertions(+), 33 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py b/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py
index cb56ed59..d21a54bf 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/gremlin_generate_web_demo.py
@@ -58,7 +58,7 @@ def gremlin_generate(inp, use_schema, use_example, example_num, schema):
         """# HugeGraph LLM Text2Gremlin Demo"""
     )
     gr.Markdown("## Set up the LLM")
-    llm_dropdown = gr.Dropdown(["openai", "qianfan_wenxin", "ollama/local"], value=settings.llm_type,
+    llm_dropdown = gr.Dropdown(["openai", "qianfan_wenxin", "ollama/local"], value=settings.text2gql_llm_type,
                                label="LLM")
 
 
@@ -68,25 +68,25 @@ def llm_settings(llm_type):
     if llm_type == "openai":
         with gr.Row():
             llm_config_input = [
-                gr.Textbox(value=settings.openai_api_key, label="api_key"),
-                gr.Textbox(value=settings.openai_api_base, label="api_base"),
-                gr.Textbox(value=settings.openai_language_model, label="model_name"),
-                gr.Textbox(value=str(settings.openai_max_tokens), label="max_token"),
+                gr.Textbox(value=settings.openai_text2gql_api_key, label="api_key"),
+                gr.Textbox(value=settings.openai_text2gql_api_base, label="api_base"),
+                gr.Textbox(value=settings.openai_text2gql_language_model, label="model_name"),
+                gr.Textbox(value=str(settings.openai_text2gql_tokens), label="max_token"),
             ]
     elif llm_type == "qianfan_wenxin":
         with gr.Row():
             llm_config_input = [
-                gr.Textbox(value=settings.qianfan_api_key, label="api_key"),
-                gr.Textbox(value=settings.qianfan_secret_key, label="secret_key"),
+                gr.Textbox(value=settings.qianfan_text2gql_api_key, label="api_key"),
+                gr.Textbox(value=settings.qianfan_text2gql_secret_key, label="secret_key"),
                 gr.Textbox(value=settings.qianfan_chat_url, label="chat_url"),
-                gr.Textbox(value=settings.qianfan_language_model, label="model_name")
+                gr.Textbox(value=settings.qianfan_text2gql_language_model, label="model_name")
             ]
     elif llm_type == "ollama/local":
         with gr.Row():
             llm_config_input = [
-                gr.Textbox(value=settings.ollama_host, label="host"),
-                gr.Textbox(value=str(settings.ollama_port), label="port"),
-                gr.Textbox(value=settings.ollama_language_model, label="model_name"),
+                gr.Textbox(value=settings.ollama_text2gql_host, label="host"),
+                gr.Textbox(value=str(settings.ollama_text2gql_port), label="port"),
+                gr.Textbox(value=settings.ollama_text2gql_language_model, label="model_name"),
                 gr.Textbox(value="", visible=False)
             ]
     else:
@@ -96,19 +96,19 @@ def llm_settings(llm_type):
 
     def apply_configuration(arg1, arg2, arg3, arg4):
         llm_option = settings.text2gql_llm_type
         if llm_option == "openai":
-            settings.openai_api_key = arg1
-            settings.openai_api_base = arg2
-            settings.openai_language_model = arg3
-            settings.openai_max_tokens = int(arg4)
+            settings.openai_text2gql_api_key = arg1
+            settings.openai_text2gql_api_base = arg2
+            settings.openai_text2gql_language_model = arg3
+            settings.openai_text2gql_tokens = int(arg4)
         elif llm_option == "qianfan_wenxin":
-            settings.qianfan_api_key = arg1
-            settings.qianfan_secret_key = arg2
+            settings.qianfan_text2gql_api_key = arg1
+            settings.qianfan_text2gql_secret_key = arg2
             settings.qianfan_chat_url = arg3
-            settings.qianfan_language_model = arg4
-        elif llm_option == "ollama/local":
-            settings.ollama_host = arg1
-            settings.ollama_port = int(arg2)
-            settings.ollama_language_model = arg3
+            settings.qianfan_text2gql_language_model = arg4
+        elif llm_option == "ollama/local":
+            settings.ollama_text2gql_host = arg1
+            settings.ollama_text2gql_port = int(arg2)
+            settings.ollama_text2gql_language_model = arg3
         gr.Info("configured!")
     llm_config_button.click(apply_configuration, inputs=llm_config_input)  # pylint: disable=no-member
 
@@ -126,15 +126,15 @@ def embedding_settings(embedding_type):
     if embedding_type == "openai":
         with gr.Row():
             embedding_config_input = [
-                gr.Textbox(value=settings.openai_api_key, label="api_key"),
-                gr.Textbox(value=settings.openai_api_base, label="api_base"),
+                gr.Textbox(value=settings.openai_text2gql_api_key, label="api_key"),
+                gr.Textbox(value=settings.openai_text2gql_api_base, label="api_base"),
                 gr.Textbox(value=settings.openai_embedding_model, label="model_name")
             ]
     elif embedding_type == "ollama/local":
         with gr.Row():
             embedding_config_input = [
-                gr.Textbox(value=settings.ollama_host, label="host"),
-                gr.Textbox(value=str(settings.ollama_port), label="port"),
+                gr.Textbox(value=settings.ollama_text2gql_host, label="host"),
+                gr.Textbox(value=str(settings.ollama_text2gql_port), label="port"),
                 gr.Textbox(value=settings.ollama_embedding_model, label="model_name"),
             ]
     else:
@@ -144,12 +144,12 @@ def embedding_settings(embedding_type):
     def apply_configuration(arg1, arg2, arg3):
         embedding_option = settings.embedding_type
         if embedding_option == "openai":
-            settings.openai_api_key = arg1
-            settings.openai_api_base = arg2
+            settings.openai_text2gql_api_key = arg1
+            settings.openai_text2gql_api_base = arg2
             settings.openai_embedding_model = arg3
-        elif embedding_option == "ollam/local":
-            settings.ollama_host = arg1
-            settings.ollama_port = int(arg2)
+        elif embedding_option == "ollama/local":
+            settings.ollama_text2gql_host = arg1
+            settings.ollama_text2gql_port = int(arg2)
             settings.ollama_embedding_model = arg3
         gr.Info("configured!")  # pylint: disable=no-member
diff --git a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py
index db9be47c..63ea7ab9 100644
--- a/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py
+++ b/hugegraph-llm/src/hugegraph_llm/models/embeddings/init_embedding.py
@@ -42,8 +42,8 @@ def get_embedding(self):
         if self.embedding_type == "qianfan_wenxin":
             return QianFanEmbedding(
                 model_name=settings.qianfan_embedding_model,
-                api_key=settings.qianfan_api_key,
-                secret_key=settings.qianfan_secret_key
+                api_key=settings.qianfan_embedding_api_key,
+                secret_key=settings.qianfan_embedding_secret_key
             )
         raise Exception("embedding type is not supported !")
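Note: the series standardizes setting names as `<provider>_<role>_<field>`, binds the role once per Gradio tab with `functools.partial`, and resolves fields dynamically via `getattr`/`setattr` (see `apply_llm_config` in PATCH 1/2). A minimal self-contained sketch of that pattern follows; `_Settings` is a stand-in for the real `hugegraph_llm.config.settings` object, and the key/model values are placeholders:

```python
from functools import partial

class _Settings:
    """Stand-in for hugegraph_llm.config.settings (illustration only)."""
    chat_llm_type = "openai"
    openai_chat_api_key = None
    openai_chat_api_base = None
    openai_chat_language_model = None
    openai_chat_tokens = None

settings = _Settings()

def apply_llm_config(current_llm, arg1, arg2, arg3, arg4):
    # Role-scoped attribute scheme from configs_block.py: <provider>_<role>_<field>
    llm_option = getattr(settings, f"{current_llm}_llm_type")
    if llm_option == "openai":
        setattr(settings, f"openai_{current_llm}_api_key", arg1)
        setattr(settings, f"openai_{current_llm}_api_base", arg2)
        setattr(settings, f"openai_{current_llm}_language_model", arg3)
        setattr(settings, f"openai_{current_llm}_tokens", int(arg4))

# Each tab binds its role once, as the patches do with functools.partial:
apply_llm_config_with_chat_op = partial(apply_llm_config, "chat")
apply_llm_config_with_chat_op("sk-placeholder", "https://api.openai.com/v1", "gpt-4o-mini", 4096)
assert settings.openai_chat_language_model == "gpt-4o-mini"
```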
From 8dc30438a6c4b0effae02306a6dd9390b2a44c8e Mon Sep 17 00:00:00 2001
From: yhj <1454yhj@gmail.com>
Date: Thu, 21 Nov 2024 16:18:05 +0800
Subject: [PATCH 4/5] modify option name in config block

---
 .../demo/rag_demo/configs_block.py            | 19 +++++++------------
 1 file changed, 7 insertions(+), 12 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
index cfe8aa44..39b036bb 100644
--- a/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
+++ b/hugegraph-llm/src/hugegraph_llm/demo/rag_demo/configs_block.py
@@ -27,11 +27,6 @@
 from hugegraph_llm.utils.log import log
 from functools import partial
 
-label_mapping = {
-    "chat LLM": "chat",
-    "extract LLM": "extract",
-    "text2gql LLM": "text2gql"
-}
 current_llm = "chat"
 
@@ -221,10 +216,10 @@ def create_configs_block() -> list:
 
     #TODO: use OOP to restructure
     with gr.Accordion("2. Set up the LLM.", open=False):
-        gr.Markdown("> Tips: the openai sdk also support openai style api from other providers.")
-        with gr.Tab(label='chat LLM'):
+        gr.Markdown("> Tips: the openai option also supports openai-style api from other providers.")
+        with gr.Tab(label='chat'):
             chat_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
-                                       value=getattr(settings, f"chat_llm_type"), label=f"chat LLM")
+                                       value=getattr(settings, f"chat_llm_type"), label=f"type")
             apply_llm_config_with_chat_op = partial(apply_llm_config, "chat")
 
             @gr.render(inputs=[chat_llm_dropdown])
             def chat_llm_settings(llm_type):
@@ -256,9 +251,9 @@ def chat_llm_settings(llm_type):
             llm_config_button = gr.Button("Apply configuration")
             llm_config_button.click(apply_llm_config_with_chat_op, inputs=llm_config_input)
-        with gr.Tab(label='extract LLM'):
+        with gr.Tab(label='extract'):
             extract_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
-                                       value=getattr(settings, f"extract_llm_type"), label=f"extract LLM")
+                                       value=getattr(settings, f"extract_llm_type"), label=f"type")
             apply_llm_config_with_extract_op = partial(apply_llm_config, "extract")
 
             @gr.render(inputs=[extract_llm_dropdown])
@@ -290,9 +285,9 @@ def extract_llm_settings(llm_type):
                 llm_config_input = [gr.Textbox(value="", visible=False) for _ in range(4)]
             llm_config_button = gr.Button("Apply configuration")
             llm_config_button.click(apply_llm_config_with_extract_op, inputs=llm_config_input)
-        with gr.Tab(label='text2gql LLM'):
+        with gr.Tab(label='text2gql'):
             text2gql_llm_dropdown = gr.Dropdown(choices=["openai", "qianfan_wenxin", "ollama/local"],
-                                       value=getattr(settings, f"text2gql_llm_type"), label=f"text2gql LLM")
+                                       value=getattr(settings, f"text2gql_llm_type"), label=f"type")
             apply_llm_config_with_text2gql_op = partial(apply_llm_config, "text2gql")
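Note: after PATCH 4 the three tabs map one-to-one onto the role-specific getters from PATCH 1: `word_extract.py` and `keyword_extract.py` fall back to `get_extract_llm()`, `answer_synthesize.py` to `get_chat_llm()`, and the Text2Gremlin demo to `get_text2gql_llm()`. Since `LLMs.__init__` snapshots the `*_llm_type` settings, a fresh `LLMs()` instance is needed after a type change. A usage sketch; the prompt string is illustrative and real calls need valid provider credentials:

```python
from hugegraph_llm.models.llms.init_llm import LLMs

llms = LLMs()  # reads chat/extract/text2gql types from settings at construction
chat_llm = llms.get_chat_llm()          # used by answer_synthesize.py
extract_llm = llms.get_extract_llm()    # used by word_extract.py / keyword_extract.py
text2gql_llm = llms.get_text2gql_llm()  # used by the Text2Gremlin demo

print(chat_llm.generate(prompt="What is the capital of China?"))
```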
From b50156ce3029497636309e53ff072017d6ce4068 Mon Sep 17 00:00:00 2001
From: imbajin
Date: Thu, 21 Nov 2024 17:41:14 +0800
Subject: [PATCH 5/5] update qianfan default model

---
 hugegraph-llm/src/hugegraph_llm/config/config_data.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/hugegraph-llm/src/hugegraph_llm/config/config_data.py b/hugegraph-llm/src/hugegraph_llm/config/config_data.py
index 6865ebd7..3e5711cb 100644
--- a/hugegraph-llm/src/hugegraph_llm/config/config_data.py
+++ b/hugegraph-llm/src/hugegraph_llm/config/config_data.py
@@ -79,9 +79,9 @@ class ConfigData:
     # 4.1 URL settings
     qianfan_url_prefix: Optional[str] = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop"
     qianfan_chat_url: Optional[str] = qianfan_url_prefix + "/chat/"
-    qianfan_chat_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K"
-    qianfan_extract_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K"
-    qianfan_text2gql_language_model: Optional[str] = "ERNIE-4.0-Turbo-8K"
+    qianfan_chat_language_model: Optional[str] = "ERNIE-Speed-128K"
+    qianfan_extract_language_model: Optional[str] = "ERNIE-Speed-128K"
+    qianfan_text2gql_language_model: Optional[str] = "ERNIE-Speed-128K"
     qianfan_embed_url: Optional[str] = qianfan_url_prefix + "/embeddings/"
     # refer https://cloud.baidu.com/doc/WENXINWORKSHOP/s/alj562vvu to get more details
     qianfan_embedding_model: Optional[str] = "embedding-v1"
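Note: a closing sketch of what the split buys — each role can point at a different provider. Every settings field below exists in `config_data.py` after this series; the model names and the mixed-provider combination are only an example:

```python
from hugegraph_llm.config import settings
from hugegraph_llm.models.llms.init_llm import LLMs

settings.chat_llm_type = "openai"                  # answers: gpt-4o-mini by default
settings.extract_llm_type = "ollama/local"         # extraction on a local model
settings.ollama_extract_language_model = "llama3"  # placeholder model name
settings.text2gql_llm_type = "qianfan_wenxin"      # text2gql: ERNIE-Speed-128K by default

llms = LLMs()  # re-instantiate so the new *_llm_type values are picked up
chat_llm, extract_llm, text2gql_llm = (
    llms.get_chat_llm(), llms.get_extract_llm(), llms.get_text2gql_llm()
)
```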