From a385855e0766d2adf754ba811a7fe234f0b20b5b Mon Sep 17 00:00:00 2001 From: warren Date: Wed, 11 Dec 2024 21:54:51 +0800 Subject: [PATCH 1/6] use one method to get boto client for aws bedrock --- .../bedrock/get_bedrock_client.py | 21 +++++++++++++++++++ .../model_providers/bedrock/llm/llm.py | 9 ++------ .../model_providers/bedrock/rerank/rerank.py | 11 ++-------- .../bedrock/text_embedding/text_embedding.py | 12 ++--------- 4 files changed, 27 insertions(+), 26 deletions(-) create mode 100644 api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py new file mode 100644 index 00000000000000..0b463026741f5c --- /dev/null +++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py @@ -0,0 +1,21 @@ +import boto3 +from botocore.config import Config + + +def get_bedrock_client(service_name, credentials=None): + client_config = Config(region_name=credentials["aws_region"]) + aws_access_key_id = (credentials["aws_access_key_id"],) + aws_secret_access_key = credentials["aws_secret_access_key"] + if aws_access_key_id and aws_secret_access_key: + # 使用 AKSK 方式 + client = boto3.client( + service_name=service_name, + config=client_config, + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + ) + else: + # 使用 IAM 角色方式 + client = boto3.client(service_name=service_name, config=client_config) + + return client diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index e6e8a765ee9e05..fd8903ee8e5713 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -7,6 +7,7 @@ # 3rd import import boto3 +from api.core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client from botocore.config import Config from botocore.exceptions import ( ClientError, @@ -173,13 +174,7 @@ def _generate_with_converse( :param stream: is stream response :return: full response or stream response chunk generator result """ - bedrock_client = boto3.client( - service_name="bedrock-runtime", - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - region_name=credentials["aws_region"], - ) - + bedrock_client = get_bedrock_client("bedrock-runtime", credentials) system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages) inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop) diff --git a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py index 397f65e8c960c8..47d72836256595 100644 --- a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py @@ -1,7 +1,6 @@ from typing import Optional -import boto3 -from botocore.config import Config +from api.core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_runtime.errors.invoke import ( @@ -48,13 +47,7 @@ def _invoke( return RerankResult(model=model, docs=docs) # initialize client - client_config = Config(region_name=credentials["aws_region"]) - bedrock_runtime = boto3.client( - service_name="bedrock-agent-runtime", - config=client_config, - aws_access_key_id=credentials.get("aws_access_key_id", ""), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - ) + bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials) queries = [{"type": "TEXT", "textQuery": {"text": query}}] text_sources = [] for text in docs: diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index 2f998d8bdaee90..26d2982ba32411 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -3,8 +3,7 @@ import time from typing import Optional -import boto3 -from botocore.config import Config +from api.core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client from botocore.exceptions import ( ClientError, EndpointConnectionError, @@ -48,14 +47,7 @@ def _invoke( :param input_type: input type :return: embeddings result """ - client_config = Config(region_name=credentials["aws_region"]) - - bedrock_runtime = boto3.client( - service_name="bedrock-runtime", - config=client_config, - aws_access_key_id=credentials.get("aws_access_key_id"), - aws_secret_access_key=credentials.get("aws_secret_access_key"), - ) + bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials) embeddings = [] token_usage = 0 From 26a8c65aa09785abe5ac9ad4bbaea73ebdcb6cfb Mon Sep 17 00:00:00 2001 From: warren Date: Wed, 11 Dec 2024 21:56:31 +0800 Subject: [PATCH 2/6] fix aws_access_key_id --- .../model_runtime/model_providers/bedrock/get_bedrock_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py index 0b463026741f5c..a9684db7c795c6 100644 --- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py +++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py @@ -4,7 +4,7 @@ def get_bedrock_client(service_name, credentials=None): client_config = Config(region_name=credentials["aws_region"]) - aws_access_key_id = (credentials["aws_access_key_id"],) + aws_access_key_id = credentials["aws_access_key_id"] aws_secret_access_key = credentials["aws_secret_access_key"] if aws_access_key_id and aws_secret_access_key: # 使用 AKSK 方式 From 381e6be0813a36b2bb416fee9728127e24930f2b Mon Sep 17 00:00:00 2001 From: warren Date: Wed, 11 Dec 2024 22:20:24 +0800 Subject: [PATCH 3/6] fix import --- api/core/model_runtime/model_providers/bedrock/llm/llm.py | 3 +-- .../model_runtime/model_providers/bedrock/rerank/rerank.py | 3 +-- .../model_providers/bedrock/text_embedding/text_embedding.py | 3 +-- 3 files changed, 3 insertions(+), 6 deletions(-) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index fd8903ee8e5713..186ab1fe2b4f69 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -7,7 +7,6 @@ # 3rd import import boto3 -from api.core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client from botocore.config import Config from botocore.exceptions import ( ClientError, @@ -16,7 +15,6 @@ ServiceNotInRegionError, UnknownServiceError, ) - # local import from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta @@ -41,6 +39,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client logger = logging.getLogger(__name__) ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object. diff --git a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py index 47d72836256595..e134db646f3d39 100644 --- a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py @@ -1,7 +1,5 @@ from typing import Optional -from api.core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client - from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_runtime.errors.invoke import ( InvokeAuthorizationError, @@ -13,6 +11,7 @@ ) from core.model_runtime.errors.validate import CredentialsValidateFailedError from core.model_runtime.model_providers.__base.rerank_model import RerankModel +from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client class BedrockRerankModel(RerankModel): diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index 26d2982ba32411..edd33c7085d5c3 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -3,7 +3,6 @@ import time from typing import Optional -from api.core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client from botocore.exceptions import ( ClientError, EndpointConnectionError, @@ -11,7 +10,6 @@ ServiceNotInRegionError, UnknownServiceError, ) - from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult @@ -24,6 +22,7 @@ InvokeServerUnavailableError, ) from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client logger = logging.getLogger(__name__) From 4bae2607f95657a16a59c8822735fbf9a8b00355 Mon Sep 17 00:00:00 2001 From: warren Date: Wed, 11 Dec 2024 22:25:09 +0800 Subject: [PATCH 4/6] fix comments --- .../model_providers/bedrock/get_bedrock_client.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py index a9684db7c795c6..a19ffbb20a6a9e 100644 --- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py +++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py @@ -7,7 +7,7 @@ def get_bedrock_client(service_name, credentials=None): aws_access_key_id = credentials["aws_access_key_id"] aws_secret_access_key = credentials["aws_secret_access_key"] if aws_access_key_id and aws_secret_access_key: - # 使用 AKSK 方式 + # use aksk to call bedrock client = boto3.client( service_name=service_name, config=client_config, @@ -15,7 +15,7 @@ def get_bedrock_client(service_name, credentials=None): aws_secret_access_key=aws_secret_access_key, ) else: - # 使用 IAM 角色方式 + # use iam without aksk to call client = boto3.client(service_name=service_name, config=client_config) return client From 948040b191c5c10d3145f14b9f3b0ea9b8c090e7 Mon Sep 17 00:00:00 2001 From: warren Date: Wed, 11 Dec 2024 22:28:28 +0800 Subject: [PATCH 5/6] fix style --- api/core/model_runtime/model_providers/bedrock/llm/llm.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/model_runtime/model_providers/bedrock/llm/llm.py b/api/core/model_runtime/model_providers/bedrock/llm/llm.py index 186ab1fe2b4f69..75ed7ad62404cb 100644 --- a/api/core/model_runtime/model_providers/bedrock/llm/llm.py +++ b/api/core/model_runtime/model_providers/bedrock/llm/llm.py @@ -15,6 +15,7 @@ ServiceNotInRegionError, UnknownServiceError, ) + # local import from core.model_runtime.callbacks.base_callback import Callback from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta From 77a3dbf1d0d48c20423e7a989289b7d128c44587 Mon Sep 17 00:00:00 2001 From: warren Date: Wed, 11 Dec 2024 22:33:35 +0800 Subject: [PATCH 6/6] fix style --- .../model_providers/bedrock/text_embedding/text_embedding.py | 1 + 1 file changed, 1 insertion(+) diff --git a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py index edd33c7085d5c3..5505797f7658b3 100644 --- a/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/bedrock/text_embedding/text_embedding.py @@ -10,6 +10,7 @@ ServiceNotInRegionError, UnknownServiceError, ) + from core.entities.embedding_type import EmbeddingInputType from core.model_runtime.entities.model_entities import PriceType from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult