Skip to content

Commit

Permalink
Refactor code to use global config object
Browse files Browse the repository at this point in the history
The code has been refactored to use a global configuration object instead of a local one. This change standardizes how the config is accessed across multiple modules and simplifies the code by reducing redundant variable assignments. Along with this, color print formatting has been updated to use the 'rich' module's syntax.
  • Loading branch information
use-the-fork committed Dec 27, 2023
1 parent 8611ffa commit eb9725d
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 48 deletions.
31 changes: 11 additions & 20 deletions mentat/agent_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from mentat.session_input import ask_yes_no, collect_user_input
from mentat.transcripts import ModelMessage
from mentat.config import config
from rich import print

agent_file_selection_prompt_path = config.ai.prompts.get("agent_file_selection_prompt")
agent_command_prompt_path = config.ai.prompts.get("agent_command_selection_prompt")
Expand All @@ -39,9 +40,7 @@ def disable_agent_mode(self):
async def enable_agent_mode(self):
ctx = SESSION_CONTEXT.get()

ctx.stream.send(
"Finding files to determine how to test changes...", color="cyan"
)
print(f"* [cyan]Finding files to determine how to test changes...[/cyan]")
features = ctx.code_context.get_all_features(split_intervals=False)
messages: List[ChatCompletionMessageParam] = [
ChatCompletionSystemMessageParam(
Expand All @@ -66,11 +65,8 @@ async def enable_agent_mode(self):
file_contents = "\n\n".join(ctx.code_file_manager.read_file(path))
self.agent_file_message += f"{path}\n\n{file_contents}"

ctx.stream.send(
"The model has chosen these files to help it determine how to test its"
" changes:",
color="cyan",
)
print(f"[cyan]The model has chosen these files to help it determine how to test its changes:[/cyan]")

ctx.stream.send("\n".join(str(path) for path in paths))
ctx.cost_tracker.display_last_api_call()

Expand Down Expand Up @@ -107,7 +103,7 @@ async def _determine_commands(self) -> List[str]:
response = await ctx.llm_api_handler.call_llm_api(messages, model, False)
ctx.cost_tracker.display_last_api_call()
except BadRequestError as e:
ctx.stream.send(f"Error accessing OpenAI API: {e.message}", color="red")
print(f"[red]Error accessing OpenAI API: {e.message}[/red]")
return []

content = response.choices[0].message.content or ""
Expand All @@ -129,20 +125,15 @@ async def add_agent_context(self) -> bool:
commands = await self._determine_commands()
if not commands:
return True
ctx.stream.send(
"The model has chosen these commands to test its changes:", color="cyan"
)
print(f"[cyan]The model has chosen these commands to test its changes:[/cyan]")

for command in commands:
ctx.stream.send("* ", end="")
ctx.stream.send(command, color="light_yellow")
ctx.stream.send("Run these commands?", color="cyan")
print(f"* [yellow]{command}[/yellow]")

print(f"* [cyan]Run these commands?[/cyan]")
run_commands = await ask_yes_no(default_yes=True)
if not run_commands:
ctx.stream.send(
"Enter a new-line separated list of commands to run, or nothing to"
" return control to the user:",
color="cyan",
)
print(f"* [cyan]Enter a new-line separated list of commands to run, or nothing to return control to the user:[/cyan]")
commands: list[str] = (await collect_user_input()).data.strip().splitlines()
if not commands:
return True
Expand Down
20 changes: 9 additions & 11 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from mentat.llm_api_handler import count_tokens, get_max_tokens, is_context_sufficient
from mentat.session_context import SESSION_CONTEXT
from mentat.session_stream import SessionStream
from mentat.config import config


class CodeContext:
Expand Down Expand Up @@ -60,7 +61,6 @@ def display_context(self):
"""Display the baseline context: included files and auto-context settings"""
session_context = SESSION_CONTEXT.get()
stream = session_context.stream
config = session_context.config

stream.send("Code Context:", color="blue")
prefix = " "
Expand All @@ -69,9 +69,9 @@ def display_context(self):
stream.send(f"{prefix}Diff:", end=" ")
stream.send(self.diff_context.get_display_context(), color="green")

if config.auto_context_tokens > 0:
if config.run.auto_context_tokens > 0:
stream.send(f"{prefix}Auto-Context: Enabled")
stream.send(f"{prefix}Auto-Context Tokens: {config.auto_context_tokens}")
stream.send(f"{prefix}Auto-Context Tokens: {config.run.auto_context_tokens}")
else:
stream.send(f"{prefix}Auto-Context: Disabled")

Expand Down Expand Up @@ -120,9 +120,7 @@ async def get_code_message(
'prompt_tokens' argument is the total number of tokens used by the prompt before the code message,
used to ensure that the code message won't overflow the model's context size
"""
session_context = SESSION_CONTEXT.get()
config = session_context.config
model = config.model
model = config.ai.model

# Setup code message metadata
code_message = list[str]()
Expand Down Expand Up @@ -151,14 +149,14 @@ async def get_code_message(
)

tokens_used = (
prompt_tokens + meta_tokens + include_files_tokens + config.token_buffer
prompt_tokens + meta_tokens + include_files_tokens + config.ai.token_buffer
)
if not is_context_sufficient(tokens_used):
raise ContextSizeInsufficient()
auto_tokens = min(get_max_tokens() - tokens_used, config.auto_context_tokens)
auto_tokens = min(get_max_tokens() - tokens_used, config.run.auto_context_tokens)

# Get auto included features
if config.auto_context_tokens > 0 and prompt:
if config.run.auto_context_tokens > 0 and prompt:
features = self.get_all_features()
feature_filter = DefaultFilter(
auto_tokens,
Expand Down Expand Up @@ -190,7 +188,7 @@ def get_all_features(

abs_exclude_patterns: Set[Path] = set()
for pattern in self.ignore_patterns.union(
session_context.config.file_exclude_glob_list
config.run.file_exclude_glob_list
):
if not Path(pattern).is_absolute():
abs_exclude_patterns.add(session_context.cwd / pattern)
Expand Down Expand Up @@ -278,7 +276,7 @@ def include(
[
*exclude_patterns,
*self.ignore_patterns,
*session_context.config.file_exclude_glob_list,
*config.run.file_exclude_glob_list,
]
)
for pattern in all_exclude_patterns:
Expand Down
2 changes: 2 additions & 0 deletions mentat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ class RunSettings(DataClassJsonMixin):
file_exclude_glob_list: List[Path] = field(default_factory=list)
auto_context: bool = False
auto_tokens: int = 8000
# Automatically select code files to include in context on every request; adds up to this many tokens to the context per request.
auto_context_tokens: int = 0

@dataclass()
class AIModelSettings(DataClassJsonMixin):
Expand Down
30 changes: 13 additions & 17 deletions mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
from mentat.code_context import CodeContext
from mentat.code_edit_feedback import get_user_feedback_on_edits
from mentat.code_file_manager import CodeFileManager
from mentat.config import Config
from mentat.config import config
from mentat.conversation import Conversation
from mentat.cost_tracker import CostTracker
from mentat.ctags import ensure_ctags_installed
from mentat.errors import MentatError, SessionExit, UserError
from mentat.errors import MentatError, SessionExit, UserError, ContextSizeInsufficient
from mentat.git_handler import get_git_root_for_path
from mentat.llm_api_handler import LlmApiHandler, is_test_environment
from mentat.logging_config import setup_logging
Expand All @@ -47,7 +47,6 @@ def __init__(
ignore_paths: List[Path] = [],
diff: Optional[str] = None,
pr_diff: Optional[str] = None,
config: Config = Config(),
):
# All errors thrown here need to be caught here
self.stopped = False
Expand Down Expand Up @@ -126,15 +125,14 @@ async def _main(self):
agent_handler = session_context.agent_handler

# check early for ctags so we can fail fast
if config.run.auto_context:
if session_context.config.auto_context_tokens > 0:
if config.run.auto_context_tokens > 0:
ensure_ctags_installed()

session_context.llm_api_handler.initialize_client()
code_context.display_context()
await conversation.display_token_count()

stream.send("Type 'q' or use Ctrl-C to quit at any time.")
print(f"Type 'q' or use Ctrl-C to quit at any time.")
need_user_request = True
while True:
try:
Expand All @@ -143,12 +141,9 @@ async def _main(self):
# edits made between user input to be collected together.
if agent_handler.agent_enabled:
code_file_manager.history.push_edits()
stream.send(
"Use /undo to undo all changes from agent mode since last"
" input.",
color="green",
)
stream.send("\nWhat can I do for you?", color="light_blue")
print(f"[green]Use /undo to undo all changes from agent mode since last input.[/green]")

print(f"[blue]What can I do for you?[/blue]")
message = await collect_input_with_commands()
if message.data.strip() == "":
continue
Expand All @@ -174,10 +169,11 @@ async def _main(self):
applied_edits = await code_file_manager.write_changes_to_files(
file_edits
)
stream.send(
"Changes applied." if applied_edits else "No changes applied.",
color="light_blue",
)

if applied_edits:
print(f"[blue]Changes applied.[/blue]")
else:
print(f"[blue]No Changes applied.[/blue]")

if agent_handler.agent_enabled:
if parsed_llm_response.interrupted:
Expand All @@ -193,7 +189,7 @@ async def _main(self):
need_user_request = True
continue
except (APITimeoutError, RateLimitError, BadRequestError) as e:
stream.send(f"Error accessing OpenAI API: {e.message}", color="red")
print(f"[red]Error accessing OpenAI API: {e.message}[/red]")
break

async def listen_for_session_exit(self):
Expand Down

0 comments on commit eb9725d

Please sign in to comment.