From 2f28e75c50bc6ae91d206da9916ea885b26c6ea9 Mon Sep 17 00:00:00 2001
From: "Dr. Kiji" <merdan.jp@gmail.com>
Date: Fri, 20 Dec 2024 08:44:17 +0900
Subject: [PATCH 1/4] fix: add safe dictionary access for bedrock credentials

---
 .../model_providers/bedrock/get_bedrock_client.py    | 12 ++++++++----
 .../model_providers/bedrock/rerank/rerank.py         |  5 ++++-
 2 files changed, 12 insertions(+), 5 deletions(-)

diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
index a19ffbb20a6a9e..9b9feb9cfba2bb 100644
--- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
+++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
@@ -1,11 +1,15 @@
 import boto3
 from botocore.config import Config
-
+from core.model_runtime.errors.invoke import InvokeBadRequestError
 
 def get_bedrock_client(service_name, credentials=None):
-    client_config = Config(region_name=credentials["aws_region"])
-    aws_access_key_id = credentials["aws_access_key_id"]
-    aws_secret_access_key = credentials["aws_secret_access_key"]
+    region_name = credentials.get("aws_region")
+    if not region_name:
+        raise InvokeBadRequestError("aws_region is required")
+    client_config = Config(region_name=region_name)
+    aws_access_key_id = credentials.get("aws_access_key_id")
+    aws_secret_access_key = credentials.get("aws_secret_access_key")
+    
     if aws_access_key_id and aws_secret_access_key:
         # use aksk to call bedrock
         client = boto3.client(
diff --git a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py
index e134db646f3d39..9da23ba1b0f08f 100644
--- a/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py
+++ b/api/core/model_runtime/model_providers/bedrock/rerank/rerank.py
@@ -62,7 +62,10 @@ def _invoke(
                 }
             )
         modelId = model
-        region = credentials["aws_region"]
+        region = credentials.get("aws_region")
+        # region is a required field
+        if not region:
+            raise InvokeBadRequestError("aws_region is required in credentials")
         model_package_arn = f"arn:aws:bedrock:{region}::foundation-model/{modelId}"
         rerankingConfiguration = {
             "type": "BEDROCK_RERANKING_MODEL",

From 9ef8cd6ab294f8cc6bdf674323c6c9526666b9ab Mon Sep 17 00:00:00 2001
From: "Dr. Kiji" <merdan.jp@gmail.com>
Date: Fri, 20 Dec 2024 08:52:00 +0900
Subject: [PATCH 2/4] lint and format

---
 .../model_runtime/model_providers/bedrock/get_bedrock_client.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
index 9b9feb9cfba2bb..c2ada1d18caee2 100644
--- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
+++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
@@ -1,7 +1,9 @@
 import boto3
 from botocore.config import Config
+
 from core.model_runtime.errors.invoke import InvokeBadRequestError
 
+
 def get_bedrock_client(service_name, credentials=None):
     region_name = credentials.get("aws_region")
     if not region_name:

From 96f9de6d0881aa5269ab9d2e90ed38a291a6de0b Mon Sep 17 00:00:00 2001
From: "Dr. Kiji" <merdan.jp@gmail.com>
Date: Fri, 20 Dec 2024 08:54:59 +0900
Subject: [PATCH 3/4] lint and format

---
 .../model_runtime/model_providers/bedrock/get_bedrock_client.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
index c2ada1d18caee2..3d0fdc0abacf5b 100644
--- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
+++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
@@ -11,7 +11,7 @@ def get_bedrock_client(service_name, credentials=None):
     client_config = Config(region_name=region_name)
     aws_access_key_id = credentials.get("aws_access_key_id")
     aws_secret_access_key = credentials.get("aws_secret_access_key")
-    
+
     if aws_access_key_id and aws_secret_access_key:
         # use aksk to call bedrock
         client = boto3.client(

From 8e109104b1dd090834105817ba824ac1a86000d8 Mon Sep 17 00:00:00 2001
From: -LAN- <laipz8200@outlook.com>
Date: Fri, 20 Dec 2024 11:00:58 +0800
Subject: [PATCH 4/4] fix: enforce type hinting for credentials in
 get_bedrock_client function

Signed-off-by: -LAN- <laipz8200@outlook.com>
---
 .../model_providers/bedrock/get_bedrock_client.py             | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
index 3d0fdc0abacf5b..2ad37cef3b38f1 100644
--- a/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
+++ b/api/core/model_runtime/model_providers/bedrock/get_bedrock_client.py
@@ -1,10 +1,12 @@
+from collections.abc import Mapping
+
 import boto3
 from botocore.config import Config
 
 from core.model_runtime.errors.invoke import InvokeBadRequestError
 
 
-def get_bedrock_client(service_name, credentials=None):
+def get_bedrock_client(service_name: str, credentials: Mapping[str, str]):
     region_name = credentials.get("aws_region")
     if not region_name:
         raise InvokeBadRequestError("aws_region is required")