diff --git a/docs/src/content/docs/mcp/kit-dev-mcp.mdx b/docs/src/content/docs/mcp/kit-dev-mcp.mdx
index c7a11f4..8b3ef20 100644
--- a/docs/src/content/docs/mcp/kit-dev-mcp.mdx
+++ b/docs/src/content/docs/mcp/kit-dev-mcp.mdx
@@ -82,11 +82,13 @@ The server provides many tools including:
 - **open_repository** - Open local or remote Git repositories
 - **search_code** - Pattern-based code search
-- **grep_code** - Fast literal string search
+- **grep_code** - Fast literal string search (120s default timeout, configurable via `KIT_GREP_TIMEOUT`)
+- **get_file_tree** - Repository file structure with pagination support (`limit`/`offset` params)
 - **get_file_content** - Read file contents
 - **extract_symbols** - Extract functions, classes, and symbols
 - **find_symbol_usages** - Find where symbols are used
 - **get_code_summary** - AI-powered code summaries
+- **warm_cache** - Pre-warm caches for faster operations on large codebases (100K+ files)
 - **review_diff** - AI-powered diff reviews
 - **deep_research_package** - Comprehensive package documentation
 - **semantic_search** - Vector-based code search
@@ -94,6 +96,13 @@ The server provides many tools including:
 - **package_search_hybrid** - Semantic search in package source code
 - **package_search_read_file** - Read specific files from packages
 
+
+## Learn More
+
diff --git a/src/kit/mcp/dev_server.py b/src/kit/mcp/dev_server.py
--- a/src/kit/mcp/dev_server.py
+++ b/src/kit/mcp/dev_server.py
@@ ... @@ def get_file_tree(self, repo_id: str) -> List[Dict[str, Any]]:
         repo = self.get_repo(repo_id)
         return repo.get_file_tree()
 
+    def warm_cache(self, repo_id: str, warm_file_tree: bool = True, warm_symbols: bool = False) -> Dict[str, Any]:
+        """Pre-warm caches for faster subsequent operations on large codebases.
+
+        This is useful for very large repos where the first file_tree or symbol
+        extraction can take 30+ seconds. Warming caches upfront avoids timeouts.
+
+        Args:
+            repo_id: Repository ID to warm caches for
+            warm_file_tree: Pre-cache file tree (fast, ~1-5s for 100K files)
+            warm_symbols: Pre-cache symbols (slower, ~30-60s for 100K files)
+
+        Returns:
+            Dict with timing stats for each warmed cache
+        """
+        import time
+
+        repo = self.get_repo(repo_id)
+        stats: Dict[str, Any] = {"repo_id": repo_id}
+
+        if warm_file_tree:
+            start = time.time()
+            tree = repo.get_file_tree()
+            stats["file_tree"] = {
+                "elapsed_seconds": round(time.time() - start, 2),
+                "file_count": len(tree),
+            }
+
+        if warm_symbols:
+            start = time.time()
+            # Trigger a full repo scan by calling extract_symbols with no file
+            symbols = repo.extract_symbols()
+            stats["symbols"] = {
+                "elapsed_seconds": round(time.time() - start, 2),
+                "symbol_count": len(symbols),
+            }
+
+        return stats
+
     def extract_symbols(self, repo_id: str, file_path: str, symbol_type: Optional[str] = None) -> List[Dict[str, Any]]:
         """Extract symbols from a file."""
         repo = self.get_repo(repo_id)
@@ -545,6 +599,11 @@ def list_tools(self) -> List[Tool]:
             description="Get source code of a specific symbol (lazy loading for context efficiency)",
             inputSchema=GetSymbolCodeParams.model_json_schema(),
         ),
+        Tool(
+            name="warm_cache",
+            description="Pre-warm caches for faster operations on large codebases (call before get_file_tree on huge repos)",
+            inputSchema=WarmCacheParams.model_json_schema(),
+        ),
     ]
 
@@ -1072,15 +1131,33 @@ async def call_tool(name: str, arguments: dict) -> List[TextContent]:
     elif name == "get_file_tree":
         tree_params = GetFileTreeParams(**arguments)
         result = logic.get_file_tree(tree_params.repo_id)
+
+        # Apply pagination for large codebases
+        total_count = len(result)
+        start = tree_params.offset
+        end = start + tree_params.limit
+        paginated = result[start:end]
+        has_more = end < total_count
+
         # Compact mode: newline-separated paths (saves ~75% context)
         if tree_params.compact:
             paths = []
-            for item in result:
+            for item in paginated:
                 is_dir = item.get("is_dir", False)
                 if tree_params.include_dirs or not is_dir:
                     paths.append(item.get("path", ""))
-            return [TextContent(type="text", text="\n".join(paths))]
-        return [TextContent(type="text", text=json.dumps(result, indent=2))]
+            # Include pagination metadata as a header line for compact mode
+            header = f"# total={total_count} offset={start} limit={tree_params.limit} has_more={has_more}\n"
+            return [TextContent(type="text", text=header + "\n".join(paths))]
+        # JSON mode: include pagination metadata in the response
+        response = {
+            "files": paginated,
+            "total_count": total_count,
+            "offset": start,
+            "limit": tree_params.limit,
+            "has_more": has_more,
+        }
+        return [TextContent(type="text", text=json.dumps(response, indent=2))]
     elif name == "get_code_summary":
         summary_params = GetCodeSummaryParams(**arguments)
         result = logic.get_code_summary(
@@ -1119,6 +1196,14 @@ async def call_tool(name: str, arguments: dict) -> List[TextContent]:
             symbol_code_params.symbol_name,
         )
         return [TextContent(type="text", text=json.dumps(result, indent=2))]
+    elif name == "warm_cache":
+        cache_params = WarmCacheParams(**arguments)
+        result = logic.warm_cache(
+            cache_params.repo_id,
+            cache_params.warm_file_tree,
+            cache_params.warm_symbols,
+        )
+        return [TextContent(type="text", text=json.dumps(result, indent=2))]
     else:
         # Should not happen since we checked the name is in the list
         return [TextContent(type="text", text=f"Tool {name} is recognized but not implemented")]
diff --git a/src/kit/repository.py b/src/kit/repository.py
index 7b3b03c..85a7190 100644
--- a/src/kit/repository.py
+++ b/src/kit/repository.py
@@ -379,6 +379,7 @@ def grep(
         max_results: int = 1000,
         directory: Optional[str] = None,
         include_hidden: bool = False,
+        timeout: Optional[int] = None,
     ) -> List[Dict[str, Any]]:
         """
         Performs literal grep search on repository files using system grep.
@@ -389,8 +390,11 @@ def grep(
             include_pattern: Glob pattern for files to include (e.g. '*.py').
             exclude_pattern: Glob pattern for files to exclude.
             max_results: Maximum number of results to return. Defaults to 1000.
+                Also passed to grep's -m flag so each file stops being scanned early.
             directory: Limit search to specific directory within repository (e.g. 'src', 'lib/utils').
             include_hidden: Whether to search hidden directories (starting with '.'). Defaults to False.
+            timeout: Search timeout in seconds. Defaults to the KIT_GREP_TIMEOUT env var
+                if set, otherwise 120s. For very large codebases (10M+ files), consider increasing this.
 
         Returns:
             List[Dict[str, Any]]: List of matches with file, line_number, line_content.
@@ -403,6 +407,17 @@ def grep(
 
         self._ensure_git_state_valid()
 
+        # Resolve timeout: parameter > env var > default (120s)
+        if timeout is None:
+            env_timeout = os.environ.get("KIT_GREP_TIMEOUT")
+            if env_timeout:
+                try:
+                    timeout = int(env_timeout)
+                except ValueError:
+                    timeout = 120
+            else:
+                timeout = 120
+
         # Build grep command
         cmd = ["grep", "-r", "-n", "-H"]  # -r for recursive, -n for line numbers, -H for filenames
 
@@ -477,6 +492,11 @@ def grep(
             raise ValueError(f"Directory not found in repository: {directory}")
         search_path = directory
 
+        # Early termination: the -m flag makes grep stop scanning each file after
+        # max_results matches, which substantially speeds up searches over large
+        # files (results are still capped at max_results overall afterwards)
+        if max_results > 0:
+            cmd.extend(["-m", str(max_results)])
+
         # Search recursively in specified directory
         cmd.append(search_path)
 
@@ -487,10 +507,13 @@ def grep(
                 capture_output=True,
                 text=True,
                 encoding="utf-8",
-                timeout=30,  # 30 second timeout
+                timeout=timeout,
             )
         except subprocess.TimeoutExpired:
-            raise RuntimeError("Grep search timed out after 30 seconds")
+            raise RuntimeError(
+                f"Grep search timed out after {timeout} seconds. "
+                "Set the KIT_GREP_TIMEOUT env var or pass the timeout parameter to allow longer searches."
+            )
         except FileNotFoundError:
             raise RuntimeError("grep command not found. Please ensure grep is installed and in PATH.")
diff --git a/tests/test_large_codebase.py b/tests/test_large_codebase.py
new file mode 100644
index 0000000..94c1453
--- /dev/null
+++ b/tests/test_large_codebase.py
@@ -0,0 +1,232 @@
+"""Tests for large codebase support features.
+
+These tests verify the features added for handling large codebases:
+
+- Grep timeout parameter and early termination (-m flag)
+- File tree pagination
+- Warm cache MCP tool
+"""
+
+import os
+import tempfile
+from pathlib import Path
+from unittest.mock import patch
+
+import pytest
+
+from kit.mcp.dev_server import (
+    GetFileTreeParams,
+    LocalDevServerLogic,
+    WarmCacheParams,
+)
+from kit.repository import Repository
+
+
+class TestGrepTimeoutAndEarlyTermination:
+    """Test grep timeout parameter and early termination."""
+
+    @pytest.fixture
+    def repo(self):
+        """Create a test repository with some files."""
+        with tempfile.TemporaryDirectory() as tmpdir:
+            # Create a git repo
+            git_dir = Path(tmpdir) / ".git"
+            git_dir.mkdir()
+
+            # Create files with searchable content
+            src_dir = Path(tmpdir) / "src"
+            src_dir.mkdir()
+
+            for i in range(10):
+                (src_dir / f"file{i}.py").write_text(f"# TODO: task {i}\ndef func{i}():\n    pass\n")
+
+            yield Repository(tmpdir)
+
+    def test_grep_default_timeout(self, repo):
+        """Test that grep runs with the default 120s timeout when none is given."""
+        # Without the env var or parameter, the 120s default applies
+        results = repo.grep("TODO", max_results=5)
+        assert len(results) == 5  # Limited by max_results
+
+    def test_grep_timeout_parameter(self, repo):
+        """Test that grep accepts a timeout parameter."""
+        results = repo.grep("TODO", timeout=60)
+        assert len(results) > 0
+
+    def test_grep_timeout_from_env(self, repo):
+        """Test that grep reads the timeout from the KIT_GREP_TIMEOUT env var."""
+        with patch.dict(os.environ, {"KIT_GREP_TIMEOUT": "300"}):
+            results = repo.grep("TODO")
+            assert len(results) > 0
+
+    def test_grep_early_termination(self, repo):
+        """Test that grep uses the -m flag for early termination."""
+        # With max_results=2, grep should stop early
+        results = repo.grep("TODO", max_results=2)
+        assert len(results) <= 2
+
+
+class TestFileTreePagination:
+    """Test file tree pagination in the MCP server."""
+
+    @pytest.fixture
+    def server_with_repo(self):
"""Create a server with a test repository.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create a git repo + git_dir = Path(tmpdir) / ".git" + git_dir.mkdir() + + # Create 50 test files + for i in range(50): + (Path(tmpdir) / f"file{i:03d}.py").write_text(f"# File {i}\n") + + server = LocalDevServerLogic() + repo_id = server.open_repository(tmpdir) + + yield server, repo_id + + def test_pagination_default_limit(self, server_with_repo): + """Test that default limit is 10000.""" + server, repo_id = server_with_repo + tree = server.get_file_tree(repo_id) + assert len(tree) == 50 # All files (less than default limit) + + def test_pagination_with_limit(self, server_with_repo): + """Test pagination with explicit limit.""" + server, repo_id = server_with_repo + tree = server.get_file_tree(repo_id) + + # Simulate pagination at MCP layer + limit = 10 + offset = 0 + paginated = tree[offset : offset + limit] + + assert len(paginated) == 10 + # File order is not deterministic, just verify they're .py files + assert all(item["path"].endswith(".py") for item in paginated) + + def test_pagination_with_offset(self, server_with_repo): + """Test pagination with offset.""" + server, repo_id = server_with_repo + tree = server.get_file_tree(repo_id) + + # Get second page + limit = 10 + offset = 10 + paginated = tree[offset : offset + limit] + + assert len(paginated) == 10 + + def test_pagination_has_more(self, server_with_repo): + """Test has_more calculation.""" + server, repo_id = server_with_repo + tree = server.get_file_tree(repo_id) + + total_count = len(tree) + limit = 10 + offset = 0 + has_more = offset + limit < total_count + + assert has_more is True + + # Last page + offset = 45 + has_more = offset + limit < total_count + + assert has_more is False + + +class TestWarmCacheTool: + """Test the warm_cache MCP tool.""" + + @pytest.fixture + def server_with_repo(self): + """Create a server with a test repository.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create a git repo + git_dir = Path(tmpdir) / ".git" + git_dir.mkdir() + + # Create some test files + src_dir = Path(tmpdir) / "src" + src_dir.mkdir() + + for i in range(5): + (src_dir / f"module{i}.py").write_text( + f"class Class{i}:\n def method{i}(self):\n pass\n" + ) + + server = LocalDevServerLogic() + repo_id = server.open_repository(tmpdir) + + yield server, repo_id + + def test_warm_cache_file_tree(self, server_with_repo): + """Test warming file tree cache.""" + server, repo_id = server_with_repo + + result = server.warm_cache(repo_id, warm_file_tree=True, warm_symbols=False) + + assert "repo_id" in result + assert result["repo_id"] == repo_id + assert "file_tree" in result + assert "elapsed_seconds" in result["file_tree"] + assert "file_count" in result["file_tree"] + assert result["file_tree"]["file_count"] >= 5 + + def test_warm_cache_symbols(self, server_with_repo): + """Test warming symbol cache.""" + server, repo_id = server_with_repo + + result = server.warm_cache(repo_id, warm_file_tree=False, warm_symbols=True) + + assert "symbols" in result + assert "elapsed_seconds" in result["symbols"] + assert "symbol_count" in result["symbols"] + # Should find classes and methods + assert result["symbols"]["symbol_count"] >= 10 + + def test_warm_cache_both(self, server_with_repo): + """Test warming both caches.""" + server, repo_id = server_with_repo + + result = server.warm_cache(repo_id, warm_file_tree=True, warm_symbols=True) + + assert "file_tree" in result + assert "symbols" in result + + +class TestGetFileTreeParamsModel: 
+    """Test GetFileTreeParams model has pagination fields."""
+
+    def test_has_limit_field(self):
+        """Test that GetFileTreeParams has a limit field."""
+        params = GetFileTreeParams(repo_id="test")
+        assert params.limit == 10000  # default
+
+    def test_has_offset_field(self):
+        """Test that GetFileTreeParams has an offset field."""
+        params = GetFileTreeParams(repo_id="test")
+        assert params.offset == 0  # default
+
+    def test_custom_pagination(self):
+        """Test custom pagination values."""
+        params = GetFileTreeParams(repo_id="test", limit=100, offset=50)
+        assert params.limit == 100
+        assert params.offset == 50
+
+
+class TestWarmCacheParamsModel:
+    """Test WarmCacheParams model."""
+
+    def test_default_values(self):
+        """Test default values."""
+        params = WarmCacheParams(repo_id="test")
+        assert params.warm_file_tree is True
+        assert params.warm_symbols is False
+
+    def test_custom_values(self):
+        """Test custom values."""
+        params = WarmCacheParams(repo_id="test", warm_file_tree=False, warm_symbols=True)
+        assert params.warm_file_tree is False
+        assert params.warm_symbols is True
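
A quick usage sketch of the new grep timeout plumbing, for reviewers. Everything here comes from this diff except the repository path, which is hypothetical; the match keys (`file`, `line_number`, `line_content`) are the ones documented in the `grep` docstring above.

```python
from kit.repository import Repository

repo = Repository("/path/to/large/checkout")  # hypothetical local path

# Precedence is: explicit timeout > KIT_GREP_TIMEOUT env var > 120s default.
# max_results also feeds grep's -m flag, so each file stops being scanned early.
matches = repo.grep("TODO", include_pattern="*.py", max_results=100, timeout=300)

for m in matches:
    print(f"{m['file']}:{m['line_number']}: {m['line_content'].strip()}")
```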
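Similarly, a sketch of how an MCP client might page through `get_file_tree` using the JSON response shape added here (`files`, `total_count`, `offset`, `limit`, `has_more`). The `call_tool_json` callable is hypothetical; it stands in for whatever client machinery invokes the tool and parses the returned JSON text.

```python
from typing import Any, Callable, Dict, Iterator


def iter_file_tree(
    call_tool_json: Callable[[str, Dict[str, Any]], Dict[str, Any]],  # hypothetical client hook
    repo_id: str,
    page_size: int = 1000,
) -> Iterator[Dict[str, Any]]:
    """Yield every file entry, fetching pages until has_more is False."""
    offset = 0
    while True:
        page = call_tool_json(
            "get_file_tree",
            {"repo_id": repo_id, "limit": page_size, "offset": offset},
        )
        yield from page["files"]
        if not page["has_more"]:
            return
        offset += page_size
```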
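And a sketch of the intended `warm_cache` call pattern, going through `LocalDevServerLogic` directly (the MCP tool is a thin wrapper over the same method). The repository path is hypothetical; the stats keys match what `warm_cache` returns in this diff.

```python
from kit.mcp.dev_server import LocalDevServerLogic

logic = LocalDevServerLogic()
repo_id = logic.open_repository("/path/to/huge/monorepo")  # hypothetical path

# File-tree warming is cheap (~1-5s for 100K files); symbol warming is opt-in
# because a full scan can take 30-60s at that scale.
stats = logic.warm_cache(repo_id, warm_file_tree=True, warm_symbols=False)
print(
    f"file tree cached in {stats['file_tree']['elapsed_seconds']}s "
    f"({stats['file_tree']['file_count']} files)"
)
```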