From 2f28e75c50bc6ae91d206da9916ea885b26c6ea9 Mon Sep 17 00:00:00 2001 From: "Dr. Kiji" <merdan.jp@gmail.com> Date: Fri, 20 Dec 2024 08:44:17 +0900 Subject: [PATCH 1/4] fix: add safe dictionary access for bedrock credentials --- .../model_providers/bedrock/get_bedrock_client.py | 12 ++++++++---- .../model_providers/bedrock/rerank/rerank.py | 5 ++++- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py index a19ffbb20a6a9e..9b9feb9cfba2bb 100644 --- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py +++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py @@ -1,11 +1,15 @@ import boto3 from botocore.config import Config - +from core.model_runtime.errors.invoke import InvokeBadRequestError def get_bedrock_client(service_name, credentials=None): - client_config = Config(region_name=credentials["aws_region"]) - aws_access_key_id = credentials["aws_access_key_id"] - aws_secret_access_key = credentials["aws_secret_access_key"] + region_name = credentials.get("aws_region") + if not region_name: + raise InvokeBadRequestError("aws_region is required") + client_config = Config(region_name=region_name) + aws_access_key_id = credentials.get("aws_access_key_id") + aws_secret_access_key = credentials.get("aws_secret_access_key") + if aws_access_key_id and aws_secret_access_key: # use aksk to call bedrock client = boto3.client( diff --git a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py index e134db646f3d39..9da23ba1b0f08f 100644 --- a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py +++ b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py @@ -62,7 +62,10 @@ def _invoke( } ) modelId = model - region = credentials["aws_region"] + region = credentials.get("aws_region") + # region is a required field + if not region: + raise InvokeBadRequestError("aws_region is required in credentials") model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}" rerankingConfiguration = { "type": "BEDROCK_RERANKING_MODEL", From 9ef8cd6ab294f8cc6bdf674323c6c9526666b9ab Mon Sep 17 00:00:00 2001 From: "Dr. Kiji" <merdan.jp@gmail.com> Date: Fri, 20 Dec 2024 08:52:00 +0900 Subject: [PATCH 2/4] lint and format --- .../model_runtime/model_providers/bedrock/get_bedrock_client.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py index 9b9feb9cfba2bb..c2ada1d18caee2 100644 --- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py +++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py @@ -1,7 +1,9 @@ import boto3 from botocore.config import Config + from core.model_runtime.errors.invoke import InvokeBadRequestError + def get_bedrock_client(service_name, credentials=None): region_name = credentials.get("aws_region") if not region_name: From 96f9de6d0881aa5269ab9d2e90ed38a291a6de0b Mon Sep 17 00:00:00 2001 From: "Dr. Kiji" <merdan.jp@gmail.com> Date: Fri, 20 Dec 2024 08:54:59 +0900 Subject: [PATCH 3/4] lint and format --- .../model_runtime/model_providers/bedrock/get_bedrock_client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py index c2ada1d18caee2..3d0fdc0abacf5b 100644 --- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py +++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py @@ -11,7 +11,7 @@ def get_bedrock_client(service_name, credentials=None): client_config = Config(region_name=region_name) aws_access_key_id = credentials.get("aws_access_key_id") aws_secret_access_key = credentials.get("aws_secret_access_key") - + if aws_access_key_id and aws_secret_access_key: # use aksk to call bedrock client = boto3.client( From 8e109104b1dd090834105817ba824ac1a86000d8 Mon Sep 17 00:00:00 2001 From: -LAN- <laipz8200@outlook.com> Date: Fri, 20 Dec 2024 11:00:58 +0800 Subject: [PATCH 4/4] fix: enforce type hinting for credentials in get_bedrock_client function Signed-off-by: -LAN- <laipz8200@outlook.com> --- .../model_providers/bedrock/get_bedrock_client.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py index 3d0fdc0abacf5b..2ad37cef3b38f1 100644 --- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py +++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py @@ -1,10 +1,12 @@ +from collections.abc import Mapping + import boto3 from botocore.config import Config from core.model_runtime.errors.invoke import InvokeBadRequestError -def get_bedrock_client(service_name, credentials=None): +def get_bedrock_client(service_name: str, credentials: Mapping[str, str]): region_name = credentials.get("aws_region") if not region_name: raise InvokeBadRequestError("aws_region is required")