Skip to content

Commit

Permalink
Added ability to disable and enable plugins / Commands based on YAML.
Browse files Browse the repository at this point in the history
  • Loading branch information
use-the-fork committed Dec 30, 2023
1 parent 19d3bf4 commit efdb45b
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 21 deletions.
23 changes: 23 additions & 0 deletions mentat/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,20 @@ class RunSettings(DataClassJsonMixin):
auto_context: bool = False
auto_tokens: int = 8000
auto_context_tokens: int = 0
active_plugins: List[str] = field(default_factory=list)

def __init__(
self,
file_exclude_glob_list: Optional[List[Path]] = None,
active_plugins: Optional[List[str]] = None,
auto_context: Optional[bool] = None,
auto_tokens: Optional[int] = None,
auto_context_tokens: Optional[int] = None,
) -> None:
if file_exclude_glob_list is not None:
self.file_exclude_glob_list = file_exclude_glob_list
if active_plugins is not None:
self.active_plugins = active_plugins
if auto_context is not None:
self.auto_context = auto_context
if auto_tokens is not None:
Expand Down Expand Up @@ -200,6 +204,7 @@ class RunningSessionConfig(DataClassJsonMixin):
)
maximum_context: Optional[int] = None
auto_context_tokens: Optional[int] = 0
active_plugins: Optional[List[str]] = None

@classmethod
def get_fields(cls) -> List[str]:
Expand Down Expand Up @@ -269,6 +274,9 @@ def load_settings(config_session: Optional[RunningSessionConfig] = None):
if yaml_config.file_exclude_glob_list is None:
yaml_config.file_exclude_glob_list = []

if yaml_config.active_plugins is None:
yaml_config.active_plugins = []

if yaml_config.temperature is None:
yaml_config.temperature = 0.2

Expand Down Expand Up @@ -296,6 +304,7 @@ def load_settings(config_session: Optional[RunningSessionConfig] = None):
file_exclude_glob_list=[
Path(p) for p in file_exclude_glob_list
], # pyright: ignore[reportUnknownVariableType]
active_plugins=yaml_config.active_plugins,
auto_context_tokens=yaml_config.auto_context_tokens,
)

Expand Down Expand Up @@ -379,3 +388,17 @@ def get_config(setting: str) -> None:
def load_config() -> None:
init_config()
load_settings()


def is_active_plugin(plugin: str | None = None) -> bool:
config = mentat.user_session.get("config")
if (
plugin is not None
and config is not None
and config.run is not None
and config.run.active_plugins is not None
and plugin in config.run.active_plugins
):
return True

return False
4 changes: 4 additions & 0 deletions mentat/resources/conf/.mentatconf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ file_exclude_glob_list:
# - "**/.*/**"
auto_context_tokens:

# a list of plugins that should be active. Current options include sampler
active_plugins:
- sampler

#settings related to the "parser"

# Mentat parses files following a specific format, which you can set here.
Expand Down
27 changes: 15 additions & 12 deletions mentat/sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import mentat
from mentat.code_feature import get_consolidated_feature_refs
from mentat.config import is_active_plugin
from mentat.errors import SampleError
from mentat.git_handler import get_git_diff, get_git_root_for_path, get_hexsha_active
from mentat.parsers.git_parser import GitParser
Expand All @@ -17,18 +18,6 @@
from mentat.utils import get_relative_path


def init_settings(
repo: str | None = None, merge_base_target: str | None = None
) -> None:
mentat.user_session.set(
"sampler_settings",
{
"repo": repo,
"merge_base_target": merge_base_target,
},
)


def parse_message(message: ChatCompletionMessageParam) -> dict[str, str]:
content = message.get("content")
text, code = "", ""
Expand Down Expand Up @@ -58,15 +47,29 @@ def parse_message(message: ChatCompletionMessageParam) -> dict[str, str]:


class Sampler:
is_active: bool = False
diff_active: str | None = None
commit_active: str | None = None
last_sample_id: str | None = None
last_sample_hexsha: str | None = None

# set up the base config settings that sampler will use.
def __init__(self):
self.is_active = is_active_plugin("sampler")
if not mentat.user_session.get("sampler_settings"):
mentat.user_session.set(
"sampler_settings",
{
"repo": None,
"merge_base_target": None,
},
)

def set_active_diff(self):
# Create a temporary commit with the active changes
ctx = SESSION_CONTEXT.get()
git_root = get_git_root_for_path(ctx.cwd, raise_error=False)

if not git_root:
return
repo = Repo(git_root)
Expand Down
2 changes: 1 addition & 1 deletion mentat/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ async def _main(self):
for file_edit in file_edits:
file_edit.resolve_conflicts()

if session_context.sampler:
if session_context.sampler and session_context.sampler.is_active:
session_context.sampler.set_active_diff()

applied_edits = await code_file_manager.write_changes_to_files(
Expand Down
3 changes: 0 additions & 3 deletions mentat/terminal/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@

import mentat
from mentat.config import update_config
from mentat.sampler import sampler
from mentat.session import Session
from mentat.session_stream import StreamMessageSource
from mentat.terminal.loading import LoadingHandler
Expand Down Expand Up @@ -248,8 +247,6 @@ def start(
maximum_context: Optional[int],
) -> None:

sampler.init_settings()

if model is not None:
update_config("model", model)
if temperature is not None:
Expand Down
24 changes: 19 additions & 5 deletions tests/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
ChatCompletionUserMessageParam,
)

import mentat
from mentat.errors import SampleError
from mentat.git_handler import get_git_diff
from mentat.parsers.block_parser import BlockParser
from mentat.parsers.git_parser import GitParser
from mentat.python_client.client import PythonClient
from mentat.sampler import __version__
from mentat.sampler.sample import Sample
from mentat.sampler.sampler import Sampler, init_settings
from mentat.sampler.sampler import Sampler
from mentat.sampler.utils import get_active_snapshot_commit
from mentat.session import Session
from scripts.evaluate_samples import evaluate_sample
Expand All @@ -35,7 +36,13 @@ async def test_sample_from_context(
mock_session_context,
mock_collect_user_input,
):
init_settings(repo="test_sample_repo", merge_base_target="")
mentat.user_session.set(
"sampler_settings",
{
"repo": "test_sample_repo",
"merge_base_target": "",
},
)

mocker.patch(
"mentat.conversation.Conversation.get_messages",
Expand Down Expand Up @@ -97,7 +104,13 @@ def is_sha1(string: str) -> bool:

@pytest.mark.asyncio
async def test_sample_command(temp_testbed, mock_collect_user_input, mock_call_llm_api):
init_settings(repo=None)
mentat.user_session.set(
"sampler_settings",
{
"repo": None,
"merge_base_target": None,
},
)

mock_collect_user_input.set_stream_messages([
"Request",
Expand Down Expand Up @@ -325,8 +338,9 @@ def get_updates_as_parsed_llm_message(cwd):
async def test_sampler_integration(
temp_testbed, mock_session_context, mock_call_llm_api
):
init_settings(repo=None)
# Setup the environemnt
mentat.user_session.set("sampler_settings", {"repo": None})

# Setup the environment
repo = Repo(temp_testbed)
(temp_testbed / "test_file.py").write_text("permanent commit")
repo.git.add("test_file.py")
Expand Down

0 comments on commit efdb45b

Please sign in to comment.