Skip to content

Commit b90dde4

Browse files
authored
Merge pull request #483 from tmokmss/add-bedrock-support
Add Amazon Bedrock support
2 parents c57807e + 5e642c1 commit b90dde4

File tree

6 files changed

+51
-9
lines changed

6 files changed

+51
-9
lines changed

Usage.md

+17
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,23 @@ Your [application default credentials](https://cloud.google.com/docs/authenticat
328328
329329
If you do want to set explicit credentials then you can use the `GOOGLE_APPLICATION_CREDENTIALS` environment variable set to a path to a json credentials file.
330330
331+
#### Amazon Bedrock
332+
333+
To use Amazon Bedrock and its foundational models, add the below configuration:
334+
335+
```
336+
[config] # in configuration.toml
337+
model = "anthropic.claude-v2"
338+
fallback_models="anthropic.claude-instant-v1"
339+
340+
[aws] # in .secrets.toml
341+
bedrock_region = "us-east-1"
342+
```
343+
344+
Note that you have to add access to foundational models before using them. Please refer to [this document](https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html) for more details.
345+
346+
AWS session is automatically authenticated from your environment, but you can also explicitly set `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables.
347+
331348
### Working with large PRs
332349
333350
The default mode of CodiumAI is to have a single call per tool, using GPT-4, which has a token limit of 8000 tokens.

pr_agent/algo/__init__.py

+3
Original file line numberDiff line numberDiff line change
@@ -18,4 +18,7 @@
1818
'vertex_ai/codechat-bison-32k': 32000,
1919
'codechat-bison': 6144,
2020
'codechat-bison-32k': 32000,
21+
'anthropic.claude-v2': 100000,
22+
'anthropic.claude-instant-v1': 100000,
23+
'anthropic.claude-v1': 100000,
2124
}

pr_agent/algo/ai_handler.py

+18-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22

3+
import boto3
34
import litellm
45
import openai
56
from litellm import acompletion
@@ -24,6 +25,7 @@ def __init__(self):
2425
Raises a ValueError if the OpenAI key is missing.
2526
"""
2627
self.azure = False
28+
self.aws_bedrock_client = None
2729

2830
if get_settings().get("OPENAI.KEY", None):
2931
openai.api_key = get_settings().openai.key
@@ -60,6 +62,12 @@ def __init__(self):
6062
litellm.vertex_location = get_settings().get(
6163
"VERTEXAI.VERTEX_LOCATION", None
6264
)
65+
if get_settings().get("AWS.BEDROCK_REGION", None):
66+
litellm.AmazonAnthropicConfig.max_tokens_to_sample = 2000
67+
self.aws_bedrock_client = boto3.client(
68+
service_name="bedrock-runtime",
69+
region_name=get_settings().aws.bedrock_region,
70+
)
6371

6472
@property
6573
def deployment_id(self):
@@ -100,13 +108,16 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
100108
if self.azure:
101109
model = 'azure/' + model
102110
messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
103-
response = await acompletion(
104-
model=model,
105-
deployment_id=deployment_id,
106-
messages=messages,
107-
temperature=temperature,
108-
force_timeout=get_settings().config.ai_timeout
109-
)
111+
kwargs = {
112+
"model": model,
113+
"deployment_id": deployment_id,
114+
"messages": messages,
115+
"temperature": temperature,
116+
"force_timeout": get_settings().config.ai_timeout,
117+
}
118+
if self.aws_bedrock_client:
119+
kwargs["aws_bedrock_client"] = self.aws_bedrock_client
120+
response = await acompletion(**kwargs)
110121
except (APIError, Timeout, TryAgain) as e:
111122
get_logger().error("Error during OpenAI inference: ", e)
112123
raise

pr_agent/algo/utils.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,15 @@ def try_fix_yaml(response_text: str) -> dict:
325325
break
326326
except:
327327
pass
328-
return data
328+
329+
# third fallback - try to remove leading and trailing curly brackets
330+
response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}')
331+
try:
332+
data = yaml.safe_load(response_text_copy,)
333+
get_logger().info(f"Successfully parsed AI prediction after removing curly brackets")
334+
return data
335+
except:
336+
pass
329337

330338

331339
def set_custom_labels(variables):

pr_agent/settings/.secrets_template.toml

+3
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ api_base = "" # the base url for your local Llama 2, Code Llama, and other model
4040
vertex_project = "" # the google cloud platform project name for your vertexai deployment
4141
vertex_location = "" # the google cloud platform location for your vertexai deployment
4242

43+
[aws]
44+
bedrock_region = "" # the AWS region to call Bedrock APIs
45+
4346
[github]
4447
# ---- Set the following only for deployment type == "user"
4548
user_token = "" # A GitHub personal access token with 'repo' scope.

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ GitPython==3.1.32
1414
PyYAML==6.0
1515
starlette-context==0.3.6
1616
litellm==0.12.5
17-
boto3==1.28.25
17+
boto3==1.33.1
1818
google-cloud-storage==2.10.0
1919
ujson==5.8.0
2020
azure-devops==7.1.0b3

0 commit comments

Comments
 (0)