diff --git a/.github/workflows/benchmarks.yml b/.github/workflows/benchmarks.yml index 62b8f510c..b337599fe 100644 --- a/.github/workflows/benchmarks.yml +++ b/.github/workflows/benchmarks.yml @@ -20,9 +20,6 @@ jobs: python -m pip install . pip install -r dev-requirements.txt - - name: Install universal-ctags - run: sudo apt update && sudo apt install universal-ctags - - name: Run and upload benchmarks run: ./scripts/run_and_upload_benchmarks.sh env: diff --git a/.github/workflows/lint_and_test.yml b/.github/workflows/lint_and_test.yml index 11368880f..1f8f4103e 100644 --- a/.github/workflows/lint_and_test.yml +++ b/.github/workflows/lint_and_test.yml @@ -43,15 +43,6 @@ jobs: uses: actions/setup-node@v4 with: node-version: "16" - - name: Install universal-ctags (Ubuntu) - if: runner.os == 'Linux' - run: sudo apt update && sudo apt install universal-ctags - - name: Install universal-ctags (OSX) - if: runner.os == 'macOS' - run: brew update && brew install universal-ctags - - name: Install universal-ctags (Windows) - if: runner.os == 'Windows' - run: choco install universal-ctags - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v3 with: diff --git a/benchmarks/benchmark_runner.py b/benchmarks/benchmark_runner.py index cffbae0b7..335436e2c 100755 --- a/benchmarks/benchmark_runner.py +++ b/benchmarks/benchmark_runner.py @@ -228,7 +228,7 @@ async def run(self, retries: int = 1) -> list[BenchmarkResult]: ) try: if sample.context and self.config.auto_context_tokens: - score = await run_auto_context_benchmark(sample, self.config, include_context=False) + score = await run_auto_context_benchmark(sample, self.config) result.context_results = {**score, "auto_context_tokens": self.config.auto_context_tokens} result.context_precision = score["precision"] result.context_recall = score["recall"] diff --git a/benchmarks/context_benchmark.py b/benchmarks/context_benchmark.py index a7887d0e8..031785d18 100755 --- a/benchmarks/context_benchmark.py +++ b/benchmarks/context_benchmark.py @@ -5,13 +5,13 @@ from pathlib import Path from typing import Any +from ragdaemon.daemon import Daemon + from benchmarks.arg_parser import common_benchmark_parser from benchmarks.run_sample import setup_sample from benchmarks.swe_bench_runner import SWE_BENCH_SAMPLES_DIR, get_swe_samples -from mentat import Mentat from mentat.config import Config from mentat.sampler.sample import Sample -from mentat.session_context import SESSION_CONTEXT def _score(predicted: set[Path], expected: set[Path]) -> dict[str, Any]: @@ -23,9 +23,7 @@ def _score(predicted: set[Path], expected: set[Path]) -> dict[str, Any]: return {"precision": precision, "recall": recall, "n_true": len(expected)} -async def run_auto_context_benchmark( - sample: Sample, config: Config, cwd: Path | str | None = None, include_context: bool = False -) -> dict[str, Any]: +async def run_auto_context_benchmark(sample: Sample, config: Config, cwd: Path | str | None = None) -> dict[str, Any]: """Run a sample using Mentat and return the resulting diff""" starting_dir = Path.cwd() @@ -34,24 +32,24 @@ async def run_auto_context_benchmark( "In order to run the auto-context benchmark, sample.context must not " "be empty (ground truth) and config.auto_context_tokens must be > 0." ) - paths = [] if not include_context else [Path(a) for a in sample.context] try: _, cwd, _, _ = setup_sample(sample, None, skip_test_exec=True) - exclude_paths = [cwd / ".venv"] - mentat = Mentat(cwd=cwd, paths=paths, exclude_paths=exclude_paths, config=config or Config()) - await mentat.startup() - await asyncio.sleep(0.01) # Required to initialize llm_api_handler for embeddings + ignore_patterns = [cwd / ".venv"] + annotators = { + "hierarchy": {"ignore_patterns": ignore_patterns}, + "chunker_line": {"lines_per_chunk": 100}, + } + daemon = Daemon(cwd=cwd, annotators=annotators) + await daemon.update() # TODO: If there's a conversation history, we might consider the cumulative context. # Setup a mock for the LLM response and run the conversation until this point. - code_context = SESSION_CONTEXT.get().code_context - _ = await code_context.get_code_message(0, sample.message_prompt) - predicted = set(path.relative_to(cwd) for path in code_context.include_files.keys()) + context = daemon.get_context(sample.message_prompt, auto_tokens=config.auto_context_tokens) + predicted = {Path(a) for a in context.context.keys()} actual = {Path(a) for a in sample.context} score = _score(predicted, actual) - await mentat.shutdown() return score finally: os.chdir(starting_dir) diff --git a/benchmarks/run_sample.py b/benchmarks/run_sample.py index 77e1e32df..da89f4305 100644 --- a/benchmarks/run_sample.py +++ b/benchmarks/run_sample.py @@ -31,6 +31,15 @@ def setup_sample( ) cwd = Path(repo.working_dir) + # Make sure there's a .gitignore file, and that '.ragdaemon/*' is in it + gitignore_path = cwd / ".gitignore" + if not gitignore_path.exists(): + gitignore_path.write_text(".ragdaemon/*\n") + else: + gitignore_contents = gitignore_path.read_text() + if ".ragdaemon/*" not in gitignore_contents: + gitignore_path.write_text(gitignore_contents + ".ragdaemon/*\n") + test_executable = None if not skip_test_exec and (sample.FAIL_TO_PASS or sample.PASS_TO_PASS): # If there's an environment_setup_commit, this is what it's needed for. diff --git a/docs/source/developer/mentat.feature_filters.rst b/docs/source/developer/mentat.feature_filters.rst deleted file mode 100644 index db5c1d3c8..000000000 --- a/docs/source/developer/mentat.feature_filters.rst +++ /dev/null @@ -1,61 +0,0 @@ -mentat.feature\_filters package -=============================== - -Submodules ----------- - -mentat.feature\_filters.default\_filter module ----------------------------------------------- - -.. automodule:: mentat.feature_filters.default_filter - :members: - :undoc-members: - :show-inheritance: - -mentat.feature\_filters.embedding\_similarity\_filter module ------------------------------------------------------------- - -.. automodule:: mentat.feature_filters.embedding_similarity_filter - :members: - :undoc-members: - :show-inheritance: - -mentat.feature\_filters.feature\_filter module ----------------------------------------------- - -.. automodule:: mentat.feature_filters.feature_filter - :members: - :undoc-members: - :show-inheritance: - -mentat.feature\_filters.llm\_feature\_filter module ---------------------------------------------------- - -.. automodule:: mentat.feature_filters.llm_feature_filter - :members: - :undoc-members: - :show-inheritance: - -mentat.feature\_filters.truncate\_filter module ------------------------------------------------ - -.. automodule:: mentat.feature_filters.truncate_filter - :members: - :undoc-members: - :show-inheritance: - -mentat.feature\_filters.user\_include\_sort\_filter module ----------------------------------------------------------- - -.. automodule:: mentat.feature_filters.user_include_sort_filter - :members: - :undoc-members: - :show-inheritance: - -Module contents ---------------- - -.. automodule:: mentat.feature_filters - :members: - :undoc-members: - :show-inheritance: diff --git a/docs/source/developer/mentat.rst b/docs/source/developer/mentat.rst index f91ed37f3..f9c6d1f1b 100644 --- a/docs/source/developer/mentat.rst +++ b/docs/source/developer/mentat.rst @@ -8,7 +8,6 @@ Subpackages :maxdepth: 4 mentat.command - mentat.feature_filters mentat.parsers mentat.prompts mentat.python_client @@ -98,14 +97,6 @@ mentat.conversation module :undoc-members: :show-inheritance: -mentat.ctags module -------------------- - -.. automodule:: mentat.ctags - :members: - :undoc-members: - :show-inheritance: - mentat.diff\_context module --------------------------- @@ -122,14 +113,6 @@ mentat.edit\_history module :undoc-members: :show-inheritance: -mentat.embeddings module ------------------------- - -.. automodule:: mentat.embeddings - :members: - :undoc-members: - :show-inheritance: - mentat.errors module -------------------- diff --git a/docs/source/user/getting_started.rst b/docs/source/user/getting_started.rst index 221ff0f7f..ee9c66743 100644 --- a/docs/source/user/getting_started.rst +++ b/docs/source/user/getting_started.rst @@ -39,26 +39,6 @@ If you want to use the whisper transcription on an OS besides macOS or Windows y sudo apt-get install libportaudio2 -In order to use auto context you will need to install ctags: - -Windows: - -.. code-block:: bash - - choco install universal-ctags - -macOS: - -.. code-block:: bash - - brew install universal-ctags - -Ubuntu: - -.. code-block:: bash - - sudo apt install universal-ctags - Basic Usage ----------- diff --git a/mentat/code_context.py b/mentat/code_context.py index 41ad62911..46af2a8b4 100644 --- a/mentat/code_context.py +++ b/mentat/code_context.py @@ -1,25 +1,17 @@ from __future__ import annotations -import os from pathlib import Path -from typing import Dict, Iterable, List, Optional, Set, TypedDict, Union +from typing import Any, Dict, Iterable, List, Optional, Set, TypedDict, Union -from mentat.code_feature import ( - CodeFeature, - get_code_message_from_features, - get_consolidated_feature_refs, - split_file_into_intervals, -) +from ragdaemon.daemon import Daemon + +from mentat.code_feature import CodeFeature, get_consolidated_feature_refs from mentat.diff_context import DiffContext from mentat.errors import PathValidationError -from mentat.feature_filters.default_filter import DefaultFilter -from mentat.feature_filters.embedding_similarity_filter import EmbeddingSimilarityFilter from mentat.include_files import ( PathType, get_code_features_for_path, get_path_type, - get_paths_for_directory, - is_file_text_encoded, match_path_with_patterns, validate_and_format_path, ) @@ -27,6 +19,7 @@ from mentat.llm_api_handler import get_max_tokens from mentat.session_context import SESSION_CONTEXT from mentat.session_stream import SessionStream +from mentat.utils import get_relative_path, mentat_dir_path class ContextStreamMessage(TypedDict): @@ -41,7 +34,13 @@ class ContextStreamMessage(TypedDict): total_cost: float +graphs_dir = mentat_dir_path / "ragdaemon" +graphs_dir.mkdir(parents=True, exist_ok=True) + + class CodeContext: + daemon: Daemon + def __init__( self, stream: SessionStream, @@ -59,6 +58,28 @@ def __init__( self.include_files: Dict[Path, List[CodeFeature]] = {} self.ignore_files: Set[Path] = set() + async def refresh_daemon(self): + """Call before interacting with context to ensure daemon is up to date.""" + if not hasattr(self, "daemon"): + # Daemon is initialized after setup because it needs the embedding_provider. + ctx = SESSION_CONTEXT.get() + cwd = ctx.cwd + llm_api_handler = ctx.llm_api_handler + + annotators: dict[str, dict[str, Any]] = { + "hierarchy": {"ignore_patterns": [str(p) for p in self.ignore_patterns]}, + "chunker_line": {"lines_per_chunk": 50}, + "diff": {"diff": self.diff_context.target}, + } + self.daemon = Daemon( + cwd=cwd, + annotators=annotators, + verbose=False, + graph_path=graphs_dir / f"ragdaemon-{cwd.name}.json", + spice_client=getattr(llm_api_handler, "spice_client", None), + ) + await self.daemon.update() + async def refresh_context_display(self): """ Sends a message to the client with the code context. It is called in the main loop. @@ -107,53 +128,75 @@ async def get_code_message( config = session_context.config llm_api_handler = session_context.llm_api_handler model = config.model + cwd = session_context.cwd + code_file_manager = session_context.code_file_manager - # Setup code message metadata - code_message = list[str]() - - # Since there is no way of knowing when the git diff changes, - # we just refresh the cache every time get_code_message is called + # Setup the header (Mentat-specific, before ragdaemon context) + header_lines = list[str]() self.diff_context.refresh() if self.diff_context.diff_files(): - code_message += [ - "Diff References:", - f' "-" = {self.diff_context.name}', - ' "+" = Active Changes', - "", - ] - - code_message += ["Code Files:\n"] - - # Get auto included features + header_lines += [f"Diff References: {self.diff_context.name}\n"] + header_lines += ["Code Files:\n\n"] + + # Setup a ContextBuilder from Mentat's include_files / diff_context + await self.refresh_daemon() + context_builder = self.daemon.get_context("", max_tokens=0) + diff_nodes: list[str] = [ + node + for node, data in self.daemon.graph.nodes(data=True) # pyright: ignore + if data and "type" in data and data["type"] == "diff" + ] + if not self.include_files.values(): + for node in diff_nodes: + context_builder.add_diff(node) + for path, features in self.include_files.items(): + for feature in features: + interval_string = feature.interval_string() + if interval_string and "-" in interval_string: + start, exclusive_end = interval_string.split("-") + inclusive_end = str(int(exclusive_end) - 1) + interval_string = f"{start}-{inclusive_end}" + ref = feature.rel_path(session_context.cwd) + interval_string + context_builder.add_ref(ref, tags=["user-included"]) + relative_path = get_relative_path(path, cwd).as_posix() + diffs_for_path = [node for node in diff_nodes if f":{relative_path}" in node] + for diff in diffs_for_path: + context_builder.add_diff(diff) + + # If auto-context, replace the context_builder with a new one if config.auto_context_tokens > 0 and prompt: - meta_tokens = llm_api_handler.spice.count_tokens("\n".join(code_message), model, is_message=True) + meta_tokens = llm_api_handler.spice.count_tokens("\n".join(header_lines), model, is_message=True) - # Calculate user included features token size - include_files_message = get_code_message_from_features( - [feature for file_features in self.include_files.values() for feature in file_features] - ) - include_files_tokens = llm_api_handler.spice.count_tokens( - "\n".join(include_files_message), model, is_message=False - ) + include_files_message = context_builder.render() + include_files_tokens = llm_api_handler.spice.count_tokens(include_files_message, model, is_message=False) tokens_used = prompt_tokens + meta_tokens + include_files_tokens auto_tokens = min( get_max_tokens() - tokens_used - config.token_buffer, config.auto_context_tokens, ) - features = self.get_all_features() - feature_filter = DefaultFilter(auto_tokens, prompt, expected_edits) - self.include_features(await feature_filter.filter(features)) - - # TODO: We want to show the auto included features immediately, but refreshing the context display - # also refreshes the token count per message, which calls this function again causing an infinite loop. - # To fix this, we should completely separate the token count per message from the context display message - # await self.refresh_context_display() - - include_features = [feature for file_features in self.include_files.values() for feature in file_features] - code_message += get_code_message_from_features(include_features) - - return "\n".join(code_message) + context_builder = self.daemon.get_context( + query=prompt, + context_builder=context_builder, # Pass include_files / diff_context to ragdaemon + max_tokens=get_max_tokens(), + auto_tokens=auto_tokens, + ) + for ref in context_builder.to_refs(): + path, interval_str = split_intervals_from_path(Path(ref)) + intervals = parse_intervals(interval_str) + for interval in intervals: + feature = CodeFeature(cwd / path, interval) + self.include_features([feature]) # Save ragdaemon context back to include_files + + # The context message is rendered by ragdaemon (ContextBuilder.render()) + context_message = context_builder.render() + for relative_path in context_builder.context.keys(): + path = Path(cwd / relative_path).resolve() + if path not in code_file_manager.file_lines: + with open(path, "r") as file: # Used by code_file_manager to validate file_edits + lines = file.read().split("\n") + code_file_manager.file_lines[path] = lines + return "\n".join(header_lines) + context_message def get_all_features( self, @@ -164,28 +207,19 @@ def get_all_features( Retrieves every CodeFeature under the cwd. If files_only is True the features won't be split into intervals """ session_context = SESSION_CONTEXT.get() + cwd = session_context.cwd - abs_exclude_patterns: Set[Path] = set() - for pattern in self.ignore_patterns.union(session_context.config.file_exclude_glob_list): - if not Path(pattern).is_absolute(): - abs_exclude_patterns.add(session_context.cwd / pattern) - else: - abs_exclude_patterns.add(Path(pattern)) - - all_features: List[CodeFeature] = [] - for path in get_paths_for_directory(path=session_context.cwd, exclude_patterns=abs_exclude_patterns): - if not is_file_text_encoded(path) or os.path.getsize(path) > max_chars: + all_features = list[CodeFeature]() + for _, data in self.daemon.graph.nodes(data=True): # pyright: ignore + if data is None or "type" not in data or "ref" not in data or data["type"] not in {"file", "chunk"}: continue - - if not split_intervals: - _feature = CodeFeature(path) - all_features.append(_feature) - else: - full_feature = CodeFeature(path) - _split_features = split_file_into_intervals(full_feature) - all_features += _split_features - - return sorted(all_features, key=lambda f: f.path) + path, interval = split_intervals_from_path(data["ref"]) # pyright: ignore + intervals = parse_intervals(interval) + if not intervals: + all_features.append(CodeFeature(cwd / path)) + for _interval in intervals: + all_features.append(CodeFeature(cwd / path, _interval)) + return all_features def include_features(self, code_features: Iterable[CodeFeature]): """ @@ -199,7 +233,7 @@ def include_features(self, code_features: Iterable[CodeFeature]): else: code_feature_not_included = True for included_code_feature in self.include_files[code_feature.path]: - # Intervals can still overlap if user includes intervals different than what ctags breaks up, + # Intervals can still overlap if user includes intervals different than what chunker breaks up, # but we merge when making code message and don't duplicate lines if ( included_code_feature.interval == code_feature.interval @@ -372,10 +406,18 @@ async def search( ) -> list[tuple[CodeFeature, float]]: """Return the top n features that are most similar to the query.""" - all_features = self.get_all_features() - - embedding_similarity_filter = EmbeddingSimilarityFilter(query) - all_features_sorted = await embedding_similarity_filter.score(all_features) + cwd = SESSION_CONTEXT.get().cwd + all_nodes_sorted = self.daemon.search(query, max_results) + all_features_sorted = list[tuple[CodeFeature, float]]() + for node in all_nodes_sorted: + if node.get("type") not in {"file", "chunk"}: + continue + distance = node["distance"] + path, interval = split_intervals_from_path(Path(node["ref"])) + intervals = parse_intervals(interval) + for _interval in intervals: + feature = CodeFeature(cwd / path, _interval) + all_features_sorted.append((feature, distance)) if max_results is None: return all_features_sorted else: diff --git a/mentat/code_feature.py b/mentat/code_feature.py index 5f76b0a4d..1782d8bd5 100644 --- a/mentat/code_feature.py +++ b/mentat/code_feature.py @@ -1,17 +1,13 @@ from __future__ import annotations -import asyncio -import logging -from collections import OrderedDict, defaultdict +from collections import defaultdict from pathlib import Path from typing import Optional import attr +from ragdaemon.utils import get_document -from mentat.ctags import get_ctag_lines_and_names -from mentat.diff_context import annotate_file_message, parse_diff from mentat.errors import MentatError -from mentat.git_handler import get_diff_for_file from mentat.interval import INTERVAL_FILE_END, Interval from mentat.session_context import SESSION_CONTEXT from mentat.utils import get_relative_path @@ -19,56 +15,6 @@ MIN_INTERVAL_LINES = 10 -def split_file_into_intervals( - feature: CodeFeature, - min_lines: int | None = None, -) -> list[CodeFeature]: - min_lines = min_lines or MIN_INTERVAL_LINES - session_context = SESSION_CONTEXT.get() - code_file_manager = session_context.code_file_manager - n_lines = len(code_file_manager.read_file(feature.path)) - - lines_and_names = get_ctag_lines_and_names(session_context.cwd.joinpath(feature.path)) - - if len(lines_and_names) == 0: - return [feature] - - lines, names = map(list, zip(*sorted(lines_and_names))) - lines[0] = 1 # first interval covers from start of file - draft_named_intervals = [(name, start, end) for name, start, end in zip(names, lines, lines[1:] + [n_lines])] - - def length(interval: tuple[str, int, int]): - return interval[2] - interval[1] - - def merge_intervals(int1: tuple[str, int, int], int2: tuple[str, int, int]): - return (f"{int1[0]},{int2[0]}", int1[1], int2[2]) - - named_intervals = [draft_named_intervals[0]] - for next_interval in draft_named_intervals[1:]: - last_interval = named_intervals[-1] - if length(last_interval) < min_lines: - named_intervals[-1] = merge_intervals(last_interval, next_interval) - elif length(next_interval) < min_lines and next_interval == draft_named_intervals[-1]: - # this is the last interval it's too short, so merge it with previous - named_intervals[-1] = merge_intervals(last_interval, next_interval) - else: - named_intervals.append(next_interval) - - if len(named_intervals) <= 1: - return [feature] - - # Create and return separate features for each interval - _features = list[CodeFeature]() - for name, start, end in named_intervals: - _feature = CodeFeature( - feature.path, - interval=Interval(start, end), - name=name, - ) - _features.append(_feature) - return _features - - @attr.define(frozen=True) class CodeFeature: """ @@ -112,120 +58,14 @@ def interval_string(self) -> str: def __str__(self, cwd: Optional[Path] = None) -> str: return self.rel_path(cwd) + self.interval_string() - def get_code_message(self, standalone: bool = True) -> list[str]: - """ - Gets this code features code message. - If standalone is true, will include the filename at top and extra newline at the end. - If feature contains entire file, will add inline diff annotations; otherwise, will append them to the end. - """ - if not self.path.exists() or self.path.is_dir(): - return [] - - session_context = SESSION_CONTEXT.get() - code_file_manager = session_context.code_file_manager - parser = session_context.config.parser - code_context = session_context.code_context - - code_message: list[str] = [] - - if standalone: - # We always want to give GPT posix paths - code_message_path = get_relative_path(self.path, session_context.cwd) - code_message.append(str(code_message_path.as_posix())) - - # Get file lines - file_lines = code_file_manager.read_file(self.path) - for i, line in enumerate(file_lines): - if self.interval.contains(i + 1): - if parser.provide_line_numbers(): - code_message.append(f"{i + parser.line_number_starting_index()}:{line}") - else: - code_message.append(f"{line}") - - if standalone: - code_message.append("") - - if self.path in code_context.diff_context.diff_files(): - diff = get_diff_for_file(code_context.diff_context.target, self.path) - diff_annotations = parse_diff(diff) - if self.interval.whole_file(): - code_message = annotate_file_message(code_message, diff_annotations) - else: - for section in diff_annotations: - # TODO: Place diff_annotations inside interval where they belong - if section.start >= self.interval.start and section.start < self.interval.end: - code_message += section.message - return code_message - - def get_checksum(self) -> str: - # TODO: Only update checksum if last modified time of file updates to conserve file system reads - session_context = SESSION_CONTEXT.get() - code_file_manager = session_context.code_file_manager - - return code_file_manager.get_file_checksum(self.path, self.interval) - - def count_tokens(self, model: str) -> int: - ctx = SESSION_CONTEXT.get() - - code_message = self.get_code_message() - return ctx.llm_api_handler.spice.count_tokens("\n".join(code_message), model, is_message=False) - - -async def count_feature_tokens(features: list[CodeFeature], model: str) -> list[int]: - """Return the number of tokens in each feature.""" - sem = asyncio.Semaphore(10) - - feature_tokens = list[int]() - for feature in features: - async with sem: - tokens = feature.count_tokens(model) - feature_tokens.append(tokens) - return feature_tokens - - -def _get_code_message_from_intervals(features: list[CodeFeature]) -> list[str]: - """ - Merge multiple features for the same file into a single code message. - """ - features_sorted = sorted(features, key=lambda f: f.interval) - posix_path = features_sorted[0].get_code_message()[0] - code_message = [posix_path] - next_line = 1 - for feature in features_sorted: - starting_line = feature.interval.start - if starting_line < next_line: - logging.info(f"Features overlap: {feature}") - if feature.interval.end <= next_line: - continue - feature = CodeFeature( - feature.path, - interval=Interval(next_line, feature.interval.end), - name=feature.name, - ) - elif starting_line > next_line: - code_message += ["..."] - code_message += feature.get_code_message(standalone=False) - next_line = feature.interval.end - return code_message + [""] - - -def get_code_message_from_features(features: list[CodeFeature]) -> list[str]: - """ - Generate a code message from a list of features. - Will automatically handle overlapping intervals. - """ - code_message = list[str]() - features_by_path: dict[Path, list[CodeFeature]] = OrderedDict() - for feature in features: - if feature.path not in features_by_path: - features_by_path[feature.path] = list[CodeFeature]() - features_by_path[feature.path].append(feature) - for path_features in features_by_path.values(): - if len(path_features) == 1: - code_message += path_features[0].get_code_message() - else: - code_message += _get_code_message_from_intervals(path_features) - return code_message + +def count_feature_tokens(feature: CodeFeature, model: str) -> int: + ctx = SESSION_CONTEXT.get() + + cwd = ctx.cwd + ref = feature.__str__(cwd) + document = get_document(ref, cwd) + return ctx.llm_api_handler.spice.count_tokens(document, model, is_message=False) def get_consolidated_feature_refs(features: list[CodeFeature]) -> list[str]: diff --git a/mentat/command/commands/search.py b/mentat/command/commands/search.py index d6406ff2c..b804365b6 100644 --- a/mentat/command/commands/search.py +++ b/mentat/command/commands/search.py @@ -2,6 +2,7 @@ from typing_extensions import override +from mentat.code_feature import count_feature_tokens from mentat.command.command import Command, CommandArgument from mentat.errors import UserError from mentat.session_context import SESSION_CONTEXT @@ -60,7 +61,7 @@ async def apply(self, *args: str) -> None: file_interval = feature.interval_string() stream.send(file_interval, color="bright_cyan", end="") - tokens = feature.count_tokens(config.model) + tokens = count_feature_tokens(feature, config.model) cumulative_tokens += tokens tokens_str = f" ({tokens} tokens)" stream.send(tokens_str, color="yellow") diff --git a/mentat/config.py b/mentat/config.py index 5ca1c4b90..97093eb7f 100644 --- a/mentat/config.py +++ b/mentat/config.py @@ -41,13 +41,6 @@ class Config: metadata={"auto_completions": [model.name for model in models if isinstance(model, TextModel)]}, ) provider: Optional[str] = attr.field(default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]}) - feature_selection_model: str = attr.field( - default="gpt-4-1106-preview", - metadata={"auto_completions": [model.name for model in models if isinstance(model, TextModel)]}, - ) - feature_selection_provider: Optional[str] = attr.field( - default=None, metadata={"auto_completions": ["openai", "anthropic", "azure"]} - ) embedding_model: str = attr.field( default="text-embedding-ada-002", metadata={"auto_completions": [model.name for model in models if isinstance(model, EmbeddingModel)]}, @@ -135,21 +128,6 @@ class Config: converter=int, validator=validators.ge(0), # pyright: ignore ) - llm_feature_filter: int = attr.field( # pyright: ignore - default=0, - metadata={ - "description": ( - "Send this many tokens of auto-context-selected code files to an LLM" - " along with the user_prompt to post-select only files which are" - " relevant to the task. Post-files will then be sent to the LLM again" - " to respond to the user's prompt." - ), - "abbreviation": "l", - "const": 5000, - }, - converter=int, - validator=validators.ge(0), # pyright: ignore - ) # Sample specific settings sample_repo: str | None = attr.field( diff --git a/mentat/ctags.py b/mentat/ctags.py deleted file mode 100644 index d68572592..000000000 --- a/mentat/ctags.py +++ /dev/null @@ -1,64 +0,0 @@ -import json -import platform -import subprocess -from functools import cache -from pathlib import Path - -from mentat.errors import MentatError - - -@cache -def ensure_ctags_installed() -> None: - try: - subprocess.run( - ["ctags", "--help"], - check=True, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - ) - return - except subprocess.CalledProcessError: - pass - - os_name = platform.system() - match os_name: - case "Linux": - suggested_install_command = "sudo apt install universal-ctags" - case "Darwin": # macOS - suggested_install_command = "brew install universal-ctags" - case "Windows": - suggested_install_command = "choco install universal-ctags" - case _: - suggested_install_command = None - - error_message = "Missing Dependency: universal-ctags (required for auto-context)" - if suggested_install_command: - error_message += f"\nSuggested install method for your OS: `{suggested_install_command}`" - error_message += "\nSee README.md for full installation details." - raise MentatError(error_message) - - -def get_ctag_lines_and_names(path: Path) -> list[tuple[int, str]]: - ensure_ctags_installed() - - json_tags = ( - subprocess.check_output( - ["ctags", "--output-format=json", "--fields=+n", str(path)], - stderr=subprocess.DEVNULL, - start_new_session=True, - text=True, - ) - .strip() - .splitlines() - ) - - lines_and_names: list[tuple[int, str]] = [] - for json_tag in json_tags: - tag_dict = json.loads(json_tag) - name = tag_dict["name"] - if "scope" in tag_dict: - name = f"{tag_dict['scope']}.{name}" - line = tag_dict["line"] - lines_and_names.append((line, name)) - - return lines_and_names diff --git a/mentat/diff_context.py b/mentat/diff_context.py index 9c8109b5d..f33c46dd2 100644 --- a/mentat/diff_context.py +++ b/mentat/diff_context.py @@ -2,9 +2,6 @@ from pathlib import Path from typing import List, Literal, Optional -import attr - -from mentat.errors import MentatError from mentat.git_handler import ( check_head_exists, get_diff_for_file, @@ -13,84 +10,14 @@ get_treeish_metadata, get_untracked_files, ) -from mentat.interval import Interval from mentat.session_context import SESSION_CONTEXT from mentat.session_stream import SessionStream -@attr.define(frozen=True) -class DiffAnnotation(Interval): - start: int | float = attr.field() - message: List[str] = attr.field() - end: int | float = attr.field( - default=attr.Factory( - lambda self: self.start + sum(bool(line.startswith("-")) for line in self.message), - takes_self=True, - ) - ) - - -def parse_diff(diff: str) -> list[DiffAnnotation]: - """Parse diff into a list of annotations.""" - annotations: list[DiffAnnotation] = [] - active_annotation: Optional[DiffAnnotation] = None - lines = diff.splitlines() - for line in lines: - if line.startswith(("---", "+++", "//", "diff", "index")): - continue - elif line.startswith("@@"): - if active_annotation: - annotations.append(active_annotation) - _new_index = line.split(" ")[2] - if "," in _new_index: - new_start = _new_index[1:].split(",")[0] - else: - new_start = _new_index[1:] - active_annotation = DiffAnnotation(start=int(new_start), message=[]) - elif line.startswith(("+", "-")): - if not active_annotation: - raise MentatError("Invalid diff") - active_annotation.message.append(line) - if active_annotation: - annotations.append(active_annotation) - annotations.sort(key=lambda a: a.start) - return annotations - - -def annotate_file_message(code_message: list[str], annotations: list[DiffAnnotation]) -> list[str]: - """Return the code_message with annotations inserted.""" - active_index = 0 - annotated_message: list[str] = [] - for annotation in annotations: - # Fill-in lines between annotations - if active_index < annotation.start: - unaffected_lines = code_message[active_index : annotation.start] - annotated_message += unaffected_lines - active_index = annotation.start - if annotation.start == 0: - # Make sure the PATH stays on line 1 - annotated_message.append(code_message[0]) - active_index += 1 - i_minus = None - for line in annotation.message: - sign = line[0] - if sign == "+": - # Add '+' lines in place of code_message lines - annotated_message.append(f"{active_index}:{line}") - active_index += 1 - i_minus = None - elif sign == "-": - # Insert '-' lines at the point they were removed - i_minus = 0 if i_minus is None else i_minus - annotated_message.append(f"{annotation.start + i_minus}:{line}") - i_minus += 1 - if active_index < len(code_message): - annotated_message += code_message[active_index:] - - return annotated_message - - class DiffContext: + target: str = "" + name: str = "index (last commit)" + def __init__( self, stream: SessionStream, @@ -114,8 +41,6 @@ def __init__( target = diff or pr_diff if not target: - self.target = "HEAD" - self.name = "HEAD (last commit)" return name = "" @@ -123,13 +48,11 @@ def __init__( if treeish_type is None: stream.send(f"Invalid treeish: {target}", style="failure") stream.send("Disabling diff and pr-diff.", style="warning") - self.target = "HEAD" - self.name = "HEAD (last commit)" return if treeish_type == "branch": name += f"Branch {target}: " - elif treeish_type == "relative": + elif treeish_type in {"relative"}: name += f"{target}: " if pr_diff: @@ -141,15 +64,18 @@ def __init__( f"Cannot identify merge base between HEAD and {pr_diff}. Disabling pr-diff.", style="warning", ) - self.target = "HEAD" - self.name = "HEAD (last commit)" return - meta = get_treeish_metadata(self.git_root, target) - name += f'{meta["hexsha"][:8]}: {meta["summary"]}' - if target == "HEAD": - name = "HEAD (last commit)" + def _get_treeish_metadata(git_root: Path, _target: str): + meta = get_treeish_metadata(git_root, _target) + return f'{meta["hexsha"][:8]}: {meta["summary"]}' + if not target: + return + elif treeish_type == "compare": + name += "Comparing " + ", ".join(_get_treeish_metadata(self.git_root, part) for part in target.split(" ")) + else: + name += _get_treeish_metadata(self.git_root, target) self.target = target self.name = name @@ -182,12 +108,6 @@ def refresh(self): self._diff_files = [(ctx.cwd / f).resolve() for f in get_files_in_diff(self.target)] self._untracked_files = [(ctx.cwd / f).resolve() for f in get_untracked_files(ctx.cwd)] - def get_annotations(self, rel_path: Path) -> list[DiffAnnotation]: - if not self.git_root: - return [] - diff = get_diff_for_file(self.target, rel_path) - return parse_diff(diff) - def get_display_context(self) -> Optional[str]: if not self.git_root: return None @@ -202,15 +122,8 @@ def get_display_context(self) -> Optional[str]: num_lines += len([line for line in diff_lines if line.startswith(("+ ", "- "))]) return f" {self.name} | {num_files} files | {num_lines} lines" - def annotate_file_message(self, rel_path: Path, file_message: list[str]) -> list[str]: - """Return file_message annotated with active diff.""" - if not self.git_root: - return [] - annotations = self.get_annotations(rel_path) - return annotate_file_message(file_message, annotations) - -TreeishType = Literal["commit", "branch", "relative"] +TreeishType = Literal["commit", "branch", "relative", "compare"] def _git_command(git_root: Path, *args: str) -> str | None: @@ -221,6 +134,13 @@ def _git_command(git_root: Path, *args: str) -> str | None: def _get_treeish_type(git_root: Path, treeish: str) -> TreeishType | None: + if " " in treeish: + parts = treeish.split(" ") + types = [_get_treeish_type(git_root, part) for part in parts] + if not all(types): + return None + return "compare" + object_type = _git_command(git_root, "cat-file", "-t", treeish) if not object_type: diff --git a/mentat/embeddings.py b/mentat/embeddings.py deleted file mode 100644 index 22a4b15b3..000000000 --- a/mentat/embeddings.py +++ /dev/null @@ -1,155 +0,0 @@ -import logging -from timeit import default_timer - -import chromadb -from chromadb.api.types import Embeddable, EmbeddingFunction, Embeddings -from spice.spice import get_model_from_name - -from mentat.code_feature import CodeFeature, count_feature_tokens -from mentat.errors import MentatError -from mentat.session_context import SESSION_CONTEXT -from mentat.session_input import ask_yes_no -from mentat.utils import mentat_dir_path - -EMBEDDINGS_API_BATCH_SIZE = 1536 - -client = chromadb.PersistentClient(path=str(mentat_dir_path / "chroma")) - - -class MentatEmbeddingFunction(EmbeddingFunction[Embeddable]): - def __call__(self, input: Embeddable) -> Embeddings: - if not all(isinstance(item, str) for item in input): - raise MentatError("MentatEmbeddings only enabled for text files") - session_context = SESSION_CONTEXT.get() - config = session_context.config - llm_api_handler = session_context.llm_api_handler - - n_batches = 0 if len(input) == 0 else len(input) // EMBEDDINGS_API_BATCH_SIZE + 1 - output: Embeddings = [] - for batch in range(n_batches): - i_start, i_end = ( - batch * EMBEDDINGS_API_BATCH_SIZE, - (batch + 1) * EMBEDDINGS_API_BATCH_SIZE, - ) - response = llm_api_handler.call_embedding_api(input[i_start:i_end], config.embedding_model) - output += response.embeddings - return output - - -class Collection: - def __init__(self, embedding_model: str): - self._collection = client.get_or_create_collection( - name=f"mentat-{embedding_model}", - embedding_function=MentatEmbeddingFunction(), - ) - - def exists(self, id: str) -> bool: - assert self._collection is not None, "Collection not initialized" - return len(self._collection.get(id)["ids"]) > 0 - - def add(self, checksums: list[str], texts: list[str]) -> None: - assert self._collection is not None, "Collection not initialized" - return self._collection.add( # type: ignore - ids=checksums, - documents=texts, - metadatas=[{"active": False} for _ in checksums], - ) - - def query(self, prompt: str, checksums: list[str]) -> dict[str, float]: - assert self._collection is not None, "Collection not initialized" - - self._collection.update( # type: ignore - ids=checksums, - metadatas=[{"active": True} for _ in checksums], - ) - results = self._collection.query( # type: ignore - query_texts=[prompt], - where={"active": True}, - n_results=len(checksums), - ) - self._collection.update( # type: ignore - ids=checksums, - metadatas=[{"active": False} for _ in checksums], - ) - assert results["distances"], "Error calculating distances" - return {c: e for c, e in zip(results["ids"][0], results["distances"][0])} - - -async def get_feature_similarity_scores( - prompt: str, - features: list[CodeFeature], -) -> list[float]: - """Return the similarity scores for a given prompt and list of features.""" - session_context = SESSION_CONTEXT.get() - stream = session_context.stream - config = session_context.config - embedding_model = session_context.config.embedding_model - - max_model_tokens = get_model_from_name(embedding_model).context_length - if max_model_tokens is None: - stream.send( - f"Warning: Could not determine context size for model {embedding_model}. Using default value of 8192.", - style="warning", - ) - max_model_tokens = 8192 - - # Initialize DB - collection = Collection(embedding_model) - - # Identify which items need embeddings. - checksums: list[str] = [f.get_checksum() for f in features] - ignored_checksums = set[str]() - tokens: list[int] = await count_feature_tokens(features, embedding_model) - embed_texts = list[str]() - embed_checksums = list[str]() - embed_tokens = list[int]() - for feature, checksum, token in zip(features, checksums, tokens): - if token > max_model_tokens: - stream.send( - f"Warning: Feature {str(feature)} has {token} tokens, which exceeds the" - f" maximum of {max_model_tokens} for model {config.embedding_model}." - " Skipping." - ) - ignored_checksums.add(checksum) - continue - if not collection.exists(checksum) and checksum not in embed_checksums: - embed_texts.append("\n".join(feature.get_code_message())) - embed_checksums.append(checksum) - embed_tokens.append(token) - - # If it costs more than $1, get confirmation from user. - cost = get_model_from_name(embedding_model).input_cost - if cost is None: - stream.send( - "Warning: Could not determine cost of embeddings. Continuing anyway.", - style="warning", - ) - else: - expected_cost = (sum(embed_tokens) * cost) / 1_000_000 / 100 - if expected_cost > 1: - stream.send(f"Embedding {sum(embed_tokens)} tokens will cost ${expected_cost:.2f}. Continue anyway?") - if not await ask_yes_no(default_yes=True): - stream.send("Ignoring embeddings for now.") - return [0.0 for _ in checksums] - - # Load embeddings - if embed_texts: - start_time = default_timer() - stream.send(None, channel="loading") - collection.add(embed_checksums, embed_texts) - total_time = default_timer() - start_time - - if cost is not None: - call_cost = (sum(embed_tokens) * cost) / 1_000_000 / 100 - - costs_logger = logging.getLogger("costs") - costs_logger.info(f"Cost: ${call_cost:.2f}") - - stream.send(f"Embedding call time and cost: {total_time:.2f}s, ${call_cost:.2f}", style="info") - - # Get similarity scores - stream.send(None, channel="loading", terminate=True) - _checksums = list(c for c in set(checksums) if c not in ignored_checksums) - scores = collection.query(prompt, _checksums) - - return [scores.get(f.get_checksum(), 0) for f in features] diff --git a/mentat/feature_filters/__init__.py b/mentat/feature_filters/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/mentat/feature_filters/default_filter.py b/mentat/feature_filters/default_filter.py deleted file mode 100644 index c7a5ba1a0..000000000 --- a/mentat/feature_filters/default_filter.py +++ /dev/null @@ -1,41 +0,0 @@ -from typing import Optional - -from mentat.code_feature import CodeFeature -from mentat.errors import ModelError, ReturnToUser -from mentat.feature_filters.embedding_similarity_filter import EmbeddingSimilarityFilter -from mentat.feature_filters.feature_filter import FeatureFilter -from mentat.feature_filters.llm_feature_filter import LLMFeatureFilter -from mentat.feature_filters.truncate_filter import TruncateFilter -from mentat.session_context import SESSION_CONTEXT - - -class DefaultFilter(FeatureFilter): - def __init__( - self, - max_tokens: int, - user_prompt: Optional[str] = None, - expected_edits: Optional[list[str]] = None, - ): - self.max_tokens = max_tokens - self.user_prompt = user_prompt or "" - self.expected_edits = expected_edits - - async def filter(self, features: list[CodeFeature]) -> list[CodeFeature]: - ctx = SESSION_CONTEXT.get() - use_llm = bool(ctx.config.llm_feature_filter) - - if ctx.config.auto_context_tokens > 0 and self.user_prompt != "": - features = await EmbeddingSimilarityFilter(self.user_prompt).filter(features) - - if use_llm: - try: - features = await LLMFeatureFilter(self.max_tokens, self.user_prompt, self.expected_edits).filter( - features - ) - except (ModelError, ReturnToUser): - ctx.stream.send("Feature-selection LLM response invalid. Using TruncateFilter instead.") - features = await TruncateFilter(self.max_tokens, ctx.config.model).filter(features) - else: - features = await TruncateFilter(self.max_tokens, ctx.config.model).filter(features) - - return features diff --git a/mentat/feature_filters/embedding_similarity_filter.py b/mentat/feature_filters/embedding_similarity_filter.py deleted file mode 100644 index cf84a2249..000000000 --- a/mentat/feature_filters/embedding_similarity_filter.py +++ /dev/null @@ -1,30 +0,0 @@ -from mentat.code_feature import CodeFeature -from mentat.feature_filters.feature_filter import FeatureFilter - - -class EmbeddingSimilarityFilter(FeatureFilter): - def __init__(self, query: str): - self.query = query - - async def score( - self, - features: list[CodeFeature], - ) -> list[tuple[CodeFeature, float]]: - from mentat.embeddings import ( # dynamic import to improve startup time. - get_feature_similarity_scores, - ) - - if self.query == "": - return [(f, 0.0) for f in features] - - sim_scores = await get_feature_similarity_scores(self.query, features) - features_scored = zip(features, sim_scores) - return sorted(features_scored, key=lambda x: x[1]) - - async def filter( - self, - features: list[CodeFeature], - ) -> list[CodeFeature]: - if self.query == "": - return features - return [f for f, _ in await self.score(features)] diff --git a/mentat/feature_filters/feature_filter.py b/mentat/feature_filters/feature_filter.py deleted file mode 100644 index 7ac602502..000000000 --- a/mentat/feature_filters/feature_filter.py +++ /dev/null @@ -1,22 +0,0 @@ -from abc import ABC, abstractmethod - -from mentat.code_feature import CodeFeature - - -class FeatureFilter(ABC): - """ - Tools to pare down a list of Features to a final list to be put in an LLMs context. - Despite the name they may not be pure filters: - * New CodeFeatures may be introduced by splitting Features into intervals. - * The CodeFeatures may be reordered by priority in future steps. - - A feature filter may want more information than the list of features and user prompt. If so it - can get it from the config or have the information passed into its constructor. - """ - - @abstractmethod - async def filter( - self, - features: list[CodeFeature], - ) -> list[CodeFeature]: - raise NotImplementedError() diff --git a/mentat/feature_filters/llm_feature_filter.py b/mentat/feature_filters/llm_feature_filter.py deleted file mode 100644 index ece78985a..000000000 --- a/mentat/feature_filters/llm_feature_filter.py +++ /dev/null @@ -1,127 +0,0 @@ -import json -from pathlib import Path -from typing import Optional, Set - -from openai.types.chat import ( - ChatCompletionAssistantMessageParam, - ChatCompletionMessageParam, - ChatCompletionSystemMessageParam, - ChatCompletionUserMessageParam, -) -from openai.types.chat.completion_create_params import ResponseFormat -from spice.spice import get_model_from_name - -from mentat.code_feature import CodeFeature, get_code_message_from_features -from mentat.errors import ModelError, PathValidationError, UserError -from mentat.feature_filters.feature_filter import FeatureFilter -from mentat.feature_filters.truncate_filter import TruncateFilter -from mentat.include_files import get_code_features_for_path -from mentat.prompts.prompts import read_prompt -from mentat.session_context import SESSION_CONTEXT - - -class LLMFeatureFilter(FeatureFilter): - feature_selection_prompt_path = Path("feature_selection_prompt.txt") - - def __init__( - self, - max_tokens: int, - user_prompt: Optional[str] = None, - expected_edits: Optional[list[str]] = None, - ): - self.max_tokens = max_tokens - self.user_prompt = user_prompt or "" - self.expected_edits = expected_edits - - async def filter( - self, - features: list[CodeFeature], - ) -> list[CodeFeature]: - session_context = SESSION_CONTEXT.get() - stream = session_context.stream - config = session_context.config - llm_api_handler = session_context.llm_api_handler - - stream.send(None, channel="loading") - - # Preselect as many features as fit in the context window - model = config.feature_selection_model - context_size = get_model_from_name(model).context_length - if context_size is None: - raise UserError("Unknown context size for feature selection model: " f"{config.feature_selection_model}") - context_size = min(context_size, config.llm_feature_filter) - system_prompt = read_prompt(self.feature_selection_prompt_path) - system_prompt_tokens = llm_api_handler.spice.count_tokens( - system_prompt, config.feature_selection_model, is_message=True - ) - user_prompt_tokens = llm_api_handler.spice.count_tokens(self.user_prompt, model, is_message=True) - expected_edits_tokens = ( - 0 - if not self.expected_edits - else llm_api_handler.spice.count_tokens("\n".join(self.expected_edits), model, is_message=True) - ) - preselect_max_tokens = ( - context_size - system_prompt_tokens - user_prompt_tokens - expected_edits_tokens - config.token_buffer - ) - truncate_filter = TruncateFilter(preselect_max_tokens, model) - preselected_features = await truncate_filter.filter(features) - - # Ask the model to return only relevant features - messages: list[ChatCompletionMessageParam] = [ - ChatCompletionSystemMessageParam(role="system", content=system_prompt), - ChatCompletionSystemMessageParam( - role="system", - content="\n".join(["CODE FILES:"] + get_code_message_from_features(preselected_features)), - ), - ChatCompletionUserMessageParam(role="user", content=f"USER QUERY: {self.user_prompt}"), - ] - if self.expected_edits: - messages.append( - ChatCompletionAssistantMessageParam(role="assistant", content=f"Expected Edits:\n{self.expected_edits}") - ) - messages.append( - ChatCompletionSystemMessageParam( - role="system", - content=( - "Now, identify the CODE FILES that are relevant to answering the" - " USER QUERY, Return a dict of {path: reason} for each file you" - " identify as relevant. e.g. {'src/main.js': 'Create new file'," - " 'public/index.html': 'Import main.js'}" - ), - ) - ) - selected_refs = list[Path]() - llm_response = await llm_api_handler.call_llm_api( - messages=messages, - model=model, - provider=config.feature_selection_provider, - stream=False, - response_format=ResponseFormat(type="json_object"), - ) - message = llm_response.text - stream.send(None, channel="loading", terminate=True) - - # Parse response into features - try: - response = json.loads(message) # type: ignore - selected_refs = [Path(r) for r in response] - except json.JSONDecodeError: - raise ModelError(f"The response is not valid json: {message}") - postselected_features: Set[CodeFeature] = set() - for selected_ref in selected_refs: - try: - parsed_features = get_code_features_for_path(path=selected_ref, cwd=session_context.cwd) - for feature in parsed_features: - assert any( - in_feat.path == feature.path and in_feat.interval.intersects for in_feat in preselected_features - ) - postselected_features.add(feature) - except (PathValidationError, AssertionError): - stream.send( - f"LLM selected invalid path: {selected_ref}, skipping.", - style="warning", - ) - - # Truncate again to enforce max_tokens - truncate_filter = TruncateFilter(self.max_tokens, config.model) - return await truncate_filter.filter(postselected_features) diff --git a/mentat/feature_filters/truncate_filter.py b/mentat/feature_filters/truncate_filter.py deleted file mode 100644 index b9905125b..000000000 --- a/mentat/feature_filters/truncate_filter.py +++ /dev/null @@ -1,28 +0,0 @@ -from typing import Iterable - -from mentat.code_feature import CodeFeature -from mentat.feature_filters.feature_filter import FeatureFilter - - -class TruncateFilter(FeatureFilter): - def __init__( - self, - max_tokens: int, - model: str = "gpt-4", - ): - self.max_tokens = max_tokens - self.model = model - - async def filter( - self, - features: Iterable[CodeFeature], - ) -> list[CodeFeature]: - """Truncate the features to max_token tokens.""" - output = list[CodeFeature]() - remaining_tokens = self.max_tokens - for feature in features: - if feature.count_tokens(self.model) <= remaining_tokens: - output.append(feature) - remaining_tokens -= feature.count_tokens(self.model) - - return output diff --git a/mentat/git_handler.py b/mentat/git_handler.py index 5fb9499b4..5b0ed4797 100644 --- a/mentat/git_handler.py +++ b/mentat/git_handler.py @@ -133,8 +133,9 @@ def get_diff_for_file(target: str, path: Path) -> str: session_context = SESSION_CONTEXT.get() try: + args = target.split(" ") if target else [] diff_content = subprocess.check_output( - ["git", "diff", "-U0", f"{target}", "--", path], + ["git", "diff", "-U0", *args, "--", path], cwd=session_context.cwd, text=True, stderr=subprocess.DEVNULL, @@ -167,14 +168,16 @@ def get_files_in_diff(target: str) -> list[Path]: session_context = SESSION_CONTEXT.get() try: + args = target.split(" ") if target else [] diff_content = subprocess.check_output( - ["git", "diff", "--name-only", f"{target}", "--"], + ["git", "diff", "--name-only", *args, "--"], cwd=session_context.cwd, text=True, stderr=subprocess.DEVNULL, ).strip() if diff_content: - return [Path(path) for path in diff_content.split("\n")] + paths = [Path(path) for path in diff_content.split("\n") if path] + return paths else: return [] except subprocess.CalledProcessError: diff --git a/mentat/sampler/utils.py b/mentat/sampler/utils.py index b4d3e0c2c..f0511a1f0 100644 --- a/mentat/sampler/utils.py +++ b/mentat/sampler/utils.py @@ -89,6 +89,12 @@ def setup_repo( if errors: raise SampleError(f"Error applying diff_active: {errors}") + # Add a name/email if missing + if not repo.config_reader().has_option("user", "name"): + repo.git.config("user.name", "Test User") + if not repo.config_reader().has_option("user", "email"): + repo.git.config("user.email", "test@example.com") + return repo diff --git a/mentat/session.py b/mentat/session.py index bc8933766..0ebb05078 100644 --- a/mentat/session.py +++ b/mentat/session.py @@ -23,7 +23,6 @@ from mentat.code_file_manager import CodeFileManager from mentat.config import Config from mentat.conversation import Conversation -from mentat.ctags import ensure_ctags_installed from mentat.errors import MentatError, ReturnToUser, SessionExit, UserError from mentat.llm_api_handler import LlmApiHandler, is_test_environment from mentat.logging_config import setup_logging @@ -163,11 +162,11 @@ async def _main(self): code_file_manager = session_context.code_file_manager agent_handler = session_context.agent_handler - # check early for ctags so we can fail fast - if session_context.config.auto_context_tokens > 0: - ensure_ctags_installed() - await session_context.llm_api_handler.initialize_client() + + print("Scanning codebase for updates...") + await code_context.refresh_daemon() + check_model() need_user_request = True diff --git a/requirements.txt b/requirements.txt index e608b030b..8d2cad85a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ attrs==23.1.0 backoff==2.2.1 -chromadb==0.4.22 +chromadb==0.4.24 fire==0.5.0 gitpython==3.1.41 httpx==0.25.1 -jinja2==3.1.2 +jinja2==3.1.3 jsonschema>=4.17.0 numpy==1.26.0 -openai==1.3.0 +openai==1.13.3 pillow==10.1.0 prompt-toolkit==3.0.39 Pygments==2.15.1 @@ -17,6 +17,7 @@ pytest-mock==3.11.1 pytest-reportlog==0.4.0 pytest-timeout==2.2.0 python-dotenv==1.0.0 +ragdaemon==0.1.4 selenium==4.15.2 sentry-sdk==1.34.0 sounddevice==0.4.6 @@ -25,8 +26,8 @@ spiceai==0.1.11 termcolor==2.3.0 textual==0.47.1 textual-autocomplete==2.1.0b0 -tiktoken==0.4.0 +tiktoken==0.6.0 typing_extensions==4.8.0 -tqdm==4.66.1 +tqdm>=4.66.1 webdriver_manager==4.0.1 watchfiles==0.21.0 diff --git a/scripts/sampler/__main__.py b/scripts/sampler/__main__.py index e26370f66..18534971b 100644 --- a/scripts/sampler/__main__.py +++ b/scripts/sampler/__main__.py @@ -10,10 +10,7 @@ from pathlib import Path from typing import Any -from add_context import add_context from finetune import generate_finetune -from remove_context import remove_context -from spice import Spice from validate import validate_sample from mentat.sampler.sample import Sample @@ -50,13 +47,6 @@ async def main(): help="Validate samples conform to spec", ) parser.add_argument("--finetune", "-f", action="store_true", help="Generate fine-tuning examples") - parser.add_argument("--add-context", "-a", action="store_true", help="Add extra context to samples") - parser.add_argument( - "--remove-context", - "-r", - action="store_true", - help="Remove context from samples", - ) args = parser.parse_args() sample_files = [] if args.sample_ids: @@ -81,11 +71,6 @@ async def main(): except Exception as e: warn(f"Error loading sample {sample_file}: {e}") continue - if (args.add_context or args.remove_context) and ( - "[ADDED CONTEXT]" in sample.title or "[REMOVED CONTEXT]" in sample.title - ): - warn(f"Skipping {sample.id[:8]}: has already been modified.") - continue if args.validate: is_valid, reason = await validate_sample(sample) status = "\033[92mPASSED\033[0m" if is_valid else f"\033[91mFAILED: {reason}\033[0m" @@ -104,26 +89,6 @@ async def main(): logs.append(example) except Exception as e: warn(f"Error generating finetune example for sample {sample.id[:8]}: {e}") - elif args.add_context: - try: - new_sample = await add_context(sample) - sample_file = SAMPLES_DIR / f"sample_{new_sample.id}.json" - new_sample.save(sample_file) - print(f"Generated new sample with extra context: {sample_file}") - logs.append({"id": new_sample.id, "prototype_id": sample.id}) - except Exception as e: - warn(f"Error adding extra context to sample {sample.id[:8]}: {e}") - elif args.remove_context: - if not sample.context or len(sample.context) == 1: - warn(f"Skipping {sample.id[:8]}: no context to remove.") - continue - try: - new_sample = await remove_context(sample) - new_sample.save(SAMPLES_DIR / f"sample_{new_sample.id}.json") - print(f"Generated new sample with context removed: {sample_file}") - logs.append({"id": new_sample.id, "prototype_id": sample.id}) - except Exception as e: - warn(f"Error removing context from sample {sample.id[:8]}: {e}") else: print(f"Running sample {sample.id[:8]}") print(f" Prompt: {sample.message_prompt}") @@ -161,10 +126,6 @@ async def main(): del log["tokens"] f.write(json.dumps(log) + "\n") print(f"{len(logs)} fine-tuning examples ({tokens} tokens) saved to {fname}.") - elif args.add_context: - print(f"{len(logs)} samples with extra context generated.") - elif args.remove_context: - print(f"{len(logs)} samples with context removed generated.") if __name__ == "__main__": diff --git a/scripts/sampler/add_context.py b/scripts/sampler/add_context.py deleted file mode 100644 index ba49ad2fa..000000000 --- a/scripts/sampler/add_context.py +++ /dev/null @@ -1,41 +0,0 @@ -from pathlib import Path -from uuid import uuid4 - -import attr - -from mentat.code_feature import get_consolidated_feature_refs -from mentat.python_client.client import PythonClient -from mentat.sampler.sample import Sample -from mentat.sampler.utils import setup_repo -from mentat.session_context import SESSION_CONTEXT - - -async def add_context(sample, extra_tokens: int = 5000) -> Sample: - """Return a duplicate sample with extra (auto-context generated) context.""" - # Setup mentat CodeContext with included_files - repo = setup_repo( - url=sample.repo, - commit=sample.merge_base, - diff_merge_base=sample.diff_merge_base, - diff_active=sample.diff_active, - ) - cwd = Path(repo.working_dir) - paths = list[Path]() - for a in sample.context: - paths.append(Path(a)) - python_client = PythonClient(cwd=cwd, paths=paths) - await python_client.startup() - - # Use auto-context to add extra tokens, then copy the resulting features - ctx = SESSION_CONTEXT.get() - ctx.config.auto_context_tokens = extra_tokens - _ = await ctx.code_context.get_code_message(prompt_tokens=0, prompt=sample.message_prompt) - included_features = list(f for fs in ctx.code_context.include_files.values() for f in fs) - all_features = get_consolidated_feature_refs(included_features) - await python_client.shutdown() - - new_sample = Sample(**attr.asdict(sample)) - new_sample.context = [str(f) for f in all_features] - new_sample.id = uuid4().hex - new_sample.title = f"{sample.title} [ADDED CONTEXT]" - return new_sample diff --git a/scripts/sampler/remove_context.py b/scripts/sampler/remove_context.py deleted file mode 100644 index 2581f51d4..000000000 --- a/scripts/sampler/remove_context.py +++ /dev/null @@ -1,132 +0,0 @@ -import random -from pathlib import Path -from textwrap import dedent -from uuid import uuid4 - -import attr -from openai.types.chat import ( - ChatCompletionSystemMessageParam, - ChatCompletionUserMessageParam, -) - -from mentat.code_feature import CodeFeature, get_code_message_from_features -from mentat.errors import SampleError -from mentat.python_client.client import PythonClient -from mentat.sampler.sample import Sample -from mentat.sampler.utils import setup_repo - - -async def remove_context(sample) -> Sample: - """Return a duplicate sample with one context item removed and a warning message""" - - # Setup the repo and load context files - repo = setup_repo( - url=sample.repo, - commit=sample.merge_base, - diff_merge_base=sample.diff_merge_base, - diff_active=sample.diff_active, - ) - cwd = Path(repo.working_dir) - python_client = PythonClient(cwd=Path("."), paths=[]) - await python_client.startup() - - context = [CodeFeature(cwd / p) for p in sample.context] - i_target = random.randint(0, len(context) - 1) - target = context[i_target] - print("-" * 80) - print("Prompt\n", sample.message_prompt) - print("Context\n", sample.context) - print("Removed:", target) - print("") - - # Build conversation: [rejection_prompt, message_prompt, keep_context, remove_context] - target_context = target.get_code_message(standalone=False) - background_features = context[:i_target] + context[i_target + 1 :] - background_context = "\n".join(get_code_message_from_features(background_features)) - messages = [ - ChatCompletionSystemMessageParam( - role="system", - content=dedent( - """\ - You are part of an LLM Coding Assistant, designed to answer questions and - complete tasks for developers. Specifically, you generate examples of - interactions where the user has not provided enough context to fulfill the - query. You will be shown an example query, some background code which will - be included, and some target code which is NOT be included. - - Pretend you haven't seen the target code, and tell the user what additional - information you'll need in order to fulfill the task. Take a deep breath, - focus, and then complete your task by following this procedure: - - 1. Read the USER QUERY (below) carefully. Consider the steps involved in - completing it. - 2. Read the BACKROUND CONTEXT (below that) carefully. Consider how it - contributes to completing the task. - 3. Read the TARGET CONTEXT (below that) carefully. Consider how it - contributes to completing the task. - 4. Think of a short (1-sentence) explanation of why the TARGET CONTEXT is - required to complete the task. - 5. Return a ~1 paragraph message to the user explaining why the BACKGROUND - CONTEXT is not sufficient to answer the question. - - REMEMBER: - * Don't reference TARGET CONTEXT specifically. Answer as if you've never - seen it, you just know you're missing something essential. - * Return #5 (your response to the user) as a single sentence, without - preamble, notes, extra spacing or additional commentary. - - EXAMPLE - ============= - USER QUERY: "Can you make it so that I can write questions/answers in a - list at the top of the file, and then use that list to populate the - component." - BACKGROUND_CONTEXT: "" - TARGET_CONTEXT: - RESPONSE: "No code files have been included. In order to make the - requested changes, I need to see the context related to \"writing - questions/answers\" and \"populating the component\"." - """ - ), - ), - ChatCompletionUserMessageParam(role="user", content=f"USER QUERY:\n{sample.message_prompt}"), - ChatCompletionSystemMessageParam( - role="system", - content=f"BACKGROUND CONTEXT:\n{background_context}", - ), - ChatCompletionSystemMessageParam( - role="system", - content=f"TARGET CONTEXT:\n{target_context}", - ), - ] - - # Ask gpt-4 to generate rejection prompt - llm_api_handler = python_client.session.ctx.llm_api_handler - llm_api_handler.initialize_client() - llm_response = await llm_api_handler.call_llm_api( - messages=messages, - model=python_client.session.ctx.config.model, - stream=False, - ) - message = (llm_response.choices[0].message.content) or "" - await python_client.shutdown() - - # Ask user to review and accept/reject - print("Generated reason:", message) - print("Press ENTER to accept, 's' to skip this sample, or type a new reason to reject.") - response = input() - if response: - if response.lower() == "s": - raise SampleError("Skipping sample.") - message = response - if not message: - raise SampleError("No rejection reason provided. Aborting.") - - # Create and return a duplicate/udpated sample - new_sample = Sample(**attr.asdict(sample)) - new_sample.context = [str(f) for f in background_features] - new_sample.id = uuid4().hex - new_sample.title = f"{sample.title} [REMOVED CONTEXT]" - new_sample.message_edit = message - new_sample.diff_edit = "" - - return new_sample diff --git a/tests/code_context_test.py b/tests/code_context_test.py index dc9c2e5e7..5c046f5ca 100644 --- a/tests/code_context_test.py +++ b/tests/code_context_test.py @@ -2,19 +2,18 @@ from pathlib import Path from textwrap import dedent from unittest import TestCase -from unittest.mock import AsyncMock import pytest from mentat.code_context import CodeContext from mentat.config import Config -from mentat.feature_filters.default_filter import DefaultFilter from mentat.git_handler import get_non_gitignored_files from mentat.include_files import is_file_text_encoded from mentat.interval import Interval from tests.conftest import run_git_command +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_path_gitignoring(temp_testbed, mock_code_context): gitignore_path = ".gitignore" @@ -47,6 +46,7 @@ async def test_path_gitignoring(temp_testbed, mock_code_context): case.assertListEqual(sorted(expected_file_paths), sorted(file_paths)) +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_bracket_file(temp_testbed, mock_code_context): file_path_1 = Path("[file].tsx") @@ -69,6 +69,7 @@ async def test_bracket_file(temp_testbed, mock_code_context): case.assertListEqual(sorted(expected_file_paths), sorted(file_paths)) +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_config_glob_exclude(mocker, temp_testbed, mock_code_context): # Makes sure glob exclude config works @@ -95,6 +96,7 @@ async def test_config_glob_exclude(mocker, temp_testbed, mock_code_context): assert Path(temp_testbed / directly_added_glob_excluded_path) in mock_code_context.include_files +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_glob_include(temp_testbed, mock_code_context): # Make sure glob include works @@ -120,6 +122,7 @@ async def test_glob_include(temp_testbed, mock_code_context): assert os.path.join(temp_testbed, glob_include_path2) in file_paths +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_cli_glob_exclude(temp_testbed, mock_code_context): # Make sure cli glob exclude works and overrides regular include @@ -140,6 +143,7 @@ async def test_cli_glob_exclude(temp_testbed, mock_code_context): assert os.path.join(temp_testbed, glob_exclude_path) not in file_paths +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_text_encoding_checking(temp_testbed, mock_session_context): # Makes sure we don't include non text encoded files, and we quit if user gives us one @@ -168,6 +172,7 @@ async def test_text_encoding_checking(temp_testbed, mock_session_context): assert not code_context.include_files +@pytest.mark.ragdaemon @pytest.mark.asyncio @pytest.mark.clear_testbed async def test_max_auto_tokens(mocker, temp_testbed, mock_session_context): @@ -201,22 +206,40 @@ def func_4(string): code_context = CodeContext( mock_session_context.stream, - mock_session_context.code_context.diff_context.git_root, + temp_testbed, ) + await code_context.refresh_daemon() code_context.include("file_1.py") mock_session_context.config.auto_context_tokens = 8000 - filter_mock = AsyncMock(side_effect=lambda features: features) - mocker.patch.object(DefaultFilter, "filter", side_effect=filter_mock) - - async def _count_max_tokens_where(tokens_used: int) -> int: - code_message = await code_context.get_code_message(tokens_used, prompt="prompt") - return mock_session_context.llm_api_handler.spice.count_tokens(code_message, "gpt-4", is_message=True) - assert await _count_max_tokens_where(0) == 89 # Code + code_message = await code_context.get_code_message(0, prompt="prompt") + assert mock_session_context.llm_api_handler.spice.count_tokens(code_message, "gpt-4", is_message=True) == 95 # Code + assert ( + code_message + == """\ +Code Files: + +file_1.py (search-result, user-included) +1:def func_1(x, y): +2: return x + y +3: +4:def func_2(): +5: return 3 + +file_2.py (search-result) +1:def func_3(a, b, c): +2: return a * b ** c +3: +4:def func_4(string): +5: print(string) +""" + ) +@pytest.mark.ragdaemon +@pytest.mark.asyncio @pytest.mark.clear_testbed -def test_get_all_features(temp_testbed, mock_code_context): +async def test_get_all_features(temp_testbed, mock_session_context): # Create a sample file path1 = Path(temp_testbed) / "sample_path1.py" path2 = Path(temp_testbed) / "sample_path2.py" @@ -225,6 +248,12 @@ def test_get_all_features(temp_testbed, mock_code_context): with open(path2, "w") as file2: file2.write("def sample_function():\n pass\n") + mock_code_context = CodeContext( + mock_session_context.stream, + temp_testbed, + ) + await mock_code_context.refresh_daemon() + # Test without include_files features = mock_code_context.get_all_features() assert len(features) == 2 @@ -244,12 +273,11 @@ def test_get_all_features(temp_testbed, mock_code_context): assert feature2b.interval.whole_file() +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_get_code_message_ignore(mocker, temp_testbed, mock_session_context): mock_session_context.config.auto_context_tokens = 8000 mocker.patch.object(Config, "maximum_context", new=7000) - filter_mock = AsyncMock(side_effect=lambda features: features) - mocker.patch.object(DefaultFilter, "filter", side_effect=filter_mock) code_context = CodeContext( mock_session_context.stream, temp_testbed, @@ -260,6 +288,8 @@ async def test_get_code_message_ignore(mocker, temp_testbed, mock_session_contex # Iterate through all files in temp_testbed; if they're not in the ignore # list, they should be in the code message. for file in get_non_gitignored_files(temp_testbed): + if str(file).startswith(".ragdaemon"): + continue abs_path = temp_testbed / file rel_path = abs_path.relative_to(temp_testbed).as_posix() if not is_file_text_encoded(abs_path) or "scripts" in rel_path or rel_path.endswith(".txt"): diff --git a/tests/code_feature_test.py b/tests/code_feature_test.py index b5edac716..66990a7f8 100644 --- a/tests/code_feature_test.py +++ b/tests/code_feature_test.py @@ -1,37 +1,7 @@ -from textwrap import dedent - -from mentat.code_feature import ( - CodeFeature, - get_consolidated_feature_refs, - split_file_into_intervals, -) +from mentat.code_feature import CodeFeature, get_consolidated_feature_refs from mentat.interval import Interval -def test_split_file_into_intervals(temp_testbed, mock_session_context): - with open("file_1.py", "w") as f: - f.write( - dedent( - """\ - def func_1(x, y): - return x + y - - def func_2(): - return 3 - """ - ) - ) - code_feature = CodeFeature(mock_session_context.cwd / "file_1.py") - interval_features = split_file_into_intervals(code_feature, 1) - - assert len(interval_features) == 2 - - interval_1 = interval_features[0].interval - interval_2 = interval_features[1].interval - assert (interval_1.start, interval_1.end) == (1, 4) - assert (interval_2.start, interval_2.end) == (4, 6) - - def test_ref_method(temp_testbed): test_file = temp_testbed / "test_file.py" test_file.write_text("\n".join([""] * 10)) diff --git a/tests/code_file_manager_test.py b/tests/code_file_manager_test.py index 13675608c..8873b0da0 100644 --- a/tests/code_file_manager_test.py +++ b/tests/code_file_manager_test.py @@ -20,9 +20,10 @@ async def test_posix_paths(mock_session_context): mock_session_context.code_context.include(file_path) code_message = await mock_session_context.code_context.get_code_message(0) - assert dir_name + "/" + file_name in code_message.split("\n") + assert any(line.startswith(dir_name + "/" + file_name) for line in code_message.split("\n")) +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_partial_files(mocker, mock_session_context): dir_name = "dir" @@ -50,15 +51,17 @@ async def test_partial_files(mocker, mock_session_context): """\ Code Files: - dir/file.txt + dir/file.txt (user-included) 1:I am a file ... 3:third 4:fourth + ... """ ) +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_run_from_subdirectory( temp_testbed, @@ -121,6 +124,7 @@ async def test_run_from_subdirectory( assert echo_output[0].strip() == "# Hello" +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_run_from_superdirectory( temp_testbed, @@ -185,6 +189,7 @@ async def test_run_from_superdirectory( assert echo_output[0].strip() == "# Hello" +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_change_after_creation( mock_collect_user_input, diff --git a/tests/commands_test.py b/tests/commands_test.py index 1cfb56cde..c4585a9a3 100644 --- a/tests/commands_test.py +++ b/tests/commands_test.py @@ -101,6 +101,7 @@ async def test_save_command(temp_testbed, mock_collect_user_input): assert [str(calculator_script_path)] in (saved_code_context.values()) +@pytest.mark.ragdaemon # Required to count loaded context tokens @pytest.mark.asyncio async def test_load_command_success(temp_testbed, mock_collect_user_input): scripts_dir = Path(temp_testbed) / "scripts" @@ -155,7 +156,7 @@ async def test_load_command_file_not_found(temp_testbed, mock_collect_user_input session.start() await session.stream.recv(channel="client_exit") - assert "Context file not found" in session.stream.messages[1].data + assert any("Context file not found" in m.data for m in session.stream.messages) @pytest.mark.asyncio @@ -174,7 +175,7 @@ async def test_load_command_invalid_json(temp_testbed, mock_collect_user_input): session = Session(cwd=temp_testbed) session.start() await session.stream.recv(channel="client_exit") - assert "Failed to parse context file" in session.stream.messages[1].data + assert any("Failed to parse context file" in m.data for m in session.stream.messages) @pytest.mark.asyncio @@ -363,6 +364,7 @@ async def test_undo_all_command(temp_testbed, mock_collect_user_input, mock_call assert content == expected_content +@pytest.mark.ragdaemon @pytest.mark.asyncio async def test_clear_command(temp_testbed, mock_collect_user_input, mock_call_llm_api): mock_collect_user_input.set_stream_messages( diff --git a/tests/conftest.py b/tests/conftest.py index 9f9a02ab9..a7e1f13eb 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ import subprocess import tempfile from pathlib import Path -from unittest.mock import AsyncMock, MagicMock +from unittest.mock import AsyncMock from uuid import uuid4 import pytest @@ -56,6 +56,7 @@ def pytest_configure(config): config.addinivalue_line("markers", "uitest: run ui-tests that get evaluated by humans") config.addinivalue_line("markers", "clear_testbed: create a testbed without any existing files") config.addinivalue_line("markers", "no_git_testbed: create a testbed without git") + config.addinivalue_line("markers", "ragdaemon: DON'T mock the daemon in the testbed") def pytest_collection_modifyitems(config, items): @@ -145,17 +146,6 @@ async def call_llm_api_mock(messages, model, provider, stream, response_format=" return completion_mock -@pytest.fixture(scope="function") -def mock_call_embedding_api(mocker): - embedding_mock = mocker.patch.object(LlmApiHandler, "call_embedding_api") - - def set_embedding_values(value): - embedding_mock.return_value = value - - embedding_mock.set_embedding_values = set_embedding_values - return embedding_mock - - ### Auto-used fixtures @@ -240,6 +230,9 @@ def add_permissions(func, path, exc_info): If the error is due to an access error (read only file) it attempts to add write permission and then retries. + If the error is because the file is being used by another process, + it retries after a short delay. + If the error is for another reason it re-raises the error. """ @@ -253,7 +246,7 @@ def add_permissions(func, path, exc_info): @pytest.fixture(autouse=True) -def temp_testbed(monkeypatch, get_marks): +def temp_testbed(mocker, monkeypatch, get_marks): # Allow us to run tests from any directory base_dir = Path(__file__).parent.parent @@ -282,6 +275,9 @@ def temp_testbed(monkeypatch, get_marks): run_git_command(temp_testbed, "add", ".") run_git_command(temp_testbed, "commit", "-m", "add testbed") + if "ragdaemon" not in get_marks: + mocker.patch("ragdaemon.daemon.Daemon.update", side_effect=AsyncMock()) + # necessary to undo chdir before calling rmtree, or it fails on windows with monkeypatch.context() as m: m.chdir(temp_testbed) diff --git a/tests/diff_context_test.py b/tests/diff_context_test.py index 84ce7ced1..8d4751088 100644 --- a/tests/diff_context_test.py +++ b/tests/diff_context_test.py @@ -67,8 +67,8 @@ def test_diff_context_default(temp_testbed, git_history, mock_session_context): mock_session_context.stream, temp_testbed, ) - assert diff_context.target == "HEAD" - assert diff_context.name == "HEAD (last commit)" + assert diff_context.target == "" + assert diff_context.name == "index (last commit)" assert diff_context.diff_files() == [] # DiffContext.files (property): return git-tracked files with active changes @@ -76,15 +76,6 @@ def test_diff_context_default(temp_testbed, git_history, mock_session_context): diff_context._diff_files = None # This is usually cached assert diff_context.diff_files() == [abs_path] - # DiffContext.annotate_file_message(): modify file_message with diff - file_message = _get_file_message(abs_path) - annotated_message = diff_context.annotate_file_message(abs_path, file_message) - expected = file_message[:-1] + [ - "14:- return commit3", - "14:+ return commit5", - ] - assert annotated_message == expected - @pytest.mark.asyncio async def test_diff_context_commit(temp_testbed, git_history, mock_session_context): @@ -101,14 +92,6 @@ async def test_diff_context_commit(temp_testbed, git_history, mock_session_conte assert diff_context.name == f"{last_commit[:8]}: add testbed" assert diff_context.diff_files() == [abs_path] - file_message = _get_file_message(abs_path) - annotated_message = diff_context.annotate_file_message(abs_path, file_message) - expected = file_message[:-1] + [ - "14:- return a / b", - "14:+ return commit3", - ] - assert annotated_message == expected - @pytest.mark.asyncio async def test_diff_context_branch(temp_testbed, git_history, mock_session_context): @@ -124,14 +107,6 @@ async def test_diff_context_branch(temp_testbed, git_history, mock_session_conte assert diff_context.name.endswith(": commit4") assert diff_context.diff_files() == [abs_path] - file_message = _get_file_message(abs_path) - annotated_message = diff_context.annotate_file_message(abs_path, file_message) - expected = file_message[:-1] + [ - "14:- return commit4", - "14:+ return commit3", - ] - assert annotated_message == expected - @pytest.mark.asyncio async def test_diff_context_relative(temp_testbed, git_history, mock_session_context): @@ -147,14 +122,6 @@ async def test_diff_context_relative(temp_testbed, git_history, mock_session_con assert diff_context.name.endswith(": add testbed") assert diff_context.diff_files() == [abs_path] - file_message = _get_file_message(abs_path) - annotated_message = diff_context.annotate_file_message(abs_path, file_message) - expected = file_message[:-1] + [ - "14:- return a / b", - "14:+ return commit3", - ] - assert annotated_message == expected - @pytest.mark.asyncio async def test_diff_context_pr(temp_testbed, git_history, mock_session_context): diff --git a/tests/feature_filters/llm_feature_filter_test.py b/tests/feature_filters/llm_feature_filter_test.py deleted file mode 100644 index 8e77ae1e9..000000000 --- a/tests/feature_filters/llm_feature_filter_test.py +++ /dev/null @@ -1,27 +0,0 @@ -import pytest - -from mentat.code_feature import CodeFeature -from mentat.feature_filters.llm_feature_filter import LLMFeatureFilter - - -@pytest.mark.asyncio -async def test_llm_feature_filter(mocker, temp_testbed, mock_call_llm_api, mock_session_context): - all_features = [ - CodeFeature(temp_testbed / "multifile_calculator" / "calculator.py"), # 188 tokens - CodeFeature(temp_testbed / "multifile_calculator" / "operations.py"), # 87 tokens - ] - - mock_call_llm_api.set_unstreamed_values('{"multifile_calculator/operations.py": "test reason"}') - mock_session_context.config.llm_feature_filter = 10000 - - feature_filter = LLMFeatureFilter(100, user_prompt="test prompt") - selected = await feature_filter.filter(all_features) - - messages = mock_call_llm_api.call_args.kwargs["messages"] - assert messages[0]["content"].startswith("You are part of") - assert messages[1]["content"].startswith("CODE FILES") - assert messages[2]["content"].startswith("USER QUERY") - assert messages[3]["content"].startswith("Now,") - - # Only one file returned - assert len(selected) == 1 diff --git a/tests/feature_filters/truncate_filter_test.py b/tests/feature_filters/truncate_filter_test.py deleted file mode 100644 index ebca33e8d..000000000 --- a/tests/feature_filters/truncate_filter_test.py +++ /dev/null @@ -1,22 +0,0 @@ -import pytest - -from mentat.code_feature import CodeFeature -from mentat.feature_filters.truncate_filter import TruncateFilter - - -@pytest.mark.asyncio -async def test_truncate_feature_selector(temp_testbed, mock_call_llm_api): - all_features = [ - CodeFeature(temp_testbed / "multifile_calculator" / "calculator.py"), # 188 tokens - CodeFeature(temp_testbed / "multifile_calculator" / "operations.py"), # 87 tokens - ] - - feature_filter = TruncateFilter(100) - selected = await feature_filter.filter(all_features) - assert len(selected) == 1 - assert selected[0].path.name == "operations.py" - - feature_filter = TruncateFilter(200) - selected = await feature_filter.filter(all_features) - assert len(selected) == 1 - assert selected[0].path.name == "calculator.py" diff --git a/tests/sampler_test.py b/tests/sampler_test.py index 91079438f..4d30d614b 100644 --- a/tests/sampler_test.py +++ b/tests/sampler_test.py @@ -202,8 +202,9 @@ async def test_sample_command(temp_testbed, mock_collect_user_input, mock_call_l } +@pytest.mark.ragdaemon @pytest.mark.asyncio -async def test_sample_eval(mock_call_llm_api): +async def test_sample_eval(temp_testbed, mock_call_llm_api): parsedLLMResponse = GitParser().parse_llm_response(test_sample["diff_edit"]) edit_message = BlockParser().file_edits_to_llm_message(parsedLLMResponse) mock_call_llm_api.set_streamed_values(