Skip to content

fix: 修复xinference向量模型添加失败的缺陷 #1566

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
try:
model_list = provider.get_base_model_list(model_credential.get('api_base'), 'embedding')
model_list = provider.get_base_model_list(model_credential.get('api_base'), model_credential.get('api_key'),
'embedding')
except Exception as e:
raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
exist = provider.get_model_info_by_name(model_list, model_name)
Expand All @@ -36,3 +37,4 @@ def build_model(self, model_info: Dict[str, object]):
return self

api_base = forms.TextInputField('API 域名', required=True)
api_key = forms.PasswordInputField('API Key', required=True)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

代码中并没有发现明显的不符合要求或问题的地方,但可能存在一些需要改进的空间:

  1. check_credentials() 函数中的异常处理:
    可以更清晰地说明错误的具体原因,使用有意义的对象来描述当前的状态信息,如 `raise AppApiException(Vali
    ##Suffix:

##Middle:
dCode.valid_error.value,
"Invalid API address")`。

  1. 将方法重载部分进行合并:
    如果 get_base_model_list() 和 get_model_info_by_name() 成为一个单独的方法,则可以将这两个函数整合在一起,减少重复编写和提高可读性。

  2. 添加单元测试或API调用文档:
    这有助于确保在实际开发中不会出现类似问题。建议增加单元测试以验证所有可能的情况是否都能正确处理。同时也可以通过API文档格式化的方式记录API接口的信息。这能帮助后续的查询者快速了解如何与该服务交互并获得预期的结果。

Original file line number Diff line number Diff line change
@@ -1,18 +1,26 @@
# coding=utf-8
import threading
from typing import Dict
from typing import Dict, Optional, List, Any

from langchain_community.embeddings import XinferenceEmbeddings
from langchain_core.embeddings import Embeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel


class XinferenceEmbedding(MaxKBBaseModel, XinferenceEmbeddings):
class XinferenceEmbedding(MaxKBBaseModel, Embeddings):
client: Any
server_url: Optional[str]
"""URL of the xinference server"""
model_uid: Optional[str]
"""UID of the launched model"""

@staticmethod
def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
return XinferenceEmbedding(
model_uid=model_name,
server_url=model_credential.get('api_base'),
api_key=model_credential.get('api_key'),
)

def down_model(self):
Expand All @@ -22,3 +30,63 @@ def start_down_model_thread(self):
thread = threading.Thread(target=self.down_model)
thread.daemon = True
thread.start()

def __init__(
self, server_url: Optional[str] = None, model_uid: Optional[str] = None,
api_key: Optional[str] = None
):
try:
from xinference.client import RESTfulClient
except ImportError:
try:
from xinference_client import RESTfulClient
except ImportError as e:
raise ImportError(
"Could not import RESTfulClient from xinference. Please install it"
" with `pip install xinference` or `pip install xinference_client`."
) from e

if server_url is None:
raise ValueError("Please provide server URL")

if model_uid is None:
raise ValueError("Please provide the model UID")

self.server_url = server_url

self.model_uid = model_uid

self.api_key = api_key

self.client = RESTfulClient(server_url, api_key)

def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Embed a list of documents using Xinference.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""

model = self.client.get_model(self.model_uid)

embeddings = [
model.create_embedding(text)["data"][0]["embedding"] for text in texts
]
return [list(map(float, e)) for e in embeddings]

def embed_query(self, text: str) -> List[float]:
"""Embed a query of documents using Xinference.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""

model = self.client.get_model(self.model_uid)

embedding_res = model.create_embedding(text)

embedding = embedding_res["data"][0]["embedding"]

return list(map(float, embedding))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

上述代码存在以下问题:

  1. 在创建子类时,没有导入必要的模块。
  2. 运行模型和下载模型文件的代码块被放在了下划线函数 down_model 的后面,并且它们没有对应的文档描述。

对于问题一:由于需要导入 restfulclient 模块,请在调用新建子类方法之前确保该插件已安装或通过 pip 安装它。对于问题二,请提供上下文中的更详细文档说明,以便于理解代码意图。

Loading