Skip to content

Commit

Permalink
Merge pull request Codium-ai#172 from krrishdholakia/patch-1
Browse files Browse the repository at this point in the history
adding support for Anthropic, Cohere, Replicate, Azure
  • Loading branch information
okotek authored Aug 6, 2023
2 parents b36e8a0 + 39eb0fb commit 3522f9f
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 25 deletions.
4 changes: 4 additions & 0 deletions pr_agent/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,8 @@
'gpt-4': 8000,
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
'claude-instant-1': 100000,
'claude-2': 100000,
'command-nightly': 4096,
'replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1': 4096,
}
33 changes: 23 additions & 10 deletions pr_agent/algo/ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import openai
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry

import litellm
from litellm import acompletion
from pr_agent.config_loader import get_settings

import traceback
OPENAI_RETRIES=5

class AiHandler:
Expand All @@ -22,15 +23,25 @@ def __init__(self):
"""
try:
openai.api_key = get_settings().openai.key
litellm.openai_key = get_settings().openai.key
self.azure = False
if get_settings().get("OPENAI.ORG", None):
openai.organization = get_settings().openai.org
litellm.organization = get_settings().openai.org
self.deployment_id = get_settings().get("OPENAI.DEPLOYMENT_ID", None)
if get_settings().get("OPENAI.API_TYPE", None):
openai.api_type = get_settings().openai.api_type
if get_settings().openai.api_type == "azure":
self.azure = True
litellm.azure_key = get_settings().openai.key
if get_settings().get("OPENAI.API_VERSION", None):
openai.api_version = get_settings().openai.api_version
litellm.api_version = get_settings().openai.api_version
if get_settings().get("OPENAI.API_BASE", None):
openai.api_base = get_settings().openai.api_base
litellm.api_base = get_settings().openai.api_base
if get_settings().get("ANTHROPIC.KEY", None):
litellm.anthropic_key = get_settings().anthropic.key
if get_settings().get("COHERE.KEY", None):
litellm.cohere_key = get_settings().cohere.key
if get_settings().get("REPLICATE.KEY", None):
litellm.replicate_key = get_settings().replicate.key
except AttributeError as e:
raise ValueError("OpenAI key is required") from e

Expand All @@ -57,14 +68,15 @@ async def chat_completion(self, model: str, temperature: float, system: str, use
TryAgain: If there is an attribute error during OpenAI inference.
"""
try:
response = await openai.ChatCompletion.acreate(
response = await acompletion(
model=model,
deployment_id=self.deployment_id,
messages=[
{"role": "system", "content": system},
{"role": "user", "content": user}
],
temperature=temperature,
azure=self.azure
)
except (APIError, Timeout, TryAgain) as e:
logging.error("Error during OpenAI inference: ", e)
Expand All @@ -75,8 +87,9 @@ async def chat_completion(self, model: str, temperature: float, system: str, use
except (Exception) as e:
logging.error("Unknown error during OpenAI inference: ", e)
raise TryAgain from e
if response is None or len(response.choices) == 0:
if response is None or len(response["choices"]) == 0:
raise TryAgain
resp = response.choices[0]['message']['content']
finish_reason = response.choices[0].finish_reason
resp = response["choices"][0]['message']['content']
finish_reason = response["choices"][0]["finish_reason"]
print(resp, finish_reason)
return resp, finish_reason
10 changes: 6 additions & 4 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
from __future__ import annotations

import re
import difflib
import logging
from typing import Callable, Tuple, List, Any
import re
import traceback
from typing import Any, Callable, List, Tuple

from github import RateLimitExceededException

from pr_agent.algo import MAX_TOKENS
from pr_agent.algo.git_patch_processing import convert_to_hunks_with_lines_numbers, extend_patch, handle_patch_deletions
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import GitProvider, FilePatchInfo
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider

DELETED_FILES_ = "Deleted files:\n"

Expand Down Expand Up @@ -215,7 +217,7 @@ async def retry_with_fallback_models(f: Callable):
try:
return await f(model)
except Exception as e:
logging.warning(f"Failed to generate prediction with {model}: {e}")
logging.warning(f"Failed to generate prediction with {model}: {traceback.format_exc()}")
if i == len(all_models) - 1: # If it's the last iteration
raise # Re-raise the last exception

Expand Down
5 changes: 2 additions & 3 deletions pr_agent/algo/token_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from jinja2 import Environment, StrictUndefined
from tiktoken import encoding_for_model
from tiktoken import encoding_for_model, get_encoding

from pr_agent.config_loader import get_settings

Expand Down Expand Up @@ -27,7 +27,7 @@ def __init__(self, pr, vars: dict, system, user):
- system: The system string.
- user: The user string.
"""
self.encoder = encoding_for_model(get_settings().config.model)
self.encoder = encoding_for_model(get_settings().config.model) if "gpt" in get_settings().config.model else get_encoding("cl100k_base")
self.prompt_tokens = self._get_system_user_tokens(pr, self.encoder, vars, system, user)

def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
Expand All @@ -47,7 +47,6 @@ def _get_system_user_tokens(self, pr, encoder, vars: dict, system, user):
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(system).render(vars)
user_prompt = environment.from_string(user).render(vars)

system_prompt_tokens = len(encoder.encode(system_prompt))
user_prompt_tokens = len(encoder.encode(user_prompt))
return system_prompt_tokens + user_prompt_tokens
Expand Down
19 changes: 14 additions & 5 deletions pr_agent/settings/.secrets_template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,26 @@
# See README for details about GitHub App deployment.

[openai]
key = "<API_KEY>" # Acquire through https://platform.openai.com
org = "<ORGANIZATION>" # Optional, may be commented out.
key = "" # Acquire through https://platform.openai.com
#org = "<ORGANIZATION>" # Optional, may be commented out.
# Uncomment the following for Azure OpenAI
#api_type = "azure"
#api_version = '2023-05-15' # Check Azure documentation for the current API version
#api_base = "<API_BASE>" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
#deployment_id = "<DEPLOYMENT_ID>" # The deployment name you chose when you deployed the engine
#api_base = "" # The base URL for your Azure OpenAI resource. e.g. "https://<your resource name>.openai.azure.com"
#deployment_id = "" # The deployment name you chose when you deployed the engine

[anthropic]
key = "" # Optional, uncomment if you want to use Anthropic. Acquire through https://www.anthropic.com/

[cohere]
key = "" # Optional, uncomment if you want to use Cohere. Acquire through https://dashboard.cohere.ai/

[replicate]
key = "" # Optional, uncomment if you want to use Replicate. Acquire through https://replicate.com/

[github]
# ---- Set the following only for deployment type == "user"
user_token = "<TOKEN>" # A GitHub personal access token with 'repo' scope.
user_token = "" # A GitHub personal access token with 'repo' scope.
deployment_type = "user" # Set to "user" by default

# ---- Set the following only for deployment type == "app", see README for details.
private_key = """\
Expand Down
5 changes: 4 additions & 1 deletion pr_agent/tools/pr_reviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _prepare_pr_review(self) -> str:
del pr_feedback['Security concerns']
data.setdefault('PR Analysis', {})['Security concerns'] = security_concerns

#
#
if 'Code feedback' in pr_feedback:
code_feedback = pr_feedback['Code feedback']

Expand Down Expand Up @@ -218,6 +218,9 @@ def _prepare_pr_review(self) -> str:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Markdown response:\n{markdown_text}")

if markdown_text == None or len(markdown_text) == 0:
markdown_text = review

return markdown_text

def _publish_inline_code_comments(self) -> None:
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ dependencies = [
"aiohttp~=3.8.4",
"atlassian-python-api==3.39.0",
"GitPython~=3.1.32",
"starlette-context==0.3.6"
"starlette-context==0.3.6",
"litellm~=0.1.351"
]

[project.urls]
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ python-gitlab==3.15.0
pytest~=7.4.0
aiohttp~=3.8.4
atlassian-python-api==3.39.0
GitPython~=3.1.32
GitPython~=3.1.32
litellm~=0.1.351

0 comments on commit 3522f9f

Please sign in to comment.