diff --git a/docs/docs/usage-guide/additional_configurations.md b/docs/docs/usage-guide/additional_configurations.md
index 4ae014148..121d77b61 100644
--- a/docs/docs/usage-guide/additional_configurations.md
+++ b/docs/docs/usage-guide/additional_configurations.md
@@ -66,7 +66,8 @@
 By default, around any change in your PR, git patch provides three lines of context above and below the change.
-For the `review`, `describe`, `ask` and `add_docs` tools, if the token budget allows, PR-Agent tries to increase the number of lines of context, via the parameter:
+For the `review`, `describe`, `ask` and `add_docs` tools, if the token budget allows, PR-Agent tries to increase the number of lines of context, via the parameters:
 ```
 [config]
-patch_extra_lines=3
+patch_extra_lines_before=4
+patch_extra_lines_after=2
 ```
-Increasing this number provides more context to the model, but will also increase the token budget.
+Increasing these numbers provides more context to the model, but will also increase the token budget.
diff --git a/pr_agent/algo/__init__.py b/pr_agent/algo/__init__.py
index f51c4415d..f7aa6b60e 100644
--- a/pr_agent/algo/__init__.py
+++ b/pr_agent/algo/__init__.py
@@ -46,6 +46,7 @@
     'bedrock/anthropic.claude-3-sonnet-20240229-v1:0': 100000,
     'bedrock/anthropic.claude-3-haiku-20240307-v1:0': 100000,
     'bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0': 100000,
+    'claude-3-5-sonnet': 100000,
     'groq/llama3-8b-8192': 8192,
     'groq/llama3-70b-8192': 8192,
     'groq/mixtral-8x7b-32768': 32768,
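The bare `'claude-3-5-sonnet'` entry gives the short model alias its own entry in the token-budget table, so context-window accounting works even before the alias is expanded to the full Bedrock identifier by `set_claude_model()` (added in `pr_agent/git_providers/utils.py` below). A minimal sanity sketch, assuming the dictionary above is the `MAX_TOKENS` registry exported from `pr_agent/algo/__init__.py`:

```python
from pr_agent.algo import MAX_TOKENS

# The short alias and the full Bedrock identifier share the same 100k budget,
# so token math is identical whichever form the user configures.
assert MAX_TOKENS['claude-3-5-sonnet'] == 100000
assert MAX_TOKENS['bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0'] == 100000
```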
diff --git a/pr_agent/algo/git_patch_processing.py b/pr_agent/algo/git_patch_processing.py
index 15343c97a..69e06fcf1 100644
--- a/pr_agent/algo/git_patch_processing.py
+++ b/pr_agent/algo/git_patch_processing.py
@@ -7,19 +7,8 @@
 from pr_agent.log import get_logger
 
 
-def extend_patch(original_file_str, patch_str, num_lines) -> str:
-    """
-    Extends the given patch to include a specified number of surrounding lines.
-
-    Args:
-        original_file_str (str): The original file to which the patch will be applied.
-        patch_str (str): The patch to be applied to the original file.
-        num_lines (int): The number of surrounding lines to include in the extended patch.
-
-    Returns:
-        str: The extended patch string.
-    """
-    if not patch_str or num_lines == 0:
+def extend_patch(original_file_str, patch_str, patch_extra_lines_before=0, patch_extra_lines_after=0) -> str:
+    if not patch_str or (patch_extra_lines_before == 0 and patch_extra_lines_after == 0):
         return patch_str
 
     if type(original_file_str) == bytes:
@@ -29,6 +18,7 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
         return ""
 
     original_lines = original_file_str.splitlines()
+    len_original_lines = len(original_lines)
     patch_lines = patch_str.splitlines()
     extended_patch_lines = []
 
@@ -40,10 +30,11 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
         if line.startswith('@@'):
             match = RE_HUNK_HEADER.match(line)
             if match:
-                # finish previous hunk
-                if start1 != -1:
-                    extended_patch_lines.extend(
-                        original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines])
+                # finish the previous hunk by appending extra context lines after it
+                if start1 != -1 and patch_extra_lines_after > 0:
+                    delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
+                    delta_lines = [f' {line}' for line in delta_lines]  # prefix each extra line as diff context
+                    extended_patch_lines.extend(delta_lines)
 
                 res = list(match.groups())
                 for i in range(len(res)):
@@ -55,15 +46,33 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
                     start1, size1, size2 = map(int, res[:3])
                     start2 = 0
                 section_header = res[4]
-                extended_start1 = max(1, start1 - num_lines)
-                extended_size1 = size1 + (start1 - extended_start1) + num_lines
-                extended_start2 = max(1, start2 - num_lines)
-                extended_size2 = size2 + (start2 - extended_start2) + num_lines
+
+                if patch_extra_lines_before > 0 or patch_extra_lines_after > 0:
+                    extended_start1 = max(1, start1 - patch_extra_lines_before)
+                    extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after
+                    if extended_start1 - 1 + extended_size1 > len_original_lines:
+                        extended_size1 = len_original_lines - extended_start1 + 1
+                    extended_start2 = max(1, start2 - patch_extra_lines_before)
+                    extended_size2 = size2 + (start2 - extended_start2) + patch_extra_lines_after
+                    if extended_start2 - 1 + extended_size2 > len_original_lines:
+                        extended_size2 = len_original_lines - extended_start2 + 1
+                    delta_lines = original_lines[extended_start1 - 1:start1 - 1]
+                    delta_lines = [f' {line}' for line in delta_lines]  # prefix each extra line as diff context
+                    if section_header:
+                        for delta_line in delta_lines:
+                            if section_header in delta_line:
+                                section_header = ''  # drop the section header if it already appears in the extra lines
+                                break
+                else:
+                    extended_start1 = start1
+                    extended_size1 = size1
+                    extended_start2 = start2
+                    extended_size2 = size2
+                    delta_lines = []
                 extended_patch_lines.append(
                     f'@@ -{extended_start1},{extended_size1} '
                     f'+{extended_start2},{extended_size2} @@ {section_header}')
-                extended_patch_lines.extend(
-                    original_lines[extended_start1 - 1:start1 - 1]) # one to zero based
+                extended_patch_lines.extend(delta_lines)
                 continue
             extended_patch_lines.append(line)
     except Exception as e:
@@ -71,10 +80,12 @@ def extend_patch(original_file_str, patch_str, num_lines) -> str:
             get_logger().error(f"Failed to extend patch: {e}")
             return patch_str
 
-    # finish previous hunk
-    if start1 != -1:
-        extended_patch_lines.extend(
-            original_lines[start1 + size1 - 1:start1 + size1 - 1 + num_lines])
+    # finish the last hunk
+    if start1 != -1 and patch_extra_lines_after > 0:
+        delta_lines = original_lines[start1 + size1 - 1:start1 + size1 - 1 + patch_extra_lines_after]
+        # add a space at the beginning of each extra line, marking it as diff context
+        delta_lines = [f' {line}' for line in delta_lines]
+        extended_patch_lines.extend(delta_lines)
 
     extended_patch_str = '\n'.join(extended_patch_lines)
     return extended_patch_str
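To make the new hunk-header arithmetic concrete, here is a small sanity sketch (not part of the PR) exercising `extend_patch` with asymmetric context. The expected header follows from `extended_start1 = max(1, start1 - patch_extra_lines_before)` and `extended_size1 = size1 + (start1 - extended_start1) + patch_extra_lines_after`:

```python
from pr_agent.algo.git_patch_processing import extend_patch

original = "line1\nline2\nline3\nline4\nline5"
patch = "@@ -2,2 +2,2 @@ init()\n-line2\n+new_line2\n line3"

# Two extra lines before the hunk, one after:
# start1=2 -> extended_start1 = max(1, 2-2) = 1,
# extended_size1 = 2 + (2-1) + 1 = 4, which fits the 5-line file, so no clamping.
print(extend_patch(original, patch,
                   patch_extra_lines_before=2, patch_extra_lines_after=1))
# @@ -1,4 +1,4 @@ init()
#  line1
# -line2
# +new_line2
#  line3
#  line4
```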
diff --git a/pr_agent/algo/pr_processing.py b/pr_agent/algo/pr_processing.py
index d635ec35f..a51820d9a 100644
--- a/pr_agent/algo/pr_processing.py
+++ b/pr_agent/algo/pr_processing.py
@@ -33,9 +33,11 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
                 large_pr_handling=False,
                 return_remaining_files=False):
     if disable_extra_lines:
-        PATCH_EXTRA_LINES = 0
+        PATCH_EXTRA_LINES_BEFORE = 0
+        PATCH_EXTRA_LINES_AFTER = 0
     else:
-        PATCH_EXTRA_LINES = get_settings().config.patch_extra_lines
+        PATCH_EXTRA_LINES_BEFORE = get_settings().config.patch_extra_lines_before
+        PATCH_EXTRA_LINES_AFTER = get_settings().config.patch_extra_lines_after
 
     try:
         diff_files_original = git_provider.get_diff_files()
@@ -64,7 +66,8 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
 
     # generate a standard diff string, with patch extension
     patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
-        pr_languages, token_handler, add_line_numbers_to_hunks, patch_extra_lines=PATCH_EXTRA_LINES)
+        pr_languages, token_handler, add_line_numbers_to_hunks,
+        patch_extra_lines_before=PATCH_EXTRA_LINES_BEFORE, patch_extra_lines_after=PATCH_EXTRA_LINES_AFTER)
 
     # if we are under the limit, return the full diff
     if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
@@ -72,7 +75,7 @@ def get_pr_diff(git_provider: GitProvider, token_handler: TokenHandler,
                            f"returning full diff.")
         return "\n".join(patches_extended)
 
-    # if we are over the limit, start pruning
+    # if we are over the limit, start pruning (if we got here, the patches will not be extended with extra lines)
     get_logger().info(f"Tokens: {total_tokens}, total tokens over limit: {get_max_tokens(model)}, "
                       f"pruning diff.")
     patches_compressed_list, total_tokens_list, deleted_files_list, remaining_files_list, file_dict, files_in_patches_list = \
@@ -174,17 +177,8 @@ def get_pr_diff_multiple_patchs(git_provider: GitProvider, token_handler: TokenHandler,
 def pr_generate_extended_diff(pr_languages: list,
                               token_handler: TokenHandler,
                               add_line_numbers_to_hunks: bool,
-                              patch_extra_lines: int = 0) -> Tuple[list, int, list]:
-    """
-    Generate a standard diff string with patch extension, while counting the number of tokens used and applying diff
-    minimization techniques if needed.
-
-    Args:
-    - pr_languages: A list of dictionaries representing the languages used in the pull request and their corresponding
-    files.
-    - token_handler: An object of the TokenHandler class used for handling tokens in the context of the pull request.
-    - add_line_numbers_to_hunks: A boolean indicating whether to add line numbers to the hunks in the diff.
-    """
+                              patch_extra_lines_before: int = 0,
+                              patch_extra_lines_after: int = 0) -> Tuple[list, int, list]:
     total_tokens = token_handler.prompt_tokens  # initial tokens
     patches_extended = []
     patches_extended_tokens = []
@@ -196,7 +190,8 @@ def pr_generate_extended_diff(pr_languages: list,
             continue
 
         # extend each patch with extra lines of context
-        extended_patch = extend_patch(original_file_content_str, patch, num_lines=patch_extra_lines)
+        extended_patch = extend_patch(original_file_content_str, patch,
+                                      patch_extra_lines_before, patch_extra_lines_after)
         if not extended_patch:
             get_logger().warning(f"Failed to extend patch for file: {file.filename}")
             continue
@@ -405,10 +400,13 @@ def get_pr_multi_diffs(git_provider: GitProvider,
     for lang in pr_languages:
         sorted_files.extend(sorted(lang['files'], key=lambda x: x.tokens, reverse=True))
 
-    # try first a single run with standard diff string, with patch extension, and no deletions
     patches_extended, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
-        pr_languages, token_handler, add_line_numbers_to_hunks=True)
+        pr_languages, token_handler, add_line_numbers_to_hunks=True,
+        patch_extra_lines_before=get_settings().config.patch_extra_lines_before,
+        patch_extra_lines_after=get_settings().config.patch_extra_lines_after)
+
+
+    # if we are under the limit, return the full diff
     if total_tokens + OUTPUT_BUFFER_TOKENS_SOFT_THRESHOLD < get_max_tokens(model):
         return ["\n".join(patches_extended)] if patches_extended else []
diff --git a/pr_agent/git_providers/utils.py b/pr_agent/git_providers/utils.py
index a0d65b668..8a9579cff 100644
--- a/pr_agent/git_providers/utils.py
+++ b/pr_agent/git_providers/utils.py
@@ -47,3 +47,17 @@ def apply_repo_settings(pr_url):
             os.remove(repo_settings_file)
         except Exception as e:
             get_logger().error(f"Failed to remove temporary settings file {repo_settings_file}", e)
+
+    # enable switching models with a short definition
+    if get_settings().config.model.lower() == 'claude-3-5-sonnet':
+        set_claude_model()
+
+
+def set_claude_model():
+    """
+    Set the claude-3-5-sonnet model easily (even by users), just by stating: --config.model='claude-3-5-sonnet'
+    """
+    model_claude = "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
+    get_settings().set('config.model', model_claude)
+    get_settings().set('config.model_turbo', model_claude)
+    get_settings().set('config.fallback_models', [model_claude])
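A quick sketch of what this shortcut does end to end (illustrative only; it assumes `get_settings` from `pr_agent.config_loader` and the `set_claude_model` helper added above):

```python
from pr_agent.config_loader import get_settings
from pr_agent.git_providers.utils import set_claude_model

# Simulate a user passing --config.model='claude-3-5-sonnet' on the CLI.
get_settings().set('config.model', 'claude-3-5-sonnet')

# apply_repo_settings() performs this check and expands the alias.
if get_settings().config.model.lower() == 'claude-3-5-sonnet':
    set_claude_model()

# model, model_turbo and fallback_models now all point at the full identifier.
assert get_settings().config.model == "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0"
```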
diff --git a/pr_agent/settings/configuration.toml b/pr_agent/settings/configuration.toml
index 5336a48a9..5bfc5071e 100644
--- a/pr_agent/settings/configuration.toml
+++ b/pr_agent/settings/configuration.toml
@@ -20,7 +20,8 @@ max_commits_tokens = 500
 max_model_tokens = 32000 # Limits the maximum number of tokens that can be used by any model, regardless of the model's default capabilities.
 custom_model_max_tokens=-1 # for models not in the default list
 #
-patch_extra_lines = 1
+patch_extra_lines_before = 3 # Number of extra lines (in addition to the default 3 context lines) to include before each hunk in the patch
+patch_extra_lines_after = 1 # Number of extra lines (in addition to the default 3 context lines) to include after each hunk in the patch
 secret_provider=""
 cli_mode=false
 ai_disclaimer_title=""  # Pro feature, title for a collapsible disclaimer to AI outputs
@@ -96,7 +97,7 @@ enable_help_text=false
 
 [pr_code_suggestions] # /improve #
-max_context_tokens=10000
+max_context_tokens=14000
 num_code_suggestions=4
 commitable_code_suggestions = false
 extra_instructions = ""
diff --git a/pr_agent/tools/pr_code_suggestions.py b/pr_agent/tools/pr_code_suggestions.py
index f98590ce3..1a965192a 100644
--- a/pr_agent/tools/pr_code_suggestions.py
+++ b/pr_agent/tools/pr_code_suggestions.py
@@ -286,7 +286,7 @@ async def _prepare_prediction(self, model: str) -> dict:
                                               self.token_handler,
                                               model,
                                               add_line_numbers_to_hunks=True,
-                                              disable_extra_lines=True)
+                                              disable_extra_lines=False)
 
         if self.patches_diff:
             get_logger().debug(f"PR diff", artifact=self.patches_diff)
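Net effect for `/improve`: with `disable_extra_lines=False`, the suggestions flow now inherits the extended-context defaults set in `configuration.toml` above, and the raised `max_context_tokens` pays for the extra lines. A sketch of the effective defaults (values as introduced in this PR, read through the settings object):

```python
from pr_agent.config_loader import get_settings

settings = get_settings()
# Three extra context lines before each hunk, one after,
# plus a larger context budget for code suggestions to absorb them.
assert settings.config.patch_extra_lines_before == 3
assert settings.config.patch_extra_lines_after == 1
assert settings.pr_code_suggestions.max_context_tokens == 14000
```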
diff --git a/tests/unittest/test_extend_patch.py b/tests/unittest/test_extend_patch.py
index ba0af881b..9d309822f 100644
--- a/tests/unittest/test_extend_patch.py
+++ b/tests/unittest/test_extend_patch.py
@@ -1,54 +1,18 @@
-
-# Generated by CodiumAI
-
-
+import pytest
 from pr_agent.algo.git_patch_processing import extend_patch
-
-"""
-Code Analysis
-
-Objective:
-The objective of the 'extend_patch' function is to extend a given patch to include a specified number of surrounding
-lines. This function takes in an original file string, a patch string, and the number of lines to extend the patch by,
-and returns the extended patch string.
-
-Inputs:
-- original_file_str: a string representing the original file
-- patch_str: a string representing the patch to be extended
-- num_lines: an integer representing the number of lines to extend the patch by
-
-Flow:
-1. Split the original file string and patch string into separate lines
-2. Initialize variables to keep track of the current hunk's start and size for both the original file and the patch
-3. Iterate through each line in the patch string
-4. If the line starts with '@@', extract the start and size values for both the original file and the patch, and
-calculate the extended start and size values
-5. Append the extended hunk header to the extended patch lines list
-6. Append the specified number of lines before the hunk to the extended patch lines list
-7. Append the current line to the extended patch lines list
-8. If the line is not a hunk header, append it to the extended patch lines list
-9. Return the extended patch string
-
-Outputs:
-- extended_patch_str: a string representing the extended patch
-
-Additional aspects:
-- The function uses regular expressions to extract the start and size values from the hunk header
-- The function handles cases where the start value of a hunk is less than the number of lines to extend by by setting
-the extended start value to 1
-- The function handles cases where the hunk extends beyond the end of the original file by only including lines up to
-the end of the original file in the extended patch
-"""
+from pr_agent.algo.pr_processing import pr_generate_extended_diff
+from pr_agent.algo.token_handler import TokenHandler
 
 
 class TestExtendPatch:
     # Tests that the function works correctly with valid input
     def test_happy_path(self):
         original_file_str = 'line1\nline2\nline3\nline4\nline5'
-        patch_str = '@@ -2,2 +2,2 @@ init()\n-line2\n+new_line2\nline3'
+        patch_str = '@@ -2,2 +2,2 @@ init()\n-line2\n+new_line2\n line3'
         num_lines = 1
-        expected_output = '@@ -1,4 +1,4 @@ init()\nline1\n-line2\n+new_line2\nline3\nline4'
-        actual_output = extend_patch(original_file_str, patch_str, num_lines)
+        expected_output = '@@ -1,4 +1,4 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4'
+        actual_output = extend_patch(original_file_str, patch_str,
+                                     patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
         assert actual_output == expected_output
 
     # Tests that the function returns an empty string when patch_str is empty
@@ -57,14 +21,16 @@ def test_empty_patch(self):
         original_file_str = 'line1\nline2\nline3\nline4\nline5'
         patch_str = ''
         num_lines = 1
         expected_output = ''
-        assert extend_patch(original_file_str, patch_str, num_lines) == expected_output
+        assert extend_patch(original_file_str, patch_str,
+                            patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) == expected_output
 
     # Tests that the function returns the original patch when num_lines is 0
     def test_zero_num_lines(self):
         original_file_str = 'line1\nline2\nline3\nline4\nline5'
         patch_str = '@@ -2,2 +2,2 @@ init()\n-line2\n+new_line2\nline3'
         num_lines = 0
-        assert extend_patch(original_file_str, patch_str, num_lines) == patch_str
+        assert extend_patch(original_file_str, patch_str,
+                            patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines) == patch_str
 
     # Tests that the function returns the original patch when patch_str contains no hunks
     def test_no_hunks(self):
@@ -77,17 +43,73 @@
     # Tests that the function extends a patch with a single hunk correctly
     def test_single_hunk(self):
         original_file_str = 'line1\nline2\nline3\nline4\nline5'
-        patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\nline3\nline4'
-        num_lines = 1
-        expected_output = '@@ -1,5 +1,5 @@ init()\nline1\n-line2\n+new_line2\nline3\nline4\nline5'
-        actual_output = extend_patch(original_file_str, patch_str, num_lines)
-        assert actual_output == expected_output
+        patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\n line3\n line4'
+
+        for num_lines in [1, 2, 3]:  # even when the requested context extends past the file bounds, the output is clamped to the file
+            expected_output = '@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5'
+            actual_output = extend_patch(original_file_str, patch_str,
+                                         patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
+            assert actual_output == expected_output
 
     # Tests the functionality of extending a patch with multiple hunks.
     def test_multiple_hunks(self):
         original_file_str = 'line1\nline2\nline3\nline4\nline5\nline6'
-        patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\nline3\nline4\n@@ -4,1 +4,1 @@ init2()\n-line4\n+new_line4'  # noqa: E501
+        patch_str = '@@ -2,3 +2,3 @@ init()\n-line2\n+new_line2\n line3\n line4\n@@ -4,1 +4,1 @@ init2()\n-line4\n+new_line4'  # noqa: E501
         num_lines = 1
-        expected_output = '@@ -1,5 +1,5 @@ init()\nline1\n-line2\n+new_line2\nline3\nline4\nline5\n@@ -3,3 +3,3 @@ init2()\nline3\n-line4\n+new_line4\nline5'  # noqa: E501
-        actual_output = extend_patch(original_file_str, patch_str, num_lines)
+        expected_output = '@@ -1,5 +1,5 @@ init()\n line1\n-line2\n+new_line2\n line3\n line4\n line5\n@@ -3,3 +3,3 @@ init2()\n line3\n-line4\n+new_line4\n line5'  # noqa: E501
+        actual_output = extend_patch(original_file_str, patch_str,
+                                     patch_extra_lines_before=num_lines, patch_extra_lines_after=num_lines)
         assert actual_output == expected_output
+
+
+class TestExtendedPatchMoreLines:
+    class File:
+        def __init__(self, base_file, patch, filename):
+            self.base_file = base_file
+            self.patch = patch
+            self.filename = filename
+
+    @pytest.fixture
+    def token_handler(self):
+        # Create a TokenHandler instance with dummy data
+        th = TokenHandler(system="System prompt", user="User prompt")
+        th.prompt_tokens = 100
+        return th
+
+    @pytest.fixture
+    def pr_languages(self):
+        # Create a list of languages with files containing base_file and patch data
+        return [
+            {
+                'files': [
+                    self.File(base_file="line000\nline00\nline0\nline1\noriginal content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10",
+                              patch="@@ -5,5 +5,5 @@\n-original content\n+modified content\n line2\n line3\n line4\n line5",
+                              filename="file1"),
+                    self.File(base_file="original content\nline2\nline3\nline4\nline5\nline6\nline7\nline8\nline9\nline10",
+                              patch="@@ -6,5 +6,5 @@\nline6\nline7\nline8\n-line9\n+modified line9\nline10",
+                              filename="file2")
+                ]
+            }
+        ]
+
+    def test_extend_patches_with_extra_lines(self, token_handler, pr_languages):
+        patches_extended_no_extra_lines, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
+            pr_languages, token_handler, add_line_numbers_to_hunks=False,
+            patch_extra_lines_before=0,
+            patch_extra_lines_after=0
+        )
+
+        # Check that with no extra lines, the patches are the same as the original patches
+        p0 = patches_extended_no_extra_lines[0].strip()
+        p1 = patches_extended_no_extra_lines[1].strip()
+        assert p0 == '## file1\n\n' + pr_languages[0]['files'][0].patch.strip()
+        assert p1 == '## file2\n\n' + pr_languages[0]['files'][1].patch.strip()
+
+        patches_extended_with_extra_lines, total_tokens, patches_extended_tokens = pr_generate_extended_diff(
+            pr_languages, token_handler, add_line_numbers_to_hunks=False,
+            patch_extra_lines_before=2,
+            patch_extra_lines_after=1
+        )
+
+        p0_extended = patches_extended_with_extra_lines[0].strip()
+        assert p0_extended == '## file1\n\n@@ -3,8 +3,8 @@ \n line0\n line1\n-original content\n+modified content\n line2\n line3\n line4\n line5\n line6'
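For reference, the expected header in the final assertion can be reproduced by hand from the arithmetic in `extend_patch` (a sketch using the file1 fixture values; the 14-line `base_file` leaves the size unclamped):

```python
# file1: hunk '@@ -5,5 +5,5 @@', patch_extra_lines_before=2, patch_extra_lines_after=1
start1, size1 = 5, 5
before, after = 2, 1
len_original_lines = 14  # lines in file1's base_file

extended_start1 = max(1, start1 - before)                      # 3
extended_size1 = size1 + (start1 - extended_start1) + after    # 8
if extended_start1 - 1 + extended_size1 > len_original_lines:  # 10 > 14 is False -> no clamping
    extended_size1 = len_original_lines - extended_start1 + 1

# The '+' side is identical here, since the hunk has equal old and new ranges.
print(f'@@ -{extended_start1},{extended_size1} +{extended_start1},{extended_size1} @@')
# @@ -3,8 +3,8 @@   (the real output also has a trailing space, since section_header is empty)
```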