Xinference embedding provider: add API-key support.

Summary of the changes carried by this patch:
  * credential/embedding.py - adds a required `api_key` password field and
    passes `model_credential.get('api_key')` to `provider.get_base_model_list`.
  * model/embedding.py - `XinferenceEmbedding` now derives from
    `langchain_core.embeddings.Embeddings` and defines its own `__init__`,
    `embed_documents` and `embed_query`, building a `RESTfulClient` from
    `server_url` and `api_key`.

NOTE(review): leading whitespace inside the hunks was restored from Python
syntax; confirm it against the upstream commit before applying.

diff --git a/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py
index 200183e6c03..7cddb4f09da 100644
--- a/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py
+++ b/apps/setting/models_provider/impl/xinference_model_provider/credential/embedding.py
@@ -15,7 +15,8 @@ def is_valid(self, model_type: str, model_name, model_credential: Dict[str, obje
         if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
             raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
         try:
-            model_list = provider.get_base_model_list(model_credential.get('api_base'), 'embedding')
+            model_list = provider.get_base_model_list(model_credential.get('api_base'), model_credential.get('api_key'),
+                                                      'embedding')
         except Exception as e:
             raise AppApiException(ValidCode.valid_error.value, "API 域名无效")
         exist = provider.get_model_info_by_name(model_list, model_name)
@@ -36,3 +37,4 @@ def build_model(self, model_info: Dict[str, object]):
         return self
 
     api_base = forms.TextInputField('API 域名', required=True)
+    api_key = forms.PasswordInputField('API Key', required=True)
diff --git a/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py b/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py
index 1cf34aaf875..935f4d23919 100644
--- a/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py
+++ b/apps/setting/models_provider/impl/xinference_model_provider/model/embedding.py
@@ -1,18 +1,26 @@
 # coding=utf-8
 import threading
-from typing import Dict
+from typing import Dict, Optional, List, Any
 
 from langchain_community.embeddings import XinferenceEmbeddings
+from langchain_core.embeddings import Embeddings
 
 from setting.models_provider.base_model_provider import MaxKBBaseModel
 
 
-class XinferenceEmbedding(MaxKBBaseModel, XinferenceEmbeddings):
+class XinferenceEmbedding(MaxKBBaseModel, Embeddings):
+    client: Any
+    server_url: Optional[str]
+    """URL of the xinference server"""
+    model_uid: Optional[str]
+    """UID of the launched model"""
+
     @staticmethod
     def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
         return XinferenceEmbedding(
             model_uid=model_name,
             server_url=model_credential.get('api_base'),
+            api_key=model_credential.get('api_key'),
         )
 
     def down_model(self):
@@ -22,3 +30,63 @@ def start_down_model_thread(self):
         thread = threading.Thread(target=self.down_model)
         thread.daemon = True
         thread.start()
+
+    def __init__(
+        self, server_url: Optional[str] = None, model_uid: Optional[str] = None,
+        api_key: Optional[str] = None
+    ):
+        try:
+            from xinference.client import RESTfulClient
+        except ImportError:
+            try:
+                from xinference_client import RESTfulClient
+            except ImportError as e:
+                raise ImportError(
+                    "Could not import RESTfulClient from xinference. Please install it"
+                    " with `pip install xinference` or `pip install xinference_client`."
+                ) from e
+
+        if server_url is None:
+            raise ValueError("Please provide server URL")
+
+        if model_uid is None:
+            raise ValueError("Please provide the model UID")
+
+        self.server_url = server_url
+
+        self.model_uid = model_uid
+
+        self.api_key = api_key
+
+        self.client = RESTfulClient(server_url, api_key)
+
+    def embed_documents(self, texts: List[str]) -> List[List[float]]:
+        """Embed a list of documents using Xinference.
+        Args:
+            texts: The list of texts to embed.
+        Returns:
+            List of embeddings, one for each text.
+        """
+
+        model = self.client.get_model(self.model_uid)
+
+        embeddings = [
+            model.create_embedding(text)["data"][0]["embedding"] for text in texts
+        ]
+        return [list(map(float, e)) for e in embeddings]
+
+    def embed_query(self, text: str) -> List[float]:
+        """Embed a query of documents using Xinference.
+        Args:
+            text: The text to embed.
+        Returns:
+            Embeddings for the text.
+        """
+
+        model = self.client.get_model(self.model_uid)
+
+        embedding_res = model.create_embedding(text)
+
+        embedding = embedding_res["data"][0]["embedding"]
+
+        return list(map(float, embedding))