Introduce to weak model #1387

Merged · 10 commits · Dec 11, 2024
22 changes: 11 additions & 11 deletions docs/docs/usage-guide/changing_a_model.md
@@ -5,7 +5,7 @@ To use a different model than the default (GPT-4), you need to edit in the [conf
```
[config]
model = "..."
model_turbo = "..."
model_weak = "..."
fallback_models = ["..."]
```

@@ -28,7 +28,7 @@ and set in your configuration file:
```
[config]
model="" # the OpenAI model you've deployed on Azure (e.g. gpt-3.5-turbo)
model_turbo="" # the OpenAI model you've deployed on Azure (e.g. gpt-3.5-turbo)
model_weak="" # the OpenAI model you've deployed on Azure (e.g. gpt-3.5-turbo)
fallback_models=["..."] # the OpenAI model you've deployed on Azure (e.g. gpt-3.5-turbo)
```

@@ -52,7 +52,7 @@ MAX_TOKENS={

[config] # in configuration.toml
model = "ollama/llama2"
model_turbo = "ollama/llama2"
model_weak = "ollama/llama2"
fallback_models=["ollama/llama2"]

[ollama] # in .secrets.toml
@@ -76,7 +76,7 @@ MAX_TOKENS={
}
[config] # in configuration.toml
model = "huggingface/meta-llama/Llama-2-7b-chat-hf"
model_turbo = "huggingface/meta-llama/Llama-2-7b-chat-hf"
model_weak = "huggingface/meta-llama/Llama-2-7b-chat-hf"
fallback_models=["huggingface/meta-llama/Llama-2-7b-chat-hf"]

[huggingface] # in .secrets.toml
@@ -91,7 +91,7 @@ To use Llama2 model with Replicate, for example, set:
```
[config] # in configuration.toml
model = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
model_turbo = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
model_weak = "replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"
fallback_models=["replicate/llama-2-70b-chat:2c1608e18606fad2812020dc541930f2d0495ce32eee50074220b87300bc16e1"]
[replicate] # in .secrets.toml
key = ...
@@ -107,7 +107,7 @@ To use Llama3 model with Groq, for example, set:
```
[config] # in configuration.toml
model = "llama3-70b-8192"
model_turbo = "llama3-70b-8192"
model_weak = "llama3-70b-8192"
fallback_models = ["groq/llama3-70b-8192"]
[groq] # in .secrets.toml
key = ... # your Groq api key
@@ -121,7 +121,7 @@ To use Google's Vertex AI platform and its associated models (chat-bison/codecha
```
[config] # in configuration.toml
model = "vertex_ai/codechat-bison"
model_turbo = "vertex_ai/codechat-bison"
model_weak = "vertex_ai/codechat-bison"
fallback_models="vertex_ai/codechat-bison"

[vertexai] # in .secrets.toml
@@ -140,7 +140,7 @@ To use [Google AI Studio](https://aistudio.google.com/) models, set the relevant
```toml
[config] # in configuration.toml
model="google_ai_studio/gemini-1.5-flash"
model_turbo="google_ai_studio/gemini-1.5-flash"
model_weak="google_ai_studio/gemini-1.5-flash"
fallback_models=["google_ai_studio/gemini-1.5-flash"]

[google_ai_studio] # in .secrets.toml
@@ -156,7 +156,7 @@ To use Anthropic models, set the relevant models in the configuration section of
```
[config]
model="anthropic/claude-3-opus-20240229"
model_turbo="anthropic/claude-3-opus-20240229"
model_weak="anthropic/claude-3-opus-20240229"
fallback_models=["anthropic/claude-3-opus-20240229"]
```

@@ -173,7 +173,7 @@ To use Amazon Bedrock and its foundational models, add the below configuration:
```
[config] # in configuration.toml
model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
model_turbo="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
model_weak="bedrock/anthropic.claude-3-sonnet-20240229-v1:0"
fallback_models=["bedrock/anthropic.claude-v2:1"]
```

@@ -195,7 +195,7 @@ If the relevant model doesn't appear [here](https://github.com/Codium-ai/pr-agen
```
[config]
model="custom_model_name"
model_turbo="custom_model_name"
model_weak="custom_model_name"
fallback_models=["custom_model_name"]
```
(2) Set the maximal tokens for the model:
6 changes: 3 additions & 3 deletions pr_agent/algo/pr_processing.py
@@ -333,7 +333,7 @@ def generate_full_patch(convert_hunks_to_line_numbers, file_dict, max_tokens_mod
     return total_tokens, patches, remaining_files_list_new, files_in_patch_list


-async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.REGULAR):
+async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelType.WEAK):
     all_models = _get_all_models(model_type)
     all_deployments = _get_all_deployments(all_models)
     # try each (model, deployment_id) pair until one is successful, otherwise raise exception
@@ -354,8 +354,8 @@ async def retry_with_fallback_models(f: Callable, model_type: ModelType = ModelT


 def _get_all_models(model_type: ModelType = ModelType.REGULAR) -> List[str]:
-    if model_type == ModelType.TURBO:
-        model = get_settings().config.model_turbo
+    if get_settings().config.get('model_weak') and model_type == ModelType.WEAK:
+        model = get_settings().config.model_weak
     else:
         model = get_settings().config.model
     fallback_models = get_settings().config.fallback_models
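For context, the behavior this hunk introduces can be sketched in isolation. The following is a minimal, self-contained reconstruction, not the actual pr_agent code: `resolve_models` is a hypothetical name, and a plain dict stands in for the Dynaconf settings object that `get_settings()` returns.

```python
from enum import Enum
from typing import List


class ModelType(str, Enum):
    REGULAR = "regular"
    WEAK = "weak"


def resolve_models(config: dict, model_type: ModelType) -> List[str]:
    # Mirror of the selection in _get_all_models(): use the weak model only
    # when it is both requested AND configured; otherwise fall back to the
    # regular model, so configs without model_weak keep their old behavior.
    if config.get("model_weak") and model_type == ModelType.WEAK:
        model = config["model_weak"]
    else:
        model = config["model"]
    fallback_models = config.get("fallback_models", [])
    if not isinstance(fallback_models, list):
        fallback_models = [fallback_models]
    # retry_with_fallback_models() then tries these in order until one succeeds.
    return [model] + fallback_models


# With the new defaults from configuration.toml (see the hunk below):
cfg = {
    "model": "gpt-4o-2024-11-20",
    "model_weak": "gpt-4o-mini-2024-07-18",
    "fallback_models": ["gpt-4o-2024-08-06"],
}
assert resolve_models(cfg, ModelType.WEAK) == ["gpt-4o-mini-2024-07-18", "gpt-4o-2024-08-06"]
assert resolve_models(cfg, ModelType.REGULAR) == ["gpt-4o-2024-11-20", "gpt-4o-2024-08-06"]
```

The `config.get('model_weak')` guard is the key design choice here: configurations that never define `model_weak` are unaffected, since even `ModelType.WEAK` requests resolve to the regular `model`.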
3 changes: 1 addition & 2 deletions pr_agent/algo/utils.py
@@ -35,8 +35,7 @@ class Range(BaseModel):

 class ModelType(str, Enum):
     REGULAR = "regular"
-    TURBO = "turbo"
-
+    WEAK = "weak"

 class PRReviewHeader(str, Enum):
     REGULAR = "## PR Reviewer Guide"
2 changes: 1 addition & 1 deletion pr_agent/git_providers/utils.py
@@ -99,5 +99,5 @@ def set_claude_model():
"""
model_claude = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
get_settings().set('config.model', model_claude)
get_settings().set('config.model_turbo', model_claude)
get_settings().set('config.model_weak', model_claude)
get_settings().set('config.fallback_models', [model_claude])
4 changes: 2 additions & 2 deletions pr_agent/settings/configuration.toml
@@ -1,7 +1,7 @@
[config]
# models
model="gpt-4-turbo-2024-04-09"
model_turbo="gpt-4o-2024-11-20"
model_weak="gpt-4o-mini-2024-07-18"
model="gpt-4o-2024-11-20"
fallback_models=["gpt-4o-2024-08-06"]
# CLI
git_provider="github"
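For reference, a user-side override in a local configuration file now takes the following shape. This is an illustrative snippet in the style of the docs changes above; the values simply echo the new defaults from this hunk:

```
[config]
model_weak="gpt-4o-mini-2024-07-18"   # replaces the removed model_turbo key
model="gpt-4o-2024-11-20"
fallback_models=["gpt-4o-2024-08-06"] # tried in order if the primary model fails
```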
4 changes: 2 additions & 2 deletions pr_agent/tools/pr_code_suggestions.py
@@ -111,9 +111,9 @@ async def run(self):
self.git_provider.publish_comment("Preparing suggestions...", is_temporary=True)

if not self.is_extended:
data = await retry_with_fallback_models(self._prepare_prediction)
data = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
else:
data = await retry_with_fallback_models(self._prepare_prediction_extended)
data = await retry_with_fallback_models(self._prepare_prediction_extended, model_type=ModelType.REGULAR)
if not data:
data = {"code_suggestions": []}

Expand Down
2 changes: 1 addition & 1 deletion pr_agent/tools/pr_description.py
@@ -99,7 +99,7 @@ async def run(self):
     # ticket extraction if exists
     await extract_and_cache_pr_tickets(self.git_provider, self.vars)

-    await retry_with_fallback_models(self._prepare_prediction, ModelType.TURBO)
+    await retry_with_fallback_models(self._prepare_prediction, ModelType.WEAK)

     if self.prediction:
         self._prepare_data()
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/tools/pr_help_message.py
@@ -114,7 +114,7 @@ async def run(self):
     self.vars['snippets'] = docs_prompt.strip()

     # run the AI model
-    response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
+    response = await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)
     response_yaml = load_yaml(response)
     response_str = response_yaml.get('response')
     relevant_sections = response_yaml.get('relevant_sections')
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/tools/pr_line_questions.py
@@ -79,7 +79,7 @@ async def run(self):
         line_end=line_end,
         side=side)
     if self.patch_with_lines:
-        response = await retry_with_fallback_models(self._get_prediction, model_type=ModelType.TURBO)
+        response = await retry_with_fallback_models(self._get_prediction, model_type=ModelType.WEAK)

     get_logger().info('Preparing answer...')
     if comment_id:
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/tools/pr_questions.py
@@ -63,7 +63,7 @@ async def run(self):
     if img_path:
         get_logger().debug(f"Image path identified", artifact=img_path)

-    await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.TURBO)
+    await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)

     pr_comment = self._prepare_pr_answer()
     get_logger().debug(f"PR output", artifact=pr_comment)
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/tools/pr_reviewer.py
@@ -148,7 +148,7 @@ async def run(self) -> None:
     if get_settings().config.publish_output and not get_settings().config.get('is_auto_command', False):
         self.git_provider.publish_comment("Preparing review...", is_temporary=True)

-    await retry_with_fallback_models(self._prepare_prediction)
+    await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.REGULAR)
     if not self.prediction:
         self.git_provider.remove_initial_comment()
         return None
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/tools/pr_update_changelog.py
@@ -73,7 +73,7 @@ async def run(self):
     if get_settings().config.publish_output:
         self.git_provider.publish_comment("Preparing changelog updates...", is_temporary=True)

-    await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.TURBO)
+    await retry_with_fallback_models(self._prepare_prediction, model_type=ModelType.WEAK)

     new_file_content, answer = self._prepare_changelog_update()
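Reading the tool hunks above together, the per-tool assignment after this PR works out as follows. The dict below is a hypothetical summary for this writeup, not a structure that exists in the codebase:

```python
# Hypothetical summary only: which ModelType each tool now passes to
# retry_with_fallback_models(), as read off the hunks in this diff.
TOOL_MODEL_TYPE = {
    "pr_code_suggestions": "regular",  # heavier analysis stays on the main model
    "pr_reviewer": "regular",
    "pr_description": "weak",          # lighter, high-volume tools move to model_weak
    "pr_help_message": "weak",
    "pr_line_questions": "weak",
    "pr_questions": "weak",
    "pr_update_changelog": "weak",
}
```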