Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Lazy Initialization for AI Handlers in PR Tools #528

Merged
merged 1 commit into from
Dec 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions pr_agent/agent/pr_agent.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import shlex
from functools import partial

from pr_agent.algo.ai_handlers.base_ai_handler import BaseAiHandler
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler

Expand Down Expand Up @@ -41,8 +43,8 @@
commands = list(command2class.keys())

class PRAgent:
def __init__(self, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
self.ai_handler = ai_handler
def __init__(self, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.ai_handler = ai_handler # will be initialized in run_action

async def handle_request(self, pr_url, request, notify=None) -> bool:
# First, apply repo specific settings if exists
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_add_docs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import textwrap
from functools import partial
from typing import Dict

from jinja2 import Environment, StrictUndefined
Expand All @@ -17,14 +18,14 @@

class PRAddDocs:
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):

self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)

self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_code_suggestions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import textwrap
from functools import partial
from typing import Dict, List
from jinja2 import Environment, StrictUndefined

Expand All @@ -16,7 +17,7 @@

class PRCodeSuggestions:
def __init__(self, pr_url: str, cli_mode=False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):

self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
Expand All @@ -33,7 +34,7 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None,
else:
num_code_suggestions = get_settings().pr_code_suggestions.num_code_suggestions

self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_description.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import re
from functools import partial
from typing import List, Tuple

from jinja2 import Environment, StrictUndefined
Expand All @@ -17,7 +18,7 @@

class PRDescription:
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
"""
Initialize the PRDescription object with the necessary attributes and objects for generating a PR description
using an AI model.
Expand All @@ -38,7 +39,7 @@ def __init__(self, pr_url: str, args: list = None,
get_settings().pr_description.enable_semantic_files_types = False

# Initialize the AI handler
self.ai_handler = ai_handler
self.ai_handler = ai_handler()

# Initialize the variables dictionary
self.vars = {
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_generate_labels.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import re
from functools import partial
from typing import List, Tuple

from jinja2 import Environment, StrictUndefined
Expand All @@ -17,7 +18,7 @@

class PRGenerateLabels:
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
"""
Initialize the PRGenerateLabels object with the necessary attributes and objects for generating labels
corresponding to the PR using an AI model.
Expand All @@ -33,7 +34,7 @@ def __init__(self, pr_url: str, args: list = None,
self.pr_id = self.git_provider.get_pr_id()

# Initialize the AI handler
self.ai_handler = ai_handler
self.ai_handler = ai_handler()

# Initialize the variables dictionary
self.vars = {
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_information_from_user.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from functools import partial

from jinja2 import Environment, StrictUndefined

Expand All @@ -14,12 +15,12 @@

class PRInformationFromUser:
def __init__(self, pr_url: str, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.vars = {
"title": self.git_provider.pr.title,
"branch": self.git_provider.get_pr_branch(),
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_questions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import copy
from functools import partial

from jinja2 import Environment, StrictUndefined

Expand All @@ -13,13 +14,13 @@


class PRQuestions:
def __init__(self, pr_url: str, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
def __init__(self, pr_url: str, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
question_str = self.parse_args(args)
self.git_provider = get_git_provider()(pr_url)
self.main_pr_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.question_str = question_str
self.vars = {
"title": self.git_provider.pr.title,
Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_reviewer.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import copy
import datetime
from collections import OrderedDict
from functools import partial
from typing import List, Tuple

import yaml
Expand All @@ -24,7 +25,7 @@ class PRReviewer:
The PRReviewer class is responsible for reviewing a pull request and generating feedback using an AI model.
"""
def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False, args: list = None,
ai_handler: BaseAiHandler = LiteLLMAIHandler()):
ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):
"""
Initialize the PRReviewer object with the necessary attributes and objects to review a pull request.

Expand All @@ -47,7 +48,7 @@ def __init__(self, pr_url: str, is_answer: bool = False, is_auto: bool = False,

if self.is_answer and not self.git_provider.is_supported("get_issue_comments"):
raise Exception(f"Answer mode is not supported for {get_settings().config.git_provider} for now")
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None

Expand Down
5 changes: 3 additions & 2 deletions pr_agent/tools/pr_update_changelog.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
from datetime import date
from functools import partial
from time import sleep
from typing import Tuple

Expand All @@ -18,15 +19,15 @@


class PRUpdateChangelog:
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: BaseAiHandler = LiteLLMAIHandler()):
def __init__(self, pr_url: str, cli_mode=False, args=None, ai_handler: partial[BaseAiHandler,] = LiteLLMAIHandler):

self.git_provider = get_git_provider()(pr_url)
self.main_language = get_main_pr_language(
self.git_provider.get_languages(), self.git_provider.get_files()
)
self.commit_changelog = get_settings().pr_update_changelog.push_changelog_changes
self._get_changlog_file() # self.changelog_file_str
self.ai_handler = ai_handler
self.ai_handler = ai_handler()
self.patches_diff = None
self.prediction = None
self.cli_mode = cli_mode
Expand Down
Loading