Skip to content

Commit

Permalink
Factor out some functions into llm-utils
Browse files Browse the repository at this point in the history
  • Loading branch information
nicovank committed Oct 4, 2023
1 parent 42a621f commit 53a1c55
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 91 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
{ name="Nicolas van Kempen", email="nvankemp@gmail.com" },
{ name="Bryce Adelstein Lelbach", email="brycelelbach@gmail.com" }
]
dependencies = ["openai>=0.27.0", "tiktoken>=0.4.0"]
dependencies = ["openai==0.28.1", "llm_utils==0.1.4"]
description = "Explains and proposes fixes for compile-time errors for many programming languages."
readme = "README.md"
requires-python = ">=3.7"
Expand Down
98 changes: 8 additions & 90 deletions src/cwhy/cwhy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,91 +7,7 @@
from typing import Dict, List, Tuple

import openai
import tiktoken


def word_wrap_except_code_blocks(text: str) -> str:
    """
    Wraps text except for code blocks.

    Splits the text into paragraphs (separated by blank lines) and wraps
    each paragraph with `textwrap.fill`, except for paragraphs that are
    part of a code block denoted by ` ``` `. Returns the updated text.

    Args:
        text: The text to wrap.

    Returns:
        The wrapped text, with paragraphs re-joined by blank lines.
    """
    fence = "```"
    wrapped_paragraphs = []
    # Whether the paragraphs currently being visited are inside a code block.
    in_code_block = False

    for paragraph in text.split("\n\n"):
        starts_with_fence = paragraph.startswith(fence)
        ends_with_fence = paragraph.endswith(fence)

        if starts_with_fence and ends_with_fence:
            # Two or more fences: the paragraph opens and closes its own
            # code block, so the surrounding state does not change.
            # A single fence (a bare "```" paragraph) satisfies both checks
            # with the same characters; it is an opening or a closing fence
            # depending on where we are, so it flips the state. Without
            # this, code following "```" + blank line would be re-wrapped.
            if paragraph.count(fence) < 2:
                in_code_block = not in_code_block
        elif starts_with_fence:
            # Beginning of a code block that continues in later paragraphs.
            in_code_block = True
        elif ends_with_fence:
            # End of a code block: stop skipping after this paragraph.
            in_code_block = False
        elif not in_code_block:
            # Ordinary prose: this is the only case that gets wrapped.
            wrapped_paragraphs.append(textwrap.fill(paragraph))
            continue

        # Anything that touches or sits inside a code block is kept as is.
        wrapped_paragraphs.append(paragraph)

    return "\n\n".join(wrapped_paragraphs)


def read_lines(file_path, start_line, end_line):
    """
    Read a range of lines from a file.

    Trailing whitespace is stripped from every line, and lines longer than
    128 characters are cut to 128 characters with "..." appended.

    Args:
        file_path (str): The path of the file to read.
        start_line (int): The line number of the first line to include (1-indexed). Will be bounded below by 1.
        end_line (int): The line number of the last line to include (1-indexed). Will be bounded above by file's line count.

    Returns:
        A tuple of the lines read as a list and the number of the first
        line included (i.e. start_line after bounding). The list is empty
        when the requested range does not overlap the file.

    Raises:
        FileNotFoundError: If the file does not exist.
    """
    max_chars_per_line = 128  # Prevent pathological case where lines are REALLY long.

    def truncate(s, limit):
        """
        Truncate the string to at most the given length, adding ellipses if truncated.
        """
        # `<=` rather than `<`: a string of exactly `limit` characters loses
        # nothing, so it must not be marked with an ellipsis.
        if len(s) <= limit:
            return s
        return s[:limit] + "..."

    with open(file_path, "r") as f:
        lines = [truncate(line.rstrip(), max_chars_per_line) for line in f]

    # Ensure indices are in range. The lower bound on end_line keeps a
    # negative value from being interpreted as a from-the-end slice index.
    start_line = max(1, start_line)
    end_line = max(0, min(len(lines), end_line))

    return (lines[start_line - 1 : end_line], start_line)
from llm_utils import llm_utils


def complete(args, user_prompt, **kwargs):
Expand Down Expand Up @@ -188,7 +104,7 @@ def evaluate_text_prompt(args, prompt, wrap=True, **kwargs):
completion = complete(args, prompt, **kwargs)
text = completion.choices[0].message.content
if wrap:
text = word_wrap_except_code_blocks(text)
text = llm_utils.word_wrap_except_code_blocks(text)
return text


Expand Down Expand Up @@ -231,7 +147,6 @@ def evaluate_text_prompt(args, prompt, wrap=True, **kwargs):
class explain_context:
def __init__(self, args, diagnostic):
self.args = args
self.encoding = tiktoken.encoding_for_model(args["llm"])
self.diagnostic_lines = diagnostic.splitlines()

# We group by source file.
Expand All @@ -254,7 +169,7 @@ def __init__(self, args, diagnostic):
continue

try:
(abridged_code, line_start) = read_lines(
(abridged_code, line_start) = llm_utils.read_lines(
file_name, line_number - 7, line_number + 3
)
except FileNotFoundError:
Expand Down Expand Up @@ -292,7 +207,7 @@ def build_diagnostic_string():
line = self.diagnostic_lines[n - i // 2 - 1]
list = back
list.append(line)
count = len(self.encoding.encode(build_diagnostic_string()))
count = llm_utils.count_tokens(self.args["llm"], build_diagnostic_string())
if count > self.args["max_error_tokens"]:
list.pop()
break
Expand Down Expand Up @@ -368,7 +283,10 @@ def format_file_locations(filename: str, lines: Dict[int, str]) -> str:
for filename, lines in self.code_locations.items()
]

counts = [len(self.encoding.encode(x)) for x in formatted_file_locations]
counts = [
llm_utils.count_tokens(self.args["llm"], x)
for x in formatted_file_locations
]
index = 0
total = 0
while (
Expand Down

0 comments on commit 53a1c55

Please sign in to comment.