Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1831d18

Browse files
committed Jan 3, 2025
feat: Support Anthropic
1 parent d1b1aaa commit 1831d18

File tree

8 files changed

+286
-0
lines changed

8 files changed

+286
-0
lines changed
 

‎apps/setting/models_provider/constants/model_provider_constants.py

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
from setting.models_provider.impl.aliyun_bai_lian_model_provider.aliyun_bai_lian_model_provider import \
1212
AliyunBaiLianModelProvider
13+
from setting.models_provider.impl.anthropic_model_provider.anthropic_model_provider import AnthropicModelProvider
1314
from setting.models_provider.impl.aws_bedrock_model_provider.aws_bedrock_model_provider import BedrockModelProvider
1415
from setting.models_provider.impl.azure_model_provider.azure_model_provider import AzureModelProvider
1516
from setting.models_provider.impl.deepseek_model_provider.deepseek_model_provider import DeepSeekModelProvider
@@ -47,3 +48,4 @@ class ModelProvideConstants(Enum):
4748
model_xinference_provider = XinferenceModelProvider()
4849
model_vllm_provider = VllmModelProvider()
4950
aliyun_bai_lian_model_provider = AliyunBaiLianModelProvider()
51+
model_anthropic_provider = AnthropicModelProvider()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: __init__.py.py
6+
@date:2024/3/28 16:25
7+
@desc:
8+
"""
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: openai_model_provider.py
6+
@date:2024/3/28 16:26
7+
@desc:
8+
"""
9+
import os
10+
11+
from common.util.file_util import get_file_content
12+
from setting.models_provider.base_model_provider import IModelProvider, ModelProvideInfo, ModelInfo, \
13+
ModelTypeConst, ModelInfoManage
14+
from setting.models_provider.impl.anthropic_model_provider.credential.image import AnthropicImageModelCredential
15+
from setting.models_provider.impl.anthropic_model_provider.credential.llm import AnthropicLLMModelCredential
16+
from setting.models_provider.impl.anthropic_model_provider.model.image import AnthropicImage
17+
from setting.models_provider.impl.anthropic_model_provider.model.llm import AnthropicChatModel
18+
from smartdoc.conf import PROJECT_DIR
19+
20+
# Credential singletons shared by every Anthropic model entry below.
# Renamed from the copy-pasted `openai_*` names — this module is the
# Anthropic provider, not OpenAI.
anthropic_llm_model_credential = AnthropicLLMModelCredential()
anthropic_image_model_credential = AnthropicImageModelCredential()

# Chat (LLM) models exposed by this provider. The first entry is the default.
model_info_list = [
    ModelInfo('claude-3-opus-20240229', '', ModelTypeConst.LLM,
              anthropic_llm_model_credential, AnthropicChatModel),
    ModelInfo('claude-3-sonnet-20240229', '', ModelTypeConst.LLM, anthropic_llm_model_credential,
              AnthropicChatModel),
    ModelInfo('claude-3-haiku-20240307', '', ModelTypeConst.LLM, anthropic_llm_model_credential,
              AnthropicChatModel),
    ModelInfo('claude-3-5-sonnet-20240620', '', ModelTypeConst.LLM, anthropic_llm_model_credential,
              AnthropicChatModel),
    ModelInfo('claude-3-5-haiku-20241022', '', ModelTypeConst.LLM, anthropic_llm_model_credential,
              AnthropicChatModel),
    ModelInfo('claude-3-5-sonnet-20241022', '', ModelTypeConst.LLM, anthropic_llm_model_credential,
              AnthropicChatModel),
]

# Vision-capable models; claude-3-5-sonnet-20241022 is intentionally listed
# both as an LLM (above) and as an IMAGE model (here).
image_model_info = [
    ModelInfo('claude-3-5-sonnet-20241022', '', ModelTypeConst.IMAGE, anthropic_image_model_credential,
              AnthropicImage),
]

