Add Amazon Bedrock support #483

Merged
5 commits merged on Nov 29, 2023
17 changes: 17 additions & 0 deletions Usage.md
@@ -328,6 +328,23 @@ Your [application default credentials](https://cloud.google.com/docs/authenticat

If you do want to set explicit credentials then you can use the `GOOGLE_APPLICATION_CREDENTIALS` environment variable set to a path to a json credentials file.

#### Amazon Bedrock

To use Amazon Bedrock and its foundation models, add the following configuration:

```
[config] # in configuration.toml
model = "anthropic.claude-v2"
fallback_models = "anthropic.claude-instant-v1"

[aws] # in .secrets.toml
bedrock_region = "us-east-1"
```

Note that you must request access to the foundation models before using them. See [this document](https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html) for details.

The AWS session is authenticated automatically from your environment, but you can also set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables explicitly.
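
For reference, here is a minimal sketch of how that authentication resolves when the Bedrock client is created. The region value is a placeholder, and the credential chain described is standard boto3 behavior rather than anything specific to this change:

```
import boto3

# boto3 resolves credentials through its default chain: environment
# variables (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY), shared config
# files such as ~/.aws/credentials, or an instance/role profile.
client = boto3.client(
    service_name="bedrock-runtime",  # the same client ai_handler.py creates
    region_name="us-east-1",         # placeholder; matches bedrock_region above
)
```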

### Working with large PRs

The default mode of CodiumAI is to make a single call per tool, using GPT-4, which has an 8000-token limit.
3 changes: 3 additions & 0 deletions pr_agent/algo/__init__.py
@@ -18,4 +18,7 @@
    'vertex_ai/codechat-bison-32k': 32000,
    'codechat-bison': 6144,
    'codechat-bison-32k': 32000,
    'anthropic.claude-v2': 100000,
    'anthropic.claude-instant-v1': 100000,
    'anthropic.claude-v1': 100000,
Comment from the contributor author:

I added support only for Claude for now, because it's difficult to tune additional parameters such as `max_tokens_to_sample` for each specific model.

}
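
As a usage sketch, downstream code can read a model's context size from this map. This assumes the dict shown above is exported as `MAX_TOKENS`; its definition line sits outside the hunk:

```
from pr_agent.algo import MAX_TOKENS  # assumed export name, not shown in this hunk

# Look up the context window for a Bedrock Claude model, with a
# conservative fallback for models missing from the map.
limit = MAX_TOKENS.get("anthropic.claude-v2", 4096)
print(limit)  # 100000
```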
25 changes: 18 additions & 7 deletions pr_agent/algo/ai_handler.py
@@ -1,5 +1,6 @@
import os

import boto3
import litellm
import openai
from litellm import acompletion
@@ -24,6 +25,7 @@ def __init__(self):
        Raises a ValueError if the OpenAI key is missing.
        """
        self.azure = False
        self.aws_bedrock_client = None

        if get_settings().get("OPENAI.KEY", None):
            openai.api_key = get_settings().openai.key
@@ -60,6 +62,12 @@ def __init__(self):
            litellm.vertex_location = get_settings().get(
                "VERTEXAI.VERTEX_LOCATION", None
            )
        if get_settings().get("AWS.BEDROCK_REGION", None):
            litellm.AmazonAnthropicConfig.max_tokens_to_sample = 2000
Comment from the contributor author (tmokmss, Nov 28, 2023):

We need this because `max_tokens_to_sample` defaults to 256 in litellm, which would stop the completion in the middle of a review. We could let users set this parameter from a config file, but hardcoding this large value should be enough for most users.

            self.aws_bedrock_client = boto3.client(
                service_name="bedrock-runtime",
                region_name=get_settings().aws.bedrock_region,
            )

    @property
    def deployment_id(self):
@@ -100,13 +108,16 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
            if self.azure:
                model = 'azure/' + model
            messages = [{"role": "system", "content": system}, {"role": "user", "content": user}]
-            response = await acompletion(
-                model=model,
-                deployment_id=deployment_id,
-                messages=messages,
-                temperature=temperature,
-                force_timeout=get_settings().config.ai_timeout
-            )
+            kwargs = {
+                "model": model,
+                "deployment_id": deployment_id,
+                "messages": messages,
+                "temperature": temperature,
+                "force_timeout": get_settings().config.ai_timeout,
+            }
+            if self.aws_bedrock_client:
+                kwargs["aws_bedrock_client"] = self.aws_bedrock_client
+            response = await acompletion(**kwargs)
Comment from the contributor author:

Passing the `aws_bedrock_client` kwarg when calling other providers such as OpenAI fails with the error below, so I only add this argument when using Bedrock:

openai.error.InvalidRequestError: Unrecognized request argument supplied: aws_bedrock_client

        except (APIError, Timeout, TryAgain) as e:
            get_logger().error("Error during OpenAI inference: ", e)
            raise
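
For orientation, here is a hypothetical end-to-end call through this handler with a Bedrock model. The module path, the class name `AiHandler`, and the settings wiring are assumptions inferred from the file above, not something this diff shows:

```
import asyncio

from pr_agent.algo.ai_handler import AiHandler  # assumed class name for the handler above


async def main():
    # Assumes AWS.BEDROCK_REGION is set in .secrets.toml, so __init__
    # builds the bedrock-runtime client shown in the diff.
    handler = AiHandler()
    response = await handler.chat_completion(
        model="anthropic.claude-v2",  # routed through litellm to Bedrock
        system="You are a helpful PR reviewer.",
        user="Summarize the changes in this diff.",
        temperature=0.2,
    )
    print(response)


asyncio.run(main())
```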
10 changes: 9 additions & 1 deletion pr_agent/algo/utils.py
@@ -325,7 +325,15 @@ def try_fix_yaml(response_text: str) -> dict:
            break
        except:
            pass
-    return data
+
+    # third fallback - try to remove leading and trailing curly brackets
+    response_text_copy = response_text.strip().removeprefix('{').removesuffix('}')
+    try:
+        data = yaml.safe_load(response_text_copy)
+        get_logger().info("Successfully parsed AI prediction after removing curly brackets")
+        return data
+    except:
+        pass


def set_custom_labels(variables):
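
To illustrate the new third fallback in `try_fix_yaml`, here is a minimal sketch with a hypothetical model response wrapped in curly brackets:

```
import yaml

# Hypothetical AI response: block-style YAML wrapped in curly brackets,
# which yaml.safe_load rejects as a malformed flow mapping.
response_text = "{\nreview: looks good\nscore: 8\n}"

# The third fallback strips the brackets, leaving valid block YAML.
cleaned = response_text.strip().removeprefix("{").removesuffix("}")
print(yaml.safe_load(cleaned))  # {'review': 'looks good', 'score': 8}
```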
3 changes: 3 additions & 0 deletions pr_agent/settings/.secrets_template.toml
@@ -40,6 +40,9 @@ api_base = "" # the base url for your local Llama 2, Code Llama, and other model
vertex_project = "" # the google cloud platform project name for your vertexai deployment
vertex_location = "" # the google cloud platform location for your vertexai deployment

[aws]
bedrock_region = "" # the AWS region used for Bedrock API calls

[github]
# ---- Set the following only for deployment type == "user"
user_token = "" # A GitHub personal access token with 'repo' scope.
2 changes: 1 addition & 1 deletion requirements.txt
@@ -14,7 +14,7 @@ GitPython==3.1.32
PyYAML==6.0
starlette-context==0.3.6
litellm==0.12.5
-boto3==1.28.25
+boto3==1.33.1
Comment from the contributor author (tmokmss, Nov 28, 2023):

litellm requires boto3>=1.28.57, so I updated boto3 to the latest version. I don't see any breaking changes between 1.28.25 and 1.33.1 (doc).

google-cloud-storage==2.10.0
ujson==5.8.0
azure-devops==7.1.0b3