add test command #507

Draft
wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion mentat/code_context.py
@@ -187,7 +187,7 @@ def get_all_features(
split_intervals: bool = True,
) -> list[CodeFeature]:
"""
Retrieves every CodeFeature under the cwd. If files_only is True the features won't be split into intervals
Retrieves every CodeFeature under the cwd. If split_intervals is False the features won't be split
"""
session_context = SESSION_CONTEXT.get()

5 changes: 5 additions & 0 deletions mentat/code_file_manager.py
@@ -143,6 +143,11 @@ async def write_changes_to_files(
self.history.add_edit(applied_edit)
if not agent_handler.agent_enabled:
self.history.push_edits()

stream.send(
"Changes applied." if applied_edits else "No changes applied.",
style="input",
)
return applied_edits

def get_file_checksum(self, path: Path, interval: Interval | None = None) -> str:
1 change: 1 addition & 0 deletions mentat/command/commands/__init__.py
@@ -17,6 +17,7 @@
from .screenshot import ScreenshotCommand
from .search import SearchCommand
from .talk import TalkCommand
from .test import TestCommand
from .undo import UndoCommand
from .undoall import UndoAllCommand
from .viewer import ViewerCommand
123 changes: 123 additions & 0 deletions mentat/command/commands/test.py
@@ -0,0 +1,123 @@
from pathlib import Path
from typing import List, Set

from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionMessageParam,
    ChatCompletionSystemMessageParam,
)
from typing_extensions import override

from mentat.command.command import Command, CommandArgument
from mentat.git_handler import get_git_diff
from mentat.include_files import get_code_features_for_path
from mentat.prompts.prompts import read_prompt
from mentat.session_context import SESSION_CONTEXT
from mentat.transcripts import ModelMessage
from mentat.utils import get_relative_path

test_selection_prompt_path = Path("test_selection_prompt.txt")
test_selection_prompt = read_prompt(test_selection_prompt_path)
test_selection_prompt_2_path = Path("test_selection_prompt_2.txt")
test_selection_prompt_2 = read_prompt(test_selection_prompt_2_path)


class TestCommand(Command, command_name="test"):
    @override
    async def apply(self, *args: str) -> None:
        ctx = SESSION_CONTEXT.get()

        target = args[0] if args else "main"
        diff = get_git_diff(target)

        features = ctx.code_context.get_all_features(split_intervals=False)
        messages: List[ChatCompletionMessageParam] = [
            ChatCompletionSystemMessageParam(
                role="system", content=test_selection_prompt
            ),
            ChatCompletionSystemMessageParam(
                role="system",
                content="\n".join(
                    str(feature.path.relative_to(ctx.cwd)) for feature in features
                ),
            ),
            ChatCompletionSystemMessageParam(role="system", content=diff),
        ]
        response = await ctx.llm_api_handler.call_llm_api(
            messages, model=ctx.config.model, stream=False
        )
        message = response.choices[0].message.content or ""
        messages.append(
            ChatCompletionAssistantMessageParam(content=message, role="assistant")
        )

        if message.strip() != "NO FILES NEEDED":
            included: Set[Path] = set()
            for line in message.split("\n"):
                included.update(ctx.code_context.include(line))
            for included_path in included:
                rel_path = get_relative_path(included_path, ctx.cwd)
                ctx.stream.send(f"{rel_path} added to context", style="success")

        messages.append(
            ChatCompletionSystemMessageParam(
                role="system", content=test_selection_prompt_2
            )
        )
        response = await ctx.llm_api_handler.call_llm_api(
            messages, model=ctx.config.model, stream=False
        )
        message = response.choices[0].message.content or ""
        messages.append(
            ChatCompletionAssistantMessageParam(content=message, role="assistant")
        )
        ctx.conversation.add_transcript_message(
            ModelMessage(message=message, prior_messages=messages, message_type="test")
        )

        all_tests: List[str] = []
        if message.strip() != "NO TESTS FOUND":
            for line in message.split("\n"):
                all_tests += [
                    feature.rel_path(ctx.cwd)
                    for feature in get_code_features_for_path(Path(line), ctx.cwd)
                ]

        ctx.conversation.add_user_message(
            "You will be given both a list of all existing test files in this"
            " repository as well as a git diff of a recent PR. Create a set of"
            " comprehensive tests for this PR if they don't already exist. Keep your"
            " changes laser focused! Only make tests for larger, broader additions or"
            " changes."
        )
        ctx.conversation.add_message(
            ChatCompletionSystemMessageParam(
                role="system", content="All test files:\n" + "\n".join(all_tests)
            )
        )
        ctx.conversation.add_message(
            ChatCompletionSystemMessageParam(role="system", content=diff)
        )
        parsed_llm_response = await ctx.conversation.get_model_response()
        file_edits = parsed_llm_response.file_edits
        await ctx.code_file_manager.write_changes_to_files(file_edits)

    @override
    @classmethod
    def arguments(cls) -> List[CommandArgument]:
        return [CommandArgument("optional", "git tree-ish")]

    @override
    @classmethod
    def argument_autocompletions(
        cls, arguments: list[str], argument_position: int
    ) -> list[str]:
        return []

    @override
    @classmethod
    def help_message(cls) -> str:
        return (
            "Write tests for a PR using the diff to the given branch or commit."
            " Defaults to main."
        )
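
A usage sketch (an editor's illustration; the invocation syntax is assumed from Mentat's existing slash commands and is not shown in this diff): inside a session, "/test" would diff the working tree against main, while "/test HEAD~2" or any other git tree-ish would diff against that ref instead. The command then makes three model calls: one to pick context files, one to locate existing test files, and a final one to generate the test edits that write_changes_to_files applies.
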
5 changes: 5 additions & 0 deletions mentat/conversation.py
@@ -307,6 +307,11 @@ async def get_model_response(self) -> ParsedLLMResponse:
finally:
if loading_multiplier:
stream.send(None, channel="loading", terminate=True)
response.file_edits = [
file_edit for file_edit in response.file_edits if file_edit.is_valid()
]
for file_edit in response.file_edits:
file_edit.resolve_conflicts()
return response

def remaining_context(self) -> int | None:
7 changes: 7 additions & 0 deletions mentat/resources/prompts/test_selection_prompt.txt
@@ -0,0 +1,7 @@
You are part of an automated coding system. Your responses must follow the required format so they can be parsed programmatically.
Your job is to create a comprehensive test suite for a given PR in a coding project. In order to do this, you will need information on existing tests and configuration files in the project.
You will be given a list of all files in a software project and a git diff of the given PR.
Output a newline-separated list of every file whose content you believe will be needed to create the test suite.
Make sure to include any tests for files changed in this PR and any files needed to help create new tests.
Context is very limited!!! Only include at most 10 files!!!
If there are no files needed, output NO FILES NEEDED.
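
For illustration only (the paths below are hypothetical, not prescribed by the prompt), a conforming reply is either the single line NO FILES NEEDED or a bare newline-separated list such as:
mentat/git_handler.py
tests/conftest.py
tests/commands_test.py
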
2 changes: 2 additions & 0 deletions mentat/resources/prompts/test_selection_prompt_2.txt
@@ -0,0 +1,2 @@
Next, output a newline-separated list of files, directories, and glob patterns that will encompass all test files in the project (for example: src/tests, or **/*.spec).
If there are no tests in the project, output NO TESTS FOUND.
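
To make both reply contracts concrete, here is a minimal standalone sketch (an editor's illustration, not part of the PR; the function name and sample replies are assumptions) of how such replies can be split into path lists, loosely mirroring the checks in TestCommand.apply:

from typing import List

NO_FILES_SENTINEL = "NO FILES NEEDED"
NO_TESTS_SENTINEL = "NO TESTS FOUND"


def parse_selection(message: str, sentinel: str) -> List[str]:
    # An exact sentinel reply means the model found nothing to list.
    if message.strip() == sentinel:
        return []
    # Otherwise every non-empty line is treated as a path or glob pattern.
    return [line.strip() for line in message.splitlines() if line.strip()]


# Hypothetical replies to the two prompts above:
print(parse_selection("mentat/git_handler.py\ntests/conftest.py", NO_FILES_SENTINEL))
# -> ['mentat/git_handler.py', 'tests/conftest.py']
print(parse_selection("NO TESTS FOUND", NO_TESTS_SENTINEL))
# -> []
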
4 changes: 4 additions & 0 deletions mentat/resources/templates/css/transcript.css
@@ -90,6 +90,10 @@ pre {
background-color: rgb(243, 246, 196);
}

.test {
background-color: rgb(246, 196, 234);
}

.button-group {
position: absolute;
top: 20px;
16 changes: 2 additions & 14 deletions mentat/session.py
@@ -169,13 +169,7 @@ async def _main(self):
conversation.add_user_message(message.data)

parsed_llm_response = await conversation.get_model_response()
file_edits = [
file_edit
for file_edit in parsed_llm_response.file_edits
if file_edit.is_valid()
]
for file_edit in file_edits:
file_edit.resolve_conflicts()
file_edits = parsed_llm_response.file_edits
if file_edits:
if session_context.config.revisor:
await revise_edits(file_edits)
@@ -188,13 +182,7 @@ async def _main(self):
if session_context.config.sampler:
session_context.sampler.set_active_diff()

applied_edits = await code_file_manager.write_changes_to_files(
file_edits
)
stream.send(
"Changes applied." if applied_edits else "No changes applied.",
style="input",
)
await code_file_manager.write_changes_to_files(file_edits)

if agent_handler.agent_enabled:
if parsed_llm_response.interrupted: