Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

moving the 'improve' command to turbo mode, with auto_extended=true #636

Merged
merged 1 commit into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pr_agent/algo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
'gpt-4-0613': 8000,
'gpt-4-32k': 32000,
'gpt-4-1106-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'gpt-4-0125-preview': 128000, # 128K, but may be limited by config.max_model_tokens
'claude-instant-1': 100000,
'claude-2': 100000,
'command-nightly': 4096,
Expand Down
13 changes: 8 additions & 5 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pr_agent.algo.language_handler import sort_files_by_main_languages
from pr_agent.algo.file_filter import filter_ignored
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import get_max_tokens
from pr_agent.algo.utils import get_max_tokens, ModelType
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.git_provider import FilePatchInfo, GitProvider, EDIT_TYPE
from pr_agent.log import get_logger
Expand Down Expand Up @@ -220,8 +220,8 @@ def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, mo
return patches, modified_files_list, deleted_files_list, added_files_list


async def retry_with_fallback_models(f: Callable):
all_models = _get_all_models()
async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
all_models = _get_all_models(model_type)
all_deployments = _get_all_deployments(all_models)
# try each (model, deployment_id) pair until one is successful, otherwise raise exception
for i, (model, deployment_id) in enumerate(zip(all_models, all_deployments)):
Expand All @@ -243,8 +243,11 @@ async def retry_with_fallback_models(f: Callable):
raise # Re-raise the last exception


def _get_all_models() -> List[str]:
model = get_settings().config.model
def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
if model_type == ModelType.TURBO:
model = get_settings().config.model_turbo
else:
model = get_settings().config.model
fallback_models = get_settings().config.fallback_models
if not isinstance(fallback_models, list):
fallback_models = [m.strip() for m in fallback_models.split(",")]
Expand Down
4 changes: 4 additions & 0 deletions pr_agent/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import re
import textwrap
from datetime import datetime
from enum import Enum
from typing import Any, List

import yaml
Expand All @@ -15,6 +16,9 @@
from pr_agent.config_loader import get_settings, global_settings
from pr_agent.log import get_logger

class ModelType(str, Enum):
    """Kind of model a command asks for when choosing which model to call.

    Members are also ``str`` instances, so each one compares equal to its
    plain string value (for example ``ModelType.TURBO == "turbo"``).
    """

    # Default choice.
    REGULAR = "regular"
    # Alternative "turbo" choice.
    TURBO = "turbo"

def get_setting(key: str) -> Any:
try:
Expand Down
12 changes: 7 additions & 5 deletions pr_agent/settings/configuration.toml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
[config]
model="gpt-4" # "gpt-4-0125-preview"
model_turbo="gpt-4-0125-preview"
fallback_models=["gpt-3.5-turbo-16k"]
git_provider="github"
publish_output=true
Expand Down Expand Up @@ -68,17 +69,18 @@ enable_help_text=true


[pr_code_suggestions] # /improve #
max_context_tokens=8000
num_code_suggestions=4
summarize = true
extra_instructions = ""
rank_suggestions = false
enable_help_text=true
# params for '/improve --extended' mode
auto_extended_mode=false
num_code_suggestions_per_chunk=8
rank_extended_suggestions = true
max_number_of_calls = 5
final_clip_factor = 0.9
auto_extended_mode=true
num_code_suggestions_per_chunk=5
rank_extended_suggestions = false
max_number_of_calls = 3
final_clip_factor = 0.8

[pr_add_docs] # /add_docs #
extra_instructions = ""
Expand Down
14 changes: 11 additions & 3 deletions pr_agent/tools/pr_code_suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from pr_agent.algo.ai_handlers.litellm_ai_handler import LiteLLMAIHandler
from pr_agent.algo.pr_processing import get_pr_diff, get_pr_multi_diffs, retry_with_fallback_models
from pr_agent.algo.token_handler import TokenHandler
from pr_agent.algo.utils import load_yaml, replace_code_tags
from pr_agent.algo.utils import load_yaml, replace_code_tags, ModelType
from pr_agent.config_loader import get_settings
from pr_agent.git_providers import get_git_provider
from pr_agent.git_providers.git_provider import get_main_pr_language
Expand All @@ -26,6 +26,14 @@ def __init__(self, pr_url: str, cli_mode=False, args: list = None,
self.git_provider.get_languages(), self.git_provider.get_files()
)

# limit the context specifically for the improve command, whose input is hard to parse:
if get_settings().pr_code_suggestions.max_context_tokens:
MAX_CONTEXT_TOKENS_IMPROVE = get_settings().pr_code_suggestions.max_context_tokens
if get_settings().config.max_model_tokens > MAX_CONTEXT_TOKENS_IMPROVE:
get_logger().info(f"Setting max_model_tokens to {MAX_CONTEXT_TOKENS_IMPROVE} for PR improve")
get_settings().config.max_model_tokens = MAX_CONTEXT_TOKENS_IMPROVE


# extended mode
try:
self.is_extended = self._get_is_extended(args or [])
Expand Down Expand Up @@ -64,10 +72,10 @@ async def run(self):

get_logger().info('Preparing PR code suggestions...')
if not self.is_extended:
await retry_with_fallback_models(self._prepare_prediction)
await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
data = self._prepare_pr_code_suggestions()
else:
data = await retry_with_fallback_models(self._prepare_prediction_extended)
data = await retry_with_fallback_models(self._prepare_prediction_extended, ModelType.TURBO)


if (not data) or (not 'code_suggestions' in data):
Expand Down
Loading