Skip to content

Commit

Permalink
Merge pull request #229 from Codium-ai/tr/sequential_improve
Browse files Browse the repository at this point in the history
Implementing Extended Improve Mode for More Thorough PR Reviews
  • Loading branch information
mrT23 authored Aug 22, 2023
2 parents b1a2e3e + 9157fa6 commit cbe0a69
Show file tree
Hide file tree
Showing 17 changed files with 337 additions and 73 deletions.
2 changes: 1 addition & 1 deletion pr_agent/algo/ai_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def deployment_id(self):

@retry(exceptions=(APIError, Timeout, TryAgain, AttributeError, RateLimitError),
tries=OPENAI_RETRIES, delay=2, backoff=2, jitter=(1, 3))
async def chat_completion(self, model: str, temperature: float, system: str, user: str):
async def chat_completion(self, model: str, system: str, user: str, temperature: float = 0.2):
"""
Performs a chat completion using the OpenAI ChatCompletion API.
Retries in case of API errors or timeouts.
Expand Down
26 changes: 16 additions & 10 deletions pr_agent/algo/git_patch_processing.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from __future__ import annotations

import logging
import re

Expand Down Expand Up @@ -157,7 +156,7 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
example output:
## src/file.ts
--new hunk--
__new hunk__
881 line1
882 line2
883 line3
Expand All @@ -166,7 +165,7 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
889 line6
890 line7
...
--old hunk--
__old hunk__
line1
line2
- line3
Expand All @@ -176,32 +175,38 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
...
"""

patch_with_lines_str = f"## {file.filename}\n"
import re
patch_with_lines_str = f"\n\n## {file.filename}\n"
patch_lines = patch.splitlines()
RE_HUNK_HEADER = re.compile(
r"^@@ -(\d+)(?:,(\d+))? \+(\d+)(?:,(\d+))? @@[ ]?(.*)")
new_content_lines = []
old_content_lines = []
match = None
start1, size1, start2, size2 = -1, -1, -1, -1
prev_header_line = []
header_line =[]
for line in patch_lines:
if 'no newline at end of file' in line.lower():
continue

if line.startswith('@@'):
header_line = line
match = RE_HUNK_HEADER.match(line)
if match and new_content_lines: # found a new hunk, split the previous lines
if new_content_lines:
patch_with_lines_str += '\n--new hunk--\n'
if prev_header_line:
patch_with_lines_str += f'\n{prev_header_line}\n'
patch_with_lines_str += '__new hunk__\n'
for i, line_new in enumerate(new_content_lines):
patch_with_lines_str += f"{start2 + i} {line_new}\n"
if old_content_lines:
patch_with_lines_str += '--old hunk--\n'
patch_with_lines_str += '__old hunk__\n'
for line_old in old_content_lines:
patch_with_lines_str += f"{line_old}\n"
new_content_lines = []
old_content_lines = []
if match:
prev_header_line = header_line
try:
start1, size1, start2, size2 = map(int, match.groups()[:4])
except: # '@@ -0,0 +1 @@' case
Expand All @@ -219,12 +224,13 @@ def convert_to_hunks_with_lines_numbers(patch: str, file) -> str:
# finishing last hunk
if match and new_content_lines:
if new_content_lines:
patch_with_lines_str += '\n--new hunk--\n'
patch_with_lines_str += f'\n{header_line}\n'
patch_with_lines_str += '\n__new hunk__\n'
for i, line_new in enumerate(new_content_lines):
patch_with_lines_str += f"{start2 + i} {line_new}\n"
if old_content_lines:
patch_with_lines_str += '\n--old hunk--\n'
patch_with_lines_str += '\n__old hunk__\n'
for line_old in old_content_lines:
patch_with_lines_str += f"{line_old}\n"

return patch_with_lines_str.strip()
return patch_with_lines_str.rstrip()
106 changes: 93 additions & 13 deletions pr_agent/algo/pr_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
PATCH_EXTRA_LINES = 3

