diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py index c15bdf7e5a1..ddb5afd52e1 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/aws_bedrock_model_provider.py @@ -6,12 +6,10 @@ from setting.models_provider.base_model_provider import ( IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage ) +from setting.models_provider.impl.aws_bedrock_model_provider.credential.embedding import BedrockEmbeddingCredential from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential +from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel -from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential -from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential -from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel -from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel from smartdoc.conf import PROJECT_DIR @@ -118,10 +116,21 @@ def _initialize_model_info(): BedrockLLMModelCredential, BedrockModel), ] + embedded_model_info_list = [ + _create_model_info( + 'amazon.titan-embed-text-v1', + 'Titan Embed Text 是 Amazon Titan Embed 系列中最大的嵌入模型,可以处理各种文本嵌入任务,如文本分类、文本相似度计算等。', + ModelTypeConst.EMBEDDING, + BedrockEmbeddingCredential, + BedrockEmbeddingModel + ), + ] model_info_manage = ModelInfoManage.builder() \ .append_model_info_list(model_info_list) \ .append_default_model_info(model_info_list[0]) \ + .append_model_info_list(embedded_model_info_list) \ + .append_default_model_info(embedded_model_info_list[0]) \ .build() return model_info_manage diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py index 40e36caca64..520960d7ac4 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/credential/embedding.py @@ -1,64 +1,64 @@ -import json +import os +import re from typing import Dict -from tencentcloud.common import credential -from tencentcloud.common.profile.client_profile import ClientProfile -from tencentcloud.common.profile.http_profile import HttpProfile -from tencentcloud.hunyuan.v20230901 import hunyuan_client, models - from common import forms from common.exception.app_exception import AppApiException from common.forms import BaseForm from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode +from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel + + +class BedrockEmbeddingCredential(BaseForm, BaseModelCredential): + + @staticmethod + def _update_aws_credentials(profile_name, access_key_id, secret_access_key): + credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials") + os.makedirs(os.path.dirname(credentials_path), exist_ok=True) + + content = open(credentials_path, 'r').read() if os.path.exists(credentials_path) else '' + pattern = rf'\n*\[{profile_name}\]\n*(aws_access_key_id = .*)\n*(aws_secret_access_key = .*)\n*' + content = re.sub(pattern, '', content, flags=re.DOTALL) + if not re.search(rf'\[{profile_name}\]', content): + content += f"\n[{profile_name}]\naws_access_key_id = {access_key_id}\naws_secret_access_key = {secret_access_key}\n" -class TencentEmbeddingCredential(BaseForm, BaseModelCredential): - @classmethod - def _validate_model_type(cls, model_type: str, provider) -> bool: + with open(credentials_path, 'w') as file: + file.write(content) + + def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, + raise_exception=False): model_type_list = provider.get_model_type_list() if not any(mt.get('value') == model_type for mt in model_type_list): - raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') - return True + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持') + return False - @classmethod - def _validate_credential(cls, model_credential: Dict[str, object]) -> credential.Credential: - for key in ['SecretId', 'SecretKey']: - if key not in model_credential: - raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段') - return credential.Credential(model_credential['SecretId'], model_credential['SecretKey']) + required_keys = ['region_name', 'access_key_id', 'secret_access_key'] + if not all(key in model_credential for key in required_keys): + if raise_exception: + raise AppApiException(ValidCode.valid_error.value, f'以下字段为必填字段: {", ".join(required_keys)}') + return False - @classmethod - def _test_credentials(cls, client, model_name: str): - req = models.GetEmbeddingRequest() - params = { - "Model": model_name, - "Input": "测试" - } - req.from_json_string(json.dumps(params)) try: - res = client.GetEmbedding(req) - print(res.to_json_string()) + self._update_aws_credentials('aws-profile', model_credential['access_key_id'], + model_credential['secret_access_key']) + model_credential['credentials_profile_name'] = 'aws-profile' + model: BedrockEmbeddingModel = provider.get_model(model_type, model_name, model_credential) + aa = model.embed_query('你好') + print(aa) + except AppApiException: + raise except Exception as e: - raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') - - def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider, - raise_exception=True) -> bool: - try: - self._validate_model_type(model_type, provider) - cred = self._validate_credential(model_credential) - httpProfile = HttpProfile(endpoint="hunyuan.tencentcloudapi.com") - clientProfile = ClientProfile(httpProfile=httpProfile) - client = hunyuan_client.HunyuanClient(cred, "", clientProfile) - self._test_credentials(client, model_name) - return True - except AppApiException as e: if raise_exception: - raise e + raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}') return False - def encryption_dict(self, model: Dict[str, object]) -> Dict[str, object]: - encrypted_secret_key = super().encryption(model.get('SecretKey', '')) - return {**model, 'SecretKey': encrypted_secret_key} + return True + + def encryption_dict(self, model: Dict[str, object]): + return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))} - SecretId = forms.PasswordInputField('SecretId', required=True) - SecretKey = forms.PasswordInputField('SecretKey', required=True) + region_name = forms.TextInputField('Region Name', required=True) + access_key_id = forms.TextInputField('Access Key ID', required=True) + secret_access_key = forms.PasswordInputField('Secret Access Key', required=True) diff --git a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py index a5bd0336a78..d08f62c6223 100644 --- a/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py +++ b/apps/setting/models_provider/impl/aws_bedrock_model_provider/model/embedding.py @@ -1,25 +1,56 @@ +from langchain_community.embeddings import BedrockEmbeddings + from setting.models_provider.base_model_provider import MaxKBBaseModel -from typing import Dict -import requests - - -class TencentEmbeddingModel(MaxKBBaseModel): - def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str): - self.secret_id = secret_id - self.secret_key = secret_key - self.api_base = api_base - self.model_name = model_name - - @staticmethod - def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs): - return TencentEmbeddingModel( - secret_id=model_credential.get('SecretId'), - secret_key=model_credential.get('SecretKey'), - api_base=model_credential.get('api_base'), - model_name=model_name, +from typing import Dict, List + + +class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings): + def __init__(self, model_id: str, region_name: str, credentials_profile_name: str, + **kwargs): + super().__init__(model_id=model_id, region_name=region_name, + credentials_profile_name=credentials_profile_name, **kwargs) + + @classmethod + def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str], + **model_kwargs) -> 'BedrockModel': + return cls( + model_id=model_name, + region_name=model_credential['region_name'], + credentials_profile_name=model_credential['credentials_profile_name'], ) + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Compute doc embeddings using a Bedrock model. + + Args: + texts: The list of texts to embed + + Returns: + List of embeddings, one for each text. + """ + results = [] + for text in texts: + response = self._embedding_func(text) + + if self.normalize: + response = self._normalize_vector(response) + + results.append(response) + + return results + + def embed_query(self, text: str) -> List[float]: + """Compute query embeddings using a Bedrock model. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + embedding = self._embedding_func(text) + + if self.normalize: + return self._normalize_vector(embedding) - def _generate_auth_token(self): - # Example method to generate an authentication token for the model API - return f"{self.secret_id}:{self.secret_key}" + return embedding