Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Abstract AiHandler to BaseAiHandler #514

Merged
merged 26 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
7e47baa
Refactor AI handler classes
brianphamsia Dec 9, 2023
f2abe5c
Abstract AiHandler to BaseAiHandler
brianpham93 Dec 9, 2023
c0303ff
Merge remote-tracking branch 'upstream/main' into abstract-BaseAiHandler
brianpham93 Dec 9, 2023
b640992
Remove extra code
brianpham93 Dec 9, 2023
523a896
Rename AiHandler to LiteLLMAiHandler
brianphamsia Dec 11, 2023
b8021d7
rename file
brianphamsia Dec 11, 2023
a1cbd80
update base ai handler
brianphamsia Dec 11, 2023
ebf7027
add openai handler
brianphamsia Dec 11, 2023
5239e1c
Load default AI Handler from util function
brianphamsia Dec 12, 2023
7eb2e76
Move ai handlers to specific folder
brianphamsia Dec 12, 2023
6c7becc
add LangChain AI Handler
brianphamsia Dec 12, 2023
506eafc
add langchain in requirement
brianphamsia Dec 12, 2023
0c66554
langchain: move model and temperature to chat_completion
brianphamsia Dec 12, 2023
a627dcd
Update langchain
brianphamsia Dec 12, 2023
b7225cc
update langchain
brianphamsia Dec 12, 2023
ca1ccd7
update base
brianphamsia Dec 12, 2023
8fb4a42
Update AI handler instantiation in server files
brianphamsia Dec 13, 2023
be8d6af
Add code documentation generation for PR diffs
brianphamsia Dec 13, 2023
ebb2ed8
Add logging to pr_agent.py
brianphamsia Dec 13, 2023
69a7c77
Refactor PRAgent class and has_ai_handler_param
brianphamsia Dec 13, 2023
557b39e
Merge branch 'base-ai-handler' into abstract-BaseAiHandler
brianphamsia Dec 13, 2023
e37598f
Merge remote-tracking branch 'upstream/main' into abstract-BaseAiHandler
brianphamsia Dec 13, 2023
3531016
Refactor AI handler instantiation in PRAgent and related classes
mrT23 Dec 14, 2023
246be61
Set LiteLLMAIHandler as default AI handler in all PR tools and simpli…
mrT23 Dec 14, 2023
38ea914
Make LangChain dependency optional in pr-agent and update requirement…
mrT23 Dec 14, 2023
02871b1
Remove logging from pr_agent.py and add line breaks in cli.py and git…
mrT23 Dec 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pr_agent/algo/ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from openai.error import APIError, RateLimitError, Timeout, TryAgain
from retry import retry
from pr_agent.config_loader import get_settings
from pr_agent.algo.base_ai_handler import BaseAiHandler
from pr_agent.log import get_logger

OPENAI_RETRIES = 5


class AiHandler:
class AiHandler(BaseAiHandler):
"""
This class handles interactions with the OpenAI API for chat completions.
It initializes the API key and other settings from a configuration file,
Expand Down
20 changes: 20 additions & 0 deletions pr_agent/algo/base_ai_handler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from abc import ABC, abstractmethod

class BaseAiHandler(ABC):
"""
This class defines the interface for an AI handler.
"""

@abstractmethod
def __init__(self):
pass

@property
@abstractmethod
def deployment_id(self):
pass

@abstractmethod
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
pass

6 changes: 3 additions & 3 deletions pr_agent/tools/pr_code_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict, List
from jinja2 import Environment, StrictUndefined

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml
Expand All @@ -14,7 +14,7 @@


class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None):
def __init__(self, pr_url: str, cli_mode=False, args: list = None, ai_handler: BaseAiHandler = AiHandler()):

self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
Expand All @@ -31,7 +31,7 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None):
else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions

self.ai_handler = AiHandler()
self.ai_handler = ai_handler
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
Expand Down
6 changes: 3 additions & 3 deletions pr_agent/tools/pr_description.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from jinja2 import Environment, StrictUndefined

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, set_custom_labels, get_user_labels
Expand All @@ -15,7 +15,7 @@


class PRDescription:
def __init__(self, pr_url: str, args: list = None):
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model.
Expand All @@ -36,7 +36,7 @@ def __init__(self, pr_url: str, args: list = None):
get_settings().pr_description.enable_semantic_files_types = False

# Initialize the AI handler
self.ai_handler = AiHandler()
self.ai_handler = ai_handler

# Initialize the variables dictionary
self.vars = {
Expand Down
6 changes: 3 additions & 3 deletions pr_agent/tools/pr_information_from_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from jinja2 import Environment, StrictUndefined

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings
Expand All @@ -12,12 +12,12 @@


class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None):
def __init__(self, pr_url: str, args: list = None, ai_handler: BaseAiHandler = AiHandler()):
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = AiHandler()
self.ai_handler = ai_handler
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
Expand Down
6 changes: 3 additions & 3 deletions pr_agent/tools/pr_questions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from jinja2 import Environment, StrictUndefined

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings
Expand All @@ -12,13 +12,13 @@


class PRQuestions:
def __init__(self, pr_url: str, args=None):
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = AiHandler()):
question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = AiHandler()
self.ai_handler = ai_handler
self.question_str = question_str
self.vars = {
"title": self.git_provider.pr.title,
Expand Down
6 changes: 3 additions & 3 deletions pr_agent/tools/pr_reviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from jinja2 import Environment, StrictUndefined
from yaml import SafeLoader

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import convert_to_markdown, load_yaml, try_fix_yaml, set_custom_labels, get_user_labels
Expand All @@ -22,7 +22,7 @@ class PRReviewer:
"""
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None):
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None, ai_handler: BaseAiHandler = AiHandler()):
"""
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.

Expand All @@ -43,7 +43,7 @@ def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False,

if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
self.ai_handler = AiHandler()
self.ai_handler = ai_handler
self.patches_diff = None
self.prediction = None

Expand Down
6 changes: 3 additions & 3 deletions pr_agent/tools/pr_update_changelog.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from jinja2 import Environment, StrictUndefined

from pr_agent.algo.ai_handler import AiHandler
from pr_agent.algo.ai_handler import BaseAiHandler, AiHandler
from pr_agent.algo.pr_processing import get_pr_diff, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.config_loader import get_settings
Expand All @@ -17,15 +17,15 @@


class PRUpdateChangelog:
def __init__(self, pr_url: str, cli_mode=False, args=None):
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = AiHandler()):

self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
self._get_changlog_file() # self.changelog_file_str
self.ai_handler = AiHandler()
self.ai_handler = ai_handler
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
Expand Down