diff --git a/Usage.md b/Usage.md
index 957077731..37fd61e1b 100644
--- a/Usage.md
+++ b/Usage.md
@@ -328,6 +328,23 @@ Your [application default credentials](https://cloud.google.com/docs/authenticat
 If you do want to set explicit credentials then you can use the `GOOGLE_APPLICATION_CREDENTIALS` environment variable set to a path to a json credentials file.
 
+#### Amazon Bedrock
+
+To use Amazon Bedrock and its foundation models, add the following configuration:
+
+```
+[config] # in configuration.toml
+model = "anthropic.claude-v2"
+fallback_models = "anthropic.claude-instant-v1"
+
+[aws] # in .secrets.toml
+bedrock_region = "us-east-1"
+```
+
+Note that you must first request access to the foundation models before using them. Please refer to [this document](https://docs.aws.amazon.com/bedrock/latest/userguide/setting-up.html) for more details.
+
+The AWS session is authenticated automatically from your environment, but you can also set the `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY` environment variables explicitly.
+
 ### Working with large PRs
 
 The default mode of CodiumAI is to have a single call per tool, using GPT-4, which has a token limit of 8000 tokens.
diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py
index 5fe82ee52..63a628a57 100644
--- a/pr_agent/algo/__init__.py
+++ b/pr_agent/algo/__init__.py
@@ -18,4 +18,7 @@
     'vertex_ai/codechat-bison-32k': 32000,
     'codechat-bison': 6144,
     'codechat-bison-32k': 32000,
+    'anthropic.claude-v2': 100000,
+    'anthropic.claude-instant-v1': 100000,
+    'anthropic.claude-v1': 100000,
 }
diff --git a/pr_agent/algo/ai_handler.py b/pr_agent/algo/ai_handler.py
index 9a48cdc3d..5b6a05f4e 100644
--- a/pr_agent/algo/ai_handler.py
+++ b/pr_agent/algo/ai_handler.py
@@ -1,5 +1,6 @@
 import os
 
+import boto3
 import litellm
 import openai
 from litellm import acompletion
@@ -24,6 +25,7 @@ def __init__(self):
         Raises a ValueError if the OpenAI key is missing.
""" self.azure = False + self.aws_bedrock_client = None if get_settings().get("OPENAI.KEY", None): openai.api_key = get_settings().openai.key @@ -60,6 +62,12 @@ def __init__(self): litellm.vertex_location = get_settings().get( "VERTEXAI.VERTEX_LOCATION", None ) + if get_settings().get("AWS.BEDROCK_REGION", None): + litellm.AmazonAnthropicConfig.max_tokens_to_sample = 2000 + self.aws_bedrock_client = boto3.client( + service_name="bedrock-runtime", + region_name=get_settings().aws.bedrock_region, + ) @property def deployment_id(self): @@ -100,13 +108,16 @@ async def chat_completion(self, model: str, system: str, user: str, temperature: if self.azure: model = 'azure/' + model messages = [{"role": "system", "content": system}, {"role": "user", "content": user}] - response = await acompletion( - model=model, - deployment_id=deployment_id, - messages=messages, - temperature=temperature, - force_timeout=get_settings().config.ai_timeout - ) + kwargs = { + "model": model, + "deployment_id": deployment_id, + "messages": messages, + "temperature": temperature, + "force_timeout": get_settings().config.ai_timeout, + } + if self.aws_bedrock_client: + kwargs["aws_bedrock_client"] = self.aws_bedrock_client + response = await acompletion(**kwargs) except (APIError, Timeout, TryAgain) as e: get_logger().error("Error during OpenAI inference: ", e) raise diff --git a/pr_agent/algo/utils.py b/pr_agent/algo/utils.py index 7a6e666c4..1599f056c 100644 --- a/pr_agent/algo/utils.py +++ b/pr_agent/algo/utils.py @@ -325,7 +325,15 @@ def try_fix_yaml(response_text: str) -> dict: break except: pass - return data + + # thrid fallback - try to remove leading and trailing curly brackets + response_text_copy = response_text.strip().rstrip().removeprefix('{').removesuffix('}') + try: + data = yaml.safe_load(response_text_copy,) + get_logger().info(f"Successfully parsed AI prediction after removing curly brackets") + return data + except: + pass def set_custom_labels(variables): diff --git a/pr_agent/settings/.secrets_template.toml b/pr_agent/settings/.secrets_template.toml index ba51382c6..e7ca4057c 100644 --- a/pr_agent/settings/.secrets_template.toml +++ b/pr_agent/settings/.secrets_template.toml @@ -40,6 +40,9 @@ api_base = "" # the base url for your local Llama 2, Code Llama, and other model vertex_project = "" # the google cloud platform project name for your vertexai deployment vertex_location = "" # the google cloud platform location for your vertexai deployment +[aws] +bedrock_region = "" # the AWS region to call Bedrock APIs + [github] # ---- Set the following only for deployment type == "user" user_token = "" # A GitHub personal access token with 'repo' scope. diff --git a/requirements.txt b/requirements.txt index eae08f4cc..678cafd66 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,7 +14,7 @@ GitPython==3.1.32 PyYAML==6.0 starlette-context==0.3.6 litellm==0.12.5 -boto3==1.28.25 +boto3==1.33.1 google-cloud-storage==2.10.0 ujson==5.8.0 azure-devops==7.1.0b3