Merge prompt code with llm response #458

Open · wants to merge 2 commits into base: main
kai/kai_trace.py: 21 changes (16 additions, 5 deletions)
@@ -129,24 +129,35 @@ def prompt(self, current_batch_count: int, prompt: str, pb_vars: dict):

     @enabled_check
     def llm_result(
-        self, current_batch_count: int, retry_count: int, result: BaseMessage
+        self,
+        current_batch_count: int,
+        retry_count: int,
+        result: BaseMessage,
+        output_filename: str,
     ):
         result_file_path = os.path.join(
-            self.trace_dir, f"{current_batch_count}", f"{retry_count}", "llm_result"
+            self.trace_dir,
+            f"{current_batch_count}",
+            f"{retry_count}",
+            f"{output_filename}",
         )
         os.makedirs(os.path.dirname(result_file_path), exist_ok=True)
         with open(result_file_path, "w") as f:
-            f.write(result.pretty_repr())
+            f.write(str(result))

     @enabled_check
     def response_metadata(
-        self, current_batch_count: int, retry_count: int, response_metadata: dict
+        self,
+        current_batch_count: int,
+        retry_count: int,
+        response_metadata: dict,
+        output_filename: str,
     ):
         response_metadata_file_path = os.path.join(
             self.trace_dir,
             f"{current_batch_count}",
             f"{retry_count}",
-            "response_metadata.json",
+            f"{output_filename}",
         )
         os.makedirs(os.path.dirname(response_metadata_file_path), exist_ok=True)
         with open(response_metadata_file_path, "w") as f:
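The practical effect of the new output_filename parameter is that a single batch/retry directory can now hold several trace artifacts instead of one hard-coded llm_result or response_metadata.json file. A minimal sketch of the resulting path layout, using a hypothetical standalone helper rather than the real KaiTrace class:

import os

def write_trace_artifact(trace_dir: str, batch: int, retry: int, filename: str, payload: str) -> str:
    # Mirrors the updated KaiTrace helpers: <trace_dir>/<batch>/<retry>/<filename>.
    path = os.path.join(trace_dir, f"{batch}", f"{retry}", filename)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        f.write(payload)
    return path

# One retry attempt can now leave multiple files side by side:
write_trace_artifact("trace", 0, 1, "llm_result_without_codeblocks", "...")
write_trace_artifact("trace", 0, 1, "llm_result_with_codeblocks", "...")
write_trace_artifact("trace", 0, 1, "response_metadata_with_codeblocks.json", "{}")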
kai/service/kai_application/kai_application.py: 105 changes (97 additions, 8 deletions)
@@ -1,6 +1,8 @@
import logging
import re
import time
import traceback
from difflib import unified_diff
from typing import Iterator, Optional, cast

import tiktoken
@@ -93,6 +95,40 @@ def __init__(self, config: KaiConfig):

         self.solution_consumer = solution_consumer_factory(config.solution_consumers)

+    def is_response_request_merge_needed(self, content: str):
+        comments_pattern = r"(Rest of the code remains unchanged)|(rest of the code remains the same)|(Other methods remain unchanged)|(Rest of the class remains unchanged)|(Rest of the methods remain unchanged)"
+        comments_pattern_matches = re.findall(comments_pattern, content, re.DOTALL)
+        if comments_pattern_matches:
+            return True
+
+    def merging_response(self, input_file: str, llm_response: str) -> str:
+        diff_output = unified_diff(
+            llm_response.splitlines(), input_file.splitlines(), lineterm=""
+        )
+        diff_lines = "\n".join(list(diff_output))
+        diff_line_list = diff_lines.splitlines()
+
+        updated_lines = [line for line in diff_line_list]
+        backtick_indices = []
+        commented_text = ""
+
+        commented_line_index = 1
+
+        for line in updated_lines:
+            if "Rest of" in line:
+                commented_line_index += updated_lines.index(line)
+                commented_text += line
+            if "```" in line:
+                backtick_indices.append(updated_lines.index(line))
+
+        cleaned_comment_text = commented_text.replace("-", "").strip()
+        extracted_code = ""
+        for i in range(commented_line_index, backtick_indices[1]):
+            if "+" in updated_lines[i]:
+                extracted_code += updated_lines[i].replace("+", "") + "\n"
+
+        return llm_response.replace(cleaned_comment_text, extracted_code)
+
     def estimating_prompt_tokens(self, prompt: str) -> int:
         try:
             enc = tiktoken.encoding_for_model(self.tiktoken_encoding_base)
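For illustration, here is a small standalone sketch of the detection step; the regular expression is copied from is_response_request_merge_needed above, while needs_merge is a hypothetical wrapper that returns an explicit bool (the method in the diff returns True or falls through to None). When it fires, merging_response diffs the LLM output against the original file and splices the elided lines back in place of the placeholder comment.

import re

PLACEHOLDER_PATTERN = (
    r"(Rest of the code remains unchanged)|(rest of the code remains the same)"
    r"|(Other methods remain unchanged)|(Rest of the class remains unchanged)"
    r"|(Rest of the methods remain unchanged)"
)

def needs_merge(llm_content: str) -> bool:
    # True when the response hides code behind a "rest unchanged" style comment.
    return bool(re.findall(PLACEHOLDER_PATTERN, llm_content, re.DOTALL))

partial_response = """```java
public class Foo {
    public void migrated() { /* updated code */ }
    // Rest of the class remains unchanged
}
```"""
print(needs_merge(partial_response))  # True -> merge with the original source file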
@@ -238,26 +274,79 @@ def get_incident_solutions_for_file(
                     application_name,
                     f'{file_name.replace("/", "-")}',
                 ):
-                    llm_result = self.model_provider.llm.invoke(prompt)
-                    trace.llm_result(count, retry_attempt_count, llm_result)
+                    llm_request = [("human", prompt)]
+                    llm_result = self.model_provider.llm.invoke(llm_request)
+                    content = parse_file_solution_content(
+                        src_file_language, str(llm_result.content)
+                    )
+
+                    # The LLM response must include code blocks (formatted within triple backticks) to be considered complete. Usually, the LM responds with code blocks, but occasionally it fails to do so, as noted in issue #350 [https://github.com/konveyor/kai/issues/350] . Complete responses are saved in the trace directory directly. For incomplete responses, an additional prompt is sent to the LLM, and the resulting complete response (with code blocks) is saved in the trace directory as a new file.
+                    if len(content.updated_file) == 0:
+                        trace.llm_result(
+                            count,
+                            retry_attempt_count,
+                            llm_result.content,
+                            "llm_result_without_codeblocks",
+                        )
+                        trace.response_metadata(
+                            count,
+                            retry_attempt_count,
+                            llm_result.response_metadata,
+                            "response_metadata_without_codeblocks.json",
+                        )
+                        self.has_tokens_exceeded(
+                            llm_result.response_metadata,
+                            estimated_prompt_tokens,
+                            file_name,
+                        )
+                        llm_request.append(
+                            (
+                                "human",
+                                "I request you to generate a complete response.",
+                            )
+                        )
+                        llm_result = self.model_provider.llm.invoke(llm_request)
+                        content = parse_file_solution_content(
+                            src_file_language, str(llm_result.content)
+                        )
+
+                    trace.llm_result(
+                        count,
+                        retry_attempt_count,
+                        llm_result.content,
+                        "llm_result_with_codeblocks",
+                    )
+                    trace.response_metadata(
+                        count,
+                        retry_attempt_count,
+                        llm_result.response_metadata,
+                        "response_metadata_with_codeblocks.json",
+                    )
                     trace.estimated_tokens(
                         count,
                         retry_attempt_count,
                         estimated_prompt_tokens,
                         self.tiktoken_encoding_base,
                     )
-                    trace.response_metadata(
-                        count, retry_attempt_count, llm_result.response_metadata
-                    )
                     self.has_tokens_exceeded(
                         llm_result.response_metadata,
                         estimated_prompt_tokens,
                         file_name,
                     )

-                    content = parse_file_solution_content(
-                        src_file_language, str(llm_result.content)
-                    )
+                    if self.is_response_request_merge_needed(
+                        str(llm_result.content)
+                    ):
+                        KAI_LOG.warning("This file contains unnecessary comments.")
+                        new_llm_response = self.merging_response(
+                            prompt, str(llm_result.content)
+                        )
+                        trace.llm_result(
+                            count,
+                            retry_attempt_count,
+                            new_llm_response,
+                            "llm_result_with_no_unnecessary_comments",
+                        )

                     if not content.updated_file:
                         raise Exception(
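Taken together, the control flow added to get_incident_solutions_for_file can be summarized by the hedged sketch below. Here chat_model stands in for self.model_provider.llm (a LangChain chat model whose invoke, as in the diff above, accepts a list of (role, content) tuples), and the triple-backtick check is a simplification of the real test on content.updated_file from parse_file_solution_content.

def invoke_with_codeblock_retry(chat_model, prompt: str) -> str:
    # First attempt: send the prompt as a single human message.
    request = [("human", prompt)]
    result = chat_model.invoke(request)
    # If no fenced code block came back (see issue #350), ask once more for a
    # complete response, keeping the earlier exchange as conversation context.
    if "```" not in str(result.content):
        request.append(("human", "I request you to generate a complete response."))
        result = chat_model.invoke(request)
    return str(result.content)

The trace files written along each branch (llm_result_without_codeblocks, llm_result_with_codeblocks, llm_result_with_no_unnecessary_comments) record which path a given retry attempt took.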