diff --git a/docs/src/content/docs/mcp/kit-dev-mcp.mdx b/docs/src/content/docs/mcp/kit-dev-mcp.mdx
index c7a11f4..8b3ef20 100644
--- a/docs/src/content/docs/mcp/kit-dev-mcp.mdx
+++ b/docs/src/content/docs/mcp/kit-dev-mcp.mdx
@@ -82,11 +82,13 @@ The server provides many tools including:
- **open_repository** - Open local or remote Git repositories
- **search_code** - Pattern-based code search
-- **grep_code** - Fast literal string search
+- **grep_code** - Fast literal string search (120s default timeout, configurable via `KIT_GREP_TIMEOUT`)
+- **get_file_tree** - Repository file structure with pagination support (`limit`/`offset` params)
- **get_file_content** - Read file contents
- **extract_symbols** - Extract functions, classes, and symbols
- **find_symbol_usages** - Find where symbols are used
- **get_code_summary** - AI-powered code summaries
+- **warm_cache** - Pre-warm caches for faster operations on large codebases (100K+ files)
- **review_diff** - AI-powered diff reviews
- **deep_research_package** - Comprehensive package documentation
- **semantic_search** - Vector-based code search
@@ -94,6 +96,13 @@ The server provides many tools including:
- **package_search_hybrid** - Semantic search in package source code
- **package_search_read_file** - Read specific files from packages
+
+
## Learn More
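To make the documented pagination concrete: a hypothetical client-side loop that consumes the new `limit`/`offset` parameters. This is a sketch, not part of this change; it assumes an MCP Python SDK `ClientSession` bound to `session`, and the `files`/`has_more`/`total_count` keys match the JSON-mode response added to the server below.

```python
import json

async def fetch_all_files(session, repo_id: str, page_size: int = 1000) -> list:
    """Page through get_file_tree; a sketch assuming an MCP ClientSession."""
    files, offset = [], 0
    while True:
        result = await session.call_tool(
            "get_file_tree",
            {"repo_id": repo_id, "limit": page_size, "offset": offset},
        )
        page = json.loads(result.content[0].text)  # JSON-mode response shape from this change
        files.extend(page["files"])
        if not page["has_more"]:
            break
        offset += page_size
    return files
```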
diff --git a/src/kit/mcp/dev_server.py b/src/kit/mcp/dev_server.py
--- a/src/kit/mcp/dev_server.py
+++ b/src/kit/mcp/dev_server.py
@@ ... @@ class LocalDevServerLogic:
    def get_file_tree(self, repo_id: str) -> List[Dict[str, Any]]:
        repo = self.get_repo(repo_id)
        return repo.get_file_tree()
+    def warm_cache(self, repo_id: str, warm_file_tree: bool = True, warm_symbols: bool = False) -> Dict[str, Any]:
+        """Pre-warm caches for faster subsequent operations on large codebases.
+
+        This is useful for very large repos where the first file_tree or symbol
+        extraction can take 30+ seconds. Warming caches upfront avoids timeouts.
+
+        Args:
+            repo_id: Repository ID to warm caches for
+            warm_file_tree: Pre-cache file tree (fast, ~1-5s for 100K files)
+            warm_symbols: Pre-cache symbols (slower, ~30-60s for 100K files)
+
+        Returns:
+            Dict with timing stats for each warmed cache
+        """
+        import time
+
+        repo = self.get_repo(repo_id)
+        stats: Dict[str, Any] = {"repo_id": repo_id}
+
+        if warm_file_tree:
+            start = time.time()
+            tree = repo.get_file_tree()
+            stats["file_tree"] = {
+                "elapsed_seconds": round(time.time() - start, 2),
+                "file_count": len(tree),
+            }
+
+        if warm_symbols:
+            start = time.time()
+            # Trigger full repo scan by calling extract_symbols with no file
+            symbols = repo.extract_symbols()
+            stats["symbols"] = {
+                "elapsed_seconds": round(time.time() - start, 2),
+                "symbol_count": len(symbols),
+            }
+
+        return stats
+
    def extract_symbols(self, repo_id: str, file_path: str, symbol_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Extract symbols from a file."""
        repo = self.get_repo(repo_id)
@@ -545,6 +599,11 @@ def list_tools(self) -> List[Tool]:
description="Get source code of a specific symbol (lazy loading for context efficiency)",
inputSchema=GetSymbolCodeParams.model_json_schema(),
),
+ Tool(
+ name="warm_cache",
+ description="Pre-warm caches for faster operations on large codebases (call before get_file_tree on huge repos)",
+ inputSchema=WarmCacheParams.model_json_schema(),
+ ),
]
@@ -1072,15 +1131,33 @@ async def call_tool(name: str, arguments: dict) -> List[TextContent]:
elif name == "get_file_tree":
tree_params = GetFileTreeParams(**arguments)
result = logic.get_file_tree(tree_params.repo_id)
+
+ # Apply pagination for large codebases
+ total_count = len(result)
+ start = tree_params.offset
+ end = start + tree_params.limit
+ paginated = result[start:end]
+ has_more = end < total_count
+
# Compact mode: newline-separated paths (saves ~75% context)
if tree_params.compact:
paths = []
- for item in result:
+ for item in paginated:
is_dir = item.get("is_dir", False)
if tree_params.include_dirs or not is_dir:
paths.append(item.get("path", ""))
- return [TextContent(type="text", text="\n".join(paths))]
- return [TextContent(type="text", text=json.dumps(result, indent=2))]
+ # Include pagination metadata as header for compact mode
+ header = f"# total={total_count} offset={start} limit={tree_params.limit} has_more={has_more}\n"
+ return [TextContent(type="text", text=header + "\n".join(paths))]
+ # JSON mode: include pagination in response
+ response = {
+ "files": paginated,
+ "total_count": total_count,
+ "offset": start,
+ "limit": tree_params.limit,
+ "has_more": has_more,
+ }
+ return [TextContent(type="text", text=json.dumps(response, indent=2))]
elif name == "get_code_summary":
summary_params = GetCodeSummaryParams(**arguments)
result = logic.get_code_summary(
@@ -1119,6 +1196,14 @@ async def call_tool(name: str, arguments: dict) -> List[TextContent]:
            symbol_code_params.symbol_name,
        )
        return [TextContent(type="text", text=json.dumps(result, indent=2))]
+    elif name == "warm_cache":
+        cache_params = WarmCacheParams(**arguments)
+        result = logic.warm_cache(
+            cache_params.repo_id,
+            cache_params.warm_file_tree,
+            cache_params.warm_symbols,
+        )
+        return [TextContent(type="text", text=json.dumps(result, indent=2))]
    else:
        # Should not happen since we checked the name is in the list
        return [TextContent(type="text", text=f"Tool {name} is recognized but not implemented")]
diff --git a/src/kit/repository.py b/src/kit/repository.py
index 7b3b03c..85a7190 100644
--- a/src/kit/repository.py
+++ b/src/kit/repository.py
@@ -379,6 +379,7 @@ def grep(
        max_results: int = 1000,
        directory: Optional[str] = None,
        include_hidden: bool = False,
+        timeout: Optional[int] = None,
    ) -> List[Dict[str, Any]]:
        """
        Performs literal grep search on repository files using system grep.
@@ -389,8 +390,11 @@ def grep(
            include_pattern: Glob pattern for files to include (e.g. '*.py').
            exclude_pattern: Glob pattern for files to exclude.
            max_results: Maximum number of results to return. Defaults to 1000.
+                Also passed to grep's -m flag, which stops scanning each file early on large codebases.
            directory: Limit search to specific directory within repository (e.g. 'src', 'lib/utils').
            include_hidden: Whether to search hidden directories (starting with '.'). Defaults to False.
+            timeout: Search timeout in seconds. Defaults to 120s (or KIT_GREP_TIMEOUT env var).
+                For very large codebases (10M+ files), consider increasing this.

        Returns:
            List[Dict[str, Any]]: List of matches with file, line_number, line_content.
@@ -403,6 +407,17 @@ def grep(
        self._ensure_git_state_valid()

+        # Resolve timeout: parameter > env var > default (120s)
+        if timeout is None:
+            env_timeout = os.environ.get("KIT_GREP_TIMEOUT")
+            if env_timeout:
+                try:
+                    timeout = int(env_timeout)
+                except ValueError:
+                    timeout = 120
+            else:
+                timeout = 120
+
        # Build grep command
        cmd = ["grep", "-r", "-n", "-H"]  # -r for recursive, -n for line numbers, -H for filenames
@@ -477,6 +492,11 @@ def grep(
raise ValueError(f"Directory not found in repository: {directory}")
search_path = directory
+ # Early termination: use -m flag to stop after max_results matches per file
+ # This provides massive speedups on large codebases by avoiding full traversal
+ if max_results > 0:
+ cmd.extend(["-m", str(max_results)])
+
# Search recursively in specified directory
cmd.append(search_path)
@@ -487,10 +507,13 @@ def grep(
                capture_output=True,
                text=True,
                encoding="utf-8",
-                timeout=30,  # 30 second timeout
+                timeout=timeout,
            )
        except subprocess.TimeoutExpired:
-            raise RuntimeError("Grep search timed out after 30 seconds")
+            raise RuntimeError(
+                f"Grep search timed out after {timeout} seconds. "
+                "Set the KIT_GREP_TIMEOUT env var or pass the timeout parameter for longer searches."
+            )
        except FileNotFoundError:
            raise RuntimeError("grep command not found. Please ensure grep is installed and in PATH.")
diff --git a/tests/test_large_codebase.py b/tests/test_large_codebase.py
new file mode 100644
index 0000000..94c1453
--- /dev/null
+++ b/tests/test_large_codebase.py
@@ -0,0 +1,232 @@
+"""Tests for large codebase support features.
+
+These tests verify the features added for handling large codebases:
+- Grep timeout parameter and early termination (-m flag)
+- File tree pagination
+- Warm cache MCP tool
+"""
+
+import os
+import tempfile
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+from kit.mcp.dev_server import (
+    GetFileTreeParams,
+    LocalDevServerLogic,
+    WarmCacheParams,
+)
+from kit.repository import Repository
+
+
+class TestGrepTimeoutAndEarlyTermination:
+    """Test grep timeout parameter and early termination."""
+
+    @pytest.fixture
+    def repo(self):
+        """Create a test repository with some files."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Create a git repo
+            git_dir = Path(tmpdir) / ".git"
+            git_dir.mkdir()
+
+            # Create files with searchable content
+            src_dir = Path(tmpdir) / "src"
+            src_dir.mkdir()
+
+            for i in range(10):
+                (src_dir / f"file{i}.py").write_text(f"# TODO: task {i}\ndef func{i}():\n    pass\n")
+
+            yield Repository(tmpdir)
+
+    def test_grep_default_timeout(self, repo):
+        """Test that grep completes with the default 120s timeout."""
+        # No env var and no parameter, so the 120s default applies internally
+        results = repo.grep("TODO", max_results=5)
+        assert len(results) == 5  # Limited by max_results
+
+    def test_grep_timeout_parameter(self, repo):
+        """Test that grep accepts timeout parameter."""
+        results = repo.grep("TODO", timeout=60)
+        assert len(results) > 0
+
+    def test_grep_timeout_from_env(self, repo):
+        """Test that grep reads timeout from KIT_GREP_TIMEOUT env var."""
+        with patch.dict(os.environ, {"KIT_GREP_TIMEOUT": "300"}):
+            results = repo.grep("TODO")
+            assert len(results) > 0
+
+    def test_grep_early_termination(self, repo):
+        """Test that grep uses the -m flag for early termination."""
+        # With max_results=2, grep stops scanning each file early
+        results = repo.grep("TODO", max_results=2)
+        assert len(results) <= 2
+
+
+class TestFileTreePagination:
+    """Test file tree pagination in MCP server."""
+
+    @pytest.fixture
+    def server_with_repo(self):
+        """Create a server with a test repository."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Create a git repo
+            git_dir = Path(tmpdir) / ".git"
+            git_dir.mkdir()
+
+            # Create 50 test files
+            for i in range(50):
+                (Path(tmpdir) / f"file{i:03d}.py").write_text(f"# File {i}\n")
+
+            server = LocalDevServerLogic()
+            repo_id = server.open_repository(tmpdir)
+
+            yield server, repo_id
+
+    def test_pagination_default_limit(self, server_with_repo):
+        """Test that all files are returned when under the default limit (10000)."""
+        server, repo_id = server_with_repo
+        tree = server.get_file_tree(repo_id)
+        assert len(tree) == 50  # All files (fewer than the default limit)
+
+    def test_pagination_with_limit(self, server_with_repo):
+        """Test pagination with explicit limit."""
+        server, repo_id = server_with_repo
+        tree = server.get_file_tree(repo_id)
+
+        # Simulate pagination at the MCP layer
+        limit = 10
+        offset = 0
+        paginated = tree[offset : offset + limit]
+
+        assert len(paginated) == 10
+        # File order is not deterministic; just verify they're .py files
+        assert all(item["path"].endswith(".py") for item in paginated)
+
+    def test_pagination_with_offset(self, server_with_repo):
+        """Test pagination with offset."""
+        server, repo_id = server_with_repo
+        tree = server.get_file_tree(repo_id)
+
+        # Get the second page
+        limit = 10
+        offset = 10
+        paginated = tree[offset : offset + limit]
+
+        assert len(paginated) == 10
+
+    def test_pagination_has_more(self, server_with_repo):
+        """Test has_more calculation."""
+        server, repo_id = server_with_repo
+        tree = server.get_file_tree(repo_id)
+
+        total_count = len(tree)
+        limit = 10
+        offset = 0
+        has_more = offset + limit < total_count
+
+        assert has_more is True
+
+        # Last page
+        offset = 45
+        has_more = offset + limit < total_count
+
+        assert has_more is False
+
+
+class TestWarmCacheTool:
+    """Test the warm_cache MCP tool."""
+
+    @pytest.fixture
+    def server_with_repo(self):
+        """Create a server with a test repository."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Create a git repo
+            git_dir = Path(tmpdir) / ".git"
+            git_dir.mkdir()
+
+            # Create some test files
+            src_dir = Path(tmpdir) / "src"
+            src_dir.mkdir()
+
+            for i in range(5):
+                (src_dir / f"module{i}.py").write_text(
+                    f"class Class{i}:\n    def method{i}(self):\n        pass\n"
+                )
+
+            server = LocalDevServerLogic()
+            repo_id = server.open_repository(tmpdir)
+
+            yield server, repo_id
+
+    def test_warm_cache_file_tree(self, server_with_repo):
+        """Test warming the file tree cache."""
+        server, repo_id = server_with_repo
+
+        result = server.warm_cache(repo_id, warm_file_tree=True, warm_symbols=False)
+
+        assert "repo_id" in result
+        assert result["repo_id"] == repo_id
+        assert "file_tree" in result
+        assert "elapsed_seconds" in result["file_tree"]
+        assert "file_count" in result["file_tree"]
+        assert result["file_tree"]["file_count"] >= 5
+
+    def test_warm_cache_symbols(self, server_with_repo):
+        """Test warming the symbol cache."""
+        server, repo_id = server_with_repo
+
+        result = server.warm_cache(repo_id, warm_file_tree=False, warm_symbols=True)
+
+        assert "symbols" in result
+        assert "elapsed_seconds" in result["symbols"]
+        assert "symbol_count" in result["symbols"]
+        # Should find classes and methods
+        assert result["symbols"]["symbol_count"] >= 10
+
+    def test_warm_cache_both(self, server_with_repo):
+        """Test warming both caches."""
+        server, repo_id = server_with_repo
+
+        result = server.warm_cache(repo_id, warm_file_tree=True, warm_symbols=True)
+
+        assert "file_tree" in result
+        assert "symbols" in result
+
+
+class TestGetFileTreeParamsModel:
+    """Test GetFileTreeParams model has pagination fields."""
+
+    def test_has_limit_field(self):
+        """Test that GetFileTreeParams has limit field."""
+        params = GetFileTreeParams(repo_id="test")
+        assert params.limit == 10000  # default
+
+    def test_has_offset_field(self):
+        """Test that GetFileTreeParams has offset field."""
+        params = GetFileTreeParams(repo_id="test")
+        assert params.offset == 0  # default
+
+    def test_custom_pagination(self):
+        """Test custom pagination values."""
+        params = GetFileTreeParams(repo_id="test", limit=100, offset=50)
+        assert params.limit == 100
+        assert params.offset == 50
+
+
+class TestWarmCacheParamsModel:
+    """Test WarmCacheParams model."""
+
+    def test_default_values(self):
+        """Test default values."""
+        params = WarmCacheParams(repo_id="test")
+        assert params.warm_file_tree is True
+        assert params.warm_symbols is False
+
+    def test_custom_values(self):
+        """Test custom values."""
+        params = WarmCacheParams(repo_id="test", warm_file_tree=False, warm_symbols=True)
+        assert params.warm_file_tree is False
+        assert params.warm_symbols is True
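One consequence of the compact-mode metadata header that these tests don't exercise: clients that previously split compact output on newlines now see the header as the first line. A hypothetical helper (`parse_compact_tree` is not part of this change) shows the parsing:

```python
def parse_compact_tree(text: str) -> tuple[dict, list[str]]:
    """Split compact get_file_tree output into metadata and paths.

    Expects the header format introduced above, e.g.
    "# total=52341 offset=0 limit=10000 has_more=True",
    followed by one path per line. Metadata values come back as strings;
    cast to int/bool as needed.
    """
    header, _, body = text.partition("\n")
    meta = dict(field.split("=", 1) for field in header.lstrip("# ").split())
    paths = body.splitlines() if body else []
    return meta, paths
```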