83 changes: 66 additions & 17 deletions src/kit/context_extractor.py
@@ -1,8 +1,9 @@
from __future__ import annotations

import ast
import os
from pathlib import Path
from typing import Any, Dict, List, Optional
from typing import Any, ClassVar, Dict, List, Optional, Tuple

from .tree_sitter_symbol_extractor import TreeSitterSymbolExtractor

@@ -13,25 +14,76 @@ class ContextExtractor:
Supports chunking by lines, symbols, and function/class scope.
"""

# Mtime-validated cache for file contents: path -> (mtime, content, lines)
# Avoids re-reading the same file multiple times across method calls
_file_cache: ClassVar[Dict[str, Tuple[float, str, List[str]]]] = {}
_cache_max_size: ClassVar[int] = 100 # Max files to cache

def __init__(self, repo_path: str) -> None:
self.repo_path: Path = Path(repo_path)

def _read_file_cached(self, abs_path: Path) -> Tuple[str, List[str]]:
"""Read file content with mtime-based caching.

Returns (content, lines) tuple. Uses cache if file hasn't changed.
This avoids redundant disk reads when multiple methods access the same file.
"""
path_str = str(abs_path)
try:
current_mtime = os.path.getmtime(abs_path)
except OSError:
# File doesn't exist or can't be accessed
raise FileNotFoundError(f"Cannot access file: {abs_path}")

# Check cache
if path_str in self._file_cache:
cached_mtime, cached_content, cached_lines = self._file_cache[path_str]
if cached_mtime == current_mtime:
return cached_content, cached_lines

# Read file
with open(abs_path, "r", encoding="utf-8", errors="ignore") as f:
content = f.read()
lines = content.splitlines(keepends=True)

# Evict oldest entries if cache is full (simple FIFO eviction)
if len(self._file_cache) >= self._cache_max_size:
# Remove first 10% of entries
keys_to_remove = list(self._file_cache.keys())[: self._cache_max_size // 10]
for key in keys_to_remove:
del self._file_cache[key]

# Cache the result
self._file_cache[path_str] = (current_mtime, content, lines)
return content, lines

def invalidate_cache(self, file_path: Optional[str] = None) -> None:
"""Invalidate file cache.

Args:
file_path: Specific file to invalidate, or None to clear entire cache.
"""
if file_path is None:
self._file_cache.clear()
elif file_path in self._file_cache:
del self._file_cache[file_path]

def chunk_file_by_lines(self, file_path: str, max_lines: int = 50) -> List[str]:
"""
Chunk file into blocks of at most max_lines lines.
"""
from .utils import validate_relative_path

abs_path = validate_relative_path(self.repo_path, file_path)
try:
_, all_lines = self._read_file_cached(abs_path)
except (FileNotFoundError, OSError):
return []

chunks: List[str] = []
with open(validate_relative_path(self.repo_path, file_path), "r", encoding="utf-8", errors="ignore") as f:
lines: List[str] = []
for i, line in enumerate(f, 1):
lines.append(line)
if i % max_lines == 0:
chunks.append("".join(lines))
lines = []
if lines:
chunks.append("".join(lines))
for i in range(0, len(all_lines), max_lines):
chunk_lines = all_lines[i : i + max_lines]
chunks.append("".join(chunk_lines))
return chunks

def chunk_file_by_symbols(self, file_path: str) -> List[Dict[str, Any]]:
@@ -40,9 +92,8 @@ def chunk_file_by_symbols(self, file_path: str) -> List[Dict[str, Any]]:
ext = Path(file_path).suffix.lower()
abs_path = validate_relative_path(self.repo_path, file_path)
try:
with open(abs_path, "r", encoding="utf-8", errors="ignore") as f:
code = f.read()
except Exception:
code, _ = self._read_file_cached(abs_path)
except (FileNotFoundError, OSError):
return []
if ext in TreeSitterSymbolExtractor.LANGUAGES:
return TreeSitterSymbolExtractor.extract_symbols(ext, code)
@@ -58,10 +109,8 @@ def extract_context_around_line(self, file_path: str, line: int) -> Optional[Dic
ext = Path(file_path).suffix.lower()
abs_path = validate_relative_path(self.repo_path, file_path)
try:
with open(abs_path, "r", encoding="utf-8", errors="ignore") as f:
all_lines = f.readlines()
code = "".join(all_lines)
except Exception:
code, all_lines = self._read_file_cached(abs_path)
except (FileNotFoundError, OSError):
return None
if ext == ".py":
try:
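To make the caching change easier to review in isolation, here is a minimal usage sketch of the new behaviour. It assumes the package is importable as kit.context_extractor (matching the src/kit layout) and that validate_relative_path accepts a plain relative filename; the temp directory, sample.py, and the chunk size are illustrative only, not part of the PR.

    import tempfile
    from pathlib import Path

    # Assumed import path based on src/kit/context_extractor.py; adjust if the package root differs.
    from kit.context_extractor import ContextExtractor

    with tempfile.TemporaryDirectory() as repo:
        sample = Path(repo) / "sample.py"
        sample.write_text("def a():\n    pass\n\ndef b():\n    pass\n")  # 5 lines

        extractor = ContextExtractor(repo)

        # First call reads from disk and populates the class-level mtime cache.
        chunks = extractor.chunk_file_by_lines("sample.py", max_lines=2)
        assert len(chunks) == 3  # chunks of 2, 2, and 1 lines

        # Second call sees an unchanged mtime and is served from the cache, no disk read.
        assert extractor.chunk_file_by_lines("sample.py", max_lines=2) == chunks

        # After an in-place edit, either the changed mtime or an explicit
        # invalidate_cache() guarantees the next read reflects the new contents.
        sample.write_text("x = 1\n")
        extractor.invalidate_cache()  # no argument clears the whole cache
        assert extractor.chunk_file_by_lines("sample.py", max_lines=2) == ["x = 1\n"]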
58 changes: 32 additions & 26 deletions src/kit/vector_searcher.py
@@ -71,6 +71,7 @@ def __init__(self, persist_dir: str, collection_name: Optional[str] = None):
self.persist_dir = persist_dir
self.client = PersistentClient(path=self.persist_dir)
self.is_local = True # Flag to identify local backend
self._needs_reset = True # Track if collection needs clearing before next add

final_collection_name = collection_name
if final_collection_name is None:
@@ -132,36 +133,41 @@ def delete(self, ids: List[str]):
pass

def _reset_collection(self) -> None:
"""Ensure we start from a clean collection before bulk re-add."""
"""Ensure we start from a clean collection before bulk re-add.

Optimized to:
- Skip if already reset (tracked via _needs_reset flag)
- Use delete_collection as primary fast path (avoids count() call)
- Fall back to other methods only if delete_collection fails
"""
if not self._needs_reset:
return

# Try delete_collection first - fastest path, avoids count() overhead
try:
if self.collection.count() > 0:
cleared = False
try:
self.client.delete_collection(self.collection_name)
cleared = True
except Exception:
pass

if not cleared:
self.client.delete_collection(self.collection_name)
except Exception:
# Collection might not exist or delete not supported - try alternatives
try:
# Check if there's anything to clear before expensive operations
if self.collection.count() > 0:
try:
self.collection.delete(where={"source": {"$ne": "__kit__never__"}})
cleared = True
except Exception:
pass

if not cleared:
try:
existing = self.collection.get(include=[])
ids = existing.get("ids") if isinstance(existing, dict) else None
if ids:
self.collection.delete(ids=list(ids))
except Exception:
pass
except Exception:
pass
finally:
self.collection = self.client.get_or_create_collection(self.collection_name)
self._batch_size = _resolve_batch_size(self.collection)
try:
existing = self.collection.get(include=[])
ids = existing.get("ids") if isinstance(existing, dict) else None
if ids:
self.collection.delete(ids=list(ids))
except Exception:
pass
except Exception:
pass

# Recreate collection and mark as reset
self.collection = self.client.get_or_create_collection(self.collection_name)
self._batch_size = _resolve_batch_size(self.collection)
self._needs_reset = False


class ChromaCloudBackend(VectorDBBackend):
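To close the loop on the vector-store side, the sketch below isolates the _needs_reset guard that this file's diff introduces. It uses a stub collection rather than a real Chroma PersistentClient, so the names here (_StubCollection, _ResettableBackend, bulk_add) are illustrative stand-ins, not part of kit's API; only the guard pattern itself comes from the diff.

    from typing import Dict

    class _StubCollection:
        """Stand-in for a Chroma collection, only here to make the sketch runnable."""

        def __init__(self) -> None:
            self.docs: Dict[str, str] = {}

        def add(self, doc_id: str, text: str) -> None:
            self.docs[doc_id] = text

    class _ResettableBackend:
        """Mirrors the _needs_reset guard; not the real ChromaDBBackend."""

        def __init__(self) -> None:
            self.collection = _StubCollection()
            self.reset_count = 0
            self._needs_reset = True  # first bulk add must start from a clean collection

        def _reset_collection(self) -> None:
            if not self._needs_reset:
                return  # fast path: nothing marked the collection dirty since the last reset
            # The real code deletes and recreates the Chroma collection here.
            self.collection = _StubCollection()
            self.reset_count += 1
            self._needs_reset = False

        def bulk_add(self, docs: Dict[str, str]) -> None:
            self._reset_collection()
            for doc_id, text in docs.items():
                self.collection.add(doc_id, text)

    backend = _ResettableBackend()
    backend.bulk_add({"a": "alpha"})
    backend.bulk_add({"b": "beta"})  # guard short-circuits: no second reset, no count() round-trip
    assert backend.reset_count == 1
    assert set(backend.collection.docs) == {"a", "b"}

Note that the diff shown here only ever sets the flag to False after a reset; whatever re-marks the collection dirty lives outside this excerpt.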