Skip to content

feat: aws向量模型 #1399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
from setting.models_provider.base_model_provider import (
IModelProvider, ModelProvideInfo, ModelInfo, ModelTypeConst, ModelInfoManage
)
from setting.models_provider.impl.aws_bedrock_model_provider.credential.embedding import BedrockEmbeddingCredential
from setting.models_provider.impl.aws_bedrock_model_provider.credential.llm import BedrockLLMModelCredential
from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel
from setting.models_provider.impl.aws_bedrock_model_provider.model.llm import BedrockModel
from setting.models_provider.impl.tencent_model_provider.credential.embedding import TencentEmbeddingCredential
from setting.models_provider.impl.tencent_model_provider.credential.llm import TencentLLMModelCredential
from setting.models_provider.impl.tencent_model_provider.model.embedding import TencentEmbeddingModel
from setting.models_provider.impl.tencent_model_provider.model.llm import TencentModel
from smartdoc.conf import PROJECT_DIR


Expand Down Expand Up @@ -118,10 +116,21 @@ def _initialize_model_info():
BedrockLLMModelCredential,
BedrockModel),
]
embedded_model_info_list = [
_create_model_info(
'amazon.titan-embed-text-v1',
'Titan Embed Text 是 Amazon Titan Embed 系列中最大的嵌入模型,可以处理各种文本嵌入任务,如文本分类、文本相似度计算等。',
ModelTypeConst.EMBEDDING,
BedrockEmbeddingCredential,
BedrockEmbeddingModel
),
]

model_info_manage = ModelInfoManage.builder() \
.append_model_info_list(model_info_list) \
.append_default_model_info(model_info_list[0]) \
.append_model_info_list(embedded_model_info_list) \
.append_default_model_info(embedded_model_info_list[0]) \
.build()

return model_info_manage
Expand Down
Original file line number Diff line number Diff line change
@@ -1,64 +1,64 @@
import json
import os
import re
from typing import Dict

from tencentcloud.common import credential
from tencentcloud.common.profile.client_profile import ClientProfile
from tencentcloud.common.profile.http_profile import HttpProfile
from tencentcloud.hunyuan.v20230901 import hunyuan_client, models

from common import forms
from common.exception.app_exception import AppApiException
from common.forms import BaseForm
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
from setting.models_provider.impl.aws_bedrock_model_provider.model.embedding import BedrockEmbeddingModel


class BedrockEmbeddingCredential(BaseForm, BaseModelCredential):
    """Credential form and validator for AWS Bedrock embedding models.

    The form collects ``region_name``, ``access_key_id`` and
    ``secret_access_key``. Validation writes the key pair into the shared AWS
    credentials file under the ``aws-profile`` profile and then issues a probe
    embedding request through the provider.
    """

    @staticmethod
    def _replace_profile_section(content, profile_name, access_key_id, secret_access_key):
        """Return ``content`` with ``[profile_name]`` rewritten to hold the given keys.

        Any existing ``[profile_name]`` section is dropped in full (from its
        header up to, but not including, the next section header) and a fresh
        section is appended at the end. All other sections are kept verbatim.

        Args:
            content: Current text of the credentials file ('' if it does not exist).
            profile_name: Name of the profile section to (re)write.
            access_key_id: Value stored as ``aws_access_key_id``.
            secret_access_key: Value stored as ``aws_secret_access_key``.

        Returns:
            The new file content, ending with the rewritten section.
        """
        header = f'[{profile_name}]'
        kept_lines = []
        in_target = False
        found = False
        for line in content.splitlines():
            stripped = line.strip()
            # A bracketed line starts a new section; it decides whether the
            # lines that follow belong to the profile being replaced.
            if stripped.startswith('[') and stripped.endswith(']'):
                in_target = stripped == header
                found = found or in_target
            if not in_target:
                kept_lines.append(line)

        if found:
            # Only rebuild the text when something was removed, so an untouched
            # file keeps its exact original bytes.
            content = '\n'.join(kept_lines).rstrip('\n')

        return (f"{content}\n[{profile_name}]\n"
                f"aws_access_key_id = {access_key_id}\n"
                f"aws_secret_access_key = {secret_access_key}\n")

    @staticmethod
    def _update_aws_credentials(profile_name, access_key_id, secret_access_key):
        """Write the key pair into ``~/.aws/credentials`` under ``profile_name``.

        Creates the ``~/.aws`` directory when missing. Other profiles already
        present in the file are preserved.

        Args:
            profile_name: Profile section to create or overwrite.
            access_key_id: AWS access key id to store.
            secret_access_key: AWS secret access key to store.
        """
        credentials_path = os.path.join(os.path.expanduser("~"), ".aws", "credentials")
        os.makedirs(os.path.dirname(credentials_path), exist_ok=True)

        content = ''
        if os.path.exists(credentials_path):
            # Context manager so the handle is closed even if reading fails.
            with open(credentials_path, 'r') as file:
                content = file.read()

        content = BedrockEmbeddingCredential._replace_profile_section(
            content, profile_name, access_key_id, secret_access_key)

        # NOTE(review): the file holds a secret key and is created with the
        # process default permissions - consider restricting it to the owner.
        with open(credentials_path, 'w') as file:
            file.write(content)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], provider,
                 raise_exception=False):
        """Check that the credential can be used for the given embedding model.

        Args:
            model_type: Model type value; must be one of the values returned by
                ``provider.get_model_type_list()``.
            model_name: Model identifier passed on to ``provider.get_model``.
            model_credential: Credential dict; must contain ``region_name``,
                ``access_key_id`` and ``secret_access_key``. On the probe path
                it is updated in place with ``credentials_profile_name``.
            provider: Model provider used to list model types and build the model.
            raise_exception: When true, validation failures raise instead of
                returning ``False``.

        Returns:
            ``True`` when the probe embedding request succeeds, ``False`` when a
            check fails and ``raise_exception`` is false.

        Raises:
            AppApiException: On a failed check when ``raise_exception`` is true.
                An ``AppApiException`` raised by the probe itself is always
                propagated, regardless of ``raise_exception``.
        """
        model_type_list = provider.get_model_type_list()
        if not any(mt.get('value') == model_type for mt in model_type_list):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')
            return False

        required_keys = ['region_name', 'access_key_id', 'secret_access_key']
        if not all(key in model_credential for key in required_keys):
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'以下字段为必填字段: {", ".join(required_keys)}')
            return False

        try:
            self._update_aws_credentials('aws-profile', model_credential['access_key_id'],
                                         model_credential['secret_access_key'])
            # Deliberate in-place update: the model is built from the profile
            # name, so it must be present in the dict handed to the provider.
            # NOTE(review): callers presumably persist this key too - confirm.
            model_credential['credentials_profile_name'] = 'aws-profile'
            model: BedrockEmbeddingModel = provider.get_model(model_type, model_name, model_credential)
            # Probe request; only success or failure matters, not the vector.
            model.embed_query('你好')
        except AppApiException:
            raise
        except Exception as e:
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value,
                                      f'校验失败,请检查参数是否正确: {str(e)}') from e
            return False

        return True

    def encryption_dict(self, model: Dict[str, object]):
        """Return a copy of ``model`` whose ``secret_access_key`` went through ``encryption``."""
        return {**model, 'secret_access_key': super().encryption(model.get('secret_access_key', ''))}

    # Form fields shown to the user when configuring the credential.
    region_name = forms.TextInputField('Region Name', required=True)
    access_key_id = forms.TextInputField('Access Key ID', required=True)
    secret_access_key = forms.PasswordInputField('Secret Access Key', required=True)
