fix: Fix the failure to add xinference embedding models #1566
Merged
@@ -1,18 +1,26 @@
 # coding=utf-8
 import threading
-from typing import Dict
+from typing import Dict, Optional, List, Any

 from langchain_community.embeddings import XinferenceEmbeddings
+from langchain_core.embeddings import Embeddings

 from setting.models_provider.base_model_provider import MaxKBBaseModel


-class XinferenceEmbedding(MaxKBBaseModel, XinferenceEmbeddings):
+class XinferenceEmbedding(MaxKBBaseModel, Embeddings):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         return XinferenceEmbedding(
             model_uid=model_name,
             server_url=model_credential.get('api_base'),
+            api_key=model_credential.get('api_key'),
         )

     def down_model(self):
@@ -22,3 +30,63 @@ def start_down_model_thread(self):
         thread = threading.Thread(target=self.down_model)
         thread.daemon = True
         thread.start()
+
+    def __init__(
+            self, server_url: Optional[str] = None, model_uid: Optional[str] = None,
+            api_key: Optional[str] = None
+    ):
+        try:
+            from xinference.client import RESTfulClient
+        except ImportError:
+            try:
+                from xinference_client import RESTfulClient
+            except ImportError as e:
+                raise ImportError(
+                    "Could not import RESTfulClient from xinference. Please install it"
+                    " with `pip install xinference` or `pip install xinference_client`."
+                ) from e
+
+        if server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.server_url = server_url
+
+        self.model_uid = model_uid
+
+        self.api_key = api_key
+
+        self.client = RESTfulClient(server_url, api_key)
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of documents using Xinference.
+        Args:
+            texts: The list of texts to embed.
+        Returns:
+            List of embeddings, one for each text.
+        """
+
+        model = self.client.get_model(self.model_uid)
+
+        embeddings = [
+            model.create_embedding(text)["data"][0]["embedding"] for text in texts
+        ]
+        return [list(map(float, e)) for e in embeddings]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed a query of documents using Xinference.
+        Args:
+            text: The text to embed.
+        Returns:
+            Embeddings for the text.
+        """
+
+        model = self.client.get_model(self.model_uid)
+
+        embedding_res = model.create_embedding(text)
+
+        embedding = embedding_res["data"][0]["embedding"]
+
+        return list(map(float, embedding))
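The call pattern that the new `embed_query` and `embed_documents` methods rely on can be tried without a running Xinference server. The following is a minimal sketch, not the merged implementation: `FakeClient`, `FakeModel`, and `SketchEmbedding` are hypothetical names introduced here, and the fakes only mimic the two calls that appear in the diff (`get_model(uid)` and `create_embedding(text)` returning a dict read as `["data"][0]["embedding"]`). Unlike the merged class, the sketch takes the client as a constructor argument so that a stand-in can be injected.

```python
from typing import Any, Dict, List


class FakeModel:
    """Hypothetical stand-in for the model handle returned by the client."""

    def create_embedding(self, text: str) -> Dict[str, Any]:
        # Deterministic toy vector: [text length, number of vowels].
        vowels = sum(ch in "aeiou" for ch in text.lower())
        return {"data": [{"embedding": [len(text), vowels]}]}


class FakeClient:
    """Hypothetical stand-in for RESTfulClient; only get_model is mimicked."""

    def get_model(self, model_uid: str) -> FakeModel:
        return FakeModel()


class SketchEmbedding:
    """Illustrative class following the same response handling as the diff."""

    def __init__(self, client: Any, model_uid: str):
        if model_uid is None:
            raise ValueError("Please provide the model UID")
        self.client = client
        self.model_uid = model_uid

    def embed_query(self, text: str) -> List[float]:
        model = self.client.get_model(self.model_uid)
        embedding = model.create_embedding(text)["data"][0]["embedding"]
        # Coerce every component to float, as the diff does.
        return list(map(float, embedding))

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        # Resolve the model handle once, then embed each text in turn.
        model = self.client.get_model(self.model_uid)
        return [
            list(map(float, model.create_embedding(t)["data"][0]["embedding"]))
            for t in texts
        ]


embedder = SketchEmbedding(FakeClient(), model_uid="demo-uid")
query_vector = embedder.embed_query("hello")
doc_vectors = embedder.embed_documents(["hello", "sky"])
```

Injecting the client is a choice made for this sketch only; it keeps the example self-contained and makes the response-parsing logic easy to exercise in a unit test.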
The code above has the following issues:
Regarding issue 1: because it needs to import ...
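No obvious violations or problems were found in the code, but there is some room for improvement:
Exception handling in `check_credentials()`: state the specific cause of the error more clearly, and use a meaningful object to describe the current state, for example `raise AppApiException(ValidCode.valid_error.value, "Invalid API address")`.
Merge the overlapping methods: if `get_base_model_list()` and `get_model_info_by_name()` were folded into a single method, the two functions could be combined, which would reduce duplication and improve readability.
Add unit tests or API usage documentation: this helps ensure that similar problems do not recur in practice. Unit tests should verify that all possible cases are handled correctly. The API can also be recorded in a structured documentation format, which helps later readers quickly understand how to interact with the service and get the expected results.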
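The suggestion about `check_credentials()` can be illustrated with a short sketch. Everything here is an assumption made for illustration: the `AppApiException` and `ValidCode` classes below are local stand-ins rather than the project's own definitions, the numeric value of `valid_error` is invented, and the function signature and the URL check are hypothetical. The point is only that each failure path raises an exception carrying a specific message.

```python
from enum import Enum
from typing import Dict
from urllib.parse import urlparse


class ValidCode(Enum):
    # Stand-in enum; the numeric value is an assumption for this sketch.
    valid_error = 500


class AppApiException(Exception):
    """Stand-in for the project's exception type (assumed code + message)."""

    def __init__(self, code: int, message: str):
        super().__init__(message)
        self.code = code
        self.message = message


def check_credentials(model_credential: Dict[str, object]) -> bool:
    """Hypothetical validation that reports why a credential is rejected."""
    api_base = model_credential.get("api_base")
    if not api_base:
        raise AppApiException(ValidCode.valid_error.value, "api_base is required")
    parsed = urlparse(str(api_base))
    # Accept only absolute http(s) URLs that include a host.
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise AppApiException(ValidCode.valid_error.value, "Invalid API address")
    return True


# Usage: a well-formed address passes, a malformed one reports the reason.
ok = check_credentials({"api_base": "http://127.0.0.1:9997"})
try:
    check_credentials({"api_base": "not-a-url"})
except AppApiException as exc:
    failure_reason = exc.message
```

Assertions like the ones in the usage lines are also a natural starting point for the unit tests the review asks for.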