diff --git a/mentat/config.py b/mentat/config.py index 148344c9f..7e84c0f06 100644 --- a/mentat/config.py +++ b/mentat/config.py @@ -36,16 +36,20 @@ class RunSettings(DataClassJsonMixin): auto_context: bool = False auto_tokens: int = 8000 auto_context_tokens: int = 0 + active_plugins: List[str] = field(default_factory=list) def __init__( self, file_exclude_glob_list: Optional[List[Path]] = None, + active_plugins: Optional[List[str]] = None, auto_context: Optional[bool] = None, auto_tokens: Optional[int] = None, auto_context_tokens: Optional[int] = None, ) -> None: if file_exclude_glob_list is not None: self.file_exclude_glob_list = file_exclude_glob_list + if active_plugins is not None: + self.active_plugins = active_plugins if auto_context is not None: self.auto_context = auto_context if auto_tokens is not None: @@ -200,6 +204,7 @@ class RunningSessionConfig(DataClassJsonMixin): ) maximum_context: Optional[int] = None auto_context_tokens: Optional[int] = 0 + active_plugins: Optional[List[str]] = None @classmethod def get_fields(cls) -> List[str]: @@ -269,6 +274,9 @@ def load_settings(config_session: Optional[RunningSessionConfig] = None): if yaml_config.file_exclude_glob_list is None: yaml_config.file_exclude_glob_list = [] + if yaml_config.active_plugins is None: + yaml_config.active_plugins = [] + if yaml_config.temperature is None: yaml_config.temperature = 0.2 @@ -296,6 +304,7 @@ def load_settings(config_session: Optional[RunningSessionConfig] = None): file_exclude_glob_list=[ Path(p) for p in file_exclude_glob_list ], # pyright: ignore[reportUnknownVariableType] + active_plugins=yaml_config.active_plugins, auto_context_tokens=yaml_config.auto_context_tokens, ) @@ -379,3 +388,17 @@ def get_config(setting: str) -> None: def load_config() -> None: init_config() load_settings() + + +def is_active_plugin(plugin: str | None = None) -> bool: + config = mentat.user_session.get("config") + if ( + plugin is not None + and config is not None + and config.run is not None + and config.run.active_plugins is not None + and plugin in config.run.active_plugins + ): + return True + + return False diff --git a/mentat/resources/conf/.mentatconf.yaml b/mentat/resources/conf/.mentatconf.yaml index 497a78677..5c532c080 100644 --- a/mentat/resources/conf/.mentatconf.yaml +++ b/mentat/resources/conf/.mentatconf.yaml @@ -22,6 +22,10 @@ file_exclude_glob_list: # - "**/.*/**" auto_context_tokens: +# a list of plugins that should be active. Current options include sampler +active_plugins: + - sampler + #settings related to the "parser" # Mentat parses files following a specific format, which you can set here. diff --git a/mentat/sampler/sampler.py b/mentat/sampler/sampler.py index e87726f1c..6c3b844a4 100644 --- a/mentat/sampler/sampler.py +++ b/mentat/sampler/sampler.py @@ -7,6 +7,7 @@ import mentat from mentat.code_feature import get_consolidated_feature_refs +from mentat.config import is_active_plugin from mentat.errors import SampleError from mentat.git_handler import get_git_diff, get_git_root_for_path, get_hexsha_active from mentat.parsers.git_parser import GitParser @@ -17,18 +18,6 @@ from mentat.utils import get_relative_path -def init_settings( - repo: str | None = None, merge_base_target: str | None = None -) -> None: - mentat.user_session.set( - "sampler_settings", - { - "repo": repo, - "merge_base_target": merge_base_target, - }, - ) - - def parse_message(message: ChatCompletionMessageParam) -> dict[str, str]: content = message.get("content") text, code = "", "" @@ -58,15 +47,29 @@ def parse_message(message: ChatCompletionMessageParam) -> dict[str, str]: class Sampler: + is_active: bool = False diff_active: str | None = None commit_active: str | None = None last_sample_id: str | None = None last_sample_hexsha: str | None = None + # set up the base config settings that sampler will use. + def __init__(self): + self.is_active = is_active_plugin("sampler") + if not mentat.user_session.get("sampler_settings"): + mentat.user_session.set( + "sampler_settings", + { + "repo": None, + "merge_base_target": None, + }, + ) + def set_active_diff(self): # Create a temporary commit with the active changes ctx = SESSION_CONTEXT.get() git_root = get_git_root_for_path(ctx.cwd, raise_error=False) + if not git_root: return repo = Repo(git_root) diff --git a/mentat/session.py b/mentat/session.py index 02126239a..bb7a3e247 100644 --- a/mentat/session.py +++ b/mentat/session.py @@ -182,7 +182,7 @@ async def _main(self): for file_edit in file_edits: file_edit.resolve_conflicts() - if session_context.sampler: + if session_context.sampler and session_context.sampler.is_active: session_context.sampler.set_active_diff() applied_edits = await code_file_manager.write_changes_to_files( diff --git a/mentat/terminal/client.py b/mentat/terminal/client.py index 42a3e23d2..6cae1b9fa 100644 --- a/mentat/terminal/client.py +++ b/mentat/terminal/client.py @@ -13,7 +13,6 @@ import mentat from mentat.config import update_config -from mentat.sampler import sampler from mentat.session import Session from mentat.session_stream import StreamMessageSource from mentat.terminal.loading import LoadingHandler @@ -248,8 +247,6 @@ def start( maximum_context: Optional[int], ) -> None: - sampler.init_settings() - if model is not None: update_config("model", model) if temperature is not None: diff --git a/tests/sampler_test.py b/tests/sampler_test.py index f2930f174..804f694d3 100644 --- a/tests/sampler_test.py +++ b/tests/sampler_test.py @@ -10,6 +10,7 @@ ChatCompletionUserMessageParam, ) +import mentat from mentat.errors import SampleError from mentat.git_handler import get_git_diff from mentat.parsers.block_parser import BlockParser @@ -17,7 +18,7 @@ from mentat.python_client.client import PythonClient from mentat.sampler import __version__ from mentat.sampler.sample import Sample -from mentat.sampler.sampler import Sampler, init_settings +from mentat.sampler.sampler import Sampler from mentat.sampler.utils import get_active_snapshot_commit from mentat.session import Session from scripts.evaluate_samples import evaluate_sample @@ -35,7 +36,13 @@ async def test_sample_from_context( mock_session_context, mock_collect_user_input, ): - init_settings(repo="test_sample_repo", merge_base_target="") + mentat.user_session.set( + "sampler_settings", + { + "repo": "test_sample_repo", + "merge_base_target": "", + }, + ) mocker.patch( "mentat.conversation.Conversation.get_messages", @@ -97,7 +104,13 @@ def is_sha1(string: str) -> bool: @pytest.mark.asyncio async def test_sample_command(temp_testbed, mock_collect_user_input, mock_call_llm_api): - init_settings(repo=None) + mentat.user_session.set( + "sampler_settings", + { + "repo": None, + "merge_base_target": None, + }, + ) mock_collect_user_input.set_stream_messages([ "Request", @@ -325,8 +338,9 @@ def get_updates_as_parsed_llm_message(cwd): async def test_sampler_integration( temp_testbed, mock_session_context, mock_call_llm_api ): - init_settings(repo=None) - # Setup the environemnt + mentat.user_session.set("sampler_settings", {"repo": None}) + + # Setup the environment repo = Repo(temp_testbed) (temp_testbed / "test_file.py").write_text("permanent commit") repo.git.add("test_file.py")