# Catalogue of all Anthropic models plus the default for each model type.
model_info_manage = (
    ModelInfoManage.builder()
    .append_model_info_list(model_info_list)
    .append_default_model_info(model_info_list[0])
    .append_model_info_list(image_model_info)
    .append_default_model_info(image_model_info[0])
    .build()
)
52+
53+
54+
class AnthropicModelProvider(IModelProvider):
    """Provider entry point for Anthropic: exposes the model catalogue and
    the provider metadata (identifier, display name, icon)."""

    def get_model_info_manage(self):
        # The catalogue is built once at module import time.
        return model_info_manage

    def get_model_provide_info(self):
        icon_path = os.path.join(PROJECT_DIR, "apps", "setting", 'models_provider', 'impl',
                                 'anthropic_model_provider', 'icon', 'anthropic_icon_svg')
        return ModelProvideInfo(provider='model_anthropic_provider', name='Anthropic',
                                icon=get_file_content(icon_path))
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# coding=utf-8
2+
import base64
3+
import os
4+
from typing import Dict
5+
6+
from langchain_core.messages import HumanMessage
7+
8+
from common import forms
9+
from common.exception.app_exception import AppApiException
10+
from common.forms import BaseForm, TooltipLabel
11+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
12+
13+
class AnthropicImageModelParams(BaseForm):
    """Tunable runtime parameters shown in the UI for Anthropic image models."""

    # Sampling temperature; higher values make output more random,
    # lower values more focused and deterministic.
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    # Maximum number of tokens the model may generate in one response.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)
28+
29+
30+
31+
class AnthropicImageModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Anthropic image (vision) models.

    Collects the API base URL and key, validates them by issuing a minimal
    streamed request, and encrypts the key when the credential is persisted.
    """

    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        """Validate the credential by making one tiny streaming call.

        Returns True when the credential works; otherwise returns False,
        or raises AppApiException when raise_exception is True.
        """
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            # An unsupported model type always raises, regardless of raise_exception.
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            res = model.stream([HumanMessage(content=[{"type": "text", "text": "你好"}])])
            # Drain the stream to force the request; a bad key/URL raises here.
            # (Removed leftover debug print of each chunk.)
            for _ in res:
                pass
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Persist the API key encrypted; all other fields are kept verbatim.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        # The same parameter form applies to every Anthropic image model.
        return AnthropicImageModelParams()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# coding=utf-8
2+
"""
3+
@project: MaxKB
4+
@Author:虎
5+
@file: llm.py
6+
@date:2024/7/11 18:32
7+
@desc:
8+
"""
9+
from typing import Dict
10+
11+
from langchain_core.messages import HumanMessage
12+
13+
from common import forms
14+
from common.exception.app_exception import AppApiException
15+
from common.forms import BaseForm, TooltipLabel
16+
from setting.models_provider.base_model_provider import BaseModelCredential, ValidCode
17+
18+
19+
class AnthropicLLMModelParams(BaseForm):
    """Tunable runtime parameters shown in the UI for Anthropic chat models."""

    # Sampling temperature; higher values make output more random,
    # lower values more focused and deterministic.
    temperature = forms.SliderField(TooltipLabel('温度', '较高的数值会使输出更加随机,而较低的数值会使其更加集中和确定'),
                                    required=True, default_value=0.7,
                                    _min=0.1,
                                    _max=1.0,
                                    _step=0.01,
                                    precision=2)

    # Maximum number of tokens the model may generate in one response.
    max_tokens = forms.SliderField(
        TooltipLabel('输出最大Tokens', '指定模型可生成的最大token个数'),
        required=True, default_value=800,
        _min=1,
        _max=100000,
        _step=1,
        precision=0)
34+
35+
36+
class AnthropicLLMModelCredential(BaseForm, BaseModelCredential):
    """Credential form for Anthropic chat (LLM) models.

    Collects the API base URL and key, validates them with one small
    invocation, and encrypts the key when the credential is persisted.
    """

    # Form fields moved to the top of the class (they were declared between
    # methods), matching AnthropicImageModelCredential; relative declaration
    # order api_base -> api_key is preserved.
    api_base = forms.TextInputField('API 域名', required=True)
    api_key = forms.PasswordInputField('API Key', required=True)

    def is_valid(self, model_type: str, model_name, model_credential: Dict[str, object], model_params, provider,
                 raise_exception=False):
        """Validate the credential by making one minimal invoke call.

        Returns True when the credential works; otherwise returns False,
        or raises AppApiException when raise_exception is True.
        """
        model_type_list = provider.get_model_type_list()
        if not any(list(filter(lambda mt: mt.get('value') == model_type, model_type_list))):
            # An unsupported model type always raises, regardless of raise_exception.
            raise AppApiException(ValidCode.valid_error.value, f'{model_type} 模型类型不支持')

        for key in ['api_base', 'api_key']:
            if key not in model_credential:
                if raise_exception:
                    raise AppApiException(ValidCode.valid_error.value, f'{key} 字段为必填字段')
                else:
                    return False
        try:
            model = provider.get_model(model_type, model_name, model_credential)
            # A bad key/URL raises here.
            model.invoke([HumanMessage(content='你好')])
        except Exception as e:
            if isinstance(e, AppApiException):
                raise e
            if raise_exception:
                raise AppApiException(ValidCode.valid_error.value, f'校验失败,请检查参数是否正确: {str(e)}')
            else:
                return False
        return True

    def encryption_dict(self, model: Dict[str, object]):
        # Persist the API key encrypted; all other fields are kept verbatim.
        return {**model, 'api_key': super().encryption(model.get('api_key', ''))}

    def get_model_params_setting_form(self, model_name):
        # The same parameter form applies to every Anthropic chat model.
        return AnthropicLLMModelParams()
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
from typing import Dict
2+
3+
from langchain_anthropic import ChatAnthropic
4+
5+
from common.config.tokenizer_manage_config import TokenizerManage
6+
from setting.models_provider.base_model_provider import MaxKBBaseModel
7+
8+
9+
def custom_get_token_ids(text: str):
    """Tokenize *text* with the shared project tokenizer and return its token ids."""
    return TokenizerManage.get_tokenizer().encode(text)
12+
13+
14+
class AnthropicImage(MaxKBBaseModel, ChatAnthropic):
    """ChatAnthropic wrapper used for Anthropic vision-capable (image) models."""

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an AnthropicImage from stored credentials.

        model_credential must provide 'api_base' and 'api_key'; model_kwargs
        are filtered down to the optional params the model accepts.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        # NOTE(review): unlike AnthropicChatModel, no custom_get_token_ids is
        # wired in here even though this module defines one — confirm intended.
        return AnthropicImage(
            model=model_name,
            anthropic_api_url=model_credential.get('api_base'),
            anthropic_api_key=model_credential.get('api_key'),
            streaming=True,
            **optional_params,
        )
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# coding=utf-8
2+
"""
3+
@project: maxkb
4+
@Author:虎
5+
@file: llm.py
6+
@date:2024/4/18 15:28
7+
@desc:
8+
"""
9+
from typing import List, Dict
10+
11+
from langchain_anthropic import ChatAnthropic
12+
from langchain_core.messages import BaseMessage, get_buffer_string
13+
14+
from common.config.tokenizer_manage_config import TokenizerManage
15+
from setting.models_provider.base_model_provider import MaxKBBaseModel
16+
17+
18+
def custom_get_token_ids(text: str):
    """Tokenize *text* with the shared project tokenizer and return its token ids."""
    return TokenizerManage.get_tokenizer().encode(text)
21+
22+
23+
class AnthropicChatModel(MaxKBBaseModel, ChatAnthropic):
    """ChatAnthropic wrapper with a local-tokenizer fallback for token counting."""

    @staticmethod
    def is_cache_model():
        # Instances are rebuilt per request rather than cached.
        return False

    @staticmethod
    def new_instance(model_type, model_name, model_credential: Dict[str, object], **model_kwargs):
        """Build an AnthropicChatModel from stored credentials.

        model_credential must provide 'api_base' and 'api_key'; model_kwargs
        are filtered down to the optional params the model accepts.
        """
        optional_params = MaxKBBaseModel.filter_optional_params(model_kwargs)
        # Renamed from the copy-pasted `azure_chat_open_ai` — this is Anthropic.
        chat_model = AnthropicChatModel(
            model=model_name,
            anthropic_api_url=model_credential.get('api_base'),
            anthropic_api_key=model_credential.get('api_key'),
            **optional_params,
            custom_get_token_ids=custom_get_token_ids
        )
        return chat_model

    def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
        """Count tokens across *messages*, falling back to the project tokenizer
        when the upstream counter fails (e.g. no API access)."""
        try:
            return super().get_num_tokens_from_messages(messages)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            return sum(len(tokenizer.encode(get_buffer_string([m]))) for m in messages)

    def get_num_tokens(self, text: str) -> int:
        """Count tokens in *text*, falling back to the project tokenizer
        when the upstream counter fails."""
        try:
            return super().get_num_tokens(text)
        except Exception:
            tokenizer = TokenizerManage.get_tokenizer()
            return len(tokenizer.encode(text))

‎pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ httpx = "^0.27.0"
4141
httpx-sse = "^0.4.0"
4242
websockets = "^13.0"
4343
langchain-google-genai = "^1.0.3"
44+
langchain-anthropic= "^0.1.0"
4445
openpyxl = "^3.1.2"
4546
xlrd = "^2.0.1"
4647
gunicorn = "^22.0.0"

0 commit comments

Comments
 (0)
Please sign in to comment.