diff --git a/docs/configuration-guide.md b/docs/configuration-guide.md index 3e7140d8..570745f5 100644 --- a/docs/configuration-guide.md +++ b/docs/configuration-guide.md @@ -158,6 +158,95 @@ builder.build_index("./indexes/my-notes", chunks) `embedding_options` is persisted to the index `meta.json`, so subsequent `LeannSearcher` or `LeannChat` sessions automatically reuse the same provider settings (the embedding server manager forwards them to the provider for you). +## Optional Embedding Features + +### Task-Specific Prompt Templates + +Some embedding models are trained with task-specific prompts to differentiate between documents and queries. The most notable example is **Google's EmbeddingGemma**, which requires different prompts depending on the use case: + +- **Indexing documents**: `"title: none | text: "` +- **Search queries**: `"task: search result | query: "` + +LEANN supports automatic prompt prepending via the `--embedding-prompt-template` flag: + +```bash +# Build index with EmbeddingGemma (via LM Studio or Ollama) +leann build my-docs \ + --docs ./documents \ + --embedding-mode openai \ + --embedding-model text-embedding-embeddinggemma-300m-qat \ + --embedding-api-base http://localhost:1234/v1 \ + --embedding-prompt-template "title: none | text: " \ + --force + +# Search with query-specific prompt +leann search my-docs \ + --query "What is quantum computing?" \ + --embedding-prompt-template "task: search result | query: " +``` + +**Important Notes:** +- **Only use with compatible models**: EmbeddingGemma and similar task-specific models +- **NOT for regular models**: Adding prompts to models like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` will corrupt embeddings +- **Template is saved**: Build-time templates are saved to `.meta.json` for reference +- **Flexible prompts**: You can use any prompt string, or leave it empty (`""`) + +**Python API:** +```python +from leann.api import LeannBuilder + +builder = LeannBuilder( + embedding_mode="openai", + embedding_model="text-embedding-embeddinggemma-300m-qat", + embedding_options={ + "base_url": "http://localhost:1234/v1", + "api_key": "lm-studio", + "prompt_template": "title: none | text: ", + }, +) +builder.build_index("./indexes/my-docs", chunks) +``` + +**References:** +- [HuggingFace Blog: EmbeddingGemma](https://huggingface.co/blog/embeddinggemma) - Technical details + +### LM Studio Auto-Detection (Optional) + +When using LM Studio with the OpenAI-compatible API, LEANN can optionally auto-detect model context lengths via the LM Studio SDK. This eliminates manual configuration for token limits. + +**Prerequisites:** +```bash +# Install Node.js (if not already installed) +# Then install the LM Studio SDK globally +npm install -g @lmstudio/sdk +``` + +**How it works:** +1. LEANN detects LM Studio URLs (`:1234`, `lmstudio` in URL) +2. Queries model metadata via Node.js subprocess +3. Automatically unloads model after query (respects your JIT auto-evict settings) +4. 
Falls back to static registry if SDK unavailable + +**No configuration needed** - it works automatically when SDK is installed: + +```bash +leann build my-docs \ + --docs ./documents \ + --embedding-mode openai \ + --embedding-model text-embedding-nomic-embed-text-v1.5 \ + --embedding-api-base http://localhost:1234/v1 + # Context length auto-detected if SDK available + # Falls back to registry (2048) if not +``` + +**Benefits:** +- ✅ Automatic token limit detection +- ✅ Respects LM Studio JIT auto-evict settings +- ✅ No manual registry maintenance +- ✅ Graceful fallback if SDK unavailable + +**Note:** This is completely optional. LEANN works perfectly fine without the SDK using the built-in token limit registry. + ## Index Selection: Matching Your Scale ### HNSW (Hierarchical Navigable Small World) diff --git a/docs/faq.md b/docs/faq.md index a2fdd522..c469a918 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -8,3 +8,51 @@ You can speed up the process by using a lightweight embedding model. Add this to --embedding-model sentence-transformers/all-MiniLM-L6-v2 ``` **Model sizes:** `all-MiniLM-L6-v2` (30M parameters), `facebook/contriever` (~100M parameters), `Qwen3-0.6B` (600M parameters) + +## 2. When should I use prompt templates? + +**Use prompt templates ONLY with task-specific embedding models** like Google's EmbeddingGemma. These models are specially trained to use different prompts for documents vs queries. + +**DO NOT use with regular models** like `nomic-embed-text`, `text-embedding-3-small`, or `bge-base-en-v1.5` - adding prompts to these models will corrupt the embeddings. + +**Example usage with EmbeddingGemma:** +```bash +# Build with document prompt +leann build my-docs --embedding-prompt-template "title: none | text: " + +# Search with query prompt +leann search my-docs --query "your question" --embedding-prompt-template "task: search result | query: " +``` + +See the [Configuration Guide: Task-Specific Prompt Templates](configuration-guide.md#task-specific-prompt-templates) for detailed usage. + +## 3. Why is LM Studio loading multiple copies of my model? + +This was fixed in recent versions. LEANN now properly unloads models after querying metadata, respecting your LM Studio JIT auto-evict settings. + +**If you still see duplicates:** +- Update to the latest LEANN version +- Restart LM Studio to clear loaded models +- Check that you have JIT auto-evict enabled in LM Studio settings + +**How it works now:** +1. LEANN loads model temporarily to get context length +2. Immediately unloads after query +3. LM Studio JIT loads model on-demand for actual embeddings +4. Auto-evicts per your settings + +## 4. Do I need Node.js and @lmstudio/sdk? + +**No, it's completely optional.** LEANN works perfectly fine without them using a built-in token limit registry. + +**Benefits if you install it:** +- Automatic context length detection for LM Studio models +- No manual registry maintenance +- Always gets accurate token limits from the model itself + +**To install (optional):** +```bash +npm install -g @lmstudio/sdk +``` + +See [Configuration Guide: LM Studio Auto-Detection](configuration-guide.md#lm-studio-auto-detection-optional) for details. 
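+**Checking the resolved limit (optional):** The snippet below is a minimal sketch that calls LEANN's token-limit resolver directly so you can see which value is in effect; the model name and URL are only examples, and `get_model_token_limit` is an internal helper whose location may change between releases.
+
+```python
+from leann.embedding_compute import get_model_token_limit
+
+# With Node.js + @lmstudio/sdk installed, the limit comes from the SDK bridge;
+# otherwise LEANN falls back to its built-in registry (2048 for nomic-embed-text).
+limit = get_model_token_limit(
+    "text-embedding-nomic-embed-text-v1.5",
+    base_url="http://localhost:1234/v1",
+)
+print(limit)
+```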
diff --git a/packages/leann-core/src/leann/api.py b/packages/leann-core/src/leann/api.py index bbcc8a3a..c5fb3ddd 100644 --- a/packages/leann-core/src/leann/api.py +++ b/packages/leann-core/src/leann/api.py @@ -916,6 +916,7 @@ def search( metadata_filters: Optional[dict[str, dict[str, Union[str, int, float, bool, list]]]] = None, batch_size: int = 0, use_grep: bool = False, + provider_options: Optional[dict[str, Any]] = None, **kwargs, ) -> list[SearchResult]: """ @@ -979,10 +980,24 @@ def search( start_time = time.time() + # Extract query template from stored embedding_options with fallback chain: + # 1. Check provider_options override (highest priority) + # 2. Check query_prompt_template (new format) + # 3. Check prompt_template (old format for backward compat) + # 4. None (no template) + query_template = None + if provider_options and "prompt_template" in provider_options: + query_template = provider_options["prompt_template"] + elif "query_prompt_template" in self.embedding_options: + query_template = self.embedding_options["query_prompt_template"] + elif "prompt_template" in self.embedding_options: + query_template = self.embedding_options["prompt_template"] + query_embedding = self.backend_impl.compute_query_embedding( query, use_server_if_available=recompute_embeddings, zmq_port=zmq_port, + query_template=query_template, ) logger.info(f" Generated embedding shape: {query_embedding.shape}") embedding_time = time.time() - start_time diff --git a/packages/leann-core/src/leann/cli.py b/packages/leann-core/src/leann/cli.py index 7dbd5af3..982ae9c6 100644 --- a/packages/leann-core/src/leann/cli.py +++ b/packages/leann-core/src/leann/cli.py @@ -144,6 +144,18 @@ def create_parser(self) -> argparse.ArgumentParser: default=None, help="API key for embedding service (defaults to OPENAI_API_KEY)", ) + build_parser.add_argument( + "--embedding-prompt-template", + type=str, + default=None, + help="Prompt template to prepend to all texts for embedding (e.g., 'query: ' for search)", + ) + build_parser.add_argument( + "--query-prompt-template", + type=str, + default=None, + help="Prompt template for queries (different from build template for task-specific models)", + ) build_parser.add_argument( "--force", "-f", action="store_true", help="Force rebuild existing index" ) @@ -260,6 +272,12 @@ def create_parser(self) -> argparse.ArgumentParser: action="store_true", help="Display file paths and metadata in search results", ) + search_parser.add_argument( + "--embedding-prompt-template", + type=str, + default=None, + help="Prompt template to prepend to query for embedding (e.g., 'query: ' for search)", + ) # Ask command ask_parser = subparsers.add_parser("ask", help="Ask questions") @@ -1398,6 +1416,14 @@ async def build_index(self, args): resolved_embedding_key = resolve_openai_api_key(args.embedding_api_key) if resolved_embedding_key: embedding_options["api_key"] = resolved_embedding_key + if args.query_prompt_template: + # New format: separate templates + if args.embedding_prompt_template: + embedding_options["build_prompt_template"] = args.embedding_prompt_template + embedding_options["query_prompt_template"] = args.query_prompt_template + elif args.embedding_prompt_template: + # Old format: single template (backward compat) + embedding_options["prompt_template"] = args.embedding_prompt_template builder = LeannBuilder( backend_name=args.backend_name, diff --git a/packages/leann-core/src/leann/embedding_compute.py b/packages/leann-core/src/leann/embedding_compute.py index 2af44698..a8dba9d6 100644 --- 
a/packages/leann-core/src/leann/embedding_compute.py +++ b/packages/leann-core/src/leann/embedding_compute.py @@ -4,8 +4,10 @@ Preserves all optimization parameters to ensure performance """ +import json import logging import os +import subprocess import time from typing import Any, Optional @@ -40,6 +42,11 @@ "text-embedding-ada-002": 8192, } +# Runtime cache for dynamically discovered token limits +# Key: (model_name, base_url), Value: token_limit +# Prevents repeated SDK/API calls for the same model +_token_limit_cache: dict[tuple[str, str], int] = {} + def get_model_token_limit( model_name: str, @@ -49,6 +56,7 @@ def get_model_token_limit( """ Get token limit for a given embedding model. Uses hybrid approach: dynamic discovery for Ollama, registry fallback for others. + Caches discovered limits to prevent repeated API/SDK calls. Args: model_name: Name of the embedding model @@ -58,12 +66,33 @@ def get_model_token_limit( Returns: Token limit for the model in tokens """ + # Check cache first to avoid repeated SDK/API calls + cache_key = (model_name, base_url or "") + if cache_key in _token_limit_cache: + cached_limit = _token_limit_cache[cache_key] + logger.debug(f"Using cached token limit for {model_name}: {cached_limit}") + return cached_limit + # Try Ollama dynamic discovery if base_url provided if base_url: # Detect Ollama servers by port or "ollama" in URL if "11434" in base_url or "ollama" in base_url.lower(): limit = _query_ollama_context_limit(model_name, base_url) if limit: + _token_limit_cache[cache_key] = limit + return limit + + # Try LM Studio SDK discovery + if "1234" in base_url or "lmstudio" in base_url.lower() or "lm.studio" in base_url.lower(): + # Convert HTTP to WebSocket URL + ws_url = base_url.replace("https://", "wss://").replace("http://", "ws://") + # Remove /v1 suffix if present + if ws_url.endswith("/v1"): + ws_url = ws_url[:-3] + + limit = _query_lmstudio_context_limit(model_name, ws_url) + if limit: + _token_limit_cache[cache_key] = limit return limit # Fallback to known model registry with version handling (from PR #154) @@ -72,19 +101,25 @@ def get_model_token_limit( # Check exact match first if model_name in EMBEDDING_MODEL_LIMITS: - return EMBEDDING_MODEL_LIMITS[model_name] + limit = EMBEDDING_MODEL_LIMITS[model_name] + _token_limit_cache[cache_key] = limit + return limit # Check base name match if base_model_name in EMBEDDING_MODEL_LIMITS: - return EMBEDDING_MODEL_LIMITS[base_model_name] + limit = EMBEDDING_MODEL_LIMITS[base_model_name] + _token_limit_cache[cache_key] = limit + return limit # Check partial matches for common patterns - for known_model, limit in EMBEDDING_MODEL_LIMITS.items(): + for known_model, registry_limit in EMBEDDING_MODEL_LIMITS.items(): if known_model in base_model_name or base_model_name in known_model: - return limit + _token_limit_cache[cache_key] = registry_limit + return registry_limit # Default fallback logger.warning(f"Unknown model '{model_name}', using default {default} token limit") + _token_limit_cache[cache_key] = default return default @@ -185,6 +220,91 @@ def _query_ollama_context_limit(model_name: str, base_url: str) -> Optional[int] return None +def _query_lmstudio_context_limit(model_name: str, base_url: str) -> Optional[int]: + """ + Query LM Studio SDK for model context length via Node.js subprocess. 
+ + Args: + model_name: Name of the LM Studio model + base_url: Base URL of the LM Studio server (WebSocket format, e.g., "ws://localhost:1234") + + Returns: + Context limit in tokens if found, None otherwise + """ + # Inline JavaScript using @lmstudio/sdk + # Note: Load model temporarily for metadata, then unload to respect JIT auto-evict + js_code = f""" + const {{ LMStudioClient }} = require('@lmstudio/sdk'); + (async () => {{ + try {{ + const client = new LMStudioClient({{ baseUrl: '{base_url}' }}); + const model = await client.embedding.load('{model_name}', {{ verbose: false }}); + const contextLength = await model.getContextLength(); + await model.unload(); // Unload immediately to respect JIT auto-evict settings + console.log(JSON.stringify({{ contextLength, identifier: '{model_name}' }})); + }} catch (error) {{ + console.error(JSON.stringify({{ error: error.message }})); + process.exit(1); + }} + }})(); + """ + + try: + # Set NODE_PATH to include global modules for @lmstudio/sdk resolution + env = os.environ.copy() + + # Try to get npm global root (works with nvm, brew node, etc.) + try: + npm_root = subprocess.run( + ["npm", "root", "-g"], + capture_output=True, + text=True, + timeout=5, + ) + if npm_root.returncode == 0: + global_modules = npm_root.stdout.strip() + # Append to existing NODE_PATH if present + existing_node_path = env.get("NODE_PATH", "") + env["NODE_PATH"] = ( + f"{global_modules}:{existing_node_path}" + if existing_node_path + else global_modules + ) + except Exception: + # If npm not available, continue with existing NODE_PATH + pass + + result = subprocess.run( + ["node", "-e", js_code], + capture_output=True, + text=True, + timeout=10, + env=env, + ) + + if result.returncode != 0: + logger.debug(f"LM Studio SDK error: {result.stderr}") + return None + + data = json.loads(result.stdout) + context_length = data.get("contextLength") + + if context_length and context_length > 0: + logger.info(f"LM Studio SDK detected {model_name} context length: {context_length}") + return context_length + + except FileNotFoundError: + logger.debug("Node.js not found - install Node.js for LM Studio SDK features") + except subprocess.TimeoutExpired: + logger.debug("LM Studio SDK query timeout") + except json.JSONDecodeError: + logger.debug("LM Studio SDK returned invalid JSON") + except Exception as e: + logger.debug(f"LM Studio SDK query failed: {e}") + + return None + + # Global model cache to avoid repeated loading _model_cache: dict[str, Any] = {} @@ -232,6 +352,7 @@ def compute_embeddings( model_name, base_url=provider_options.get("base_url"), api_key=provider_options.get("api_key"), + provider_options=provider_options, ) elif mode == "mlx": return compute_embeddings_mlx(texts, model_name) @@ -241,6 +362,7 @@ def compute_embeddings( model_name, is_build=is_build, host=provider_options.get("host"), + provider_options=provider_options, ) elif mode == "gemini": return compute_embeddings_gemini(texts, model_name, is_build=is_build) @@ -579,6 +701,7 @@ def compute_embeddings_openai( model_name: str, base_url: Optional[str] = None, api_key: Optional[str] = None, + provider_options: Optional[dict[str, Any]] = None, ) -> np.ndarray: # TODO: @yichuan-w add progress bar only in build mode """Compute embeddings using OpenAI API""" @@ -597,26 +720,37 @@ def compute_embeddings_openai( f"Found {invalid_count} empty/invalid text(s) in input. Upstream should filter before calling OpenAI." 
) - resolved_base_url = resolve_openai_base_url(base_url) - resolved_api_key = resolve_openai_api_key(api_key) + # Extract base_url and api_key from provider_options if not provided directly + provider_options = provider_options or {} + effective_base_url = base_url or provider_options.get("base_url") + effective_api_key = api_key or provider_options.get("api_key") + + resolved_base_url = resolve_openai_base_url(effective_base_url) + resolved_api_key = resolve_openai_api_key(effective_api_key) if not resolved_api_key: raise RuntimeError("OPENAI_API_KEY environment variable not set") - # Cache OpenAI client - cache_key = f"openai_client::{resolved_base_url}" - if cache_key in _model_cache: - client = _model_cache[cache_key] - else: - client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url) - _model_cache[cache_key] = client - logger.info("OpenAI client cached") + # Create OpenAI client + client = openai.OpenAI(api_key=resolved_api_key, base_url=resolved_base_url) logger.info( f"Computing embeddings for {len(texts)} texts using OpenAI API, model: '{model_name}'" ) print(f"len of texts: {len(texts)}") + # Apply prompt template if provided + prompt_template = provider_options.get("prompt_template") + + if prompt_template: + logger.warning(f"Applying prompt template: '{prompt_template}'") + texts = [f"{prompt_template}{text}" for text in texts] + + # Query token limit and apply truncation + token_limit = get_model_token_limit(model_name, base_url=effective_base_url) + logger.info(f"Using token limit: {token_limit} for model '{model_name}'") + texts = truncate_to_token_limit(texts, token_limit) + # OpenAI has limits on batch size and input length max_batch_size = 800 # Conservative batch size because the token limit is 300K all_embeddings = [] @@ -647,7 +781,15 @@ def compute_embeddings_openai( try: response = client.embeddings.create(model=model_name, input=batch_texts) batch_embeddings = [embedding.embedding for embedding in response.data] - all_embeddings.extend(batch_embeddings) + + # Verify we got the expected number of embeddings + if len(batch_embeddings) != len(batch_texts): + logger.warning( + f"Expected {len(batch_texts)} embeddings but got {len(batch_embeddings)}" + ) + + # Only take the number of embeddings that match the batch size + all_embeddings.extend(batch_embeddings[: len(batch_texts)]) except Exception as e: logger.error(f"Batch {i} failed: {e}") raise @@ -737,6 +879,7 @@ def compute_embeddings_ollama( model_name: str, is_build: bool = False, host: Optional[str] = None, + provider_options: Optional[dict[str, Any]] = None, ) -> np.ndarray: """ Compute embeddings using Ollama API with true batch processing. 
@@ -749,6 +892,7 @@ def compute_embeddings_ollama( model_name: Ollama model name (e.g., "nomic-embed-text", "mxbai-embed-large") is_build: Whether this is a build operation (shows progress bar) host: Ollama host URL (defaults to environment or http://localhost:11434) + provider_options: Optional provider-specific options (e.g., prompt_template) Returns: Normalized embeddings array, shape: (len(texts), embedding_dim) @@ -885,6 +1029,14 @@ def compute_embeddings_ollama( logger.info(f"Using batch size: {batch_size} for true batch processing") + # Apply prompt template if provided + provider_options = provider_options or {} + prompt_template = provider_options.get("prompt_template") + + if prompt_template: + logger.warning(f"Applying prompt template: '{prompt_template}'") + texts = [f"{prompt_template}{text}" for text in texts] + # Get model token limit and apply truncation before batching token_limit = get_model_token_limit(model_name, base_url=resolved_host) logger.info(f"Model '{model_name}' token limit: {token_limit}") diff --git a/packages/leann-core/src/leann/interface.py b/packages/leann-core/src/leann/interface.py index 83803a0c..6b7d7b7c 100644 --- a/packages/leann-core/src/leann/interface.py +++ b/packages/leann-core/src/leann/interface.py @@ -77,6 +77,7 @@ def compute_query_embedding( query: str, use_server_if_available: bool = True, zmq_port: Optional[int] = None, + query_template: Optional[str] = None, ) -> np.ndarray: """Compute embedding for a query string @@ -84,6 +85,7 @@ def compute_query_embedding( query: The query string to embed zmq_port: ZMQ port for embedding server use_server_if_available: Whether to try using embedding server first + query_template: Optional prompt template to prepend to query Returns: Query embedding as numpy array with shape (1, D) diff --git a/packages/leann-core/src/leann/searcher_base.py b/packages/leann-core/src/leann/searcher_base.py index 47266057..ba5b1890 100644 --- a/packages/leann-core/src/leann/searcher_base.py +++ b/packages/leann-core/src/leann/searcher_base.py @@ -90,6 +90,7 @@ def compute_query_embedding( query: str, use_server_if_available: bool = True, zmq_port: int = 5557, + query_template: Optional[str] = None, ) -> np.ndarray: """ Compute embedding for a query string. 
@@ -98,10 +99,16 @@ def compute_query_embedding( query: The query string to embed zmq_port: ZMQ port for embedding server use_server_if_available: Whether to try using embedding server first + query_template: Optional prompt template to prepend to query Returns: Query embedding as numpy array """ + # Apply query template BEFORE any computation path + # This ensures template is applied consistently for both server and fallback paths + if query_template: + query = f"{query_template}{query}" + # Try to use embedding server if available and requested if use_server_if_available: try: diff --git a/pyproject.toml b/pyproject.toml index fe49ce43..b08e15f9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -164,6 +164,7 @@ python_functions = ["test_*"] markers = [ "slow: marks tests as slow (deselect with '-m \"not slow\"')", "openai: marks tests that require OpenAI API key", + "integration: marks tests that require live services (Ollama, LM Studio, etc.)", ] timeout = 300 # Reduced from 600s (10min) to 300s (5min) for CI safety addopts = [ diff --git a/tests/README.md b/tests/README.md index d3b3ec97..48ca7f59 100644 --- a/tests/README.md +++ b/tests/README.md @@ -36,6 +36,14 @@ Tests DiskANN graph partitioning functionality: - Includes performance comparison between DiskANN (with partition) and HNSW - **Note**: These tests are skipped in CI due to hardware requirements and computation time +### `test_prompt_template_e2e.py` +Integration tests for prompt template feature with live embedding services: +- Tests prompt template prepending with EmbeddingGemma (OpenAI-compatible API via LM Studio) +- Tests hybrid token limit discovery (Ollama dynamic detection, registry fallback, default) +- Tests LM Studio SDK bridge for automatic context length detection (requires Node.js + @lmstudio/sdk) +- **Note**: These tests require live services (LM Studio, Ollama) and are marked with `@pytest.mark.integration` +- **Important**: Prompt templates are ONLY for EmbeddingGemma and similar task-specific models, NOT regular embedding models + ## Running Tests ### Install test dependencies: @@ -66,6 +74,12 @@ pytest tests/ -m "not openai" # Skip slow tests pytest tests/ -m "not slow" +# Skip integration tests (that require live services) +pytest tests/ -m "not integration" + +# Run only integration tests (requires LM Studio or Ollama running) +pytest tests/test_prompt_template_e2e.py -v -s + # Run DiskANN partition tests (requires local machine, not CI) pytest tests/test_diskann_partition.py ``` @@ -101,6 +115,20 @@ The `pytest.ini` file configures: - Custom markers for slow and OpenAI tests - Verbose output with short tracebacks +### Integration Test Prerequisites + +Integration tests (`test_prompt_template_e2e.py`) require live services: + +**Required:** +- LM Studio running at `http://localhost:1234` with EmbeddingGemma model loaded + +**Optional:** +- Ollama running at `http://localhost:11434` for token limit detection tests +- Node.js + @lmstudio/sdk installed (`npm install -g @lmstudio/sdk`) for SDK bridge tests + +Tests gracefully skip if services are unavailable. 
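+A typical skip guard looks roughly like the sketch below (illustrative only; the actual helpers in `test_prompt_template_e2e.py` may be named differently, and this assumes the `requests` package is available):
+
+```python
+import pytest
+import requests
+
+LMSTUDIO_URL = "http://localhost:1234/v1"  # default LM Studio endpoint assumed above
+
+def lmstudio_available() -> bool:
+    """Return True if an LM Studio server answers on the default port."""
+    try:
+        requests.get(f"{LMSTUDIO_URL}/models", timeout=2)
+        return True
+    except requests.RequestException:
+        return False
+
+# Integration tests can then skip cleanly instead of failing:
+requires_lmstudio = pytest.mark.skipif(
+    not lmstudio_available(), reason="LM Studio not running at http://localhost:1234"
+)
+```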
+ ### Known Issues - OpenAI tests are automatically skipped if no API key is provided +- Integration tests require live embedding services and may fail due to proxy settings (set `unset ALL_PROXY all_proxy` if needed) diff --git a/tests/test_cli_prompt_template.py b/tests/test_cli_prompt_template.py new file mode 100644 index 00000000..981bb785 --- /dev/null +++ b/tests/test_cli_prompt_template.py @@ -0,0 +1,533 @@ +""" +Tests for CLI argument integration of --embedding-prompt-template. + +These tests verify that: +1. The --embedding-prompt-template flag is properly registered on build and search commands +2. The template value flows from CLI args to embedding_options dict +3. The template is passed through to compute_embeddings() function +4. Default behavior (no flag) is handled correctly +""" + +from unittest.mock import Mock, patch + +from leann.cli import LeannCLI + + +class TestCLIPromptTemplateArgument: + """Tests for --embedding-prompt-template on build and search commands.""" + + def test_commands_accept_prompt_template_argument(self): + """Verify that build and search parsers accept --embedding-prompt-template flag.""" + cli = LeannCLI() + parser = cli.create_parser() + template_value = "search_query: " + + # Test build command + build_args = parser.parse_args( + [ + "build", + "test-index", + "--docs", + "/tmp/test-docs", + "--embedding-prompt-template", + template_value, + ] + ) + assert build_args.command == "build" + assert hasattr(build_args, "embedding_prompt_template"), ( + "build command should have embedding_prompt_template attribute" + ) + assert build_args.embedding_prompt_template == template_value + + # Test search command + search_args = parser.parse_args( + ["search", "test-index", "my query", "--embedding-prompt-template", template_value] + ) + assert search_args.command == "search" + assert hasattr(search_args, "embedding_prompt_template"), ( + "search command should have embedding_prompt_template attribute" + ) + assert search_args.embedding_prompt_template == template_value + + def test_commands_default_to_none(self): + """Verify default value is None when flag not provided (backward compatibility).""" + cli = LeannCLI() + parser = cli.create_parser() + + # Test build command default + build_args = parser.parse_args(["build", "test-index", "--docs", "/tmp/test-docs"]) + assert hasattr(build_args, "embedding_prompt_template"), ( + "build command should have embedding_prompt_template attribute" + ) + assert build_args.embedding_prompt_template is None, ( + "Build default value should be None when flag not provided" + ) + + # Test search command default + search_args = parser.parse_args(["search", "test-index", "my query"]) + assert hasattr(search_args, "embedding_prompt_template"), ( + "search command should have embedding_prompt_template attribute" + ) + assert search_args.embedding_prompt_template is None, ( + "Search default value should be None when flag not provided" + ) + + +class TestBuildCommandPromptTemplateArgumentExtras: + """Additional build-specific tests for prompt template argument.""" + + def test_build_command_prompt_template_with_multiword_value(self): + """ + Verify that template values with spaces are handled correctly. + + Templates like "search_document: " or "Represent this sentence for searching: " + should be accepted as a single string argument. 
+ """ + cli = LeannCLI() + parser = cli.create_parser() + + template = "Represent this sentence for searching: " + args = parser.parse_args( + [ + "build", + "test-index", + "--docs", + "/tmp/test-docs", + "--embedding-prompt-template", + template, + ] + ) + + assert args.embedding_prompt_template == template + + +class TestPromptTemplateStoredInEmbeddingOptions: + """Tests for template storage in embedding_options dict.""" + + @patch("leann.cli.LeannBuilder") + def test_prompt_template_stored_in_embedding_options_on_build( + self, mock_builder_class, tmp_path + ): + """ + Verify that when --embedding-prompt-template is provided to build command, + the value is stored in embedding_options dict passed to LeannBuilder. + + This test will fail because the CLI doesn't currently process this argument + and add it to embedding_options. + """ + # Setup mocks + mock_builder = Mock() + mock_builder_class.return_value = mock_builder + + # Create CLI and run build command + cli = LeannCLI() + + # Mock load_documents to return a document so builder is created + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + + parser = cli.create_parser() + + template = "search_query: " + args = parser.parse_args( + [ + "build", + "test-index", + "--docs", + str(tmp_path), + "--embedding-prompt-template", + template, + "--force", # Force rebuild to ensure LeannBuilder is called + ] + ) + + # Run the build command + import asyncio + + asyncio.run(cli.build_index(args)) + + # Check that LeannBuilder was called with embedding_options containing prompt_template + call_kwargs = mock_builder_class.call_args.kwargs + assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options" + + embedding_options = call_kwargs["embedding_options"] + assert embedding_options is not None, ( + "embedding_options should not be None when template provided" + ) + assert "prompt_template" in embedding_options, ( + "embedding_options should contain 'prompt_template' key" + ) + assert embedding_options["prompt_template"] == template, ( + f"Template should be '{template}', got {embedding_options.get('prompt_template')}" + ) + + @patch("leann.cli.LeannBuilder") + def test_prompt_template_not_in_options_when_not_provided(self, mock_builder_class, tmp_path): + """ + Verify that when --embedding-prompt-template is NOT provided, + embedding_options either doesn't have the key or it's None. + + This ensures we don't pass empty/None values unnecessarily. 
+ """ + # Setup mocks + mock_builder = Mock() + mock_builder_class.return_value = mock_builder + + cli = LeannCLI() + + # Mock load_documents to return a document so builder is created + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + + parser = cli.create_parser() + + args = parser.parse_args( + [ + "build", + "test-index", + "--docs", + str(tmp_path), + "--force", # Force rebuild to ensure LeannBuilder is called + ] + ) + + import asyncio + + asyncio.run(cli.build_index(args)) + + # Check that if embedding_options is passed, it doesn't have prompt_template + call_kwargs = mock_builder_class.call_args.kwargs + if call_kwargs.get("embedding_options"): + embedding_options = call_kwargs["embedding_options"] + # Either the key shouldn't exist, or it should be None + assert ( + "prompt_template" not in embedding_options + or embedding_options["prompt_template"] is None + ), "prompt_template should not be set when flag not provided" + + # R1 Tests: Build-time separate template storage + @patch("leann.cli.LeannBuilder") + def test_build_stores_separate_templates(self, mock_builder_class, tmp_path): + """ + R1 Test 1: Verify that when both --embedding-prompt-template and + --query-prompt-template are provided to build command, both values + are stored separately in embedding_options dict as build_prompt_template + and query_prompt_template. + + This test will fail because: + 1. CLI doesn't accept --query-prompt-template flag yet + 2. CLI doesn't store templates as separate build_prompt_template and + query_prompt_template keys + + Expected behavior after implementation: + - .meta.json contains: {"embedding_options": { + "build_prompt_template": "doc: ", + "query_prompt_template": "query: " + }} + """ + # Setup mocks + mock_builder = Mock() + mock_builder_class.return_value = mock_builder + + cli = LeannCLI() + + # Mock load_documents to return a document so builder is created + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + + parser = cli.create_parser() + + build_template = "doc: " + query_template = "query: " + args = parser.parse_args( + [ + "build", + "test-index", + "--docs", + str(tmp_path), + "--embedding-prompt-template", + build_template, + "--query-prompt-template", + query_template, + "--force", + ] + ) + + # Run the build command + import asyncio + + asyncio.run(cli.build_index(args)) + + # Check that LeannBuilder was called with separate template keys + call_kwargs = mock_builder_class.call_args.kwargs + assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options" + + embedding_options = call_kwargs["embedding_options"] + assert embedding_options is not None, ( + "embedding_options should not be None when templates provided" + ) + + assert "build_prompt_template" in embedding_options, ( + "embedding_options should contain 'build_prompt_template' key" + ) + assert embedding_options["build_prompt_template"] == build_template, ( + f"build_prompt_template should be '{build_template}'" + ) + + assert "query_prompt_template" in embedding_options, ( + "embedding_options should contain 'query_prompt_template' key" + ) + assert embedding_options["query_prompt_template"] == query_template, ( + f"query_prompt_template should be '{query_template}'" + ) + + # Old key should NOT be present when using new separate template format + assert "prompt_template" not in embedding_options, ( + "Old 'prompt_template' key should not be present with separate templates" + ) + + 
@patch("leann.cli.LeannBuilder") + def test_build_backward_compat_single_template(self, mock_builder_class, tmp_path): + """ + R1 Test 2: Verify backward compatibility - when only + --embedding-prompt-template is provided (old behavior), it should + still be stored as 'prompt_template' in embedding_options. + + This ensures existing workflows continue to work unchanged. + + This test currently passes because it matches existing behavior, but it + documents the requirement that this behavior must be preserved after + implementing the separate template feature. + + Expected behavior: + - .meta.json contains: {"embedding_options": {"prompt_template": "prompt: "}} + - No build_prompt_template or query_prompt_template keys + """ + # Setup mocks + mock_builder = Mock() + mock_builder_class.return_value = mock_builder + + cli = LeannCLI() + + # Mock load_documents to return a document so builder is created + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + + parser = cli.create_parser() + + template = "prompt: " + args = parser.parse_args( + [ + "build", + "test-index", + "--docs", + str(tmp_path), + "--embedding-prompt-template", + template, + "--force", + ] + ) + + # Run the build command + import asyncio + + asyncio.run(cli.build_index(args)) + + # Check that LeannBuilder was called with old format + call_kwargs = mock_builder_class.call_args.kwargs + assert "embedding_options" in call_kwargs, "LeannBuilder should receive embedding_options" + + embedding_options = call_kwargs["embedding_options"] + assert embedding_options is not None, ( + "embedding_options should not be None when template provided" + ) + + assert "prompt_template" in embedding_options, ( + "embedding_options should contain old 'prompt_template' key for backward compat" + ) + assert embedding_options["prompt_template"] == template, ( + f"prompt_template should be '{template}'" + ) + + # New keys should NOT be present in backward compat mode + assert "build_prompt_template" not in embedding_options, ( + "build_prompt_template should not be present with single template flag" + ) + assert "query_prompt_template" not in embedding_options, ( + "query_prompt_template should not be present with single template flag" + ) + + @patch("leann.cli.LeannBuilder") + def test_build_no_templates(self, mock_builder_class, tmp_path): + """ + R1 Test 3: Verify that when no template flags are provided, + embedding_options has no prompt template keys. + + This ensures clean defaults and no unnecessary keys in .meta.json. + + This test currently passes because it matches existing behavior, but it + documents the requirement that this behavior must be preserved after + implementing the separate template feature. 
+ + Expected behavior: + - .meta.json has no prompt_template, build_prompt_template, or + query_prompt_template keys (or embedding_options is empty/None) + """ + # Setup mocks + mock_builder = Mock() + mock_builder_class.return_value = mock_builder + + cli = LeannCLI() + + # Mock load_documents to return a document so builder is created + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + + parser = cli.create_parser() + + args = parser.parse_args(["build", "test-index", "--docs", str(tmp_path), "--force"]) + + # Run the build command + import asyncio + + asyncio.run(cli.build_index(args)) + + # Check that no template keys are present + call_kwargs = mock_builder_class.call_args.kwargs + if call_kwargs.get("embedding_options"): + embedding_options = call_kwargs["embedding_options"] + + # None of the template keys should be present + assert "prompt_template" not in embedding_options, ( + "prompt_template should not be present when no flags provided" + ) + assert "build_prompt_template" not in embedding_options, ( + "build_prompt_template should not be present when no flags provided" + ) + assert "query_prompt_template" not in embedding_options, ( + "query_prompt_template should not be present when no flags provided" + ) + + +class TestPromptTemplateFlowsToComputeEmbeddings: + """Tests for template flowing through to compute_embeddings function.""" + + @patch("leann.api.compute_embeddings") + def test_prompt_template_flows_to_compute_embeddings_via_provider_options( + self, mock_compute_embeddings, tmp_path + ): + """ + Verify that the prompt template flows from CLI args through LeannBuilder + to compute_embeddings() function via provider_options parameter. + + This is an integration test that verifies the complete flow: + CLI → embedding_options → LeannBuilder → compute_embeddings(provider_options) + + This test will fail because: + 1. CLI doesn't capture the argument yet + 2. embedding_options doesn't include prompt_template + 3. 
LeannBuilder doesn't pass it through to compute_embeddings + """ + # Mock compute_embeddings to return dummy embeddings as numpy array + import numpy as np + + mock_compute_embeddings.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32) + + # Use real LeannBuilder (not mocked) to test the actual flow + cli = LeannCLI() + + # Mock load_documents to return a simple document + cli.load_documents = Mock(return_value=[{"text": "test content", "metadata": {}}]) + + parser = cli.create_parser() + + template = "search_document: " + args = parser.parse_args( + [ + "build", + "test-index", + "--docs", + str(tmp_path), + "--embedding-prompt-template", + template, + "--backend-name", + "hnsw", # Use hnsw backend + "--force", # Force rebuild to ensure index is created + ] + ) + + # This should fail because the flow isn't implemented yet + import asyncio + + asyncio.run(cli.build_index(args)) + + # Verify compute_embeddings was called with provider_options containing prompt_template + assert mock_compute_embeddings.called, "compute_embeddings should have been called" + + # Check the call arguments + call_kwargs = mock_compute_embeddings.call_args.kwargs + assert "provider_options" in call_kwargs, ( + "compute_embeddings should receive provider_options parameter" + ) + + provider_options = call_kwargs["provider_options"] + assert provider_options is not None, "provider_options should not be None" + assert "prompt_template" in provider_options, ( + "provider_options should contain prompt_template key" + ) + assert provider_options["prompt_template"] == template, ( + f"Template should be '{template}', got {provider_options.get('prompt_template')}" + ) + + +class TestPromptTemplateArgumentHelp: + """Tests for argument help text and documentation.""" + + def test_build_command_prompt_template_has_help_text(self): + """ + Verify that --embedding-prompt-template has descriptive help text. + + Good help text is crucial for CLI usability. + """ + cli = LeannCLI() + parser = cli.create_parser() + + # Get the build subparser + # This is a bit tricky - we need to parse to get the help + # We'll check that the help includes relevant keywords + import io + from contextlib import redirect_stdout + + f = io.StringIO() + try: + with redirect_stdout(f): + parser.parse_args(["build", "--help"]) + except SystemExit: + pass # --help causes sys.exit(0) + + help_text = f.getvalue() + assert "--embedding-prompt-template" in help_text, ( + "Help text should mention --embedding-prompt-template" + ) + # Check for keywords that should be in the help + help_lower = help_text.lower() + assert any(keyword in help_lower for keyword in ["template", "prompt", "prepend"]), ( + "Help text should explain what the prompt template does" + ) + + def test_search_command_prompt_template_has_help_text(self): + """ + Verify that search command also has help text for --embedding-prompt-template. 
+ """ + cli = LeannCLI() + parser = cli.create_parser() + + import io + from contextlib import redirect_stdout + + f = io.StringIO() + try: + with redirect_stdout(f): + parser.parse_args(["search", "--help"]) + except SystemExit: + pass # --help causes sys.exit(0) + + help_text = f.getvalue() + assert "--embedding-prompt-template" in help_text, ( + "Search help text should mention --embedding-prompt-template" + ) diff --git a/tests/test_embedding_prompt_template.py b/tests/test_embedding_prompt_template.py new file mode 100644 index 00000000..22a1fef8 --- /dev/null +++ b/tests/test_embedding_prompt_template.py @@ -0,0 +1,281 @@ +"""Unit tests for prompt template prepending in OpenAI embeddings. + +This test suite defines the contract for prompt template functionality that allows +users to prepend a consistent prompt to all embedding inputs. These tests verify: + +1. Template prepending to all input texts before embedding computation +2. Graceful handling of None/missing provider_options +3. Empty string template behavior (no-op) +4. Logging of template application for observability +5. Template application before token truncation + +All tests are written in Red Phase - they should FAIL initially because the +implementation does not exist yet. +""" + +from unittest.mock import MagicMock, Mock, patch + +import numpy as np +import pytest +from leann.embedding_compute import compute_embeddings_openai + + +class TestPromptTemplatePrepending: + """Tests for prompt template prepending in compute_embeddings_openai.""" + + @pytest.fixture + def mock_openai_client(self): + """Create mock OpenAI client that captures input texts.""" + mock_client = MagicMock() + + # Mock the embeddings.create response + mock_response = Mock() + mock_response.data = [ + Mock(embedding=[0.1, 0.2, 0.3]), + Mock(embedding=[0.4, 0.5, 0.6]), + ] + mock_client.embeddings.create.return_value = mock_response + + return mock_client + + @pytest.fixture + def mock_openai_module(self, mock_openai_client, monkeypatch): + """Mock the openai module to return our mock client.""" + # Mock the API key environment variable + monkeypatch.setenv("OPENAI_API_KEY", "fake-test-key-for-mocking") + + # openai is imported inside the function, so we need to patch it there + with patch("openai.OpenAI", return_value=mock_openai_client) as mock_openai: + yield mock_openai + + def test_prompt_template_prepended_to_all_texts(self, mock_openai_module, mock_openai_client): + """Verify template is prepended to all input texts. + + When provider_options contains "prompt_template", that template should + be prepended to every text in the input list before sending to OpenAI API. + + This is the core functionality: the template acts as a consistent prefix + that provides context or instruction for the embedding model. 
+ """ + texts = ["First document", "Second document"] + template = "search_document: " + provider_options = {"prompt_template": template} + + # Call compute_embeddings_openai with provider_options + result = compute_embeddings_openai( + texts=texts, + model_name="text-embedding-3-small", + provider_options=provider_options, + ) + + # Verify embeddings.create was called with templated texts + mock_openai_client.embeddings.create.assert_called_once() + call_args = mock_openai_client.embeddings.create.call_args + + # Extract the input texts sent to API + sent_texts = call_args.kwargs["input"] + + # Verify template was prepended to all texts + assert len(sent_texts) == 2, "Should send same number of texts" + assert sent_texts[0] == "search_document: First document", ( + "Template should be prepended to first text" + ) + assert sent_texts[1] == "search_document: Second document", ( + "Template should be prepended to second text" + ) + + # Verify result is valid embeddings array + assert isinstance(result, np.ndarray) + assert result.shape == (2, 3), "Should return correct shape" + + def test_template_not_applied_when_missing_or_empty( + self, mock_openai_module, mock_openai_client + ): + """Verify template not applied when provider_options is None, missing key, or empty string. + + This consolidated test covers three scenarios where templates should NOT be applied: + 1. provider_options is None (default behavior) + 2. provider_options exists but missing 'prompt_template' key + 3. prompt_template is explicitly set to empty string "" + + In all cases, texts should be sent to the API unchanged. + """ + # Scenario 1: None provider_options + texts = ["Original text one", "Original text two"] + result = compute_embeddings_openai( + texts=texts, + model_name="text-embedding-3-small", + provider_options=None, + ) + call_args = mock_openai_client.embeddings.create.call_args + sent_texts = call_args.kwargs["input"] + assert sent_texts[0] == "Original text one", ( + "Text should be unchanged with None provider_options" + ) + assert sent_texts[1] == "Original text two" + assert isinstance(result, np.ndarray) + assert result.shape == (2, 3) + + # Reset mock for next scenario + mock_openai_client.reset_mock() + mock_response = Mock() + mock_response.data = [ + Mock(embedding=[0.1, 0.2, 0.3]), + Mock(embedding=[0.4, 0.5, 0.6]), + ] + mock_openai_client.embeddings.create.return_value = mock_response + + # Scenario 2: Missing 'prompt_template' key + texts = ["Text without template", "Another text"] + provider_options = {"base_url": "https://api.openai.com/v1"} + result = compute_embeddings_openai( + texts=texts, + model_name="text-embedding-3-small", + provider_options=provider_options, + ) + call_args = mock_openai_client.embeddings.create.call_args + sent_texts = call_args.kwargs["input"] + assert sent_texts[0] == "Text without template", "Text should be unchanged with missing key" + assert sent_texts[1] == "Another text" + assert isinstance(result, np.ndarray) + + # Reset mock for next scenario + mock_openai_client.reset_mock() + mock_openai_client.embeddings.create.return_value = mock_response + + # Scenario 3: Empty string template + texts = ["Text one", "Text two"] + provider_options = {"prompt_template": ""} + result = compute_embeddings_openai( + texts=texts, + model_name="text-embedding-3-small", + provider_options=provider_options, + ) + call_args = mock_openai_client.embeddings.create.call_args + sent_texts = call_args.kwargs["input"] + assert sent_texts[0] == "Text one", "Empty template should not 
modify text" + assert sent_texts[1] == "Text two" + assert isinstance(result, np.ndarray) + + def test_prompt_template_with_multiple_batches(self, mock_openai_module, mock_openai_client): + """Verify template is prepended in all batches when texts exceed batch size. + + OpenAI API has batch size limits. When input texts are split into + multiple batches, the template should be prepended to texts in every batch. + + This ensures consistency across all API calls. + """ + # Create many texts that will be split into multiple batches + texts = [f"Document {i}" for i in range(1000)] + template = "passage: " + provider_options = {"prompt_template": template} + + # Mock multiple batch responses + mock_response = Mock() + mock_response.data = [Mock(embedding=[0.1, 0.2, 0.3]) for _ in range(1000)] + mock_openai_client.embeddings.create.return_value = mock_response + + result = compute_embeddings_openai( + texts=texts, + model_name="text-embedding-3-small", + provider_options=provider_options, + ) + + # Verify embeddings.create was called multiple times (batching) + assert mock_openai_client.embeddings.create.call_count >= 2, ( + "Should make multiple API calls for large text list" + ) + + # Verify template was prepended in ALL batches + for call in mock_openai_client.embeddings.create.call_args_list: + sent_texts = call.kwargs["input"] + for text in sent_texts: + assert text.startswith(template), ( + f"All texts in all batches should start with template. Got: {text}" + ) + + # Verify result shape + assert result.shape[0] == 1000, "Should return embeddings for all texts" + + def test_prompt_template_with_special_characters(self, mock_openai_module, mock_openai_client): + """Verify template with special characters is handled correctly. + + Templates may contain special characters, Unicode, newlines, etc. + These should all be prepended correctly without encoding issues. + """ + texts = ["Document content"] + # Template with various special characters + template = "🔍 Search query [EN]: " + provider_options = {"prompt_template": template} + + result = compute_embeddings_openai( + texts=texts, + model_name="text-embedding-3-small", + provider_options=provider_options, + ) + + # Verify special characters in template were preserved + call_args = mock_openai_client.embeddings.create.call_args + sent_texts = call_args.kwargs["input"] + + assert sent_texts[0] == "🔍 Search query [EN]: Document content", ( + "Special characters in template should be preserved" + ) + + assert isinstance(result, np.ndarray) + + def test_prompt_template_integration_with_existing_validation( + self, mock_openai_module, mock_openai_client + ): + """Verify template works with existing input validation. + + compute_embeddings_openai has validation for empty texts and whitespace. + Template prepending should happen AFTER validation, so validation errors + are thrown based on original texts, not templated texts. + + This ensures users get clear error messages about their input. + """ + # Empty text should still raise ValueError even with template + texts = [""] + provider_options = {"prompt_template": "prefix: "} + + with pytest.raises(ValueError, match="empty/invalid"): + compute_embeddings_openai( + texts=texts, + model_name="text-embedding-3-small", + provider_options=provider_options, + ) + + def test_prompt_template_with_api_key_and_base_url( + self, mock_openai_module, mock_openai_client + ): + """Verify template works alongside other provider_options. 
+ + provider_options may contain multiple settings: prompt_template, + base_url, api_key. All should work together correctly. + """ + texts = ["Test document"] + provider_options = { + "prompt_template": "embed: ", + "base_url": "https://custom.api.com/v1", + "api_key": "test-key-123", + } + + result = compute_embeddings_openai( + texts=texts, + model_name="text-embedding-3-small", + provider_options=provider_options, + ) + + # Verify template was applied + call_args = mock_openai_client.embeddings.create.call_args + sent_texts = call_args.kwargs["input"] + assert sent_texts[0] == "embed: Test document" + + # Verify OpenAI client was created with correct base_url + mock_openai_module.assert_called() + client_init_kwargs = mock_openai_module.call_args.kwargs + assert client_init_kwargs["base_url"] == "https://custom.api.com/v1" + assert client_init_kwargs["api_key"] == "test-key-123" + + assert isinstance(result, np.ndarray) diff --git a/tests/test_lmstudio_bridge.py b/tests/test_lmstudio_bridge.py new file mode 100644 index 00000000..b5636820 --- /dev/null +++ b/tests/test_lmstudio_bridge.py @@ -0,0 +1,315 @@ +"""Unit tests for LM Studio TypeScript SDK bridge functionality. + +This test suite defines the contract for the LM Studio SDK bridge that queries +model context length via Node.js subprocess. These tests verify: + +1. Successful SDK query returns context length +2. Graceful fallback when Node.js not installed (FileNotFoundError) +3. Graceful fallback when SDK not installed (npm error) +4. Timeout handling (subprocess.TimeoutExpired) +5. Invalid JSON response handling + +All tests are written in Red Phase - they should FAIL initially because the +`_query_lmstudio_context_limit` function does not exist yet. + +The function contract: +- Inputs: model_name (str), base_url (str, WebSocket format "ws://localhost:1234") +- Outputs: context_length (int) or None on error +- Requirements: + 1. Call Node.js with inline JavaScript using @lmstudio/sdk + 2. 10-second timeout (accounts for Node.js startup) + 3. Graceful fallback on any error (returns None, doesn't raise) + 4. Parse JSON response with contextLength field + 5. Log errors at debug level (not warning/error) +""" + +import subprocess +from unittest.mock import Mock + +import pytest + +# Try to import the function - if it doesn't exist, tests will fail as expected +try: + from leann.embedding_compute import _query_lmstudio_context_limit +except ImportError: + # Function doesn't exist yet (Red Phase) - create a placeholder that will fail + def _query_lmstudio_context_limit(*args, **kwargs): + raise NotImplementedError( + "_query_lmstudio_context_limit not implemented yet - this is the Red Phase" + ) + + +class TestLMStudioBridge: + """Tests for LM Studio TypeScript SDK bridge integration.""" + + def test_query_lmstudio_success(self, monkeypatch): + """Verify successful SDK query returns context length. + + When the Node.js subprocess successfully queries the LM Studio SDK, + it should return a JSON response with contextLength field. The function + should parse this and return the integer context length. 
+ """ + + def mock_run(*args, **kwargs): + # Verify timeout is set to 10 seconds + assert kwargs.get("timeout") == 10, "Should use 10-second timeout for Node.js startup" + + # Verify capture_output and text=True are set + assert kwargs.get("capture_output") is True, "Should capture stdout/stderr" + assert kwargs.get("text") is True, "Should decode output as text" + + # Return successful JSON response + mock_result = Mock() + mock_result.returncode = 0 + mock_result.stdout = '{"contextLength": 8192, "identifier": "custom-model"}' + mock_result.stderr = "" + return mock_result + + monkeypatch.setattr("subprocess.run", mock_run) + + # Test with typical LM Studio model + limit = _query_lmstudio_context_limit( + model_name="custom-model", base_url="ws://localhost:1234" + ) + + assert limit == 8192, "Should return context length from SDK response" + + def test_query_lmstudio_nodejs_not_found(self, monkeypatch): + """Verify graceful fallback when Node.js not installed. + + When Node.js is not installed, subprocess.run will raise FileNotFoundError. + The function should catch this and return None (graceful fallback to registry). + """ + + def mock_run(*args, **kwargs): + raise FileNotFoundError("node: command not found") + + monkeypatch.setattr("subprocess.run", mock_run) + + limit = _query_lmstudio_context_limit( + model_name="custom-model", base_url="ws://localhost:1234" + ) + + assert limit is None, "Should return None when Node.js not installed" + + def test_query_lmstudio_sdk_not_installed(self, monkeypatch): + """Verify graceful fallback when @lmstudio/sdk not installed. + + When the SDK npm package is not installed, Node.js will return non-zero + exit code with error message in stderr. The function should detect this + and return None. + """ + + def mock_run(*args, **kwargs): + mock_result = Mock() + mock_result.returncode = 1 + mock_result.stdout = "" + mock_result.stderr = ( + "Error: Cannot find module '@lmstudio/sdk'\nRequire stack:\n- /path/to/script.js" + ) + return mock_result + + monkeypatch.setattr("subprocess.run", mock_run) + + limit = _query_lmstudio_context_limit( + model_name="custom-model", base_url="ws://localhost:1234" + ) + + assert limit is None, "Should return None when SDK not installed" + + def test_query_lmstudio_timeout(self, monkeypatch): + """Verify graceful fallback when subprocess times out. + + When the Node.js process takes longer than 10 seconds (e.g., LM Studio + not responding), subprocess.TimeoutExpired should be raised. The function + should catch this and return None. + """ + + def mock_run(*args, **kwargs): + raise subprocess.TimeoutExpired(cmd=["node", "lmstudio_bridge.js"], timeout=10) + + monkeypatch.setattr("subprocess.run", mock_run) + + limit = _query_lmstudio_context_limit( + model_name="custom-model", base_url="ws://localhost:1234" + ) + + assert limit is None, "Should return None on timeout" + + def test_query_lmstudio_invalid_json(self, monkeypatch): + """Verify graceful fallback when response is invalid JSON. + + When the subprocess returns malformed JSON (e.g., due to SDK error), + json.loads will raise ValueError/JSONDecodeError. The function should + catch this and return None. 
+ """ + + def mock_run(*args, **kwargs): + mock_result = Mock() + mock_result.returncode = 0 + mock_result.stdout = "This is not valid JSON" + mock_result.stderr = "" + return mock_result + + monkeypatch.setattr("subprocess.run", mock_run) + + limit = _query_lmstudio_context_limit( + model_name="custom-model", base_url="ws://localhost:1234" + ) + + assert limit is None, "Should return None when JSON parsing fails" + + def test_query_lmstudio_missing_context_length_field(self, monkeypatch): + """Verify graceful fallback when JSON lacks contextLength field. + + When the SDK returns valid JSON but without the expected contextLength + field (e.g., error response), the function should return None. + """ + + def mock_run(*args, **kwargs): + mock_result = Mock() + mock_result.returncode = 0 + mock_result.stdout = '{"identifier": "test-model", "error": "Model not found"}' + mock_result.stderr = "" + return mock_result + + monkeypatch.setattr("subprocess.run", mock_run) + + limit = _query_lmstudio_context_limit( + model_name="nonexistent-model", base_url="ws://localhost:1234" + ) + + assert limit is None, "Should return None when contextLength field missing" + + def test_query_lmstudio_null_context_length(self, monkeypatch): + """Verify graceful fallback when contextLength is null. + + When the SDK returns contextLength: null (model couldn't be loaded), + the function should return None for registry fallback. + """ + + def mock_run(*args, **kwargs): + mock_result = Mock() + mock_result.returncode = 0 + mock_result.stdout = '{"contextLength": null, "identifier": "test-model"}' + mock_result.stderr = "" + return mock_result + + monkeypatch.setattr("subprocess.run", mock_run) + + limit = _query_lmstudio_context_limit( + model_name="test-model", base_url="ws://localhost:1234" + ) + + assert limit is None, "Should return None when contextLength is null" + + def test_query_lmstudio_zero_context_length(self, monkeypatch): + """Verify graceful fallback when contextLength is zero. + + When the SDK returns contextLength: 0 (invalid value), the function + should return None to trigger registry fallback. + """ + + def mock_run(*args, **kwargs): + mock_result = Mock() + mock_result.returncode = 0 + mock_result.stdout = '{"contextLength": 0, "identifier": "test-model"}' + mock_result.stderr = "" + return mock_result + + monkeypatch.setattr("subprocess.run", mock_run) + + limit = _query_lmstudio_context_limit( + model_name="test-model", base_url="ws://localhost:1234" + ) + + assert limit is None, "Should return None when contextLength is zero" + + def test_query_lmstudio_with_custom_port(self, monkeypatch): + """Verify SDK query works with non-default WebSocket port. + + LM Studio can run on custom ports. The function should pass the + provided base_url to the Node.js subprocess. 
+ """ + + def mock_run(*args, **kwargs): + # Verify the base_url argument is passed correctly + command = args[0] if args else kwargs.get("args", []) + assert "ws://localhost:8080" in " ".join(command), ( + "Should pass custom port to subprocess" + ) + + mock_result = Mock() + mock_result.returncode = 0 + mock_result.stdout = '{"contextLength": 4096, "identifier": "custom-model"}' + mock_result.stderr = "" + return mock_result + + monkeypatch.setattr("subprocess.run", mock_run) + + limit = _query_lmstudio_context_limit( + model_name="custom-model", base_url="ws://localhost:8080" + ) + + assert limit == 4096, "Should work with custom WebSocket port" + + @pytest.mark.parametrize( + "context_length,expected", + [ + (512, 512), # Small context + (2048, 2048), # Common context + (8192, 8192), # Large context + (32768, 32768), # Very large context + ], + ) + def test_query_lmstudio_various_context_lengths(self, monkeypatch, context_length, expected): + """Verify SDK query handles various context length values. + + Different models have different context lengths. The function should + correctly parse and return any positive integer value. + """ + + def mock_run(*args, **kwargs): + mock_result = Mock() + mock_result.returncode = 0 + mock_result.stdout = f'{{"contextLength": {context_length}, "identifier": "test"}}' + mock_result.stderr = "" + return mock_result + + monkeypatch.setattr("subprocess.run", mock_run) + + limit = _query_lmstudio_context_limit( + model_name="test-model", base_url="ws://localhost:1234" + ) + + assert limit == expected, f"Should return {expected} for context length {context_length}" + + def test_query_lmstudio_logs_at_debug_level(self, monkeypatch, caplog): + """Verify errors are logged at DEBUG level, not WARNING/ERROR. + + Following the graceful fallback pattern from Ollama implementation, + errors should be logged at debug level to avoid alarming users when + fallback to registry works fine. + """ + import logging + + caplog.set_level(logging.DEBUG, logger="leann.embedding_compute") + + def mock_run(*args, **kwargs): + raise FileNotFoundError("node: command not found") + + monkeypatch.setattr("subprocess.run", mock_run) + + _query_lmstudio_context_limit(model_name="test-model", base_url="ws://localhost:1234") + + # Check that debug logging occurred (not warning/error) + debug_logs = [record for record in caplog.records if record.levelname == "DEBUG"] + assert len(debug_logs) > 0, "Should log error at DEBUG level" + + # Verify no WARNING or ERROR logs + warning_or_error_logs = [ + record for record in caplog.records if record.levelname in ["WARNING", "ERROR"] + ] + assert len(warning_or_error_logs) == 0, ( + "Should not log at WARNING/ERROR level for expected failures" + ) diff --git a/tests/test_prompt_template_e2e.py b/tests/test_prompt_template_e2e.py new file mode 100644 index 00000000..80c9ccea --- /dev/null +++ b/tests/test_prompt_template_e2e.py @@ -0,0 +1,400 @@ +"""End-to-end integration tests for prompt template and token limit features. + +These tests verify real-world functionality with live services: +- OpenAI-compatible APIs (OpenAI, LM Studio) with prompt template support +- Ollama with dynamic token limit detection +- Hybrid token limit discovery mechanism + +Run with: pytest tests/test_prompt_template_e2e.py -v -s +Skip if services unavailable: pytest tests/test_prompt_template_e2e.py -m "not integration" + +Prerequisites: +1. LM Studio running with embedding model: http://localhost:1234 +2. [Optional] Ollama running: ollama serve +3. 
[Optional] Ollama model: ollama pull nomic-embed-text +4. [Optional] Node.js + @lmstudio/sdk for context length detection +""" + +import logging +import socket + +import numpy as np +import pytest +import requests +from leann.embedding_compute import ( + compute_embeddings_ollama, + compute_embeddings_openai, + get_model_token_limit, +) + +# Test markers for conditional execution +pytestmark = pytest.mark.integration + +logger = logging.getLogger(__name__) + + +def check_service_available(host: str, port: int, timeout: float = 2.0) -> bool: + """Check if a service is available on the given host:port.""" + try: + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(timeout) + result = sock.connect_ex((host, port)) + sock.close() + return result == 0 + except Exception: + return False + + +def check_ollama_available() -> bool: + """Check if Ollama service is available.""" + if not check_service_available("localhost", 11434): + return False + try: + response = requests.get("http://localhost:11434/api/tags", timeout=2.0) + return response.status_code == 200 + except Exception: + return False + + +def check_lmstudio_available() -> bool: + """Check if LM Studio service is available.""" + if not check_service_available("localhost", 1234): + return False + try: + response = requests.get("http://localhost:1234/v1/models", timeout=2.0) + return response.status_code == 200 + except Exception: + return False + + +def get_lmstudio_first_model() -> str: + """Get the first available model from LM Studio.""" + try: + response = requests.get("http://localhost:1234/v1/models", timeout=5.0) + data = response.json() + models = data.get("data", []) + if models: + return models[0]["id"] + except Exception: + pass + return None + + +class TestPromptTemplateOpenAI: + """End-to-end tests for prompt template with OpenAI-compatible APIs (LM Studio).""" + + @pytest.mark.skipif( + not check_lmstudio_available(), reason="LM Studio service not available on localhost:1234" + ) + def test_lmstudio_embedding_with_prompt_template(self): + """Test prompt templates with LM Studio using OpenAI-compatible API.""" + model_name = get_lmstudio_first_model() + if not model_name: + pytest.skip("No models loaded in LM Studio") + + texts = ["artificial intelligence", "machine learning"] + prompt_template = "search_query: " + + # Get embeddings with prompt template via provider_options + provider_options = {"prompt_template": prompt_template} + embeddings = compute_embeddings_openai( + texts=texts, + model_name=model_name, + base_url="http://localhost:1234/v1", + api_key="lm-studio", # LM Studio doesn't require real key + provider_options=provider_options, + ) + + assert embeddings is not None + assert len(embeddings) == 2 + assert all(isinstance(emb, np.ndarray) for emb in embeddings) + assert all(len(emb) > 0 for emb in embeddings) + + logger.info( + f"✓ LM Studio embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions" + ) + + @pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available") + def test_lmstudio_prompt_template_affects_embeddings(self): + """Verify that prompt templates actually change embedding values.""" + model_name = get_lmstudio_first_model() + if not model_name: + pytest.skip("No models loaded in LM Studio") + + text = "machine learning" + base_url = "http://localhost:1234/v1" + api_key = "lm-studio" + + # Get embeddings without template + embeddings_no_template = compute_embeddings_openai( + texts=[text], + 
model_name=model_name, + base_url=base_url, + api_key=api_key, + provider_options={}, + ) + + # Get embeddings with template + embeddings_with_template = compute_embeddings_openai( + texts=[text], + model_name=model_name, + base_url=base_url, + api_key=api_key, + provider_options={"prompt_template": "search_query: "}, + ) + + # Embeddings should be different when template is applied + assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0]) + + logger.info("✓ Prompt template changes embedding values as expected") + + +class TestPromptTemplateOllama: + """End-to-end tests for prompt template with Ollama.""" + + @pytest.mark.skipif( + not check_ollama_available(), reason="Ollama service not available on localhost:11434" + ) + def test_ollama_embedding_with_prompt_template(self): + """Test prompt templates with Ollama using any available embedding model.""" + # Get any available embedding model + try: + response = requests.get("http://localhost:11434/api/tags", timeout=2.0) + models = response.json().get("models", []) + + embedding_models = [] + for model in models: + name = model["name"] + base_name = name.split(":")[0] + if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]): + embedding_models.append(name) + + if not embedding_models: + pytest.skip("No embedding models available in Ollama") + + model_name = embedding_models[0] + + texts = ["artificial intelligence", "machine learning"] + prompt_template = "search_query: " + + # Get embeddings with prompt template via provider_options + provider_options = {"prompt_template": prompt_template} + embeddings = compute_embeddings_ollama( + texts=texts, + model_name=model_name, + is_build=False, + host="http://localhost:11434", + provider_options=provider_options, + ) + + assert embeddings is not None + assert len(embeddings) == 2 + assert all(isinstance(emb, np.ndarray) for emb in embeddings) + assert all(len(emb) > 0 for emb in embeddings) + + logger.info( + f"✓ Ollama embeddings with prompt template: {len(embeddings)} vectors, {len(embeddings[0])} dimensions" + ) + + except Exception as e: + pytest.skip(f"Could not test Ollama prompt template: {e}") + + @pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available") + def test_ollama_prompt_template_affects_embeddings(self): + """Verify that prompt templates actually change embedding values with Ollama.""" + # Get any available embedding model + try: + response = requests.get("http://localhost:11434/api/tags", timeout=2.0) + models = response.json().get("models", []) + + embedding_models = [] + for model in models: + name = model["name"] + base_name = name.split(":")[0] + if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]): + embedding_models.append(name) + + if not embedding_models: + pytest.skip("No embedding models available in Ollama") + + model_name = embedding_models[0] + text = "machine learning" + host = "http://localhost:11434" + + # Get embeddings without template + embeddings_no_template = compute_embeddings_ollama( + texts=[text], model_name=model_name, is_build=False, host=host, provider_options={} + ) + + # Get embeddings with template + embeddings_with_template = compute_embeddings_ollama( + texts=[text], + model_name=model_name, + is_build=False, + host=host, + provider_options={"prompt_template": "search_query: "}, + ) + + # Embeddings should be different when template is applied + assert not np.allclose(embeddings_no_template[0], embeddings_with_template[0]) + + logger.info("✓ 
Ollama prompt template changes embedding values as expected") + + except Exception as e: + pytest.skip(f"Could not test Ollama prompt template: {e}") + + +class TestLMStudioSDK: + """End-to-end tests for LM Studio SDK integration.""" + + @pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available") + def test_lmstudio_model_listing(self): + """Test that we can list models from LM Studio.""" + try: + response = requests.get("http://localhost:1234/v1/models", timeout=5.0) + assert response.status_code == 200 + + data = response.json() + assert "data" in data + + models = data["data"] + logger.info(f"✓ LM Studio models available: {len(models)}") + + if models: + logger.info(f" First model: {models[0].get('id', 'unknown')}") + except Exception as e: + pytest.skip(f"LM Studio API error: {e}") + + @pytest.mark.skipif(not check_lmstudio_available(), reason="LM Studio service not available") + def test_lmstudio_sdk_context_length_detection(self): + """Test context length detection via LM Studio SDK bridge (requires Node.js + SDK).""" + model_name = get_lmstudio_first_model() + if not model_name: + pytest.skip("No models loaded in LM Studio") + + try: + from leann.embedding_compute import _query_lmstudio_context_limit + + # SDK requires WebSocket URL (ws://) + context_length = _query_lmstudio_context_limit( + model_name=model_name, base_url="ws://localhost:1234" + ) + + if context_length is None: + logger.warning( + "⚠ LM Studio SDK bridge returned None (Node.js or SDK may not be available)" + ) + pytest.skip("Node.js or @lmstudio/sdk not available - SDK bridge unavailable") + else: + assert context_length > 0 + logger.info( + f"✓ LM Studio context length detected via SDK: {context_length} for {model_name}" + ) + + except ImportError: + pytest.skip("_query_lmstudio_context_limit not implemented yet") + except Exception as e: + logger.error(f"LM Studio SDK test error: {e}") + raise + + +class TestOllamaTokenLimit: + """End-to-end tests for Ollama token limit discovery.""" + + @pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available") + def test_ollama_token_limit_detection(self): + """Test dynamic token limit detection from Ollama /api/show endpoint.""" + # Get any available embedding model + try: + response = requests.get("http://localhost:11434/api/tags", timeout=2.0) + models = response.json().get("models", []) + + embedding_models = [] + for model in models: + name = model["name"] + base_name = name.split(":")[0] + if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]): + embedding_models.append(name) + + if not embedding_models: + pytest.skip("No embedding models available in Ollama") + + test_model = embedding_models[0] + + # Test token limit detection + limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434") + + assert limit > 0 + logger.info(f"✓ Ollama token limit detected: {limit} for {test_model}") + + except Exception as e: + pytest.skip(f"Could not test Ollama token detection: {e}") + + +class TestHybridTokenLimit: + """End-to-end tests for hybrid token limit discovery mechanism.""" + + def test_hybrid_discovery_registry_fallback(self): + """Test fallback to static registry for known OpenAI models.""" + # Use a known OpenAI model (should be in registry) + limit = get_model_token_limit( + model_name="text-embedding-3-small", + base_url="http://fake-server:9999", # Fake URL to force registry lookup + ) + + # text-embedding-3-small should have 8192 in registry + assert limit 
== 8192 + logger.info(f"✓ Hybrid discovery (registry fallback): {limit} tokens") + + def test_hybrid_discovery_default_fallback(self): + """Test fallback to safe default for completely unknown models.""" + limit = get_model_token_limit( + model_name="completely-unknown-model-xyz-12345", + base_url="http://fake-server:9999", + default=512, + ) + + # Should get the specified default + assert limit == 512 + logger.info(f"✓ Hybrid discovery (default fallback): {limit} tokens") + + @pytest.mark.skipif(not check_ollama_available(), reason="Ollama service not available") + def test_hybrid_discovery_ollama_dynamic_first(self): + """Test that Ollama models use dynamic discovery first.""" + # Get any available embedding model + try: + response = requests.get("http://localhost:11434/api/tags", timeout=2.0) + models = response.json().get("models", []) + + embedding_models = [] + for model in models: + name = model["name"] + base_name = name.split(":")[0] + if any(emb in base_name for emb in ["embed", "bge", "minilm", "e5", "nomic"]): + embedding_models.append(name) + + if not embedding_models: + pytest.skip("No embedding models available in Ollama") + + test_model = embedding_models[0] + + # Should query Ollama /api/show dynamically + limit = get_model_token_limit(model_name=test_model, base_url="http://localhost:11434") + + assert limit > 0 + logger.info(f"✓ Hybrid discovery (Ollama dynamic): {limit} tokens for {test_model}") + + except Exception as e: + pytest.skip(f"Could not test hybrid Ollama discovery: {e}") + + +if __name__ == "__main__": + print("\n" + "=" * 70) + print("INTEGRATION TEST SUITE - Real Service Testing") + print("=" * 70) + print("\nThese tests require live services:") + print(" • LM Studio: http://localhost:1234 (with embedding model loaded)") + print(" • [Optional] Ollama: http://localhost:11434") + print(" • [Optional] Node.js + @lmstudio/sdk for SDK bridge tests") + print("\nRun with: pytest tests/test_prompt_template_e2e.py -v -s") + print("=" * 70 + "\n") diff --git a/tests/test_prompt_template_persistence.py b/tests/test_prompt_template_persistence.py new file mode 100644 index 00000000..eefda045 --- /dev/null +++ b/tests/test_prompt_template_persistence.py @@ -0,0 +1,808 @@ +""" +Integration tests for prompt template metadata persistence and reuse. + +These tests verify the complete lifecycle of prompt template persistence: +1. Template is saved to .meta.json during index build +2. Template is automatically loaded during search operations +3. Template can be overridden with explicit flag during search +4. Template is reused during chat/ask operations + +These are integration tests that: +- Use real file system with temporary directories +- Run actual build and search operations +- Inspect .meta.json file contents directly +- Mock embedding servers to avoid external dependencies +- Use small test codebases for fast execution + +Expected to FAIL in Red Phase because metadata persistence verification is not yet implemented. 
+""" + +import json +import tempfile +from pathlib import Path +from unittest.mock import Mock, patch + +import numpy as np +import pytest +from leann.api import LeannBuilder, LeannSearcher + + +class TestPromptTemplateMetadataPersistence: + """Tests for prompt template storage in .meta.json during build.""" + + @pytest.fixture + def temp_index_dir(self): + """Create temporary directory for test indexes.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + @pytest.fixture + def mock_embeddings(self): + """Mock compute_embeddings to return dummy embeddings.""" + with patch("leann.api.compute_embeddings") as mock_compute: + # Return dummy embeddings as numpy array + mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32) + yield mock_compute + + def test_prompt_template_saved_to_metadata(self, temp_index_dir, mock_embeddings): + """ + Verify that when build is run with embedding_options containing prompt_template, + the template value is saved to .meta.json file. + + This is the core persistence requirement - templates must be saved to allow + reuse in subsequent search operations without re-specifying the flag. + + Expected failure: .meta.json exists but doesn't contain embedding_options + with prompt_template, or the value is not persisted correctly. + """ + # Setup test data + index_path = temp_index_dir / "test_index.leann" + template = "search_document: " + + # Build index with prompt template in embedding_options + builder = LeannBuilder( + backend_name="hnsw", + embedding_model="text-embedding-3-small", + embedding_mode="openai", + embedding_options={"prompt_template": template}, + ) + + # Add a simple document + builder.add_text("This is a test document for indexing") + + # Build the index + builder.build_index(str(index_path)) + + # Verify .meta.json was created and contains the template + meta_path = temp_index_dir / "test_index.leann.meta.json" + assert meta_path.exists(), ".meta.json file should be created during build" + + # Read and parse metadata + with open(meta_path, encoding="utf-8") as f: + meta_data = json.load(f) + + # Verify embedding_options exists in metadata + assert "embedding_options" in meta_data, ( + "embedding_options should be saved to .meta.json when provided" + ) + + # Verify prompt_template is in embedding_options + embedding_options = meta_data["embedding_options"] + assert "prompt_template" in embedding_options, ( + "prompt_template should be saved within embedding_options" + ) + + # Verify the template value matches what we provided + assert embedding_options["prompt_template"] == template, ( + f"Template should be '{template}', got '{embedding_options.get('prompt_template')}'" + ) + + def test_prompt_template_absent_when_not_provided(self, temp_index_dir, mock_embeddings): + """ + Verify that when no prompt template is provided during build, + .meta.json either doesn't have embedding_options or prompt_template key. + + This ensures clean metadata without unnecessary keys when features aren't used. + + Expected behavior: Build succeeds, .meta.json doesn't contain prompt_template. 
+ """ + index_path = temp_index_dir / "test_no_template.leann" + + # Build index WITHOUT prompt template + builder = LeannBuilder( + backend_name="hnsw", + embedding_model="text-embedding-3-small", + embedding_mode="openai", + # No embedding_options provided + ) + + builder.add_text("Document without template") + builder.build_index(str(index_path)) + + # Verify metadata + meta_path = temp_index_dir / "test_no_template.leann.meta.json" + assert meta_path.exists() + + with open(meta_path, encoding="utf-8") as f: + meta_data = json.load(f) + + # If embedding_options exists, it should not contain prompt_template + if "embedding_options" in meta_data: + embedding_options = meta_data["embedding_options"] + assert "prompt_template" not in embedding_options, ( + "prompt_template should not be in metadata when not provided" + ) + + +class TestPromptTemplateAutoLoadOnSearch: + """Tests for automatic loading of prompt template during search operations. + + NOTE: Over-mocked test removed (test_prompt_template_auto_loaded_on_search). + This functionality is now comprehensively tested by TestQueryPromptTemplateAutoLoad + which uses simpler mocking and doesn't hang. + """ + + @pytest.fixture + def temp_index_dir(self): + """Create temporary directory for test indexes.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + @pytest.fixture + def mock_embeddings(self): + """Mock compute_embeddings to capture calls and return dummy embeddings.""" + with patch("leann.api.compute_embeddings") as mock_compute: + mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32) + yield mock_compute + + def test_search_without_template_in_metadata(self, temp_index_dir, mock_embeddings): + """ + Verify that searching an index built WITHOUT a prompt template + works correctly (backward compatibility). + + The searcher should handle missing prompt_template gracefully. + + Expected behavior: Search succeeds, no template is used. + """ + # Build index without template + index_path = temp_index_dir / "no_template.leann" + + builder = LeannBuilder( + backend_name="hnsw", + embedding_model="text-embedding-3-small", + embedding_mode="openai", + ) + builder.add_text("Document without template") + builder.build_index(str(index_path)) + + # Reset mocks + mock_embeddings.reset_mock() + + # Create searcher and search + searcher = LeannSearcher(index_path=str(index_path)) + + # Verify no template in embedding_options + assert "prompt_template" not in searcher.embedding_options, ( + "Searcher should not have prompt_template when not in metadata" + ) + + +class TestQueryPromptTemplateAutoLoad: + """Tests for automatic loading of separate query_prompt_template during search (R2). + + These tests verify the new two-template system where: + - build_prompt_template: Applied during index building + - query_prompt_template: Applied during search operations + + Expected to FAIL in Red Phase (R2) because query template extraction + and application is not yet implemented in LeannSearcher.search(). 
+ """ + + @pytest.fixture + def temp_index_dir(self): + """Create temporary directory for test indexes.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + @pytest.fixture + def mock_compute_embeddings(self): + """Mock compute_embeddings to capture calls and return dummy embeddings.""" + with patch("leann.embedding_compute.compute_embeddings") as mock_compute: + mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32) + yield mock_compute + + def test_search_auto_loads_query_template(self, temp_index_dir, mock_compute_embeddings): + """ + Verify that search() automatically loads and applies query_prompt_template from .meta.json. + + Given: Index built with separate build_prompt_template and query_prompt_template + When: LeannSearcher.search("my query") is called + Then: Query embedding is computed with "query: my query" (query template applied) + + This is the core R2 requirement - query templates must be auto-loaded and applied + during search without user intervention. + + Expected failure: compute_embeddings called with raw "my query" instead of + "query: my query" because query template extraction is not implemented. + """ + # Setup: Build index with separate templates in new format + index_path = temp_index_dir / "query_template.leann" + + builder = LeannBuilder( + backend_name="hnsw", + embedding_model="text-embedding-3-small", + embedding_mode="openai", + embedding_options={ + "build_prompt_template": "doc: ", + "query_prompt_template": "query: ", + }, + ) + builder.add_text("Test document") + builder.build_index(str(index_path)) + + # Reset mock to ignore build calls + mock_compute_embeddings.reset_mock() + + # Act: Search with query + searcher = LeannSearcher(index_path=str(index_path)) + + # Mock the backend search to avoid actual search + with patch.object(searcher.backend_impl, "search") as mock_backend_search: + mock_backend_search.return_value = { + "labels": [["test_id_0"]], # IDs (nested list for batch support) + "distances": [[0.9]], # Distances (nested list for batch support) + } + + searcher.search("my query", top_k=1, recompute_embeddings=False) + + # Assert: compute_embeddings was called with query template applied + assert mock_compute_embeddings.called, "compute_embeddings should be called during search" + + # Get the actual text passed to compute_embeddings + call_args = mock_compute_embeddings.call_args + texts_arg = call_args[0][0] # First positional arg (list of texts) + + assert len(texts_arg) == 1, "Should compute embedding for one query" + assert texts_arg[0] == "query: my query", ( + f"Query template should be applied: expected 'query: my query', got '{texts_arg[0]}'" + ) + + def test_search_backward_compat_single_template(self, temp_index_dir, mock_compute_embeddings): + """ + Verify backward compatibility with old single prompt_template format. + + Given: Index with old format (single prompt_template, no query_prompt_template) + When: LeannSearcher.search("my query") is called + Then: Query embedding computed with "doc: my query" (old template applied) + + This ensures indexes built with the old single-template system continue + to work correctly with the new search implementation. + + Expected failure: Old template not recognized/applied because backward + compatibility logic is not implemented. 
+ """ + # Setup: Build index with old single-template format + index_path = temp_index_dir / "old_template.leann" + + builder = LeannBuilder( + backend_name="hnsw", + embedding_model="text-embedding-3-small", + embedding_mode="openai", + embedding_options={"prompt_template": "doc: "}, # Old format + ) + builder.add_text("Test document") + builder.build_index(str(index_path)) + + # Reset mock + mock_compute_embeddings.reset_mock() + + # Act: Search + searcher = LeannSearcher(index_path=str(index_path)) + + with patch.object(searcher.backend_impl, "search") as mock_backend_search: + mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]} + + searcher.search("my query", top_k=1, recompute_embeddings=False) + + # Assert: Old template was applied + call_args = mock_compute_embeddings.call_args + texts_arg = call_args[0][0] + + assert texts_arg[0] == "doc: my query", ( + f"Old prompt_template should be applied for backward compatibility: " + f"expected 'doc: my query', got '{texts_arg[0]}'" + ) + + def test_search_backward_compat_no_template(self, temp_index_dir, mock_compute_embeddings): + """ + Verify backward compatibility when no template is present in .meta.json. + + Given: Index with no template in .meta.json (very old indexes) + When: LeannSearcher.search("my query") is called + Then: Query embedding computed with "my query" (no template, raw query) + + This ensures the most basic backward compatibility - indexes without + any template support continue to work as before. + + Expected failure: May fail if default template is incorrectly applied, + or if missing template causes error. + """ + # Setup: Build index without any template + index_path = temp_index_dir / "no_template.leann" + + builder = LeannBuilder( + backend_name="hnsw", + embedding_model="text-embedding-3-small", + embedding_mode="openai", + # No embedding_options at all + ) + builder.add_text("Test document") + builder.build_index(str(index_path)) + + # Reset mock + mock_compute_embeddings.reset_mock() + + # Act: Search + searcher = LeannSearcher(index_path=str(index_path)) + + with patch.object(searcher.backend_impl, "search") as mock_backend_search: + mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]} + + searcher.search("my query", top_k=1, recompute_embeddings=False) + + # Assert: No template applied (raw query) + call_args = mock_compute_embeddings.call_args + texts_arg = call_args[0][0] + + assert texts_arg[0] == "my query", ( + f"No template should be applied when missing from metadata: " + f"expected 'my query', got '{texts_arg[0]}'" + ) + + def test_search_override_via_provider_options(self, temp_index_dir, mock_compute_embeddings): + """ + Verify that explicit provider_options can override stored query template. + + Given: Index with query_prompt_template: "query: " + When: search() called with provider_options={"prompt_template": "override: "} + Then: Query embedding computed with "override: test" (override takes precedence) + + This enables users to experiment with different query templates without + rebuilding the index, or to handle special query types differently. + + Expected failure: provider_options parameter is accepted via **kwargs but + not used. Query embedding computed with raw "test" instead of "override: test" + because override logic is not implemented. 
+ """ + # Setup: Build index with query template + index_path = temp_index_dir / "override_template.leann" + + builder = LeannBuilder( + backend_name="hnsw", + embedding_model="text-embedding-3-small", + embedding_mode="openai", + embedding_options={ + "build_prompt_template": "doc: ", + "query_prompt_template": "query: ", + }, + ) + builder.add_text("Test document") + builder.build_index(str(index_path)) + + # Reset mock + mock_compute_embeddings.reset_mock() + + # Act: Search with override + searcher = LeannSearcher(index_path=str(index_path)) + + with patch.object(searcher.backend_impl, "search") as mock_backend_search: + mock_backend_search.return_value = {"labels": [["test_id_0"]], "distances": [[0.9]]} + + # This should accept provider_options parameter + searcher.search( + "test", + top_k=1, + recompute_embeddings=False, + provider_options={"prompt_template": "override: "}, + ) + + # Assert: Override template was applied + call_args = mock_compute_embeddings.call_args + texts_arg = call_args[0][0] + + assert texts_arg[0] == "override: test", ( + f"Override template should take precedence: " + f"expected 'override: test', got '{texts_arg[0]}'" + ) + + +class TestPromptTemplateReuseInChat: + """Tests for prompt template reuse in chat/ask operations.""" + + @pytest.fixture + def temp_index_dir(self): + """Create temporary directory for test indexes.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + @pytest.fixture + def mock_embeddings(self): + """Mock compute_embeddings to return dummy embeddings.""" + with patch("leann.api.compute_embeddings") as mock_compute: + mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32) + yield mock_compute + + @pytest.fixture + def mock_embedding_server_manager(self): + """Mock EmbeddingServerManager for chat tests.""" + with patch("leann.searcher_base.EmbeddingServerManager") as mock_manager_class: + mock_manager = Mock() + mock_manager.start_server.return_value = (True, 5557) + mock_manager_class.return_value = mock_manager + yield mock_manager + + @pytest.fixture + def index_with_template(self, temp_index_dir, mock_embeddings): + """Build an index with a prompt template.""" + index_path = temp_index_dir / "chat_template_index.leann" + template = "document_query: " + + builder = LeannBuilder( + backend_name="hnsw", + embedding_model="text-embedding-3-small", + embedding_mode="openai", + embedding_options={"prompt_template": template}, + ) + + builder.add_text("Test document for chat") + builder.build_index(str(index_path)) + + return str(index_path), template + + +class TestPromptTemplateIntegrationWithEmbeddingModes: + """Tests for prompt template compatibility with different embedding modes.""" + + @pytest.fixture + def temp_index_dir(self): + """Create temporary directory for test indexes.""" + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + @pytest.mark.parametrize( + "mode,model,template,filename_prefix", + [ + ( + "openai", + "text-embedding-3-small", + "Represent this for searching: ", + "openai_template", + ), + ("ollama", "nomic-embed-text", "search_query: ", "ollama_template"), + ("sentence-transformers", "facebook/contriever", "query: ", "st_template"), + ], + ) + def test_prompt_template_metadata_with_embedding_modes( + self, temp_index_dir, mode, model, template, filename_prefix + ): + """Verify prompt template is saved correctly across different embedding modes. 
+ + Tests that prompt templates are persisted to .meta.json for: + - OpenAI mode (primary use case) + - Ollama mode (also supports templates) + - Sentence-transformers mode (saved for forward compatibility) + + Expected behavior: Template is saved to .meta.json regardless of mode. + """ + with patch("leann.api.compute_embeddings") as mock_compute: + mock_compute.return_value = np.array([[0.1, 0.2, 0.3]], dtype=np.float32) + + index_path = temp_index_dir / f"{filename_prefix}.leann" + + builder = LeannBuilder( + backend_name="hnsw", + embedding_model=model, + embedding_mode=mode, + embedding_options={"prompt_template": template}, + ) + + builder.add_text(f"{mode.capitalize()} test document") + builder.build_index(str(index_path)) + + # Verify metadata + meta_path = temp_index_dir / f"{filename_prefix}.leann.meta.json" + with open(meta_path, encoding="utf-8") as f: + meta_data = json.load(f) + + assert meta_data["embedding_mode"] == mode + # Template should be saved for all modes (even if not used by some) + if "embedding_options" in meta_data: + assert meta_data["embedding_options"]["prompt_template"] == template + + +class TestQueryTemplateApplicationInComputeEmbedding: + """Tests for query template application in compute_query_embedding() (Bug Fix). + + These tests verify that query templates are applied consistently in BOTH + code paths (server and fallback) when computing query embeddings. + + This addresses the bug where query templates were only applied in the + fallback path, not when using the embedding server (the default path). + + Bug Context: + - Issue: Query templates were stored in metadata but only applied during + fallback (direct) computation, not when using embedding server + - Fix: Move template application to BEFORE any computation path in + compute_query_embedding() (searcher_base.py:107-110) + - Impact: Critical for models like EmbeddingGemma that require task-specific + templates for optimal performance + + These tests ensure the fix works correctly and prevent regression. 
+ """ + + @pytest.fixture + def temp_index_with_template(self): + """Create a temporary index with query template in metadata""" + with tempfile.TemporaryDirectory() as tmpdir: + index_dir = Path(tmpdir) + index_file = index_dir / "test.leann" + meta_file = index_dir / "test.leann.meta.json" + + # Create minimal metadata with query template + metadata = { + "version": "1.0", + "backend_name": "hnsw", + "embedding_model": "text-embedding-embeddinggemma-300m-qat", + "dimensions": 768, + "embedding_mode": "openai", + "backend_kwargs": { + "graph_degree": 32, + "complexity": 64, + "distance_metric": "cosine", + }, + "embedding_options": { + "base_url": "http://localhost:1234/v1", + "api_key": "test-key", + "build_prompt_template": "title: none | text: ", + "query_prompt_template": "task: search result | query: ", + }, + } + + meta_file.write_text(json.dumps(metadata, indent=2)) + + # Create minimal HNSW index file (empty is okay for this test) + index_file.write_bytes(b"") + + yield str(index_file) + + def test_query_template_applied_in_fallback_path(self, temp_index_with_template): + """Test that query template is applied when using fallback (direct) path""" + from leann.searcher_base import BaseSearcher + + # Create a concrete implementation for testing + class TestSearcher(BaseSearcher): + def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs): + return {"labels": [], "distances": []} + + searcher = object.__new__(TestSearcher) + searcher.index_path = Path(temp_index_with_template) + searcher.index_dir = searcher.index_path.parent + + # Load metadata + meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json" + with open(meta_file) as f: + searcher.meta = json.load(f) + + searcher.embedding_model = searcher.meta["embedding_model"] + searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers") + searcher.embedding_options = searcher.meta.get("embedding_options", {}) + + # Mock compute_embeddings to capture the query text + captured_queries = [] + + def mock_compute_embeddings(texts, model, mode, provider_options=None): + captured_queries.extend(texts) + return np.random.rand(len(texts), 768).astype(np.float32) + + with patch( + "leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings + ): + # Call compute_query_embedding with template (fallback path) + result = searcher.compute_query_embedding( + query="vector database", + use_server_if_available=False, # Force fallback path + query_template="task: search result | query: ", + ) + + # Verify template was applied + assert len(captured_queries) == 1 + assert captured_queries[0] == "task: search result | query: vector database" + assert result.shape == (1, 768) + + def test_query_template_applied_in_server_path(self, temp_index_with_template): + """Test that query template is applied when using server path""" + from leann.searcher_base import BaseSearcher + + # Create a concrete implementation for testing + class TestSearcher(BaseSearcher): + def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs): + return {"labels": [], "distances": []} + + searcher = object.__new__(TestSearcher) + searcher.index_path = Path(temp_index_with_template) + searcher.index_dir = searcher.index_path.parent + + # Load metadata + meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json" + with open(meta_file) as f: + searcher.meta = json.load(f) + + searcher.embedding_model = searcher.meta["embedding_model"] + searcher.embedding_mode = 
searcher.meta.get("embedding_mode", "sentence-transformers") + searcher.embedding_options = searcher.meta.get("embedding_options", {}) + + # Mock the server methods to capture the query text + captured_queries = [] + + def mock_ensure_server_running(passages_file, port): + return port + + def mock_compute_embedding_via_server(chunks, port): + captured_queries.extend(chunks) + return np.random.rand(len(chunks), 768).astype(np.float32) + + searcher._ensure_server_running = mock_ensure_server_running + searcher._compute_embedding_via_server = mock_compute_embedding_via_server + + # Call compute_query_embedding with template (server path) + result = searcher.compute_query_embedding( + query="vector database", + use_server_if_available=True, # Use server path + query_template="task: search result | query: ", + ) + + # Verify template was applied BEFORE calling server + assert len(captured_queries) == 1 + assert captured_queries[0] == "task: search result | query: vector database" + assert result.shape == (1, 768) + + def test_query_template_without_template_parameter(self, temp_index_with_template): + """Test that query is unchanged when no template is provided""" + from leann.searcher_base import BaseSearcher + + class TestSearcher(BaseSearcher): + def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs): + return {"labels": [], "distances": []} + + searcher = object.__new__(TestSearcher) + searcher.index_path = Path(temp_index_with_template) + searcher.index_dir = searcher.index_path.parent + + meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json" + with open(meta_file) as f: + searcher.meta = json.load(f) + + searcher.embedding_model = searcher.meta["embedding_model"] + searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers") + searcher.embedding_options = searcher.meta.get("embedding_options", {}) + + captured_queries = [] + + def mock_compute_embeddings(texts, model, mode, provider_options=None): + captured_queries.extend(texts) + return np.random.rand(len(texts), 768).astype(np.float32) + + with patch( + "leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings + ): + searcher.compute_query_embedding( + query="vector database", + use_server_if_available=False, + query_template=None, # No template + ) + + # Verify query is unchanged + assert len(captured_queries) == 1 + assert captured_queries[0] == "vector database" + + def test_query_template_consistency_between_paths(self, temp_index_with_template): + """Test that both paths apply template identically""" + from leann.searcher_base import BaseSearcher + + class TestSearcher(BaseSearcher): + def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs): + return {"labels": [], "distances": []} + + searcher = object.__new__(TestSearcher) + searcher.index_path = Path(temp_index_with_template) + searcher.index_dir = searcher.index_path.parent + + meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json" + with open(meta_file) as f: + searcher.meta = json.load(f) + + searcher.embedding_model = searcher.meta["embedding_model"] + searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers") + searcher.embedding_options = searcher.meta.get("embedding_options", {}) + + query_template = "task: search result | query: " + original_query = "vector database" + + # Capture queries from fallback path + fallback_queries = [] + + def mock_compute_embeddings(texts, model, mode, provider_options=None): + 
fallback_queries.extend(texts) + return np.random.rand(len(texts), 768).astype(np.float32) + + with patch( + "leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings + ): + searcher.compute_query_embedding( + query=original_query, + use_server_if_available=False, + query_template=query_template, + ) + + # Capture queries from server path + server_queries = [] + + def mock_ensure_server_running(passages_file, port): + return port + + def mock_compute_embedding_via_server(chunks, port): + server_queries.extend(chunks) + return np.random.rand(len(chunks), 768).astype(np.float32) + + searcher._ensure_server_running = mock_ensure_server_running + searcher._compute_embedding_via_server = mock_compute_embedding_via_server + + searcher.compute_query_embedding( + query=original_query, + use_server_if_available=True, + query_template=query_template, + ) + + # Verify both paths produced identical templated queries + assert len(fallback_queries) == 1 + assert len(server_queries) == 1 + assert fallback_queries[0] == server_queries[0] + assert fallback_queries[0] == f"{query_template}{original_query}" + + def test_query_template_with_empty_string(self, temp_index_with_template): + """Test behavior with empty template string""" + from leann.searcher_base import BaseSearcher + + class TestSearcher(BaseSearcher): + def search(self, query_vectors, top_k, complexity, beam_width=1, **kwargs): + return {"labels": [], "distances": []} + + searcher = object.__new__(TestSearcher) + searcher.index_path = Path(temp_index_with_template) + searcher.index_dir = searcher.index_path.parent + + meta_file = searcher.index_dir / f"{searcher.index_path.name}.meta.json" + with open(meta_file) as f: + searcher.meta = json.load(f) + + searcher.embedding_model = searcher.meta["embedding_model"] + searcher.embedding_mode = searcher.meta.get("embedding_mode", "sentence-transformers") + searcher.embedding_options = searcher.meta.get("embedding_options", {}) + + captured_queries = [] + + def mock_compute_embeddings(texts, model, mode, provider_options=None): + captured_queries.extend(texts) + return np.random.rand(len(texts), 768).astype(np.float32) + + with patch( + "leann.embedding_compute.compute_embeddings", side_effect=mock_compute_embeddings + ): + searcher.compute_query_embedding( + query="vector database", + use_server_if_available=False, + query_template="", # Empty string + ) + + # Empty string is falsy, so no template should be applied + assert captured_queries[0] == "vector database" diff --git a/tests/test_token_truncation.py b/tests/test_token_truncation.py index ad00e3a6..bfb3ca23 100644 --- a/tests/test_token_truncation.py +++ b/tests/test_token_truncation.py @@ -266,3 +266,378 @@ def test_truncate_exact_token_limit(self, tokenizer): assert result_tokens <= target_tokens, ( f"Should be ≤{target_tokens} tokens, got {result_tokens}" ) + + +class TestLMStudioHybridDiscovery: + """Tests for LM Studio integration in get_model_token_limit() hybrid discovery. + + These tests verify that get_model_token_limit() properly integrates with + the LM Studio SDK bridge for dynamic token limit discovery. The integration + should: + + 1. Detect LM Studio URLs (port 1234 or 'lmstudio'/'lm.studio' in URL) + 2. Convert HTTP URLs to WebSocket format for SDK queries + 3. Query LM Studio SDK and use discovered limit + 4. Fall back to registry when SDK returns None + 5. 
Execute AFTER Ollama detection but BEFORE registry fallback + + All tests are written in Red Phase - they should FAIL initially because the + LM Studio detection and integration logic does not exist yet in get_model_token_limit(). + """ + + def test_get_model_token_limit_lmstudio_success(self, monkeypatch): + """Verify LM Studio SDK query succeeds and returns detected limit. + + When a LM Studio base_url is detected and the SDK query succeeds, + get_model_token_limit() should return the dynamically discovered + context length without falling back to the registry. + """ + + # Mock _query_lmstudio_context_limit to return successful SDK query + def mock_query_lmstudio(model_name, base_url): + # Verify WebSocket URL was passed (not HTTP) + assert base_url.startswith("ws://"), ( + f"Should convert HTTP to WebSocket format, got: {base_url}" + ) + return 8192 # Successful SDK query + + monkeypatch.setattr( + "leann.embedding_compute._query_lmstudio_context_limit", + mock_query_lmstudio, + ) + + # Test with HTTP URL that should be converted to WebSocket + limit = get_model_token_limit( + model_name="custom-model", base_url="http://localhost:1234/v1" + ) + + assert limit == 8192, "Should return limit from LM Studio SDK query" + + def test_get_model_token_limit_lmstudio_fallback_to_registry(self, monkeypatch): + """Verify fallback to registry when LM Studio SDK returns None. + + When LM Studio SDK query fails (returns None), get_model_token_limit() + should fall back to the EMBEDDING_MODEL_LIMITS registry. + """ + + # Mock _query_lmstudio_context_limit to return None (SDK failure) + def mock_query_lmstudio(model_name, base_url): + return None # SDK query failed + + monkeypatch.setattr( + "leann.embedding_compute._query_lmstudio_context_limit", + mock_query_lmstudio, + ) + + # Test with known model that exists in registry + limit = get_model_token_limit( + model_name="nomic-embed-text", base_url="http://localhost:1234/v1" + ) + + # Should fall back to registry value + assert limit == 2048, "Should fall back to registry when SDK returns None" + + def test_get_model_token_limit_lmstudio_port_detection(self, monkeypatch): + """Verify detection of LM Studio via port 1234. + + get_model_token_limit() should recognize port 1234 as a LM Studio + server and attempt SDK query, regardless of hostname. + """ + query_called = False + + def mock_query_lmstudio(model_name, base_url): + nonlocal query_called + query_called = True + return 4096 + + monkeypatch.setattr( + "leann.embedding_compute._query_lmstudio_context_limit", + mock_query_lmstudio, + ) + + # Test with port 1234 (default LM Studio port) + limit = get_model_token_limit(model_name="test-model", base_url="http://127.0.0.1:1234/v1") + + assert query_called, "Should detect port 1234 and call LM Studio SDK query" + assert limit == 4096, "Should return SDK query result" + + @pytest.mark.parametrize( + "test_url,expected_limit,keyword", + [ + ("http://lmstudio.local:8080/v1", 16384, "lmstudio"), + ("http://api.lm.studio:5000/v1", 32768, "lm.studio"), + ], + ) + def test_get_model_token_limit_lmstudio_url_keyword_detection( + self, monkeypatch, test_url, expected_limit, keyword + ): + """Verify detection of LM Studio via keywords in URL. + + get_model_token_limit() should recognize 'lmstudio' or 'lm.studio' + in the URL as indicating a LM Studio server. 
+ """ + query_called = False + + def mock_query_lmstudio(model_name, base_url): + nonlocal query_called + query_called = True + return expected_limit + + monkeypatch.setattr( + "leann.embedding_compute._query_lmstudio_context_limit", + mock_query_lmstudio, + ) + + limit = get_model_token_limit(model_name="test-model", base_url=test_url) + + assert query_called, f"Should detect '{keyword}' keyword and call SDK query" + assert limit == expected_limit, f"Should return SDK query result for {keyword}" + + @pytest.mark.parametrize( + "input_url,expected_protocol,expected_host", + [ + ("http://localhost:1234/v1", "ws://", "localhost:1234"), + ("https://lmstudio.example.com:1234/v1", "wss://", "lmstudio.example.com:1234"), + ], + ) + def test_get_model_token_limit_protocol_conversion( + self, monkeypatch, input_url, expected_protocol, expected_host + ): + """Verify HTTP/HTTPS URL is converted to WebSocket format for SDK query. + + LM Studio SDK requires WebSocket URLs. get_model_token_limit() should: + 1. Convert 'http://' to 'ws://' + 2. Convert 'https://' to 'wss://' + 3. Remove '/v1' or other path suffixes (SDK expects base URL) + 4. Preserve host and port + """ + conversions_tested = [] + + def mock_query_lmstudio(model_name, base_url): + conversions_tested.append(base_url) + return 8192 + + monkeypatch.setattr( + "leann.embedding_compute._query_lmstudio_context_limit", + mock_query_lmstudio, + ) + + get_model_token_limit(model_name="test-model", base_url=input_url) + + # Verify conversion happened + assert len(conversions_tested) == 1, "Should have called SDK query once" + assert conversions_tested[0].startswith(expected_protocol), ( + f"Should convert to {expected_protocol}" + ) + assert expected_host in conversions_tested[0], ( + f"Should preserve host and port: {expected_host}" + ) + + def test_get_model_token_limit_lmstudio_executes_after_ollama(self, monkeypatch): + """Verify LM Studio detection happens AFTER Ollama detection. + + The hybrid discovery order should be: + 1. Ollama dynamic discovery (port 11434 or 'ollama' in URL) + 2. LM Studio dynamic discovery (port 1234 or 'lmstudio' in URL) + 3. Registry fallback + + If both Ollama and LM Studio patterns match, Ollama should take precedence. + This test verifies that LM Studio is checked but doesn't interfere with Ollama. + """ + ollama_called = False + lmstudio_called = False + + def mock_query_ollama(model_name, base_url): + nonlocal ollama_called + ollama_called = True + return 2048 # Ollama query succeeds + + def mock_query_lmstudio(model_name, base_url): + nonlocal lmstudio_called + lmstudio_called = True + return None # Should not be reached if Ollama succeeds + + monkeypatch.setattr( + "leann.embedding_compute._query_ollama_context_limit", + mock_query_ollama, + ) + monkeypatch.setattr( + "leann.embedding_compute._query_lmstudio_context_limit", + mock_query_lmstudio, + ) + + # Test with Ollama URL + limit = get_model_token_limit( + model_name="test-model", base_url="http://localhost:11434/api" + ) + + assert ollama_called, "Should attempt Ollama query first" + assert not lmstudio_called, "Should not attempt LM Studio query when Ollama succeeds" + assert limit == 2048, "Should return Ollama result" + + def test_get_model_token_limit_lmstudio_not_detected_for_non_lmstudio_urls(self, monkeypatch): + """Verify LM Studio SDK query is NOT called for non-LM Studio URLs. + + Only URLs with port 1234 or 'lmstudio'/'lm.studio' keywords should + trigger LM Studio SDK queries. Other URLs should skip to registry fallback. 
+ """ + lmstudio_called = False + + def mock_query_lmstudio(model_name, base_url): + nonlocal lmstudio_called + lmstudio_called = True + return 8192 + + monkeypatch.setattr( + "leann.embedding_compute._query_lmstudio_context_limit", + mock_query_lmstudio, + ) + + # Test with non-LM Studio URLs + test_cases = [ + "http://localhost:8080/v1", # Different port + "http://openai.example.com/v1", # Different service + "http://localhost:3000/v1", # Another port + ] + + for base_url in test_cases: + lmstudio_called = False # Reset for each test + get_model_token_limit(model_name="nomic-embed-text", base_url=base_url) + assert not lmstudio_called, f"Should NOT call LM Studio SDK for URL: {base_url}" + + def test_get_model_token_limit_lmstudio_case_insensitive_detection(self, monkeypatch): + """Verify LM Studio detection is case-insensitive for keywords. + + Keywords 'lmstudio' and 'lm.studio' should be detected regardless + of case (LMStudio, LMSTUDIO, LmStudio, etc.). + """ + query_called = False + + def mock_query_lmstudio(model_name, base_url): + nonlocal query_called + query_called = True + return 8192 + + monkeypatch.setattr( + "leann.embedding_compute._query_lmstudio_context_limit", + mock_query_lmstudio, + ) + + # Test various case variations + test_cases = [ + "http://LMStudio.local:8080/v1", + "http://LMSTUDIO.example.com/v1", + "http://LmStudio.local/v1", + "http://api.LM.STUDIO:5000/v1", + ] + + for base_url in test_cases: + query_called = False # Reset for each test + limit = get_model_token_limit(model_name="test-model", base_url=base_url) + assert query_called, f"Should detect LM Studio in URL: {base_url}" + assert limit == 8192, f"Should return SDK result for URL: {base_url}" + + +class TestTokenLimitCaching: + """Tests for token limit caching to prevent repeated SDK/API calls. + + Caching prevents duplicate SDK/API calls within the same Python process, + which is important because: + 1. LM Studio SDK load() can load duplicate model instances + 2. Ollama /api/show queries add latency + 3. Registry lookups are pure overhead + + Cache is process-scoped and resets between leann build invocations. 
+ """ + + def setup_method(self): + """Clear cache before each test.""" + from leann.embedding_compute import _token_limit_cache + + _token_limit_cache.clear() + + def test_registry_lookup_is_cached(self): + """Verify that registry lookups are cached.""" + from leann.embedding_compute import _token_limit_cache + + # First call + limit1 = get_model_token_limit("text-embedding-3-small") + assert limit1 == 8192 + + # Verify it's in cache + cache_key = ("text-embedding-3-small", "") + assert cache_key in _token_limit_cache + assert _token_limit_cache[cache_key] == 8192 + + # Second call should use cache + limit2 = get_model_token_limit("text-embedding-3-small") + assert limit2 == 8192 + + def test_default_fallback_is_cached(self): + """Verify that default fallbacks are cached.""" + from leann.embedding_compute import _token_limit_cache + + # First call with unknown model + limit1 = get_model_token_limit("unknown-model-xyz", default=512) + assert limit1 == 512 + + # Verify it's in cache + cache_key = ("unknown-model-xyz", "") + assert cache_key in _token_limit_cache + assert _token_limit_cache[cache_key] == 512 + + # Second call should use cache + limit2 = get_model_token_limit("unknown-model-xyz", default=512) + assert limit2 == 512 + + def test_different_urls_create_separate_cache_entries(self): + """Verify that different base_urls create separate cache entries.""" + from leann.embedding_compute import _token_limit_cache + + # Same model, different URLs + limit1 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:11434") + limit2 = get_model_token_limit("nomic-embed-text", base_url="http://localhost:1234/v1") + + # Both should find the model in registry (2048) + assert limit1 == 2048 + assert limit2 == 2048 + + # But they should be separate cache entries + cache_key1 = ("nomic-embed-text", "http://localhost:11434") + cache_key2 = ("nomic-embed-text", "http://localhost:1234/v1") + + assert cache_key1 in _token_limit_cache + assert cache_key2 in _token_limit_cache + assert len(_token_limit_cache) == 2 + + def test_cache_prevents_repeated_lookups(self): + """Verify that cache prevents repeated registry/API lookups.""" + from leann.embedding_compute import _token_limit_cache + + model_name = "text-embedding-ada-002" + + # First call - should add to cache + assert len(_token_limit_cache) == 0 + limit1 = get_model_token_limit(model_name) + + cache_size_after_first = len(_token_limit_cache) + assert cache_size_after_first == 1 + + # Multiple subsequent calls - cache size should not change + for _ in range(5): + limit = get_model_token_limit(model_name) + assert limit == limit1 + assert len(_token_limit_cache) == cache_size_after_first + + def test_versioned_model_names_cached_correctly(self): + """Verify that versioned model names (e.g., model:tag) are cached.""" + from leann.embedding_compute import _token_limit_cache + + # Model with version tag + limit = get_model_token_limit("nomic-embed-text:latest", base_url="http://localhost:11434") + assert limit == 2048 + + # Should be cached with full name including version + cache_key = ("nomic-embed-text:latest", "http://localhost:11434") + assert cache_key in _token_limit_cache + assert _token_limit_cache[cache_key] == 2048