diff --git a/src/kit/context_extractor.py b/src/kit/context_extractor.py
index 15cf190..dee5558 100644
--- a/src/kit/context_extractor.py
+++ b/src/kit/context_extractor.py
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
 import ast
+import os
 from pathlib import Path
-from typing import Any, Dict, List, Optional
+from typing import Any, ClassVar, Dict, List, Optional, Tuple
 
 from .tree_sitter_symbol_extractor import TreeSitterSymbolExtractor
 
@@ -13,25 +14,76 @@ class ContextExtractor:
     Supports chunking by lines, symbols, and function/class scope.
     """
 
+    # Class-level, mtime-validated file-content cache (FIFO eviction): path -> (mtime, content, lines)
+    # Avoids re-reading the same file across method calls
+    _file_cache: ClassVar[Dict[str, Tuple[float, str, List[str]]]] = {}
+    _cache_max_size: ClassVar[int] = 100  # Max files to cache
+
     def __init__(self, repo_path: str) -> None:
         self.repo_path: Path = Path(repo_path)
 
+    def _read_file_cached(self, abs_path: Path) -> Tuple[str, List[str]]:
+        """Read file content with mtime-based caching.
+
+        Returns a (content, lines) tuple; a cached entry is reused only while the
+        file's mtime is unchanged, which avoids redundant reads of the same file.
+        """
+        path_str = str(abs_path)
+        try:
+            current_mtime = os.path.getmtime(abs_path)
+        except OSError as exc:
+            # File is missing or unreadable; normalize to FileNotFoundError
+            raise FileNotFoundError(f"Cannot access file: {abs_path}") from exc
+
+        # Check cache
+        if path_str in self._file_cache:
+            cached_mtime, cached_content, cached_lines = self._file_cache[path_str]
+            if cached_mtime == current_mtime:
+                return cached_content, cached_lines
+
+        # Read file
+        with open(abs_path, "r", encoding="utf-8", errors="ignore") as f:
+            content = f.read()
+        lines = content.splitlines(keepends=True)
+
+        # Evict oldest entries if the cache is full (simple FIFO eviction)
+        if len(self._file_cache) >= self._cache_max_size:
+            # Drop the oldest ~10% of entries (at least one); dicts keep insertion order
+            keys_to_remove = list(self._file_cache.keys())[: max(1, self._cache_max_size // 10)]
+            for key in keys_to_remove:
+                del self._file_cache[key]
+
+        # Cache the result
+        self._file_cache[path_str] = (current_mtime, content, lines)
+        return content, lines
+
+    def invalidate_cache(self, file_path: Optional[str] = None) -> None:
+        """Invalidate the file cache.
+
+        Args:
+            file_path: Absolute path to invalidate (the cache key), or None to clear all.
+        """
+        if file_path is None:
+            self._file_cache.clear()
+        elif file_path in self._file_cache:
+            del self._file_cache[file_path]
+
     def chunk_file_by_lines(self, file_path: str, max_lines: int = 50) -> List[str]:
         """
         Chunk file into blocks of at most max_lines lines.
         """
         from .utils import validate_relative_path
 
+        abs_path = validate_relative_path(self.repo_path, file_path)
+        try:
+            _, all_lines = self._read_file_cached(abs_path)
+        except OSError:
+            return []
+
         chunks: List[str] = []
-        with open(validate_relative_path(self.repo_path, file_path), "r", encoding="utf-8", errors="ignore") as f:
-            lines: List[str] = []
-            for i, line in enumerate(f, 1):
-                lines.append(line)
-                if i % max_lines == 0:
-                    chunks.append("".join(lines))
-                    lines = []
-            if lines:
-                chunks.append("".join(lines))
+        for i in range(0, len(all_lines), max_lines):
+            chunk_lines = all_lines[i : i + max_lines]
+            chunks.append("".join(chunk_lines))
         return chunks
 
     def chunk_file_by_symbols(self, file_path: str) -> List[Dict[str, Any]]:
@@ -40,9 +92,8 @@ def chunk_file_by_symbols(self, file_path: str) -> List[Dict[str, Any]]:
         ext = Path(file_path).suffix.lower()
         abs_path = validate_relative_path(self.repo_path, file_path)
         try:
-            with open(abs_path, "r", encoding="utf-8", errors="ignore") as f:
-                code = f.read()
-        except Exception:
+            code, _ = self._read_file_cached(abs_path)
+        except OSError:
             return []
         if ext in TreeSitterSymbolExtractor.LANGUAGES:
             return TreeSitterSymbolExtractor.extract_symbols(ext, code)
@@ -58,10 +109,8 @@ def extract_context_around_line(self, file_path: str, line: int) -> Optional[Dic
         ext = Path(file_path).suffix.lower()
         abs_path = validate_relative_path(self.repo_path, file_path)
         try:
-            with open(abs_path, "r", encoding="utf-8", errors="ignore") as f:
-                all_lines = f.readlines()
-            code = "".join(all_lines)
-        except Exception:
+            code, all_lines = self._read_file_cached(abs_path)
+        except OSError:
             return None
         if ext == ".py":
            try:
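A minimal usage sketch of the new cache path (the repo path and file name below are hypothetical, and the `kit.context_extractor` import location is assumed from the `src/kit` layout). Note that the cache key is `str(abs_path)`, so `invalidate_cache` expects the absolute path:

```python
# Illustrative sketch only; "/path/to/repo" and "src/app.py" are hypothetical.
from kit.context_extractor import ContextExtractor

extractor = ContextExtractor("/path/to/repo")

# First call reads from disk and populates the class-level cache.
chunks = extractor.chunk_file_by_lines("src/app.py", max_lines=50)

# Same file, unchanged mtime: served from the cache, no second disk read.
symbols = extractor.chunk_file_by_symbols("src/app.py")

# An external edit bumps the mtime, so the next read refreshes the entry
# automatically; invalidate_cache forces a refresh without an mtime change.
extractor.invalidate_cache(str(extractor.repo_path / "src/app.py"))
```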
""" from .utils import validate_relative_path + abs_path = validate_relative_path(self.repo_path, file_path) + try: + _, all_lines = self._read_file_cached(abs_path) + except (FileNotFoundError, OSError): + return [] + chunks: List[str] = [] - with open(validate_relative_path(self.repo_path, file_path), "r", encoding="utf-8", errors="ignore") as f: - lines: List[str] = [] - for i, line in enumerate(f, 1): - lines.append(line) - if i % max_lines == 0: - chunks.append("".join(lines)) - lines = [] - if lines: - chunks.append("".join(lines)) + for i in range(0, len(all_lines), max_lines): + chunk_lines = all_lines[i : i + max_lines] + chunks.append("".join(chunk_lines)) return chunks def chunk_file_by_symbols(self, file_path: str) -> List[Dict[str, Any]]: @@ -40,9 +92,8 @@ def chunk_file_by_symbols(self, file_path: str) -> List[Dict[str, Any]]: ext = Path(file_path).suffix.lower() abs_path = validate_relative_path(self.repo_path, file_path) try: - with open(abs_path, "r", encoding="utf-8", errors="ignore") as f: - code = f.read() - except Exception: + code, _ = self._read_file_cached(abs_path) + except (FileNotFoundError, OSError): return [] if ext in TreeSitterSymbolExtractor.LANGUAGES: return TreeSitterSymbolExtractor.extract_symbols(ext, code) @@ -58,10 +109,8 @@ def extract_context_around_line(self, file_path: str, line: int) -> Optional[Dic ext = Path(file_path).suffix.lower() abs_path = validate_relative_path(self.repo_path, file_path) try: - with open(abs_path, "r", encoding="utf-8", errors="ignore") as f: - all_lines = f.readlines() - code = "".join(all_lines) - except Exception: + code, all_lines = self._read_file_cached(abs_path) + except (FileNotFoundError, OSError): return None if ext == ".py": try: diff --git a/src/kit/vector_searcher.py b/src/kit/vector_searcher.py index 9af4663..ae5eed9 100644 --- a/src/kit/vector_searcher.py +++ b/src/kit/vector_searcher.py @@ -71,6 +71,7 @@ def __init__(self, persist_dir: str, collection_name: Optional[str] = None): self.persist_dir = persist_dir self.client = PersistentClient(path=self.persist_dir) self.is_local = True # Flag to identify local backend + self._needs_reset = True # Track if collection needs clearing before next add final_collection_name = collection_name if final_collection_name is None: @@ -132,36 +133,41 @@ def delete(self, ids: List[str]): pass def _reset_collection(self) -> None: - """Ensure we start from a clean collection before bulk re-add.""" + """Ensure we start from a clean collection before bulk re-add. 
+
+        Optimized to:
+        - Skip if already reset (tracked via _needs_reset flag)
+        - Use delete_collection as primary fast path (avoids count() call)
+        - Fall back to other methods only if delete_collection fails
+        """
+        if not self._needs_reset:
+            return
+
+        # Try delete_collection first - fastest path, avoids count() overhead
         try:
-            if self.collection.count() > 0:
-                cleared = False
-                try:
-                    self.client.delete_collection(self.collection_name)
-                    cleared = True
-                except Exception:
-                    pass
-
-                if not cleared:
+            self.client.delete_collection(self.collection_name)
+        except Exception:
+            # Collection might not exist or delete not supported - try alternatives
+            try:
+                # Check if there's anything to clear before expensive operations
+                if self.collection.count() > 0:
                     try:
                         self.collection.delete(where={"source": {"$ne": "__kit__never__"}})
-                        cleared = True
-                    except Exception:
-                        pass
-
-                if not cleared:
-                    try:
-                        existing = self.collection.get(include=[])
-                        ids = existing.get("ids") if isinstance(existing, dict) else None
-                        if ids:
-                            self.collection.delete(ids=list(ids))
                     except Exception:
-                        pass
-        except Exception:
-            pass
-        finally:
-            self.collection = self.client.get_or_create_collection(self.collection_name)
-            self._batch_size = _resolve_batch_size(self.collection)
+                        try:
+                            existing = self.collection.get(include=[])
+                            ids = existing.get("ids") if isinstance(existing, dict) else None
+                            if ids:
+                                self.collection.delete(ids=list(ids))
+                        except Exception:
+                            pass
+            except Exception:
+                pass
+
+        # Recreate collection and mark as reset
+        self.collection = self.client.get_or_create_collection(self.collection_name)
+        self._batch_size = _resolve_batch_size(self.collection)
+        self._needs_reset = False
 
 
 class ChromaCloudBackend(VectorDBBackend):
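The `_needs_reset` guard makes `_reset_collection` one-shot: after the first reset the flag stays `False`, and the code that re-arms it before a subsequent re-index is not shown in this diff. A sketch of the assumed call pattern (the import path and persist directory are hypothetical):

```python
from kit.vector_searcher import ChromaDBBackend  # assumed import path

backend = ChromaDBBackend("/tmp/kit-index")  # hypothetical persist_dir

backend._reset_collection()  # drops and recreates the collection; flag -> False
backend._reset_collection()  # no-op: _needs_reset is already False

# Before the next full re-index, the caller is expected to re-arm the guard;
# otherwise previously indexed vectors would survive the "reset".
backend._needs_reset = True
backend._reset_collection()  # clears and recreates again
```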