Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update Python code formatting, configuration loading, and local model additions #942

Merged
merged 7 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions docs/docs/usage-guide/introduction.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@

After [installation](https://codium-ai.github.io/Docs-PR-Agent/installation/), there are three basic ways to invoke CodiumAI PR-Agent:
After [installation](https://pr-agent-docs.codium.ai/installation/), there are three basic ways to invoke CodiumAI PR-Agent:

1. Locally running a CLI command
2. Online usage - by [commenting](https://github.com/Codium-ai/pr-agent/pull/229#issuecomment-1695021901) on a PR
3. Enabling PR-Agent tools to run automatically when a new PR is opened


Specifically, CLI commands can be issued by invoking a pre-built [docker image](https://codium-ai.github.io/Docs-PR-Agent/installation/#run-from-source), or by invoking a [locally cloned repo](https://codium-ai.github.io/Docs-PR-Agent/installation/#locally).
For online usage, you will need to setup either a [GitHub App](https://codium-ai.github.io/Docs-PR-Agent/installation/#run-as-a-github-app), or a [GitHub Action](https://codium-ai.github.io/Docs-PR-Agent/installation/#run-as-a-github-action).
Specifically, CLI commands can be issued by invoking a pre-built [docker image](https://pr-agent-docs.codium.ai/installation/locally/#using-docker-image), or by invoking a [locally cloned repo](https://pr-agent-docs.codium.ai/installation/locally/#run-from-source).
For online usage, you will need to set up either a [GitHub App](https://pr-agent-docs.codium.ai/installation/github/#run-as-a-github-app), or a [GitHub Action](https://pr-agent-docs.codium.ai/installation/github/#run-as-a-github-action).
The GitHub App and the GitHub Action also make it possible to run specific PR-Agent tools automatically when a new PR is opened.


**git provider**: The [git_provider](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L4) field in the configuration file determines the GIT provider that will be used by PR-Agent. Currently, the following providers are supported:
**git provider**: The [git_provider](https://github.com/Codium-ai/pr-agent/blob/main/pr_agent/settings/configuration.toml#L5) field in the configuration file determines the GIT provider that will be used by PR-Agent. Currently, the following providers are supported:
`
"github", "gitlab", "bitbucket", "azure", "codecommit", "local", "gerrit"
`
Expand Down
8 changes: 5 additions & 3 deletions pr_agent/agent/pr_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@

commands = list(command2class.keys())


class PRAgent:
def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.ai_handler = ai_handler # will be initialized in run_action
self.ai_handler = ai_handler # will be initialized in run_action
self.forbidden_cli_args = ['enable_auto_approval']

async def handle_request(self, pr_url, request, notify=None) -> bool:
Expand All @@ -68,7 +69,9 @@ async def handle_request(self, pr_url, request, notify=None) -> bool:
for forbidden_arg in self.forbidden_cli_args:
for arg in args:
if forbidden_arg in arg:
get_logger().error(f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file.")
get_logger().error(
f"CLI argument for param '{forbidden_arg}' is forbidden. Use instead a configuration file."
)
return False
args = update_settings_from_args(args)

Expand All @@ -94,4 +97,3 @@ async def handle_request(self, pr_url, request, notify=None) -> bool:
else:
return False
return True

3 changes: 2 additions & 1 deletion pr_agent/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
'gpt-4': 8000,
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'gpt-4-0125-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'gpt-4o': 128000, # 128K, but may be limited by config.max_model_tokens
'gpt-4o-2024-05-13': 128000, # 128K, but may be limited by config.max_model_tokens
Expand All @@ -36,4 +36,5 @@
'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
'groq/llama3-8b-8192': 8192,
'groq/llama3-70b-8192': 8192,
'ollama/llama3': 4096,
}
4 changes: 2 additions & 2 deletions pr_agent/algo/ai_handlers/base_ai_handler.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod


class BaseAiHandler(ABC):
"""
This class defines the interface for an AI handler to be used by the PR Agents.
Expand All @@ -14,7 +15,7 @@ def __init__(self):
def deployment_id(self):
pass

@abstractmethod
@abstractmethod
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2, img_path: str = None):
"""
This method should be implemented to return a chat completion from the AI model.
Expand All @@ -25,4 +26,3 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
temperature (float): the temperature to use for the chat completion
"""
pass

16 changes: 9 additions & 7 deletions pr_agent/algo/ai_handlers/langchain_ai_handler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
try:
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI
from langchain.schema import SystemMessage, HumanMessage
except: # we don't enforce langchain as a dependency, so if it's not installed, just move on
except: # we don't enforce langchain as a dependency, so if it's not installed, just move on
pass

from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
Expand All @@ -14,6 +14,7 @@

OPENAI_RETRIES = 5


class LangChainOpenAIHandler(BaseAiHandler):
def __init__(self):
# Initialize OpenAIHandler specific attributes here
Expand All @@ -36,7 +37,7 @@ def __init__(self):
raise ValueError(f"OpenAI {e.name} is required") from e
else:
raise e

@property
def chat(self):
if self.azure:
Expand All @@ -51,17 +52,18 @@ def deployment_id(self):
Returns the deployment ID for the OpenAI API.
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
       tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
    """
    Request a chat completion through the LangChain chat object.

    Args:
        model (str): the model name forwarded to `self.chat`
        system (str): the content of the system message
        user (str): the content of the human (user) message
        temperature (float): the temperature forwarded to `self.chat`

    Returns:
        tuple: `(resp.content, finish_reason)`; `finish_reason` is always the literal "completed".

    Raises:
        Exception: any error raised while building the messages or calling `self.chat` is logged
            and re-raised unchanged (the `retry` decorator decides whether to try again).
    """
    try:
        messages = [SystemMessage(content=system), HumanMessage(content=user)]

        # get a chat completion from the formatted messages
        resp = self.chat(messages, model=model, temperature=temperature)
        finish_reason = "completed"
        return resp.content, finish_reason

    except Exception as e:
        # Embed the exception in the message itself: passing it as an extra positional argument
        # without a placeholder would drop it from the emitted log line.
        get_logger().error(f"Unknown error during OpenAI inference: {e}")
        raise
4 changes: 2 additions & 2 deletions pr_agent/algo/ai_handlers/litellm_ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def __init__(self):
if get_settings().get("HUGGINGFACE.API_BASE", None) and 'huggingface' in get_settings().config.model:
litellm.api_base = get_settings().huggingface.api_base
self.api_base = get_settings().huggingface.api_base
if get_settings().get("OLLAMA.API_BASE", None) :
if get_settings().get("OLLAMA.API_BASE", None):
litellm.api_base = get_settings().ollama.api_base
self.api_base = get_settings().ollama.api_base
if get_settings().get("HUGGINGFACE.REPITITION_PENALTY", None):
Expand Down Expand Up @@ -129,7 +129,7 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
"messages": messages,
"temperature": temperature,
"force_timeout": get_settings().config.ai_timeout,
"api_base" : self.api_base,
"api_base": self.api_base,
}
if self.aws_bedrock_client:
kwargs["aws_bedrock_client"] = self.aws_bedrock_client
Expand Down
9 changes: 5 additions & 4 deletions pr_agent/algo/ai_handlers/openai_ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,14 @@ def __init__(self):

except AttributeError as e:
raise ValueError("OpenAI key is required") from e

@property
def deployment_id(self):
"""
Returns the deployment ID for the OpenAI API.
"""
return get_settings().get("OPENAI.DEPLOYMENT_ID", None)

@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
Expand All @@ -54,8 +55,8 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
finish_reason = chat_completion["choices"][0]["finish_reason"]
usage = chat_completion.get("usage")
get_logger().info("AI response", response=resp, messages=messages, finish_reason=finish_reason,
model=model, usage=usage)
return resp, finish_reason
model=model, usage=usage)
return resp, finish_reason
except (APIError, Timeout, TryAgain) as e:
get_logger().error("Error during OpenAI inference: ", e)
raise
Expand All @@ -64,4 +65,4 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
raise
except (Exception) as e:
get_logger().error("Unknown error during OpenAI inference: ", e)
raise TryAgain from e
raise TryAgain from e
3 changes: 2 additions & 1 deletion pr_agent/algo/file_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from pr_agent.config_loader import get_settings


def filter_ignored(files):
"""
Filter out files that match the ignore patterns.
Expand All @@ -14,7 +15,7 @@ def filter_ignored(files):
if isinstance(patterns, str):
patterns = [patterns]
glob_setting = get_settings().ignore.glob
if isinstance(glob_setting, str): # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
if isinstance(glob_setting, str): # --ignore.glob=[.*utils.py], --ignore.glob=.*utils.py
glob_setting = glob_setting.strip('[]').split(",")
patterns += [fnmatch.translate(glob) for glob in glob_setting]

Expand Down
2 changes: 1 addition & 1 deletion pr_agent/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ def update_settings_from_args(args: List[str]) -> List[str]:
arg = arg.strip('-').strip()
vals = arg.split('=', 1)
if len(vals) != 2:
if len(vals) > 2: # --extended is a valid argument
if len(vals) > 2: # --extended is a valid argument
get_logger().error(f'Invalid argument format: {arg}')
other_args.append(arg)
continue
Expand Down
3 changes: 3 additions & 0 deletions pr_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
log_level = os.environ.get("LOG_LEVEL", "INFO")
setup_logger(log_level)


def set_parser():
parser = argparse.ArgumentParser(description='AI based pull request analyzer', usage=
"""\
Expand Down Expand Up @@ -50,6 +51,7 @@ def set_parser():
parser.add_argument('rest', nargs=argparse.REMAINDER, default=[])
return parser


def run_command(pr_url, command):
# Preparing the command
run_command_str = f"--pr_url={pr_url} {command.lstrip('/')}"
Expand All @@ -58,6 +60,7 @@ def run_command(pr_url, command):
# Run the command. Feedback will appear in GitHub PR comments
run(args=args)


def run(inargs=None, args=None):
parser = set_parser()
if not args:
Expand Down
13 changes: 11 additions & 2 deletions pr_agent/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,23 @@


def get_settings():
    """
    Return the settings object that is currently in effect.

    The value stored under the "settings" key of `context` takes priority. If that lookup raises
    for any reason, the module-level `global_settings` object is returned instead.

    Returns:
        The settings object from `context` when the lookup succeeds, otherwise `global_settings`.
    """
    try:
        active_settings = context["settings"]
    except Exception:
        # Any lookup failure falls back to the global default.
        active_settings = global_settings
    return active_settings


# Add local configuration from pyproject.toml of the project being reviewed
def _find_repository_root() -> Path:
def _find_repository_root() -> Optional[Path]:
"""
Identify project root directory by recursively searching for the .git directory in the parent directories.
"""
Expand All @@ -61,7 +70,7 @@ def _find_pyproject() -> Optional[Path]:
"""
repo_root = _find_repository_root()
if repo_root:
pyproject = _find_repository_root() / "pyproject.toml"
pyproject = repo_root / "pyproject.toml"
return pyproject if pyproject.is_file() else None
return None

Expand Down
4 changes: 2 additions & 2 deletions pr_agent/git_providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@
from pr_agent.git_providers.azuredevops_provider import AzureDevopsProvider
from pr_agent.git_providers.gerrit_provider import GerritProvider


_GIT_PROVIDERS = {
'github': GithubProvider,
'gitlab': GitLabProvider,
'bitbucket': BitbucketProvider,
'bitbucket_server': BitbucketServerProvider,
'azure': AzureDevopsProvider,
'codecommit': CodeCommitProvider,
'local' : LocalGitProvider,
'local': LocalGitProvider,
'gerrit': GerritProvider,
}


def get_git_provider():
try:
provider_id = get_settings().config.git_provider
Expand Down
9 changes: 6 additions & 3 deletions pr_agent/servers/github_action_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def run_action():
if event_payload.get("issue", {}).get("pull_request"):
url = event_payload.get("issue", {}).get("pull_request", {}).get("url")
is_pr = True
elif event_payload.get("comment", {}).get("pull_request_url"): # for 'pull_request_review_comment
elif event_payload.get("comment", {}).get("pull_request_url"): # for 'pull_request_review_comment
url = event_payload.get("comment", {}).get("pull_request_url")
is_pr = True
disable_eyes = True
Expand All @@ -139,8 +139,11 @@ async def run_action():
comment_id = event_payload.get("comment", {}).get("id")
provider = get_git_provider()(pr_url=url)
if is_pr:
await PRAgent().handle_request(url, body,
notify=lambda: provider.add_eyes_reaction(comment_id, disable_eyes=disable_eyes))
await PRAgent().handle_request(
url, body, notify=lambda: provider.add_eyes_reaction(
comment_id, disable_eyes=disable_eyes
)
)
else:
await PRAgent().handle_request(url, body)

Expand Down
11 changes: 4 additions & 7 deletions pr_agent/tools/pr_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def __init__(self, pr_url: str, args: list = None,
self.ai_handler = ai_handler()
self.ai_handler.main_pr_language = self.main_pr_language


# Initialize the variables dictionary
self.vars = {
"title": self.git_provider.pr.title,
Expand Down Expand Up @@ -157,7 +156,7 @@ async def run(self):
self.git_provider.remove_initial_comment()
except Exception as e:
get_logger().error(f"Error generating PR description {self.pr_id}: {e}")

return ""

async def _prepare_prediction(self, model: str) -> None:
Expand Down Expand Up @@ -221,9 +220,6 @@ def _prepare_data(self):
if 'pr_files' in self.data:
self.data['pr_files'] = self.data.pop('pr_files')




def _prepare_labels(self) -> List[str]:
pr_types = []

Expand Down Expand Up @@ -321,7 +317,7 @@ def _prepare_pr_answer(self) -> Tuple[str, str, str, List[dict]]:
value = self.file_label_dict
else:
key_publish = key.rstrip(':').replace("_", " ").capitalize()
if key_publish== "Type":
if key_publish == "Type":
key_publish = "PR Type"
# elif key_publish == "Description":
# key_publish = "PR Description"
Expand Down Expand Up @@ -512,11 +508,12 @@ def insert_br_after_x_chars(text, x=70):
is_inside_code = False
return ''.join(new_text).strip()


def replace_code_tags(text):
    """
    Convert backtick-delimited spans in `text` into <code>...</code> elements.

    The text is split on the backtick character; every odd-indexed piece (the text that follows
    an odd-numbered backtick) is wrapped in <code> and </code>, and the pieces are joined back
    together without the backticks.
    """
    segments = text.split('`')
    return ''.join(
        '<code>' + segment + '</code>' if index % 2 else segment
        for index, segment in enumerate(segments)
    )
Loading
Loading