Original file line number Diff line number Diff line change
@@ -1,25 +1,56 @@
from langchain_community.embeddings import BedrockEmbeddings

from setting.models_provider.base_model_provider import MaxKBBaseModel
from typing import Dict
import requests


class TencentEmbeddingModel(MaxKBBaseModel):
def __init__(self, secret_id: str, secret_key: str, api_base: str, model_name: str):
self.secret_id = secret_id
self.secret_key = secret_key
self.api_base = api_base
self.model_name = model_name

@staticmethod
def new_instance(model_type: str, model_name: str, model_credential: Dict[str, str], **model_kwargs):
return TencentEmbeddingModel(
secret_id=model_credential.get('SecretId'),
secret_key=model_credential.get('SecretKey'),
api_base=model_credential.get('api_base'),
model_name=model_name,
from typing import Dict, List


class BedrockEmbeddingModel(MaxKBBaseModel, BedrockEmbeddings):
    """Bedrock embedding model built from a named AWS credentials profile."""

    def __init__(self, model_id: str, region_name: str, credentials_profile_name: str,
                 **kwargs):
        """Forward the Bedrock settings to the base-class initialiser.

        Args:
            model_id: Bedrock model identifier.
            region_name: AWS region name.
            credentials_profile_name: Name of the AWS credentials profile to use.
            **kwargs: Extra keyword arguments passed through unchanged.
        """
        super().__init__(model_id=model_id, region_name=region_name,
                         credentials_profile_name=credentials_profile_name, **kwargs)

    @classmethod
    def new_instance(cls, model_type: str, model_name: str, model_credential: Dict[str, str],
                     **model_kwargs) -> 'BedrockEmbeddingModel':
        """Build a model from a credential dict.

        Args:
            model_type: Not used by this implementation.
            model_name: Used as the Bedrock ``model_id``.
            model_credential: Must contain ``region_name`` and
                ``credentials_profile_name``.
            **model_kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A new instance of this class.

        Raises:
            KeyError: If a required key is missing from ``model_credential``.
        """
        return cls(
            model_id=model_name,
            region_name=model_credential['region_name'],
            credentials_profile_name=model_credential['credentials_profile_name'],
        )

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a Bedrock model.

        Args:
            texts: The list of texts to embed

        Returns:
            List of embeddings, one for each text.
        """
        results = []
        for text in texts:
            # _embedding_func / normalize / _normalize_vector are not defined in
            # this class; presumably inherited from BedrockEmbeddings.
            response = self._embedding_func(text)

            if self.normalize:
                response = self._normalize_vector(response)

            results.append(response)

        return results

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a Bedrock model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        embedding = self._embedding_func(text)

        if self.normalize:
            return self._normalize_vector(embedding)

        return embedding
Loading