def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: str,
add_line_numbers_to_hunks: bool = False, disable_extra_lines: bool = False) -> str:
add_line_numbers_to_hunks: bool = True, disable_extra_lines: bool = True) -> str:
"""
Returns a string with the diff of the pull request, applying diff minimization techniques if needed.
Expand Down Expand Up @@ -57,7 +57,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)

# generate a standard diff string, with patch extension
patches_extended, total_tokens = pr_generate_extended_diff(pr_languages, token_handler,
patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(pr_languages, token_handler,
add_line_numbers_to_hunks)

# if we are under the limit, return the full diff
Expand All @@ -78,9 +78,9 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler, model: s
return final_diff


def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
add_line_numbers_to_hunks: bool) -> \
Tuple[list, int]:
def pr_generate_extended_diff(pr_languages: list,
token_handler: TokenHandler,
add_line_numbers_to_hunks: bool) -> Tuple[list, int, list]:
"""
Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
minimization techniques if needed.
Expand All @@ -90,13 +90,10 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,
files.
- token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
- add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
Returns:
- patches_extended: A list of extended patches for each file in the pull request.
- total_tokens: The total number of tokens used in the extended patches.
"""
total_tokens = token_handler.prompt_tokens # initial tokens
patches_extended = []
patches_extended_tokens = []
for lang in pr_languages:
for file in lang['files']:
original_file_content_str = file.base_file
Expand All @@ -106,17 +103,18 @@ def pr_generate_extended_diff(pr_languages: list, token_handler: TokenHandler,

# extend each patch with extra lines of context
extended_patch = extend_patch(original_file_content_str, patch, num_lines=PATCH_EXTRA_LINES)
full_extended_patch = f"## {file.filename}\n\n{extended_patch}\n"
full_extended_patch = f"\n\n## {file.filename}\n\n{extended_patch}\n"

if add_line_numbers_to_hunks:
full_extended_patch = convert_to_hunks_with_lines_numbers(extended_patch, file)

patch_tokens = token_handler.count_tokens(full_extended_patch)
file.tokens = patch_tokens
total_tokens += patch_tokens
patches_extended_tokens.append(patch_tokens)
patches_extended.append(full_extended_patch)

return patches_extended, total_tokens
return patches_extended, total_tokens, patches_extended_tokens


def pr_generate_compressed_diff(top_langs: list, token_handler: TokenHandler, model: str,
Expand Down Expand Up @@ -324,7 +322,9 @@ def clip_tokens(text: str, max_tokens: int) -> str:
Returns:
str: The clipped string.
"""
# We'll estimate the number of tokens by hueristically assuming 2.5 tokens per word
if not text:
return text

try:
encoder = get_token_encoder()
num_input_tokens = len(encoder.encode(text))
Expand All @@ -337,4 +337,84 @@ def clip_tokens(text: str, max_tokens: int) -> str:
return clipped_text
except Exception as e:
logging.warning(f"Failed to clip tokens: {e}")
return text
return text


def get_pr_multi_diffs(git_provider: GitProvider,
token_handler: TokenHandler,
model: str,
max_calls: int = 5) -> List[str]:
"""
Retrieves the diff files from a Git provider, sorts them by main language, and generates patches for each file.
The patches are split into multiple groups based on the maximum number of tokens allowed for the given model.
Args:
git_provider (GitProvider): An object that provides access to Git provider APIs.
token_handler (TokenHandler): An object that handles tokens in the context of a pull request.
model (str): The name of the model.
max_calls (int, optional): The maximum number of calls to retrieve diff files. Defaults to 5.
Returns:
List[str]: A list of final diff strings, split into multiple groups based on the maximum number of tokens allowed for the given model.
Raises:
RateLimitExceededException: If the rate limit for the Git provider API is exceeded.
"""
try:
diff_files = git_provider.get_diff_files()
except RateLimitExceededException as e:
logging.error(f"Rate limit exceeded for git provider API. original message {e}")
raise

# Sort files by main language
pr_languages = sort_files_by_main_languages(git_provider.get_languages(), diff_files)

# Sort files within each language group by tokens in descending order
sorted_files = []
for lang in pr_languages:
sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))

patches = []
final_diff_list = []
total_tokens = token_handler.prompt_tokens
call_number = 1
for file in sorted_files:
if call_number > max_calls:
if get_settings().config.verbosity_level >= 2:
logging.info(f"Reached max calls ({max_calls})")
break

original_file_content_str = file.base_file
new_file_content_str = file.head_file
patch = file.patch
if not patch:
continue

# Remove delete-only hunks
patch = handle_patch_deletions(patch, original_file_content_str, new_file_content_str, file.filename)
if patch is None:
continue

patch = convert_to_hunks_with_lines_numbers(patch, file)
new_patch_tokens = token_handler.count_tokens(patch)
if patch and (total_tokens + new_patch_tokens > MAX_TOKENS[model] - OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD):
final_diff = "\n".join(patches)
final_diff_list.append(final_diff)
patches = []
total_tokens = token_handler.prompt_tokens
call_number += 1
if get_settings().config.verbosity_level >= 2:
logging.info(f"Call number: {call_number}")

if patch:
patches.append(patch)
total_tokens += new_patch_tokens
if get_settings().config.verbosity_level >= 2:
logging.info(f"Tokens: {total_tokens}, last filename: {file.filename}")

# Add the last chunk
if patches:
final_diff = "\n".join(patches)
final_diff_list.append(final_diff)

