Skip to content

Commit

Permalink
Merge pull request Codium-ai#644 from Codium-ai/tr/parallel_calls
Browse files Browse the repository at this point in the history
Tr/parallel calls
  • Loading branch information
mrT23 authored Feb 7, 2024
2 parents 9002796 + df4d963 commit adec333
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 13 deletions.
3 changes: 2 additions & 1 deletion pr_agent/settings/configuration.toml
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,9 @@ enable_help_text=true
# params for '/improve --extended' mode
auto_extended_mode=true
num_code_suggestions_per_chunk=5
rank_extended_suggestions = false
max_number_of_calls = 3
parallel_calls = true
rank_extended_suggestions = false
final_clip_factor = 0.8

[pr_add_docs] # /add_docs #
Expand Down
29 changes: 17 additions & 12 deletions pr_agent/tools/pr_code_suggestions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import copy
import textwrap
from functools import partial
Expand Down Expand Up @@ -111,18 +112,18 @@ async def run(self):

async def _prepare_prediction(self, model: str):
get_logger().info('Getting PR diff...')
self.patches_diff = get_pr_diff(self.git_provider,
patches_diff = get_pr_diff(self.git_provider,
self.token_handler,
model,
add_line_numbers_to_hunks=True,
disable_extra_lines=True)

get_logger().info('Getting AI prediction...')
self.prediction = await self._get_prediction(model)
self.prediction = await self._get_prediction(model, patches_diff)

async def _get_prediction(self, model: str):
async def _get_prediction(self, model: str, patches_diff: str):
variables = copy.deepcopy(self.vars)
variables["diff"] = self.patches_diff # update diff
variables["diff"] = patches_diff # update diff
environment = Environment(undefined=StrictUndefined)
system_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.system).render(variables)
user_prompt = environment.from_string(get_settings().pr_code_suggestions_prompt.user).render(variables)
Expand Down Expand Up @@ -229,14 +230,18 @@ async def _prepare_prediction_extended(self, model: str) -> dict:
patches_diff_list = get_pr_multi_diffs(self.git_provider, self.token_handler, model,
max_calls=get_settings().pr_code_suggestions.max_number_of_calls)

get_logger().info('Getting multi AI predictions...')
prediction_list = []
for i, patches_diff in enumerate(patches_diff_list):
get_logger().info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
self.patches_diff = patches_diff
prediction = await self._get_prediction(model) # toDo: parallelize
prediction_list.append(prediction)
self.prediction_list = prediction_list
# parallelize calls to AI:
if get_settings().pr_code_suggestions.parallel_calls:
get_logger().info('Getting multi AI predictions in parallel...')
prediction_list = await asyncio.gather(*[self._get_prediction(model, patches_diff) for patches_diff in patches_diff_list])
self.prediction_list = prediction_list
else:
get_logger().info('Getting multi AI predictions...')
prediction_list = []
for i, patches_diff in enumerate(patches_diff_list):
get_logger().info(f"Processing chunk {i + 1} of {len(patches_diff_list)}")
prediction = await self._get_prediction(model, patches_diff)
prediction_list.append(prediction)

data = {}
for prediction in prediction_list:
Expand Down

0 comments on commit adec333

Please sign in to comment.