Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
greg-assa committed Dec 27, 2023
1 parent 55e29ec commit d22d33d
Show file tree
Hide file tree
Showing 17 changed files with 233 additions and 345 deletions.
36 changes: 17 additions & 19 deletions mentat/code_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set, Union
from rich import print

from mentat.code_feature import (
CodeFeature,
Expand Down Expand Up @@ -31,6 +32,7 @@
from mentat.session_context import SESSION_CONTEXT
from mentat.session_stream import SessionStream
from mentat.config import config
from mentat.utils import dd


class CodeContext:
Expand Down Expand Up @@ -66,18 +68,17 @@ def display_context(self):
prefix = " "
stream.send(f"{prefix}Directory: {session_context.cwd}")
if self.diff_context and self.diff_context.name:
stream.send(f"{prefix}Diff:", end=" ")
stream.send(self.diff_context.get_display_context(), color="green")
print(f"{prefix}Diff:[green]{self.diff_context.get_display_context()}[/green]")

if config.run.auto_context_tokens > 0:
stream.send(f"{prefix}Auto-Context: Enabled")
stream.send(f"{prefix}Auto-Context Tokens: {config.run.auto_context_tokens}")
print(f"{prefix}Auto-Context: [green]Enabled[/green]")
print(f"{prefix}Auto-Context Tokens: {config.run.auto_context_tokens}")
else:
stream.send(f"{prefix}Auto-Context: Disabled")
print(f"{prefix}Auto-Context: [yellow]Disabled[/yellow]")

if self.include_files:
stream.send(f"{prefix}Included files:")
stream.send(f"{prefix + prefix}{session_context.cwd.name}")
print(f"{prefix}Included files:")
print(f"{prefix + prefix}{session_context.cwd.name}")
features = [
feature
for file_features in self.include_files.values()
Expand All @@ -91,11 +92,10 @@ def display_context(self):
prefix + prefix,
)
else:
stream.send(f"{prefix}Included files: ", end="")
stream.send("None", color="yellow")
print(f"{prefix}Included files: [yellow]None[/yellow]")

if self.auto_features:
stream.send(f"{prefix}Auto-Included Features:")
print(f"{prefix}Auto-Included Features:")
refs = get_consolidated_feature_refs(self.auto_features)
print_path_tree(
build_path_tree([Path(r) for r in refs], session_context.cwd),
Expand Down Expand Up @@ -148,6 +148,8 @@ async def get_code_message(
"\n".join(include_files_message), model, full_message=False
)



tokens_used = (
prompt_tokens + meta_tokens + include_files_tokens + config.ai.token_buffer
)
Expand Down Expand Up @@ -291,21 +293,19 @@ def include(
cwd=session_context.cwd,
exclude_patterns=abs_exclude_patterns,
)

except PathValidationError as e:
session_context.stream.send(str(e), color="light_red")
print(f"[red]Path Validation Error:{str(e)}[/red]")
return set()

return self.include_features(code_features)

def _exclude_file(self, path: Path) -> Path | None:
session_context = SESSION_CONTEXT.get()
if path in self.include_files:
del self.include_files[path]
return path
else:
session_context.stream.send(
f"Path {path} not in context", color="light_red"
)
print(f"[red]Path {path} not in context[/red]")

def _exclude_file_interval(self, path: Path) -> Set[Path]:
session_context = SESSION_CONTEXT.get()
Expand All @@ -314,9 +314,7 @@ def _exclude_file_interval(self, path: Path) -> Set[Path]:

interval_path, interval_str = split_intervals_from_path(path)
if interval_path not in self.include_files:
session_context.stream.send(
f"Path {interval_path} not in context", color="light_red"
)
print(f"[red]Path {interval_path} not in context[/red]")
return excluded_paths

intervals = parse_intervals(interval_str)
Expand Down Expand Up @@ -399,7 +397,7 @@ def exclude(self, path: Path | str) -> Set[Path]:
case PathType.GLOB:
excluded_paths.update(self._exclude_glob(validated_path))
except PathValidationError as e:
session_context.stream.send(str(e), color="light_red")
print(f"[red]Path Validation Error: {str(e)}[/red]")

return excluded_paths

Expand Down
29 changes: 14 additions & 15 deletions mentat/command/commands/search.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from typing import List, Set

from termcolor import colored
from typing_extensions import override

from mentat.command.command import Command, CommandArgument
from mentat.errors import UserError
from mentat.session_context import SESSION_CONTEXT
from mentat.utils import get_relative_path
from rich import print

SEARCH_RESULT_BATCH_SIZE = 10

Expand Down Expand Up @@ -34,49 +34,48 @@ def _parse_include_input(user_input: str, max_num: int) -> Set[int] | None:
class SearchCommand(Command, command_name="search"):
@override
async def apply(self, *args: str) -> None:
from mentat.config import config
session_context = SESSION_CONTEXT.get()
stream = session_context.stream

code_context = session_context.code_context
config = session_context.config

if len(args) == 0:
stream.send("No search query specified", color="yellow")
print("[yellow]No search query specified[/]")
return
try:
query = " ".join(args)
results = await code_context.search(query=query)
except UserError as e:
stream.send(str(e), color="red")
print(f"[red]{str(e)}[/]")
return

cumulative_tokens = 0
for i, (feature, _) in enumerate(results, start=1):
prefix = "\n "

file_name = feature.rel_path(session_context.cwd)
file_name = colored(file_name, "blue", attrs=["bold"])
file_name += colored(feature.interval_string(), "light_cyan")
file_name = f"[blue bold]{file_name}[/]"

tokens = feature.count_tokens(config.model)
tokens = feature.count_tokens(config.ai.model)
cumulative_tokens += tokens
tokens_str = colored(f" ({tokens} tokens)", "yellow")
tokens_str = f"[yellow] ({tokens} tokens)[/]"
file_name += tokens_str

name = []
if feature.name:
name = feature.name.split(",")
name = [
f"{'└' if i == len(name) - 1 else '├'}{colored(n, 'cyan')}"
f"{'└' if i == len(name) - 1 else '├'}[blue]{n}[/]"
for i, n in enumerate(name)
]

message = f"{str(i).ljust(3)}" + prefix.join([file_name] + name + [""])
stream.send(message)
print(message)
if i > 1 and i % SEARCH_RESULT_BATCH_SIZE == 0:
# Required to avoid circular imports, but not ideal.
from mentat.session_input import collect_user_input

stream.send(
print(
"(Y/n) for more results or to exit search mode.\nResults to"
' include in context: (eg: "1 3 4" or "1-4")'
)
Expand All @@ -90,14 +89,14 @@ async def apply(self, *args: str) -> None:
rel_path = get_relative_path(
included_path, session_context.cwd
)
stream.send(f"{rel_path} added to context", color="green")
print(f"[green]{rel_path} added to context[/]")
else:
stream.send("(Y/n)")
print("(Y/n)")
user_input: str = (
await collect_user_input(plain=True)
).data.strip()
if user_input.lower() == "n":
stream.send("Exiting search mode...", color="light_blue")
print("[bright_blue]Exiting search mode...[/]")
break

@override
Expand Down
17 changes: 11 additions & 6 deletions mentat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
from mentat.git_handler import get_git_root_for_path
from mentat.parsers.parser_map import parser_map
from mentat.parsers.block_parser import BlockParser
from mentat.utils import mentat_dir_path
from mentat.utils import mentat_dir_path, dd
from dataclasses import dataclass, field
from dataclasses_json import DataClassJsonMixin
from typing import Tuple
from mentat.parsers.parser import Parser
from typing import Any, Dict, List, Optional
from rich.console import Console

console = Console()

config_file_name = Path(".mentat_config.yaml")
user_config_path = mentat_dir_path / config_file_name
Expand Down Expand Up @@ -131,7 +133,7 @@ def load_prompts(prompt_type: str):
"unified_diff_parser_prompt": Path("text/unified_diff_parser_prompt.txt"),
}

def load_settings():
def load_settings(config_session_dict = None):
"""Load the configuration from the `.mentatconf.yaml` file."""

current_conf_path = APP_ROOT / '.mentatconf.yaml'
Expand All @@ -157,6 +159,9 @@ def load_settings():
current_path_config = yaml_to_config(yaml_dict)
yaml_config = merge_configs(yaml_config, current_path_config)

if config_session_dict is not None and config_session_dict.get('file_exclude_glob_list') is not None:
yaml_config["file_exclude_glob_list"].extend(config_session_dict['file_exclude_glob_list'])

file_exclude_glob_list = yaml_config.get("file_exclude_glob_list", [])

#always ignore .mentatconf
Expand Down Expand Up @@ -191,15 +196,15 @@ def load_settings():
}


def update_config(**kwargs):
def update_config(session_config):
"""Reload the configuration using the provided keyword arguments."""
global config
if config is None:
return

# setting the values from kwargs to the global config
for key, value in kwargs.items():
setattr(config, key, value)
settings = load_settings(session_config)
config = MentatConfig(**settings)


def load_config() -> MentatConfig:
init_config()
Expand Down
Loading

0 comments on commit d22d33d

Please sign in to comment.