return final_diff_list
3 changes: 2 additions & 1 deletion pr_agent/algo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,8 @@ def update_settings_from_args(args: List[str]) -> List[str]:
arg = arg.strip('-').strip()
vals = arg.split('=', 1)
if len(vals) != 2:
logging.error(f'Invalid argument format: {arg}')
if len(vals) > 2: # --extended is a valid argument
logging.error(f'Invalid argument format: {arg}')
other_args.append(arg)
continue
key, value = _fix_key_value(*vals)
Expand Down
20 changes: 14 additions & 6 deletions pr_agent/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,21 @@ def run(inargs=None):
- cli.py --pr_url=... reflect
Supported commands:
review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
ask / ask_question [question] - Ask a question about the PR.
describe / describe_pr - Modify the PR title and description based on the PR's contents.
improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
reflect - Ask the PR author questions about the PR.
update_changelog - Update the changelog based on the PR's contents.
-review / review_pr - Add a review that includes a summary of the PR and specific suggestions for improvement.
-ask / ask_question [question] - Ask a question about the PR.
-describe / describe_pr - Modify the PR title and description based on the PR's contents.
-improve / improve_code - Suggest improvements to the code in the PR as pull request comments ready to commit.
Extended mode ('improve --extended') employs several calls, and provides a more thorough feedback
-reflect - Ask the PR author questions about the PR.
-update_changelog - Update the changelog based on the PR's contents.
Configuration:
To edit any configuration parameter from 'configuration.toml', just add -config_path=<value>.
For example: 'python cli.py --pr_url=... review --pr_reviewer.extra_instructions="focus on the file: ..."'
""")
Expand Down
1 change: 1 addition & 0 deletions pr_agent/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"settings/pr_questions_prompts.toml",
"settings/pr_description_prompts.toml",
"settings/pr_code_suggestions_prompts.toml",
"settings/pr_sort_code_suggestions_prompts.toml",
"settings/pr_information_from_user_prompts.toml",
"settings/pr_update_changelog_prompts.toml",
"settings_prod/.secrets.toml"
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/git_providers/bitbucket_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_repo_settings(self):
except Exception:
return ""

def publish_code_suggestions(self, code_suggestions: list):
def publish_code_suggestions(self, code_suggestions: list) -> bool:
"""
Publishes code suggestions as comments on the PR.
"""
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/git_providers/git_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def publish_inline_comments(self, comments: list[dict]):
pass

@abstractmethod
def publish_code_suggestions(self, code_suggestions: list):
def publish_code_suggestions(self, code_suggestions: list) -> bool:
pass

@abstractmethod
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/git_providers/github_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def create_inline_comment(self, body: str, relevant_file: str, relevant_line_in_
def publish_inline_comments(self, comments: list[dict]):
self.pr.create_review(commit=self.last_commit_id, comments=comments)

def publish_code_suggestions(self, code_suggestions: list):
def publish_code_suggestions(self, code_suggestions: list) -> bool:
"""
Publishes code suggestions as comments on the PR.
"""
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/git_providers/gitlab_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ def get_relevant_diff(self, relevant_file: str, relevant_line_in_file: int) -> O
f'No relevant diff found for {relevant_file} {relevant_line_in_file}. Falling back to last diff.')
return self.last_diff # fallback to last_diff if no relevant diff is found

def publish_code_suggestions(self, code_suggestions: list):
def publish_code_suggestions(self, code_suggestions: list) -> bool:
for suggestion in code_suggestions:
try:
body = suggestion['body']
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/git_providers/local_git_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def publish_code_suggestion(self, body: str, relevant_file: str,
relevant_lines_start: int, relevant_lines_end: int):
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')

def publish_code_suggestions(self, code_suggestions: list):
def publish_code_suggestions(self, code_suggestions: list) -> bool:
raise NotImplementedError('Publishing code suggestions is not implemented for the local git provider')

def publish_labels(self, labels):
Expand Down
2 changes: 1 addition & 1 deletion pr_agent/servers/help.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
commands_text = "> **/review [-i]**: Request a review of your Pull Request. For an incremental review, which only " \
"considers changes since the last review, include the '-i' option.\n" \
"> **/describe**: Modify the PR title and description based on the contents of the PR.\n" \
"> **/improve**: Suggest improvements to the code in the PR. \n" \
"> **/improve [--extended]**: Suggest improvements to the code in the PR. Extended mode employs several calls, and provides a more thorough feedback. \n" \
"> **/ask \\<QUESTION\\>**: Pose a question about the PR.\n" \
"> **/update_changelog**: Update the changelog based on the PR's contents.\n\n" \
">To edit any configuration parameter from **configuration.toml**, add --config_path=new_value\n" \
Expand Down
6 changes: 6 additions & 0 deletions pr_agent/settings/configuration.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,12 @@ extra_instructions = ""
[pr_code_suggestions] # /improve #
num_code_suggestions=4
extra_instructions = ""
rank_suggestions = false
# params for '/improve --extended' mode
num_code_suggestions_per_chunk=8
rank_extended_suggestions = true
max_number_of_calls = 5
final_clip_factor = 0.9

[pr_update_changelog] # /update_changelog #
push_changelog_changes=false
Expand Down
Loading

0 comments on commit cbe0a69

Please sign in to comment.