Skip to content

Commit

Permalink
Refactor prompt configuration and improve output formatting
Browse files Browse the repository at this point in the history
The prompt type for each parser has been moved from a hardcoded value within each parser to centralized configuration in 'mentat/config.py'. This improves maintainability by centralizing configuration. Additionally, use of the 'rich' library has been expanded in 'session.py' and 'code_context.py' to improve readability and color-coding of messages.
  • Loading branch information
use-the-fork committed Dec 27, 2023
1 parent e6f2596 commit c2a935e
Show file tree
Hide file tree
Showing 16 changed files with 58 additions and 33 deletions.
5 changes: 3 additions & 2 deletions mentat/agent_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
from mentat.session_context import SESSION_CONTEXT
from mentat.session_input import ask_yes_no, collect_user_input
from mentat.transcripts import ModelMessage
from mentat.config import config

agent_file_selection_prompt_path = Path("markdown/agent_file_selection_prompt.md")
agent_command_prompt_path = Path("markdown/agent_command_selection_prompt.md")
agent_file_selection_prompt_path = config.ai.prompts.get("agent_file_selection_prompt")
agent_command_prompt_path = config.ai.prompts.get("agent_command_selection_prompt")


class AgentHandler:
Expand Down
22 changes: 8 additions & 14 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,11 @@ def __init__(
def display_context(self):
"""Display the baseline context: included files and auto-context settings"""
session_context = SESSION_CONTEXT.get()
stream = session_context.stream


print("[blue]Code Context:[/blue]")
prefix = " "
stream.send(f"{prefix}Directory: {session_context.cwd}")
print(f"{prefix}Directory: {session_context.cwd}")
if self.diff_context and self.diff_context.name:
print(f"{prefix}Diff: [green]{self.diff_context.get_display_context()}[/green]")

Expand All @@ -83,11 +82,11 @@ def display_context(self):

features = None
if self.features:
stream.send(f"{prefix}Active Features:")
print(f"{prefix}Active Features:")
features = self.features
elif self.include_files:
stream.send(f"{prefix}Included files:")
stream.send(f"{prefix + prefix}{session_context.cwd.name}")
print(f"{prefix}Included files:")
print(f"{prefix + prefix}{session_context.cwd.name}")
features = [
_feat for _file in self.include_files.values() for _feat in _file
]
Expand Down Expand Up @@ -354,7 +353,7 @@ def include(
exclude_patterns=abs_exclude_patterns,
)
except PathValidationError as e:
session_context.stream.send(str(e), color="light_red")
print(f"[red]{str(e)}[/red]")
return included_paths

for code_feature in code_features:
Expand Down Expand Up @@ -382,20 +381,15 @@ def _exclude_file(self, path: Path) -> Path | None:
del self.include_files[path]
return path
else:
session_context.stream.send(
f"Path {path} not in context", color="light_red"
)
print(f"[red]Path {path} not in context[/red]")

def _exclude_file_interval(self, path: Path) -> Set[Path]:
session_context = SESSION_CONTEXT.get()

excluded_paths: Set[Path] = set()

interval_path, interval_str = split_intervals_from_path(path)
if interval_path not in self.include_files:
session_context.stream.send(
f"Path {interval_path} not in context", color="light_red"
)
print(f"[red]Path {interval_path} not in context[/red]")
return excluded_paths

intervals = parse_intervals(interval_str)
Expand Down Expand Up @@ -478,7 +472,7 @@ def exclude(self, path: Path | str) -> Set[Path]:
case PathType.GLOB:
excluded_paths.update(self._exclude_glob(validated_path))
except PathValidationError as e:
session_context.stream.send(str(e), color="light_red")
print(f"[red]Path {str(e)}[/red]")

return excluded_paths

Expand Down
39 changes: 33 additions & 6 deletions mentat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,15 @@
import yaml
import shutil

from dataclasses import asdict

from mentat.git_handler import get_git_root_for_path
from mentat.parsers.parser_map import parser_map
from mentat.parsers.block_parser import BlockParser
from mentat.utils import mentat_dir_path, dd
from mentat.utils import mentat_dir_path
from dataclasses import dataclass, field
from dataclasses_json import DataClassJsonMixin
from typing import Optional, List, Tuple
from typing import Tuple
from mentat.parsers.parser import Parser
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from typing import Any, Dict, List, Optional


config_file_name = Path(".mentat_config.yaml")
Expand Down Expand Up @@ -45,6 +43,7 @@ class AIModelSettings(DataClassJsonMixin):
model: str = "gpt-4-1106-preview"
feature_selection_model: str = "gpt-4-1106-preview"
embedding_model: str = "text-embedding-ada-002"
prompts: Dict[str, Path] = None
temperature: float = 0.2

maximum_context: Optional[int] = None
Expand Down Expand Up @@ -93,6 +92,7 @@ def yaml_to_config(yaml_dict: dict):

return {
"model": yaml_dict.get("model"),
"prompt_type": yaml_dict.get("prompt_type", "text"),
"maximum_context": yaml_dict.get("maximum_context"),
"file_exclude_glob_list": yaml_dict.get("file_exclude_glob_list", []),
"input_style": yaml_dict.get("input_style"),
Expand All @@ -108,6 +108,27 @@ def init_config():
shutil.copy(default_conf_path, current_conf_path)


def load_prompts(prompt_type: str) -> Dict[str, Path]:
    """Return the mapping of prompt name -> prompt file path for a prompt type.

    Args:
        prompt_type: "markdown" selects the markdown prompt set; any other
            value (including the default "text") selects the plain-text set.

    Returns:
        Dict mapping each prompt identifier to the relative Path of its file.
    """
    # Both branches MUST expose the same keys: callers fetch paths via
    # config.ai.prompts.get(<name>) and a missing key silently yields None.
    # Fixes two defects in the original:
    #   1. The text branch used the key "agent_command_prompt" while the
    #      markdown branch (and agent_handler.py) use
    #      "agent_command_selection_prompt" — default text mode returned None.
    #   2. "json_parser_prompt" was absent from both dicts even though
    #      JsonParser.get_system_prompt looks it up.
    prompt_names = (
        "agent_file_selection_prompt",
        "agent_command_selection_prompt",
        "block_parser_prompt",
        "feature_selection_prompt",
        "json_parser_prompt",
        "replacement_parser_prompt",
        "unified_diff_parser_prompt",
    )
    if prompt_type == "markdown":
        return {name: Path(f"markdown/{name}.md") for name in prompt_names}
    return {name: Path(f"text/{name}.txt") for name in prompt_names}

def load_settings():
"""Load the configuration from the `.mentatconf.yaml` file."""

Expand All @@ -134,8 +155,13 @@ def load_settings():
current_path_config = yaml_to_config(yaml_dict)
yaml_config = merge_configs(yaml_config, current_path_config)

file_exclude_glob_list = yaml_config.get("file_exclude_glob_list", [])

#always ignore .mentatconf
file_exclude_glob_list.append(".mentatconf.yaml")

run_settings = RunSettings(
file_exclude_glob_list=[Path(p) for p in yaml_config.get("file_exclude_glob_list", [])]
file_exclude_glob_list=[Path(p) for p in file_exclude_glob_list]
)

ui_settings = UISettings(
Expand All @@ -144,6 +170,7 @@ def load_settings():

ai_model_settings = AIModelSettings(
model=yaml_config.get("model", "gpt-4-1106-preview"),
prompts=load_prompts(yaml_config.get("prompt_type", "text")),
feature_selection_model=yaml_config.get("model", "gpt-4-1106-preview"),
maximum_context=yaml_config.get("maximum_context", 16000)
)
Expand Down
3 changes: 2 additions & 1 deletion mentat/feature_filters/llm_feature_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
CodeMessageLevel,
get_code_message_from_features,
)
from mentat.config import config
from mentat.errors import ModelError, UserError
from mentat.feature_filters.feature_filter import FeatureFilter
from mentat.feature_filters.truncate_filter import TruncateFilter
Expand All @@ -23,7 +24,7 @@


class LLMFeatureFilter(FeatureFilter):
feature_selection_prompt_path = Path("markdown/feature_selection_prompt.md")
feature_selection_prompt_path = config.ai.prompts.get("feature_selection_prompt")

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions mentat/parsers/block_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from mentat.prompts.prompts import read_prompt
from mentat.session_context import SESSION_CONTEXT

block_parser_prompt_filename = Path("markdown/block_parser_prompt.md")


class _BlockParserAction(Enum):
Insert = "insert"
Expand Down Expand Up @@ -71,6 +69,8 @@ def __init__(self, json_data: dict[str, Any]):
class BlockParser(Parser):
@override
def get_system_prompt(self) -> str:
    """Return the system prompt text for the block edit format.

    The prompt file path is taken from the centralized prompt
    configuration (config.ai.prompts) rather than a hardcoded path.
    """
    # Local import: mentat.config imports BlockParser at module level,
    # so a top-level import here would create an import cycle.
    from mentat.config import config

    prompt_path = config.ai.prompts.get("block_parser_prompt")
    return read_prompt(prompt_path)

@override
Expand Down
3 changes: 2 additions & 1 deletion mentat/parsers/json_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from mentat.session_context import SESSION_CONTEXT
from mentat.streaming_printer import StreamingPrinter

json_parser_prompt_filename = Path("markdown/json_parser_prompt.md")

comment_schema = {
"type": "object",
Expand Down Expand Up @@ -84,6 +83,8 @@
class JsonParser(Parser):
@override
def get_system_prompt(self) -> str:
    """Return the system prompt text for the JSON edit format.

    The prompt file path is taken from the centralized prompt
    configuration (config.ai.prompts) rather than a hardcoded path.
    """
    # Imported locally, presumably to avoid an import cycle with
    # mentat.config (which imports the parser modules) — confirm.
    from mentat.config import config

    prompt_path = config.ai.prompts.get("json_parser_prompt")
    return read_prompt(prompt_path)

@override
Expand Down
4 changes: 2 additions & 2 deletions mentat/parsers/replacement_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
from mentat.prompts.prompts import read_prompt
from mentat.session_context import SESSION_CONTEXT

replacement_parser_prompt_filename = Path("markdown/replacement_parser_prompt.md")


class ReplacementParser(Parser):
@override
def get_system_prompt(self) -> str:
    """Return the system prompt text for the replacement edit format.

    The prompt file path is taken from the centralized prompt
    configuration (config.ai.prompts) rather than a hardcoded path.
    """
    # Imported locally, presumably to avoid an import cycle with
    # mentat.config (which imports the parser modules) — confirm.
    from mentat.config import config

    prompt_path = config.ai.prompts.get("replacement_parser_prompt")
    return read_prompt(prompt_path)

@override
Expand Down
3 changes: 2 additions & 1 deletion mentat/parsers/unified_diff_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from mentat.parsers.parser import Parser
from mentat.prompts.prompts import read_prompt

unified_diff_parser_prompt_filename = Path("markdown/unified_diff_parser_prompt.md")


class UnifiedDiffDelimiter(Enum):
Expand All @@ -29,6 +28,8 @@ class UnifiedDiffDelimiter(Enum):
class UnifiedDiffParser(Parser):
@override
def get_system_prompt(self) -> str:
    """Return the system prompt text for the unified-diff edit format.

    The prompt file path is taken from the centralized prompt
    configuration (config.ai.prompts) rather than a hardcoded path.
    """
    # Imported locally, presumably to avoid an import cycle with
    # mentat.config (which imports the parser modules) — confirm.
    from mentat.config import config

    prompt_path = config.ai.prompts.get("unified_diff_parser_prompt")
    return read_prompt(prompt_path)

@override
Expand Down
8 changes: 4 additions & 4 deletions mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Any, Coroutine, List, Optional, Set
from uuid import uuid4

import attr
from rich import print
import sentry_sdk
from openai import APITimeoutError, BadRequestError, RateLimitError

Expand Down Expand Up @@ -222,15 +222,15 @@ async def run_main():
except (SessionExit, CancelledError):
pass
except (MentatError, UserError) as e:
self.stream.send(str(e), color="red")
print(f"[red]{str(e)}[/red]")
except Exception as e:
# All unhandled exceptions end up here
error = f"Unhandled Exception: {traceback.format_exc()}"
# Helps us handle errors in tests
if is_test_environment():
print(error)
sentry_sdk.capture_exception(e)
self.stream.send(error, color="red")
print(f"[red]{str(e)}[/red]")
finally:
await self._stop()
sentry_sdk.flush()
Expand Down Expand Up @@ -270,5 +270,5 @@ def send_errors_to_stream(self):
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
for error in self._errors:
stream.send(error, color="light_yellow")
print(f"[light_yellow3]{error}[/light_yellow3]")
self._errors = []

0 comments on commit c2a935e

Please sign in to comment.