diff --git a/README.md b/README.md index 20712e4f..f5eaceb1 100644 --- a/README.md +++ b/README.md @@ -238,7 +238,7 @@ We're excited to announce the support for **RedisVL Extensions**. These modules Increase application throughput and reduce the cost of using LLM models in production by leveraging previously generated knowledge with the [`SemanticCache`](https://docs.redisvl.com/en/stable/api/cache.html#semanticcache). ```python -from redisvl.extensions.llmcache import SemanticCache +from redisvl.extensions.cache.llm import SemanticCache # init cache with TTL and semantic distance threshold llmcache = SemanticCache( diff --git a/docs/api/cache.rst b/docs/api/cache.rst index 14132316..e3cc3131 100644 --- a/docs/api/cache.rst +++ b/docs/api/cache.rst @@ -7,9 +7,26 @@ SemanticCache .. _semantic_cache_api: -.. currentmodule:: redisvl.extensions.llmcache +.. currentmodule:: redisvl.extensions.cache.llm .. autoclass:: SemanticCache :show-inheritance: :members: :inherited-members: + + +**************** +Embeddings Cache +**************** + +EmbeddingsCache +=============== + +.. _embeddings_cache_api: + +.. currentmodule:: redisvl.extensions.cache.embeddings + +.. autoclass:: EmbeddingsCache + :show-inheritance: + :members: + :inherited-members: diff --git a/docs/user_guide/03_llmcache.ipynb b/docs/user_guide/03_llmcache.ipynb index 1a20c0d0..a9e4709a 100644 --- a/docs/user_guide/03_llmcache.ipynb +++ b/docs/user_guide/03_llmcache.ipynb @@ -88,7 +88,7 @@ } ], "source": [ - "from redisvl.extensions.llmcache import SemanticCache\n", + "from redisvl.extensions.cache.llm import SemanticCache\n", "\n", "llmcache = SemanticCache(\n", " name=\"llmcache\", # underlying search index name\n", diff --git a/docs/user_guide/04_vectorizers.ipynb b/docs/user_guide/04_vectorizers.ipynb index c4f862e1..13c3715a 100644 --- a/docs/user_guide/04_vectorizers.ipynb +++ b/docs/user_guide/04_vectorizers.ipynb @@ -609,7 +609,7 @@ "metadata": {}, "outputs": [], "source": [ - "from redisvl.extensions.llmcache import SemanticCache\n", + "from redisvl.extensions.cache.llm import SemanticCache\n", "\n", "cache = SemanticCache(name=\"custom_cache\", vectorizer=custom_vectorizer)\n", "\n", diff --git a/docs/user_guide/09_threshold_optimization.ipynb b/docs/user_guide/09_threshold_optimization.ipynb index 2dbac38f..602892df 100644 --- a/docs/user_guide/09_threshold_optimization.ipynb +++ b/docs/user_guide/09_threshold_optimization.ipynb @@ -24,7 +24,7 @@ "metadata": {}, "outputs": [], "source": [ - "from redisvl.extensions.llmcache import SemanticCache\n", + "from redisvl.extensions.cache.llm import SemanticCache\n", "\n", "sem_cache = SemanticCache(\n", " name=\"sem_cache\", # underlying search index name\n", diff --git a/docs/user_guide/10_embeddings_cache.ipynb b/docs/user_guide/10_embeddings_cache.ipynb new file mode 100644 index 00000000..d5a90096 --- /dev/null +++ b/docs/user_guide/10_embeddings_cache.ipynb @@ -0,0 +1,793 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Caching Embeddings\n", + "\n", + "RedisVL provides an `EmbeddingsCache` that makes it easy to store and retrieve embedding vectors with their associated text and metadata. This cache is particularly useful for applications that frequently compute the same embeddings, enabling you to:\n", + "\n", + "- Reduce computational costs by reusing previously computed embeddings\n", + "- Decrease latency in applications that rely on embeddings\n", + "- Store additional metadata alongside embeddings for richer applications\n", + "\n", + "This notebook will show you how to use the `EmbeddingsCache` effectively in your applications." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Setup\n", + "\n", + "First, let's import the necessary libraries. We'll use a text embedding model from HuggingFace to generate our embeddings." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import time\n", + "import numpy as np\n", + "\n", + "# Disable tokenizers parallelism to avoid deadlocks\n", + "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"False\"\n", + "\n", + "# Import the EmbeddingsCache\n", + "from redisvl.extensions.cache.embeddings import EmbeddingsCache\n", + "from redisvl.utils.vectorize import HFTextVectorizer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's create a vectorizer to generate embeddings for our texts:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the vectorizer\n", + "vectorizer = HFTextVectorizer(\n", + " model=\"sentence-transformers/all-mpnet-base-v2\",\n", + " cache_folder=os.getenv(\"SENTENCE_TRANSFORMERS_HOME\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Initializing the EmbeddingsCache\n", + "\n", + "Now let's initialize our `EmbeddingsCache`. The cache requires a Redis connection to store the embeddings and their associated data." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialize the embeddings cache\n", + "cache = EmbeddingsCache(\n", + " name=\"embedcache\", # name prefix for Redis keys\n", + " redis_url=\"redis://localhost:6379\", # Redis connection URL\n", + " ttl=None # Optional TTL in seconds (None means no expiration)\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Basic Usage\n", + "\n", + "### Storing Embeddings\n", + "\n", + "Let's store some text with its embedding in the cache. The `set` method takes the following parameters:\n", + "- `text`: The input text that was embedded\n", + "- `model_name`: The name of the embedding model used\n", + "- `embedding`: The embedding vector\n", + "- `metadata`: Optional metadata associated with the embedding\n", + "- `ttl`: Optional time-to-live override for this specific entry" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Stored with key: embedcache:059d...\n" + ] + } + ], + "source": [ + "# Text to embed\n", + "text = \"What is machine learning?\"\n", + "model_name = \"sentence-transformers/all-mpnet-base-v2\"\n", + "\n", + "# Generate the embedding\n", + "embedding = vectorizer.embed(text)\n", + "\n", + "# Optional metadata\n", + "metadata = {\"category\": \"ai\", \"source\": \"user_query\"}\n", + "\n", + "# Store in cache\n", + "key = cache.set(\n", + " text=text,\n", + " model_name=model_name,\n", + " embedding=embedding,\n", + " metadata=metadata\n", + ")\n", + "\n", + "print(f\"Stored with key: {key[:15]}...\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Retrieving Embeddings\n", + "\n", + "To retrieve an embedding from the cache, use the `get` method with the original text and model name:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Found in cache: What is machine learning?\n", + "Model: sentence-transformers/all-mpnet-base-v2\n", + "Metadata: {'category': 'ai', 'source': 'user_query'}\n", + "Embedding shape: (768,)\n" + ] + } + ], + "source": [ + "# Retrieve from cache\n", + "\n", + "if result := cache.get(text=text, model_name=model_name):\n", + " print(f\"Found in cache: {result['text']}\")\n", + " print(f\"Model: {result['model_name']}\")\n", + " print(f\"Metadata: {result['metadata']}\")\n", + " print(f\"Embedding shape: {np.array(result['embedding']).shape}\")\n", + "else:\n", + " print(\"Not found in cache.\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Checking Existence\n", + "\n", + "You can check if an embedding exists in the cache without retrieving it using the `exists` method:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "First query exists in cache: True\n", + "New query exists in cache: False\n" + ] + } + ], + "source": [ + "# Check if existing text is in cache\n", + "exists = cache.exists(text=text, model_name=model_name)\n", + "print(f\"First query exists in cache: {exists}\")\n", + "\n", + "# Check if a new text is in cache\n", + "new_text = \"What is deep learning?\"\n", + "exists = cache.exists(text=new_text, model_name=model_name)\n", + "print(f\"New query exists in cache: {exists}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Removing Entries\n", + "\n", + "To remove an entry from the cache, use the `drop` method:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "After dropping: False\n" + ] + } + ], + "source": [ + "# Remove from cache\n", + "cache.drop(text=text, model_name=model_name)\n", + "\n", + "# Verify it's gone\n", + "exists = cache.exists(text=text, model_name=model_name)\n", + "print(f\"After dropping: {exists}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Advanced Usage\n", + "\n", + "### Key-Based Operations\n", + "\n", + "The `EmbeddingsCache` also provides methods that work directly with Redis keys, which can be useful for advanced use cases:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Stored with key: embedcache:059d...\n", + "Exists by key: True\n", + "Retrieved by key: What is machine learning?\n" + ] + } + ], + "source": [ + "# Store an entry again\n", + "key = cache.set(\n", + " text=text,\n", + " model_name=model_name,\n", + " embedding=embedding,\n", + " metadata=metadata\n", + ")\n", + "print(f\"Stored with key: {key[:15]}...\")\n", + "\n", + "# Check existence by key\n", + "exists_by_key = cache.exists_by_key(key)\n", + "print(f\"Exists by key: {exists_by_key}\")\n", + "\n", + "# Retrieve by key\n", + "result_by_key = cache.get_by_key(key)\n", + "print(f\"Retrieved by key: {result_by_key['text']}\")\n", + "\n", + "# Drop by key\n", + "cache.drop_by_key(key)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Batch Operations\n", + "\n", + "When working with multiple embeddings, batch operations can significantly improve performance by reducing network roundtrips. The `EmbeddingsCache` provides methods prefixed with `m` (for \"multi\") that handle batches efficiently." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Stored 3 embeddings with batch operation\n", + "All embeddings exist: True\n", + "Retrieved 3 embeddings in one operation\n" + ] + } + ], + "source": [ + "# Create multiple embeddings\n", + "texts = [\n", + " \"What is machine learning?\",\n", + " \"How do neural networks work?\",\n", + " \"What is deep learning?\"\n", + "]\n", + "embeddings = [vectorizer.embed(t) for t in texts]\n", + "\n", + "# Prepare batch items as dictionaries\n", + "batch_items = [\n", + " {\n", + " \"text\": texts[0],\n", + " \"model_name\": model_name,\n", + " \"embedding\": embeddings[0],\n", + " \"metadata\": {\"category\": \"ai\", \"type\": \"question\"}\n", + " },\n", + " {\n", + " \"text\": texts[1],\n", + " \"model_name\": model_name,\n", + " \"embedding\": embeddings[1],\n", + " \"metadata\": {\"category\": \"ai\", \"type\": \"question\"}\n", + " },\n", + " {\n", + " \"text\": texts[2],\n", + " \"model_name\": model_name,\n", + " \"embedding\": embeddings[2],\n", + " \"metadata\": {\"category\": \"ai\", \"type\": \"question\"}\n", + " }\n", + "]\n", + "\n", + "# Store multiple embeddings in one operation\n", + "keys = cache.mset(batch_items)\n", + "print(f\"Stored {len(keys)} embeddings with batch operation\")\n", + "\n", + "# Check if multiple embeddings exist in one operation\n", + "exist_results = cache.mexists(texts, model_name)\n", + "print(f\"All embeddings exist: {all(exist_results)}\")\n", + "\n", + "# Retrieve multiple embeddings in one operation\n", + "results = cache.mget(texts, model_name)\n", + "print(f\"Retrieved {len(results)} embeddings in one operation\")\n", + "\n", + "# Delete multiple embeddings in one operation\n", + "cache.mdrop(texts, model_name)\n", + "\n", + "# Alternative: key-based batch operations\n", + "# cache.mget_by_keys(keys) # Retrieve by keys\n", + "# cache.mexists_by_keys(keys) # Check existence by keys\n", + "# cache.mdrop_by_keys(keys) # Delete by keys" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Batch operations are particularly beneficial when working with large numbers of embeddings. They provide the same functionality as individual operations but with better performance by reducing network roundtrips.\n", + "\n", + "For asynchronous applications, async versions of all batch methods are also available with the `am` prefix (e.g., `amset`, `amget`, `amexists`, `amdrop`)." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Working with TTL (Time-To-Live)\n", + "\n", + "You can set a global TTL when initializing the cache, or specify TTL for individual entries:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Immediately after setting: True\n", + "After waiting: False\n" + ] + } + ], + "source": [ + "# Create a cache with a default 5-second TTL\n", + "ttl_cache = EmbeddingsCache(\n", + " name=\"ttl_cache\",\n", + " redis_url=\"redis://localhost:6379\",\n", + " ttl=5 # 5 second TTL\n", + ")\n", + "\n", + "# Store an entry\n", + "key = ttl_cache.set(\n", + " text=text,\n", + " model_name=model_name,\n", + " embedding=embedding\n", + ")\n", + "\n", + "# Check if it exists\n", + "exists = ttl_cache.exists_by_key(key)\n", + "print(f\"Immediately after setting: {exists}\")\n", + "\n", + "# Wait for it to expire\n", + "time.sleep(6)\n", + "\n", + "# Check again\n", + "exists = ttl_cache.exists_by_key(key)\n", + "print(f\"After waiting: {exists}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also override the default TTL for individual entries:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Entry with custom TTL after 2 seconds: False\n", + "Entry with default TTL after 2 seconds: True\n" + ] + } + ], + "source": [ + "# Store an entry with a custom 1-second TTL\n", + "key1 = ttl_cache.set(\n", + " text=\"Short-lived entry\",\n", + " model_name=model_name,\n", + " embedding=embedding,\n", + " ttl=1 # Override with 1 second TTL\n", + ")\n", + "\n", + "# Store another entry with the default TTL (5 seconds)\n", + "key2 = ttl_cache.set(\n", + " text=\"Default TTL entry\",\n", + " model_name=model_name,\n", + " embedding=embedding\n", + " # No TTL specified = uses the default 5 seconds\n", + ")\n", + "\n", + "# Wait for 2 seconds\n", + "time.sleep(2)\n", + "\n", + "# Check both entries\n", + "exists1 = ttl_cache.exists_by_key(key1)\n", + "exists2 = ttl_cache.exists_by_key(key2)\n", + "\n", + "print(f\"Entry with custom TTL after 2 seconds: {exists1}\")\n", + "print(f\"Entry with default TTL after 2 seconds: {exists2}\")\n", + "\n", + "# Cleanup\n", + "ttl_cache.drop_by_key(key2)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Async Support\n", + "\n", + "The `EmbeddingsCache` provides async versions of all methods for use in async applications. The async methods are prefixed with `a` (e.g., `aset`, `aget`, `aexists`, `adrop`)." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Async set successful? True\n", + "Async get successful? True\n" + ] + } + ], + "source": [ + "async def async_cache_demo():\n", + " # Store an entry asynchronously\n", + " key = await cache.aset(\n", + " text=\"Async embedding\",\n", + " model_name=model_name,\n", + " embedding=embedding,\n", + " metadata={\"async\": True}\n", + " )\n", + " \n", + " # Check if it exists\n", + " exists = await cache.aexists_by_key(key)\n", + " print(f\"Async set successful? {exists}\")\n", + " \n", + " # Retrieve it\n", + " result = await cache.aget_by_key(key)\n", + " success = result is not None and result[\"text\"] == \"Async embedding\"\n", + " print(f\"Async get successful? {success}\")\n", + " \n", + " # Remove it\n", + " await cache.adrop_by_key(key)\n", + "\n", + "# Run the async demo\n", + "await async_cache_demo()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Real-World Example\n", + "\n", + "Let's build a simple embeddings caching system for a text classification task. We'll check the cache before computing new embeddings to save computation time." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Computing embedding for: What is artificial intelligence?\n", + "Computing embedding for: How does machine learning work?\n", + "Found in cache: What is artificial intelligence?\n", + "Computing embedding for: What are neural networks?\n", + "Found in cache: How does machine learning work?\n", + "\n", + "Statistics:\n", + "Total queries: 5\n", + "Cache hits: 2\n", + "Cache misses: 3\n", + "Cache hit rate: 40.0%\n" + ] + } + ], + "source": [ + "# Create a fresh cache for this example\n", + "example_cache = EmbeddingsCache(\n", + " name=\"example_cache\",\n", + " redis_url=\"redis://localhost:6379\",\n", + " ttl=3600 # 1 hour TTL\n", + ")\n", + "\n", + "# Function to get embedding with caching\n", + "def get_cached_embedding(text, model_name):\n", + " # Check if it's in the cache first\n", + " if cached_result := example_cache.get(text=text, model_name=model_name):\n", + " print(f\"Found in cache: {text}\")\n", + " return cached_result[\"embedding\"]\n", + " \n", + " # Not in cache, compute the embedding\n", + " print(f\"Computing embedding for: {text}\")\n", + " embedding = vectorizer.embed(text)\n", + " \n", + " # Store in cache\n", + " example_cache.set(\n", + " text=text,\n", + " model_name=model_name,\n", + " embedding=embedding,\n", + " )\n", + " \n", + " return embedding\n", + "\n", + "# Simulate processing a stream of queries\n", + "queries = [\n", + " \"What is artificial intelligence?\",\n", + " \"How does machine learning work?\",\n", + " \"What is artificial intelligence?\", # Repeated query\n", + " \"What are neural networks?\",\n", + " \"How does machine learning work?\" # Repeated query\n", + "]\n", + "\n", + "# Process the queries and track statistics\n", + "total_queries = 0\n", + "cache_hits = 0\n", + "\n", + "for query in queries:\n", + " total_queries += 1\n", + " \n", + " # Check cache before computing\n", + " before = example_cache.exists(text=query, model_name=model_name)\n", + " if before:\n", + " cache_hits += 1\n", + " \n", + " # Get embedding (will compute or use cache)\n", + " embedding = get_cached_embedding(query, model_name)\n", + "\n", + "# Report statistics\n", + "cache_misses = total_queries - cache_hits\n", + "hit_rate = (cache_hits / total_queries) * 100\n", + "\n", + "print(\"\\nStatistics:\")\n", + "print(f\"Total queries: {total_queries}\")\n", + "print(f\"Cache hits: {cache_hits}\")\n", + "print(f\"Cache misses: {cache_misses}\")\n", + "print(f\"Cache hit rate: {hit_rate:.1f}%\")\n", + "\n", + "# Cleanup\n", + "for query in set(queries): # Use set to get unique queries\n", + " example_cache.drop(text=query, model_name=model_name)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Performance Benchmark\n", + "\n", + "Let's run benchmarks to compare the performance of embedding with and without caching, as well as batch versus individual operations." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Benchmarking without caching:\n", + "Time taken without caching: 0.0940 seconds\n", + "Average time per embedding: 0.0094 seconds\n", + "\n", + "Benchmarking with caching:\n", + "Time taken with caching: 0.0237 seconds\n", + "Average time per embedding: 0.0024 seconds\n", + "\n", + "Performance comparison:\n", + "Speedup with caching: 3.96x faster\n", + "Time saved: 0.0703 seconds (74.8%)\n", + "Latency reduction: 0.0070 seconds per query\n" + ] + } + ], + "source": [ + "# Text to use for benchmarking\n", + "benchmark_text = \"This is a benchmark text to measure the performance of embedding caching.\"\n", + "benchmark_model = \"sentence-transformers/all-mpnet-base-v2\"\n", + "\n", + "# Create a fresh cache for benchmarking\n", + "benchmark_cache = EmbeddingsCache(\n", + " name=\"benchmark_cache\",\n", + " redis_url=\"redis://localhost:6379\",\n", + " ttl=3600 # 1 hour TTL\n", + ")\n", + "\n", + "# Function to get embeddings without caching\n", + "def get_embedding_without_cache(text, model_name):\n", + " return vectorizer.embed(text)\n", + "\n", + "# Function to get embeddings with caching\n", + "def get_embedding_with_cache(text, model_name):\n", + " if cached_result := benchmark_cache.get(text=text, model_name=model_name):\n", + " return cached_result[\"embedding\"]\n", + " \n", + " embedding = vectorizer.embed(text)\n", + " benchmark_cache.set(\n", + " text=text,\n", + " model_name=model_name,\n", + " embedding=embedding\n", + " )\n", + " return embedding\n", + "\n", + "# Number of iterations for the benchmark\n", + "n_iterations = 10\n", + "\n", + "# Benchmark without caching\n", + "print(\"Benchmarking without caching:\")\n", + "start_time = time.time()\n", + "get_embedding_without_cache(benchmark_text, benchmark_model)\n", + "no_cache_time = time.time() - start_time\n", + "print(f\"Time taken without caching: {no_cache_time:.4f} seconds\")\n", + "print(f\"Average time per embedding: {no_cache_time/n_iterations:.4f} seconds\")\n", + "\n", + "# Benchmark with caching\n", + "print(\"\\nBenchmarking with caching:\")\n", + "start_time = time.time()\n", + "get_embedding_with_cache(benchmark_text, benchmark_model)\n", + "cache_time = time.time() - start_time\n", + "print(f\"Time taken with caching: {cache_time:.4f} seconds\")\n", + "print(f\"Average time per embedding: {cache_time/n_iterations:.4f} seconds\")\n", + "\n", + "# Compare performance\n", + "speedup = no_cache_time / cache_time\n", + "latency_reduction = (no_cache_time/n_iterations) - (cache_time/n_iterations)\n", + "print(f\"\\nPerformance comparison:\")\n", + "print(f\"Speedup with caching: {speedup:.2f}x faster\")\n", + "print(f\"Time saved: {no_cache_time - cache_time:.4f} seconds ({(1 - cache_time/no_cache_time) * 100:.1f}%)\")\n", + "print(f\"Latency reduction: {latency_reduction:.4f} seconds per query\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Common Use Cases for Embedding Caching\n", + "\n", + "Embedding caching is particularly useful in the following scenarios:\n", + "\n", + "1. **Search applications**: Cache embeddings for frequently searched queries to reduce latency\n", + "2. **Content recommendation systems**: Cache embeddings for content items to speed up similarity calculations\n", + "3. **API services**: Reduce costs and improve response times when generating embeddings through paid APIs\n", + "4. **Batch processing**: Speed up processing of datasets that contain duplicate texts\n", + "5. **Chatbots and virtual assistants**: Cache embeddings for common user queries to provide faster responses\n", + "6. **Development** workflows" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cleanup\n", + "\n", + "Let's clean up our caches to avoid leaving data in Redis:" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "# Clean up all caches\n", + "cache.clear()\n", + "ttl_cache.clear()\n", + "example_cache.clear()\n", + "benchmark_cache.clear()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Summary\n", + "\n", + "The `EmbeddingsCache` provides an efficient way to store and retrieve embeddings with their associated text and metadata. Key features include:\n", + "\n", + "- Simple API for storing and retrieving individual embeddings (`set`/`get`)\n", + "- Batch operations for working with multiple embeddings efficiently (`mset`/`mget`/`mexists`/`mdrop`)\n", + "- Support for metadata storage alongside embeddings\n", + "- Configurable time-to-live (TTL) for cache entries\n", + "- Key-based operations for advanced use cases\n", + "- Async support for use in asynchronous applications\n", + "- Significant performance improvements (15-20x faster with batch operations)\n", + "\n", + "By using the `EmbeddingsCache`, you can reduce computational costs and improve the performance of applications that rely on embeddings." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/user_guide/index.md b/docs/user_guide/index.md index 30a51b8a..4d6e5c04 100644 --- a/docs/user_guide/index.md +++ b/docs/user_guide/index.md @@ -15,6 +15,7 @@ User guides provide helpful resources for using RedisVL and its different compon 01_getting_started 02_hybrid_queries 03_llmcache +10_embeddings_cache 04_vectorizers 05_hash_vs_json 06_rerankers diff --git a/docs/user_guide/release_guide/0_5_0_release.ipynb b/docs/user_guide/release_guide/0_5_0_release.ipynb index fa9d06ed..ae4f35de 100644 --- a/docs/user_guide/release_guide/0_5_0_release.ipynb +++ b/docs/user_guide/release_guide/0_5_0_release.ipynb @@ -248,7 +248,7 @@ ], "source": [ "from redisvl.utils.optimize import CacheThresholdOptimizer\n", - "from redisvl.extensions.llmcache import SemanticCache\n", + "from redisvl.extensions.cache.llm import SemanticCache\n", "\n", "sem_cache = SemanticCache(\n", " name=\"sem_cache\", # underlying search index name\n", diff --git a/redisvl/extensions/__init__.py b/redisvl/extensions/__init__.py index e69de29b..8b137891 100644 --- a/redisvl/extensions/__init__.py +++ b/redisvl/extensions/__init__.py @@ -0,0 +1 @@ + diff --git a/redisvl/extensions/cache/__init__.py b/redisvl/extensions/cache/__init__.py new file mode 100644 index 00000000..f3ce33ce --- /dev/null +++ b/redisvl/extensions/cache/__init__.py @@ -0,0 +1,12 @@ +""" +Redis Vector Library Cache Extensions + +This module provides caching functionality for Redis Vector Library, +including both embedding caches and LLM response caches. +""" + +from redisvl.extensions.cache.base import BaseCache +from redisvl.extensions.cache.embeddings import EmbeddingsCache +from redisvl.extensions.cache.llm import SemanticCache + +__all__ = ["BaseCache", "EmbeddingsCache", "SemanticCache"] diff --git a/redisvl/extensions/cache/base.py b/redisvl/extensions/cache/base.py new file mode 100644 index 00000000..68e120e4 --- /dev/null +++ b/redisvl/extensions/cache/base.py @@ -0,0 +1,222 @@ +"""Base cache interface for RedisVL. + +This module defines the abstract base cache interface that is implemented by +specific cache types such as LLM caches and embedding caches. +""" + +from typing import Any, Dict, Optional + +from redis import Redis +from redis.asyncio import Redis as AsyncRedis + + +class BaseCache: + """Base abstract cache interface for all RedisVL caches. + + This class defines common functionality shared by all cache implementations, + including TTL management, connection handling, and basic cache operations. + """ + + _redis_client: Optional[Redis] + _async_redis_client: Optional[AsyncRedis] + + def __init__( + self, + name: str, + ttl: Optional[int] = None, + redis_client: Optional[Redis] = None, + redis_url: str = "redis://localhost:6379", + connection_kwargs: Dict[str, Any] = {}, + ): + """Initialize a base cache. + + Args: + name (str): The name of the cache. + ttl (Optional[int], optional): The time-to-live for records cached + in Redis. Defaults to None. + redis_client (Optional[Redis], optional): A redis client connection instance. + Defaults to None. + redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. + connection_kwargs (Dict[str, Any]): The connection arguments + for the redis client. Defaults to empty {}. + """ + self.name = name + self._ttl: Optional[int] = None + self.set_ttl(ttl) + + self.redis_kwargs = { + "redis_client": redis_client, + "redis_url": redis_url, + "connection_kwargs": connection_kwargs, + } + + # Initialize Redis clients + self._async_redis_client = None + + if redis_client: + self._owns_redis_client = False + self._redis_client = redis_client + else: + self._owns_redis_client = True + self._redis_client = None # type: ignore + + def _get_prefix(self) -> str: + """Get the key prefix for Redis keys. + + Returns: + str: The prefix to use for Redis keys. + """ + return f"{self.name}:" + + def _make_key(self, entry_id: str) -> str: + """Generate a full Redis key for the given entry ID. + + Args: + entry_id (str): The unique entry ID. + + Returns: + str: The full Redis key including prefix. + """ + return f"{self._get_prefix()}{entry_id}" + + @property + def ttl(self) -> Optional[int]: + """The default TTL, in seconds, for entries in the cache.""" + return self._ttl + + def set_ttl(self, ttl: Optional[int] = None) -> None: + """Set the default TTL, in seconds, for entries in the cache. + + Args: + ttl (Optional[int], optional): The optional time-to-live expiration + for the cache, in seconds. + + Raises: + ValueError: If the time-to-live value is not an integer. + """ + if ttl: + if not isinstance(ttl, int): + raise ValueError(f"TTL must be an integer value, got {ttl}") + self._ttl = int(ttl) + else: + self._ttl = None + + def _get_redis_client(self) -> Redis: + """Get or create a Redis client. + + Returns: + Redis: A Redis client instance. + """ + if self._redis_client is None: + # Create new Redis client + url = self.redis_kwargs["redis_url"] + kwargs = self.redis_kwargs["connection_kwargs"] + self._redis_client = Redis.from_url(url, **kwargs) # type: ignore + return self._redis_client + + async def _get_async_redis_client(self) -> AsyncRedis: + """Get or create an async Redis client. + + Returns: + AsyncRedis: An async Redis client instance. + """ + if not hasattr(self, "_async_redis_client") or self._async_redis_client is None: + # Create new async Redis client + url = self.redis_kwargs["redis_url"] + kwargs = self.redis_kwargs["connection_kwargs"] + self._async_redis_client = AsyncRedis.from_url(url, **kwargs) # type: ignore + return self._async_redis_client + + def expire(self, key: str, ttl: Optional[int] = None) -> None: + """Set or refresh the expiration time for a key in the cache. + + Args: + key (str): The Redis key to set the expiration on. + ttl (Optional[int], optional): The time-to-live in seconds. If None, + uses the default TTL configured for this cache instance. + Defaults to None. + + Note: + If neither the provided TTL nor the default TTL is set (both are None), + this method will have no effect. + """ + _ttl = ttl if ttl is not None else self._ttl + if _ttl: + client = self._get_redis_client() + client.expire(key, _ttl) + + async def aexpire(self, key: str, ttl: Optional[int] = None) -> None: + """Asynchronously set or refresh the expiration time for a key in the cache. + + Args: + key (str): The Redis key to set the expiration on. + ttl (Optional[int], optional): The time-to-live in seconds. If None, + uses the default TTL configured for this cache instance. + Defaults to None. + + Note: + If neither the provided TTL nor the default TTL is set (both are None), + this method will have no effect. + """ + _ttl = ttl if ttl is not None else self._ttl + if _ttl: + client = await self._get_async_redis_client() + await client.expire(key, _ttl) + + def clear(self) -> None: + """Clear the cache of all keys.""" + client = self._get_redis_client() + prefix = self._get_prefix() + + # Scan for all keys with our prefix + cursor = 0 # Start with cursor 0 + while True: + cursor_int, keys = client.scan(cursor=cursor, match=f"{prefix}*", count=100) # type: ignore + if keys: + client.delete(*keys) + if cursor_int == 0: # Redis returns 0 when scan is complete + break + cursor = cursor_int # Update cursor for next iteration + + async def aclear(self) -> None: + """Async clear the cache of all keys.""" + client = await self._get_async_redis_client() + prefix = self._get_prefix() + + # Scan for all keys with our prefix + cursor = 0 # Start with cursor 0 + while True: + cursor_int, keys = await client.scan(cursor=cursor, match=f"{prefix}*", count=100) # type: ignore + if keys: + await client.delete(*keys) + if cursor_int == 0: # Redis returns 0 when scan is complete + break + cursor = cursor_int # Update cursor for next iteration + + def disconnect(self) -> None: + """Disconnect from Redis.""" + if self._owns_redis_client is False: + return + + if self._redis_client: + self._redis_client.close() + self._redis_client = None # type: ignore + + if hasattr(self, "_async_redis_client") and self._async_redis_client: + # Use synchronous close for async client in synchronous context + self._async_redis_client.close() # type: ignore + self._async_redis_client = None # type: ignore + + async def adisconnect(self) -> None: + """Async disconnect from Redis.""" + if self._owns_redis_client is False: + return + + if self._redis_client: + self._redis_client.close() + self._redis_client = None # type: ignore + + if hasattr(self, "_async_redis_client") and self._async_redis_client: + # Use proper async close method + await self._async_redis_client.aclose() # type: ignore + self._async_redis_client = None # type: ignore diff --git a/redisvl/extensions/cache/embeddings/__init__.py b/redisvl/extensions/cache/embeddings/__init__.py new file mode 100644 index 00000000..5c6fd3b7 --- /dev/null +++ b/redisvl/extensions/cache/embeddings/__init__.py @@ -0,0 +1,10 @@ +""" +Redis Vector Library - Embeddings Cache Extensions + +This module provides embedding caching functionality for RedisVL. +""" + +from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache +from redisvl.extensions.cache.embeddings.schema import CacheEntry + +__all__ = ["EmbeddingsCache", "CacheEntry"] diff --git a/redisvl/extensions/cache/embeddings/embeddings.py b/redisvl/extensions/cache/embeddings/embeddings.py new file mode 100644 index 00000000..795096bd --- /dev/null +++ b/redisvl/extensions/cache/embeddings/embeddings.py @@ -0,0 +1,936 @@ +"""Embeddings cache implementation for RedisVL.""" + +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +from redis import Redis +from redis.asyncio import Redis as AsyncRedis + +from redisvl.extensions.cache.base import BaseCache +from redisvl.extensions.cache.embeddings.schema import CacheEntry +from redisvl.redis.utils import convert_bytes, hashify + + +class EmbeddingsCache(BaseCache): + """Embeddings Cache for storing embedding vectors with exact key matching.""" + + def __init__( + self, + name: str = "embedcache", + ttl: Optional[int] = None, + redis_client: Optional[Redis] = None, + redis_url: str = "redis://localhost:6379", + connection_kwargs: Dict[str, Any] = {}, + ): + """Initialize an embeddings cache. + + Args: + name (str): The name of the cache. Defaults to "embedcache". + ttl (Optional[int]): The time-to-live for cached embeddings. Defaults to None. + redis_client (Optional[Redis]): Redis client instance. Defaults to None. + redis_url (str): Redis URL for connection. Defaults to "redis://localhost:6379". + connection_kwargs (Dict[str, Any]): Redis connection arguments. Defaults to {}. + + Raises: + ValueError: If vector dimensions are invalid + + .. code-block:: python + + cache = EmbeddingsCache( + name="my_embeddings_cache", + ttl=3600, # 1 hour + redis_url="redis://localhost:6379" + ) + """ + super().__init__( + name=name, + ttl=ttl, + redis_client=redis_client, + redis_url=redis_url, + connection_kwargs=connection_kwargs, + ) + + def _make_entry_id(self, text: str, model_name: str) -> str: + """Generate a deterministic entry ID for the given text and model name. + + Args: + text (str): The text input that was embedded. + model_name (str): The name of the embedding model. + + Returns: + str: A deterministic entry ID based on the text and model name. + """ + return hashify(f"{text}:{model_name}") + + def _make_cache_key(self, text: str, model_name: str) -> str: + """Generate a full Redis key for the given text and model name. + + Args: + text (str): The text input that was embedded. + model_name (str): The name of the embedding model. + + Returns: + str: The full Redis key. + """ + entry_id = self._make_entry_id(text, model_name) + return self._make_key(entry_id) + + def _prepare_entry_data( + self, + text: str, + model_name: str, + embedding: List[float], + metadata: Optional[Dict[str, Any]] = None, + ) -> Tuple[str, Dict[str, Any]]: + """Prepare data for storage in Redis + + Args: + text (str): The text input that was embedded. + model_name (str): The name of the embedding model. + embedding (List[float]): The embedding vector. + metadata (Optional[Dict[str, Any]]): Optional metadata. + + Returns: + Tuple[str, Dict[str, Any]]: A tuple of (key, entry_data) + """ + # Create cache entry with entry_id + entry_id = self._make_entry_id(text, model_name) + key = self._make_key(entry_id) + entry = CacheEntry( + entry_id=entry_id, + text=text, + model_name=model_name, + embedding=embedding, + metadata=metadata, + ) + return key, entry.to_dict() + + def _process_cache_data( + self, data: Optional[Dict[str, Any]] + ) -> Optional[Dict[str, Any]]: + """Process Redis hash data into a cache entry response. + + Args: + data (Optional[Dict[str, Any]]): Raw Redis hash data. + + Returns: + Optional[Dict[str, Any]]: Processed cache entry or None if no data. + """ + if not data: + return None + + cache_hit = CacheEntry(**convert_bytes(data)) + return cache_hit.model_dump(exclude_none=True) + + def get( + self, + text: str, + model_name: str, + ) -> Optional[Dict[str, Any]]: + """Get embedding by text and model name. + + Retrieves a cached embedding for the given text and model name. + If found, refreshes the TTL of the entry. + + Args: + text (str): The text input that was embedded. + model_name (str): The name of the embedding model. + + Returns: + Optional[Dict[str, Any]]: Embedding cache entry or None if not found. + + .. code-block:: python + + embedding_data = cache.get( + text="What is machine learning?", + model_name="text-embedding-ada-002" + ) + """ + key = self._make_cache_key(text, model_name) + return self.get_by_key(key) + + def get_by_key(self, key: str) -> Optional[Dict[str, Any]]: + """Get embedding by its full Redis key. + + Retrieves a cached embedding for the given Redis key. + If found, refreshes the TTL of the entry. + + Args: + key (str): The full Redis key for the embedding. + + Returns: + Optional[Dict[str, Any]]: Embedding cache entry or None if not found. + + .. code-block:: python + + embedding_data = cache.get_by_key("embedcache:1234567890abcdef") + """ + client = self._get_redis_client() + + # Get all fields + data = client.hgetall(key) + + # Refresh TTL if data exists + if data: + self.expire(key) + + return self._process_cache_data(data) + + def mget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]: + """Get multiple embeddings by their Redis keys. + + Efficiently retrieves multiple cached embeddings in a single network roundtrip. + If found, refreshes the TTL of each entry. + + Args: + keys (List[str]): List of Redis keys to retrieve. + + Returns: + List[Optional[Dict[str, Any]]]: List of embedding cache entries or None for keys not found. + The order matches the input keys order. + + .. code-block:: python + + # Get multiple embeddings + embedding_data = cache.mget_by_keys([ + "embedcache:key1", + "embedcache:key2" + ]) + """ + if not keys: + return [] + + client = self._get_redis_client() + + with client.pipeline(transaction=False) as pipeline: + # Queue all hgetall operations + for key in keys: + pipeline.hgetall(key) + results = pipeline.execute() + + # Process results + processed_results = [] + for i, result in enumerate(results): + if result: # If cache hit, refresh TTL separately + self.expire(keys[i]) + processed_results.append(self._process_cache_data(result)) + + return processed_results + + def mget(self, texts: List[str], model_name: str) -> List[Optional[Dict[str, Any]]]: + """Get multiple embeddings by their texts and model name. + + Efficiently retrieves multiple cached embeddings in a single operation. + If found, refreshes the TTL of each entry. + + Args: + texts (List[str]): List of text inputs that were embedded. + model_name (str): The name of the embedding model. + + Returns: + List[Optional[Dict[str, Any]]]: List of embedding cache entries or None for texts not found. + + .. code-block:: python + + # Get multiple embeddings + embedding_data = cache.mget( + texts=["What is machine learning?", "What is deep learning?"], + model_name="text-embedding-ada-002" + ) + """ + if not texts: + return [] + + # Generate keys for each text + keys = [self._make_cache_key(text, model_name) for text in texts] + + # Use the key-based batch operation + return self.mget_by_keys(keys) + + def set( + self, + text: str, + model_name: str, + embedding: List[float], + metadata: Optional[Dict[str, Any]] = None, + ttl: Optional[int] = None, + ) -> str: + """Store an embedding with its text and model name. + + Args: + text (str): The text input that was embedded. + model_name (str): The name of the embedding model. + embedding (List[float]): The embedding vector to store. + metadata (Optional[Dict[str, Any]]): Optional metadata to store with the embedding. + ttl (Optional[int]): Optional TTL override for this specific entry. + + Returns: + str: The Redis key where the embedding was stored. + + .. code-block:: python + + key = cache.set( + text="What is machine learning?", + model_name="text-embedding-ada-002", + embedding=[0.1, 0.2, 0.3, ...], + metadata={"source": "user_query"} + ) + """ + # Prepare data + key, cache_entry = self._prepare_entry_data( + text, model_name, embedding, metadata + ) + + # Store in Redis + client = self._get_redis_client() + client.hset(name=key, mapping=cache_entry) # type: ignore + + # Set TTL if specified + self.expire(key, ttl) + + return key + + def mset( + self, + items: List[Dict[str, Any]], + ttl: Optional[int] = None, + ) -> List[str]: + """Store multiple embeddings in a batch operation. + + Each item in the input list should be a dictionary with the following fields: + - 'text': The text input that was embedded + - 'model_name': The name of the embedding model + - 'embedding': The embedding vector + - 'metadata': Optional metadata to store with the embedding + + Args: + items: List of dictionaries, each containing text, model_name, embedding, and optional metadata. + ttl: Optional TTL override for these entries. + + Returns: + List[str]: List of Redis keys where the embeddings were stored. + + .. code-block:: python + + # Store multiple embeddings + keys = cache.mset([ + { + "text": "What is ML?", + "model_name": "text-embedding-ada-002", + "embedding": [0.1, 0.2, 0.3], + "metadata": {"source": "user"} + }, + { + "text": "What is AI?", + "model_name": "text-embedding-ada-002", + "embedding": [0.4, 0.5, 0.6], + "metadata": {"source": "docs"} + } + ]) + """ + if not items: + return [] + + client = self._get_redis_client() + keys = [] + + with client.pipeline(transaction=False) as pipeline: + # Process all entries + for item in items: + # Prepare and store + key, cache_entry = self._prepare_entry_data(**item) + keys.append(key) + pipeline.hset(name=key, mapping=cache_entry) # type: ignore + + pipeline.execute() + + # Set TTLs + for key in keys: + self.expire(key, ttl) + + return keys + + def exists(self, text: str, model_name: str) -> bool: + """Check if an embedding exists for the given text and model. + + Args: + text (str): The text input that was embedded. + model_name (str): The name of the embedding model. + + Returns: + bool: True if the embedding exists in the cache, False otherwise. + + .. code-block:: python + + if cache.exists("What is machine learning?", "text-embedding-ada-002"): + print("Embedding is in cache") + """ + client = self._get_redis_client() + key = self._make_cache_key(text, model_name) + return bool(client.exists(key)) + + def exists_by_key(self, key: str) -> bool: + """Check if an embedding exists for the given Redis key. + + Args: + key (str): The full Redis key for the embedding. + + Returns: + bool: True if the embedding exists in the cache, False otherwise. + + .. code-block:: python + + if cache.exists_by_key("embedcache:1234567890abcdef"): + print("Embedding is in cache") + """ + client = self._get_redis_client() + return bool(client.exists(key)) + + def mexists_by_keys(self, keys: List[str]) -> List[bool]: + """Check if multiple embeddings exist by their Redis keys. + + Efficiently checks existence of multiple keys in a single operation. + + Args: + keys (List[str]): List of Redis keys to check. + + Returns: + List[bool]: List of boolean values indicating whether each key exists. + The order matches the input keys order. + + .. code-block:: python + + # Check if multiple keys exist + exists_results = cache.mexists_by_keys(["embedcache:key1", "embedcache:key2"]) + """ + if not keys: + return [] + + client = self._get_redis_client() + + with client.pipeline(transaction=False) as pipeline: + # Queue all exists operations + for key in keys: + pipeline.exists(key) + results = pipeline.execute() + + # Convert to boolean values + return [bool(result) for result in results] + + def mexists(self, texts: List[str], model_name: str) -> List[bool]: + """Check if multiple embeddings exist by their texts and model name. + + Efficiently checks existence of multiple embeddings in a single operation. + + Args: + texts (List[str]): List of text inputs that were embedded. + model_name (str): The name of the embedding model. + + Returns: + List[bool]: List of boolean values indicating whether each embedding exists. + + .. code-block:: python + + # Check if multiple embeddings exist + exists_results = cache.mexists( + texts=["What is machine learning?", "What is deep learning?"], + model_name="text-embedding-ada-002" + ) + """ + if not texts: + return [] + + # Generate keys for each text + keys = [self._make_cache_key(text, model_name) for text in texts] + + # Use the key-based batch operation + return self.mexists_by_keys(keys) + + def drop(self, text: str, model_name: str) -> None: + """Remove an embedding from the cache. + + Args: + text (str): The text input that was embedded. + model_name (str): The name of the embedding model. + + .. code-block:: python + + cache.drop( + text="What is machine learning?", + model_name="text-embedding-ada-002" + ) + """ + key = self._make_cache_key(text, model_name) + self.drop_by_key(key) + + def drop_by_key(self, key: str) -> None: + """Remove an embedding from the cache by its Redis key. + + Args: + key (str): The full Redis key for the embedding. + + .. code-block:: python + + cache.drop_by_key("embedcache:1234567890abcdef") + """ + client = self._get_redis_client() + client.delete(key) + + def mdrop_by_keys(self, keys: List[str]) -> None: + """Remove multiple embeddings from the cache by their Redis keys. + + Efficiently removes multiple embeddings in a single operation. + + Args: + keys (List[str]): List of Redis keys to remove. + + .. code-block:: python + + # Remove multiple embeddings + cache.mdrop_by_keys(["embedcache:key1", "embedcache:key2"]) + """ + if not keys: + return + + client = self._get_redis_client() + + with client.pipeline(transaction=False) as pipeline: + for key in keys: + pipeline.delete(key) + pipeline.execute() + + def mdrop(self, texts: List[str], model_name: str) -> None: + """Remove multiple embeddings from the cache by their texts and model name. + + Efficiently removes multiple embeddings in a single operation. + + Args: + texts (List[str]): List of text inputs that were embedded. + model_name (str): The name of the embedding model. + + .. code-block:: python + + # Remove multiple embeddings + cache.mdrop( + texts=["What is machine learning?", "What is deep learning?"], + model_name="text-embedding-ada-002" + ) + """ + if not texts: + return + + # Generate keys for each text + keys = [self._make_cache_key(text, model_name) for text in texts] + + # Use the key-based batch operation + self.mdrop_by_keys(keys) + + async def aget( + self, + text: str, + model_name: str, + ) -> Optional[Dict[str, Any]]: + """Async get embedding by text and model name. + + Asynchronously retrieves a cached embedding for the given text and model name. + If found, refreshes the TTL of the entry. + + Args: + text (str): The text input that was embedded. + model_name (str): The name of the embedding model. + + Returns: + Optional[Dict[str, Any]]: Embedding cache entry or None if not found. + + .. code-block:: python + + embedding_data = await cache.aget( + text="What is machine learning?", + model_name="text-embedding-ada-002" + ) + """ + key = self._make_cache_key(text, model_name) + return await self.aget_by_key(key) + + async def aget_by_key(self, key: str) -> Optional[Dict[str, Any]]: + """Async get embedding by its full Redis key. + + Asynchronously retrieves a cached embedding for the given Redis key. + If found, refreshes the TTL of the entry. + + Args: + key (str): The full Redis key for the embedding. + + Returns: + Optional[Dict[str, Any]]: Embedding cache entry or None if not found. + + .. code-block:: python + + embedding_data = await cache.aget_by_key("embedcache:1234567890abcdef") + """ + client = await self._get_async_redis_client() + + # Get all fields + data = await client.hgetall(key) + + # Refresh TTL if data exists + if data: + await self.aexpire(key) + + return self._process_cache_data(data) + + async def amget_by_keys(self, keys: List[str]) -> List[Optional[Dict[str, Any]]]: + """Async get multiple embeddings by their Redis keys. + + Asynchronously retrieves multiple cached embeddings in a single network roundtrip. + If found, refreshes the TTL of each entry. + + Args: + keys (List[str]): List of Redis keys to retrieve. + + Returns: + List[Optional[Dict[str, Any]]]: List of embedding cache entries or None for keys not found. + The order matches the input keys order. + + .. code-block:: python + + # Get multiple embeddings asynchronously + embedding_data = await cache.amget_by_keys([ + "embedcache:key1", + "embedcache:key2" + ]) + """ + if not keys: + return [] + + client = await self._get_async_redis_client() + + # Use pipeline only for retrieval + async with client.pipeline(transaction=False) as pipeline: + # Queue all hgetall operations + for key in keys: + await pipeline.hgetall(key) + results = await pipeline.execute() + + # Process results and refresh TTLs separately + processed_results = [] + for i, result in enumerate(results): + if result: # If cache hit, refresh TTL + await self.aexpire(keys[i]) + processed_results.append(self._process_cache_data(result)) + + return processed_results + + async def amget( + self, texts: List[str], model_name: str + ) -> List[Optional[Dict[str, Any]]]: + """Async get multiple embeddings by their texts and model name. + + Asynchronously retrieves multiple cached embeddings in a single operation. + If found, refreshes the TTL of each entry. + + Args: + texts (List[str]): List of text inputs that were embedded. + model_name (str): The name of the embedding model. + + Returns: + List[Optional[Dict[str, Any]]]: List of embedding cache entries or None for texts not found. + + .. code-block:: python + + # Get multiple embeddings asynchronously + embedding_data = await cache.amget( + texts=["What is machine learning?", "What is deep learning?"], + model_name="text-embedding-ada-002" + ) + """ + if not texts: + return [] + + # Generate keys for each text + keys = [self._make_cache_key(text, model_name) for text in texts] + + # Use the key-based batch operation + return await self.amget_by_keys(keys) + + async def aset( + self, + text: str, + model_name: str, + embedding: List[float], + metadata: Optional[Dict[str, Any]] = None, + ttl: Optional[int] = None, + ) -> str: + """Async store an embedding with its text and model name. + + Asynchronously stores an embedding with its text and model name. + + Args: + text (str): The text input that was embedded. + model_name (str): The name of the embedding model. + embedding (List[float]): The embedding vector to store. + metadata (Optional[Dict[str, Any]]): Optional metadata to store with the embedding. + ttl (Optional[int]): Optional TTL override for this specific entry. + + Returns: + str: The Redis key where the embedding was stored. + + .. code-block:: python + + key = await cache.aset( + text="What is machine learning?", + model_name="text-embedding-ada-002", + embedding=[0.1, 0.2, 0.3, ...], + metadata={"source": "user_query"} + ) + """ + # Prepare data + key, cache_entry = self._prepare_entry_data( + text, model_name, embedding, metadata + ) + + # Store in Redis + client = await self._get_async_redis_client() + await client.hset(name=key, mapping=cache_entry) # type: ignore + + # Set TTL if specified + await self.aexpire(key, ttl) + + return key + + async def amset( + self, + items: List[Dict[str, Any]], + ttl: Optional[int] = None, + ) -> List[str]: + """Async store multiple embeddings in a batch operation. + + Each item in the input list should be a dictionary with the following fields: + - 'text': The text input that was embedded + - 'model_name': The name of the embedding model + - 'embedding': The embedding vector + - 'metadata': Optional metadata to store with the embedding + + Args: + items: List of dictionaries, each containing text, model_name, embedding, and optional metadata. + ttl: Optional TTL override for these entries. + + Returns: + List[str]: List of Redis keys where the embeddings were stored. + + .. code-block:: python + + # Store multiple embeddings asynchronously + keys = await cache.amset([ + { + "text": "What is ML?", + "model_name": "text-embedding-ada-002", + "embedding": [0.1, 0.2, 0.3], + "metadata": {"source": "user"} + }, + { + "text": "What is AI?", + "model_name": "text-embedding-ada-002", + "embedding": [0.4, 0.5, 0.6], + "metadata": {"source": "docs"} + } + ]) + """ + if not items: + return [] + + client = await self._get_async_redis_client() + keys = [] + + async with client.pipeline(transaction=False) as pipeline: + # Process all entries + for item in items: + # Prepare and store + key, cache_entry = self._prepare_entry_data(**item) + keys.append(key) + await pipeline.hset(name=key, mapping=cache_entry) # type: ignore + + await pipeline.execute() + + # Set TTLs + for key in keys: + await self.aexpire(key, ttl) + + return keys + + async def amexists_by_keys(self, keys: List[str]) -> List[bool]: + """Async check if multiple embeddings exist by their Redis keys. + + Asynchronously checks existence of multiple keys in a single operation. + + Args: + keys (List[str]): List of Redis keys to check. + + Returns: + List[bool]: List of boolean values indicating whether each key exists. + The order matches the input keys order. + + .. code-block:: python + + # Check if multiple keys exist asynchronously + exists_results = await cache.amexists_by_keys(["embedcache:key1", "embedcache:key2"]) + """ + if not keys: + return [] + + client = await self._get_async_redis_client() + + async with client.pipeline(transaction=False) as pipeline: + # Queue all exists operations + for key in keys: + await pipeline.exists(key) + results = await pipeline.execute() + + # Convert to boolean values + return [bool(result) for result in results] + + async def amexists(self, texts: List[str], model_name: str) -> List[bool]: + """Async check if multiple embeddings exist by their texts and model name. + + Asynchronously checks existence of multiple embeddings in a single operation. + + Args: + texts (List[str]): List of text inputs that were embedded. + model_name (str): The name of the embedding model. + + Returns: + List[bool]: List of boolean values indicating whether each embedding exists. + + .. code-block:: python + + # Check if multiple embeddings exist asynchronously + exists_results = await cache.amexists( + texts=["What is machine learning?", "What is deep learning?"], + model_name="text-embedding-ada-002" + ) + """ + if not texts: + return [] + + # Generate keys for each text + keys = [self._make_cache_key(text, model_name) for text in texts] + + # Use the key-based batch operation + return await self.amexists_by_keys(keys) + + async def amdrop_by_keys(self, keys: List[str]) -> None: + """Async remove multiple embeddings from the cache by their Redis keys. + + Asynchronously removes multiple embeddings in a single operation. + + Args: + keys (List[str]): List of Redis keys to remove. + + .. code-block:: python + + # Remove multiple embeddings asynchronously + await cache.amdrop_by_keys(["embedcache:key1", "embedcache:key2"]) + """ + if not keys: + return + + client = await self._get_async_redis_client() + await client.delete(*keys) + + async def amdrop(self, texts: List[str], model_name: str) -> None: + """Async remove multiple embeddings from the cache by their texts and model name. + + Asynchronously removes multiple embeddings in a single operation. + + Args: + texts (List[str]): List of text inputs that were embedded. + model_name (str): The name of the embedding model. + + .. code-block:: python + + # Remove multiple embeddings asynchronously + await cache.amdrop( + texts=["What is machine learning?", "What is deep learning?"], + model_name="text-embedding-ada-002" + ) + """ + if not texts: + return + + # Generate keys for each text + keys = [self._make_cache_key(text, model_name) for text in texts] + + # Use the key-based batch operation + await self.amdrop_by_keys(keys) + + async def aexists(self, text: str, model_name: str) -> bool: + """Async check if an embedding exists. + + Asynchronously checks if an embedding exists for the given text and model. + + Args: + text (str): The text input that was embedded. + model_name (str): The name of the embedding model. + + Returns: + bool: True if the embedding exists in the cache, False otherwise. + + .. code-block:: python + + if await cache.aexists("What is machine learning?", "text-embedding-ada-002"): + print("Embedding is in cache") + """ + key = self._make_cache_key(text, model_name) + return await self.aexists_by_key(key) + + async def aexists_by_key(self, key: str) -> bool: + """Async check if an embedding exists for the given Redis key. + + Asynchronously checks if an embedding exists for the given Redis key. + + Args: + key (str): The full Redis key for the embedding. + + Returns: + bool: True if the embedding exists in the cache, False otherwise. + + .. code-block:: python + + if await cache.aexists_by_key("embedcache:1234567890abcdef"): + print("Embedding is in cache") + """ + client = await self._get_async_redis_client() + return bool(await client.exists(key)) + + async def adrop(self, text: str, model_name: str) -> None: + """Async remove an embedding from the cache. + + Asynchronously removes an embedding from the cache. + + Args: + text (str): The text input that was embedded. + model_name (str): The name of the embedding model. + + .. code-block:: python + + await cache.adrop( + text="What is machine learning?", + model_name="text-embedding-ada-002" + ) + """ + key = self._make_cache_key(text, model_name) + await self.adrop_by_key(key) + + async def adrop_by_key(self, key: str) -> None: + """Async remove an embedding from the cache by its Redis key. + + Asynchronously removes an embedding from the cache by its Redis key. + + Args: + key (str): The full Redis key for the embedding. + + .. code-block:: python + + await cache.adrop_by_key("embedcache:1234567890abcdef") + """ + client = await self._get_async_redis_client() + await client.delete(key) diff --git a/redisvl/extensions/cache/embeddings/schema.py b/redisvl/extensions/cache/embeddings/schema.py new file mode 100644 index 00000000..182351c0 --- /dev/null +++ b/redisvl/extensions/cache/embeddings/schema.py @@ -0,0 +1,53 @@ +"""Schema definitions for embeddings cache in RedisVL. + +This module defines the Pydantic models used for embedding cache entries and +related data structures. +""" + +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field, model_validator + +from redisvl.extensions.constants import EMBEDDING_FIELD_NAME, METADATA_FIELD_NAME +from redisvl.utils.utils import current_timestamp, deserialize, serialize + + +class CacheEntry(BaseModel): + """Embedding cache entry data model""" + + entry_id: str + """Cache entry identifier""" + text: str + """The text input that was embedded""" + model_name: str + """The name of the embedding model used""" + embedding: List[float] + """The embedding vector representation""" + inserted_at: float = Field(default_factory=current_timestamp) + """Timestamp of when the entry was added to the cache""" + metadata: Optional[Dict[str, Any]] = Field(default=None) + """Optional metadata stored on the cache entry""" + + @model_validator(mode="before") + @classmethod + def deserialize_cache_entry(cls, values: Dict[str, Any]) -> Dict[str, Any]: + # Deserialize metadata if necessary + if METADATA_FIELD_NAME in values and isinstance( + values[METADATA_FIELD_NAME], str + ): + values[METADATA_FIELD_NAME] = deserialize(values[METADATA_FIELD_NAME]) + # Deserialize embeddings if necessary + if EMBEDDING_FIELD_NAME in values and isinstance( + values[EMBEDDING_FIELD_NAME], str + ): + values[EMBEDDING_FIELD_NAME] = deserialize(values[EMBEDDING_FIELD_NAME]) + + return values + + def to_dict(self) -> Dict[str, Any]: + """Convert the cache entry to a dictionary for storage""" + data = self.model_dump(exclude_none=True) + data[EMBEDDING_FIELD_NAME] = serialize(self.embedding) + if self.metadata is not None: + data[METADATA_FIELD_NAME] = serialize(self.metadata) + return data diff --git a/redisvl/extensions/cache/llm/__init__.py b/redisvl/extensions/cache/llm/__init__.py new file mode 100644 index 00000000..8381b3bf --- /dev/null +++ b/redisvl/extensions/cache/llm/__init__.py @@ -0,0 +1,14 @@ +""" +Redis Vector Library - LLM Cache Extensions + +This module provides LLM cache implementations for RedisVL. +""" + +from redisvl.extensions.cache.llm.schema import ( + CacheEntry, + CacheHit, + SemanticCacheIndexSchema, +) +from redisvl.extensions.cache.llm.semantic import SemanticCache + +__all__ = ["SemanticCache", "CacheEntry", "CacheHit", "SemanticCacheIndexSchema"] diff --git a/redisvl/extensions/cache/llm/base.py b/redisvl/extensions/cache/llm/base.py new file mode 100644 index 00000000..1d421b2a --- /dev/null +++ b/redisvl/extensions/cache/llm/base.py @@ -0,0 +1,121 @@ +"""Base LLM cache interface for RedisVL. + +This module defines the abstract base interface for LLM caches, which store +prompt-response pairs with semantic retrieval capabilities. +""" + +from typing import Any, Dict, List, Optional + +from redisvl.extensions.cache.base import BaseCache +from redisvl.query.filter import FilterExpression + + +class BaseLLMCache(BaseCache): + """Base abstract LLM cache interface. + + This class defines the core functionality for caching LLM responses + with semantic similarity search capabilities. + """ + + def __init__(self, name: str, ttl: Optional[int] = None, **kwargs): + """Initialize an LLM cache. + + Args: + name (str): The name of the cache. + ttl (Optional[int]): The time-to-live for cached responses. Defaults to None. + **kwargs: Additional arguments passed to the parent class. + """ + super().__init__(name=name, ttl=ttl, **kwargs) + + def delete(self) -> None: + """Delete the cache and its index entirely.""" + raise NotImplementedError + + async def adelete(self) -> None: + """Async delete the cache and its index entirely.""" + raise NotImplementedError + + def check( + self, + prompt: Optional[str] = None, + vector: Optional[List[float]] = None, + num_results: int = 1, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[FilterExpression] = None, + distance_threshold: Optional[float] = None, + ) -> List[Dict[str, Any]]: + """Check the cache for semantically similar prompts. + + Args: + prompt (Optional[str]): The text prompt to search for in the cache. + vector (Optional[List[float]]): Vector representation to search for. + num_results (int): Number of results to return. Defaults to 1. + return_fields (Optional[List[str]]): Fields to return in results. + filter_expression (Optional[FilterExpression]): Optional filter to apply. + distance_threshold (Optional[float]): Override for semantic distance threshold. + + Returns: + List[Dict[str, Any]]: List of matching cache entries. + """ + raise NotImplementedError + + async def acheck( + self, + prompt: Optional[str] = None, + vector: Optional[List[float]] = None, + num_results: int = 1, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[FilterExpression] = None, + distance_threshold: Optional[float] = None, + ) -> List[Dict[str, Any]]: + """Async check the cache for semantically similar prompts.""" + raise NotImplementedError + + def store( + self, + prompt: str, + response: str, + vector: Optional[List[float]] = None, + metadata: Optional[Dict[str, Any]] = None, + filters: Optional[Dict[str, Any]] = None, + ttl: Optional[int] = None, + ) -> str: + """Store a prompt-response pair in the cache. + + Args: + prompt (str): The user prompt to cache. + response (str): The LLM response to cache. + vector (Optional[List[float]]): Optional embedding vector. + metadata (Optional[Dict[str, Any]]): Optional metadata. + filters (Optional[Dict[str, Any]]): Optional filters for retrieval. + ttl (Optional[int]): Optional TTL override. + + Returns: + str: The Redis key for the cached entry. + """ + raise NotImplementedError + + async def astore( + self, + prompt: str, + response: str, + vector: Optional[List[float]] = None, + metadata: Optional[Dict[str, Any]] = None, + filters: Optional[Dict[str, Any]] = None, + ttl: Optional[int] = None, + ) -> str: + """Async store a prompt-response pair in the cache.""" + raise NotImplementedError + + def update(self, key: str, **kwargs) -> None: + """Update specific fields within an existing cache entry. + + Args: + key (str): The key of the document to update. + **kwargs: Field-value pairs to update. + """ + raise NotImplementedError + + async def aupdate(self, key: str, **kwargs) -> None: + """Async update specific fields within an existing cache entry.""" + raise NotImplementedError diff --git a/redisvl/extensions/cache/llm/schema.py b/redisvl/extensions/cache/llm/schema.py new file mode 100644 index 00000000..fa6f720a --- /dev/null +++ b/redisvl/extensions/cache/llm/schema.py @@ -0,0 +1,136 @@ +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator + +from redisvl.extensions.constants import ( + CACHE_VECTOR_FIELD_NAME, + INSERTED_AT_FIELD_NAME, + PROMPT_FIELD_NAME, + RESPONSE_FIELD_NAME, + UPDATED_AT_FIELD_NAME, +) +from redisvl.redis.utils import array_to_buffer, hashify +from redisvl.schema import IndexSchema +from redisvl.utils.utils import current_timestamp, deserialize, serialize + + +class CacheEntry(BaseModel): + """A single cache entry in Redis""" + + entry_id: Optional[str] = Field(default=None) + """Cache entry identifier""" + prompt: str + """Input prompt or question cached in Redis""" + response: str + """Response or answer to the question, cached in Redis""" + prompt_vector: List[float] + """Text embedding representation of the prompt""" + inserted_at: float = Field(default_factory=current_timestamp) + """Timestamp of when the entry was added to the cache""" + updated_at: float = Field(default_factory=current_timestamp) + """Timestamp of when the entry was updated in the cache""" + metadata: Optional[Dict[str, Any]] = Field(default=None) + """Optional metadata stored on the cache entry""" + filters: Optional[Dict[str, Any]] = Field(default=None) + """Optional filter data stored on the cache entry for customizing retrieval""" + + @model_validator(mode="before") + @classmethod + def generate_id(cls, values): + # Ensure entry_id is set + if not values.get("entry_id"): + values["entry_id"] = hashify(values["prompt"], values.get("filters")) + return values + + @field_validator("metadata") + @classmethod + def non_empty_metadata(cls, v): + if v is not None and not isinstance(v, dict): + raise TypeError("Metadata must be a dictionary.") + return v + + def to_dict(self, dtype: str) -> Dict: + data = self.model_dump(exclude_none=True) + data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype) + if self.metadata is not None: + data["metadata"] = serialize(self.metadata) + if self.filters is not None: + data.update(self.filters) + del data["filters"] + return data + + +class CacheHit(BaseModel): + """A cache hit based on some input query""" + + entry_id: str + """Cache entry identifier""" + prompt: str + """Input prompt or question cached in Redis""" + response: str + """Response or answer to the question, cached in Redis""" + vector_distance: float + """The semantic distance between the query vector and the stored prompt vector""" + inserted_at: float + """Timestamp of when the entry was added to the cache""" + updated_at: float + """Timestamp of when the entry was updated in the cache""" + metadata: Optional[Dict[str, Any]] = Field(default=None) + """Optional metadata stored on the cache entry""" + filters: Optional[Dict[str, Any]] = Field(default=None) + """Optional filter data stored on the cache entry for customizing retrieval""" + + # Allow extra fields to simplify handling filters + model_config = ConfigDict(extra="allow") + + @model_validator(mode="before") + @classmethod + def validate_cache_hit(cls, values: Dict[str, Any]) -> Dict[str, Any]: + # Deserialize metadata if necessary + if "metadata" in values and isinstance(values["metadata"], str): + values["metadata"] = deserialize(values["metadata"]) + + # Collect any extra fields and store them as filters + extra_data = values.pop("__pydantic_extra__", {}) or {} + if extra_data: + current_filters = values.get("filters") or {} + if not isinstance(current_filters, dict): + current_filters = {} + current_filters.update(extra_data) + values["filters"] = current_filters + + return values + + def to_dict(self) -> Dict[str, Any]: + """Convert this model to a dictionary, merging filters into the result.""" + data = self.model_dump(exclude_none=True) + if data.get("filters"): + data.update(data["filters"]) + del data["filters"] + return data + + +class SemanticCacheIndexSchema(IndexSchema): + + @classmethod + def from_params(cls, name: str, prefix: str, vector_dims: int, dtype: str): + + return cls( + index={"name": name, "prefix": prefix}, # type: ignore + fields=[ # type: ignore + {"name": PROMPT_FIELD_NAME, "type": "text"}, + {"name": RESPONSE_FIELD_NAME, "type": "text"}, + {"name": INSERTED_AT_FIELD_NAME, "type": "numeric"}, + {"name": UPDATED_AT_FIELD_NAME, "type": "numeric"}, + { + "name": CACHE_VECTOR_FIELD_NAME, + "type": "vector", + "attrs": { + "dims": vector_dims, + "datatype": dtype, + "distance_metric": "cosine", + "algorithm": "flat", + }, + }, + ], + ) diff --git a/redisvl/extensions/cache/llm/semantic.py b/redisvl/extensions/cache/llm/semantic.py new file mode 100644 index 00000000..5a97b03d --- /dev/null +++ b/redisvl/extensions/cache/llm/semantic.py @@ -0,0 +1,825 @@ +import asyncio +from typing import Any, Dict, List, Optional, Tuple + +from redis import Redis + +from redisvl.extensions.cache.llm.base import BaseLLMCache +from redisvl.extensions.cache.llm.schema import ( + CacheEntry, + CacheHit, + SemanticCacheIndexSchema, +) +from redisvl.extensions.constants import ( + CACHE_VECTOR_FIELD_NAME, + ENTRY_ID_FIELD_NAME, + INSERTED_AT_FIELD_NAME, + METADATA_FIELD_NAME, + PROMPT_FIELD_NAME, + REDIS_KEY_FIELD_NAME, + RESPONSE_FIELD_NAME, + UPDATED_AT_FIELD_NAME, +) +from redisvl.index import AsyncSearchIndex, SearchIndex +from redisvl.query import VectorRangeQuery +from redisvl.query.filter import FilterExpression +from redisvl.redis.connection import RedisConnectionFactory +from redisvl.redis.utils import hashify +from redisvl.utils.log import get_logger +from redisvl.utils.utils import ( + current_timestamp, + deprecated_argument, + serialize, + validate_vector_dims, +) +from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer + +logger = get_logger("[RedisVL]") + + +class SemanticCache(BaseLLMCache): + """Semantic Cache for Large Language Models.""" + + _index: SearchIndex + _aindex: Optional[AsyncSearchIndex] = None + + @deprecated_argument("dtype", "vectorizer") + def __init__( + self, + name: str = "llmcache", + distance_threshold: float = 0.1, + ttl: Optional[int] = None, + vectorizer: Optional[BaseVectorizer] = None, + filterable_fields: Optional[List[Dict[str, Any]]] = None, + redis_client: Optional[Redis] = None, + redis_url: str = "redis://localhost:6379", + connection_kwargs: Dict[str, Any] = {}, + overwrite: bool = False, + **kwargs, + ): + """Semantic Cache for Large Language Models. + + Args: + name (str, optional): The name of the semantic cache search index. + Defaults to "llmcache". + distance_threshold (float, optional): Semantic threshold for the + cache. Defaults to 0.1. + ttl (Optional[int], optional): The time-to-live for records cached + in Redis. Defaults to None. + vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache. + Defaults to HFTextVectorizer. + filterable_fields (Optional[List[Dict[str, Any]]]): An optional list of RedisVL fields + that can be used to customize cache retrieval with filters. + redis_client(Optional[Redis], optional): A redis client connection instance. + Defaults to None. + redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. + connection_kwargs (Dict[str, Any]): The connection arguments + for the redis client. Defaults to empty {}. + overwrite (bool): Whether or not to force overwrite the schema for + the semantic cache index. Defaults to false. + + Raises: + TypeError: If an invalid vectorizer is provided. + TypeError: If the TTL value is not an int. + ValueError: If the threshold is not between 0 and 1. + ValueError: If existing schema does not match new schema and overwrite is False. + """ + # Call parent class with all shared parameters + super().__init__( + name=name, + ttl=ttl, + redis_client=redis_client, + redis_url=redis_url, + connection_kwargs=connection_kwargs, + ) + + # Handle the deprecated dtype parameter + dtype = kwargs.pop("dtype", None) + + # Set up vectorizer - either use the provided one or create a default + if vectorizer: + if not isinstance(vectorizer, BaseVectorizer): + raise TypeError("Must provide a valid redisvl.vectorizer class.") + if dtype and vectorizer.dtype != dtype: + raise ValueError( + f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}" + ) + self._vectorizer = vectorizer + else: + # Create a default vectorizer + vectorizer_kwargs = kwargs + if dtype: + vectorizer_kwargs.update(dtype=dtype) + + self._vectorizer = HFTextVectorizer( + model="sentence-transformers/all-mpnet-base-v2", + **vectorizer_kwargs, + ) + + # Set threshold for semantic matching + self.set_threshold(distance_threshold) + + # Define the fields to return in search results + self.return_fields = [ + ENTRY_ID_FIELD_NAME, + PROMPT_FIELD_NAME, + RESPONSE_FIELD_NAME, + INSERTED_AT_FIELD_NAME, + UPDATED_AT_FIELD_NAME, + METADATA_FIELD_NAME, + ] + + # Create semantic cache schema and index + schema = SemanticCacheIndexSchema.from_params( + name, name, self._vectorizer.dims, self._vectorizer.dtype # type: ignore + ) + schema = self._modify_schema(schema, filterable_fields) + + # Initialize the search index + self._index = SearchIndex( + schema=schema, + redis_client=self._redis_client, + redis_url=self.redis_kwargs["redis_url"], + **self.redis_kwargs["connection_kwargs"], + ) + self._aindex = None + + # Check for existing cache index and handle schema mismatch + self.overwrite = overwrite + if not self.overwrite and self._index.exists(): + existing_index = SearchIndex.from_existing( + name, redis_client=self._index.client + ) + if existing_index.schema.to_dict() != self._index.schema.to_dict(): + raise ValueError( + f"Existing index {name} schema does not match the user provided schema for the semantic cache. " + "If you wish to overwrite the index schema, set overwrite=True during initialization." + ) + + # Create the search index in Redis + self._index.create(overwrite=self.overwrite, drop=False) + + def _modify_schema( + self, + schema: SemanticCacheIndexSchema, + filterable_fields: Optional[List[Dict[str, Any]]] = None, + ) -> SemanticCacheIndexSchema: + """Modify the base cache schema using the provided filterable fields""" + + if filterable_fields is not None: + protected_field_names = set(self.return_fields + [REDIS_KEY_FIELD_NAME]) + for filter_field in filterable_fields: + field_name = filter_field["name"] + if field_name in protected_field_names: + raise ValueError( + f"{field_name} is a reserved field name for the semantic cache schema" + ) + # Add to schema + schema.add_field(filter_field) + # Add to return fields too + self.return_fields.append(field_name) + + return schema + + async def _get_async_index(self) -> AsyncSearchIndex: + """Lazily construct the async search index class.""" + # Construct async index if necessary + async_client = None + if self._aindex is None: + client = self.redis_kwargs.get("redis_client") + if isinstance(client, Redis): + async_client = RedisConnectionFactory.sync_to_async_redis(client) + self._aindex = AsyncSearchIndex( + schema=self._index.schema, + redis_client=async_client, + redis_url=self.redis_kwargs["redis_url"], + **self.redis_kwargs["connection_kwargs"], + ) + return self._aindex + + @property + def index(self) -> SearchIndex: + """The underlying SearchIndex for the cache. + + Returns: + SearchIndex: The search index. + """ + return self._index + + @property + def aindex(self) -> Optional[AsyncSearchIndex]: + """The underlying AsyncSearchIndex for the cache. + + Returns: + AsyncSearchIndex: The async search index. + """ + return self._aindex + + @property + def distance_threshold(self) -> float: + """The semantic distance threshold for the cache. + + Returns: + float: The semantic distance threshold. + """ + return self._distance_threshold + + def set_threshold(self, distance_threshold: float) -> None: + """Sets the semantic distance threshold for the cache. + + Args: + distance_threshold (float): The semantic distance threshold for + the cache. + + Raises: + ValueError: If the threshold is not between 0 and 1. + """ + if not 0 <= float(distance_threshold) <= 2: + raise ValueError( + f"Distance must be between 0 and 2, got {distance_threshold}" + ) + self._distance_threshold = float(distance_threshold) + + def delete(self) -> None: + """Delete the cache and its index entirely.""" + self._index.delete(drop=True) + + async def adelete(self) -> None: + """Async delete the cache and its index entirely.""" + aindex = await self._get_async_index() + await aindex.delete(drop=True) + + def drop( + self, ids: Optional[List[str]] = None, keys: Optional[List[str]] = None + ) -> None: + """Drop specific entries from the cache by ID or Redis key. + + Args: + ids (Optional[List[str]]): List of entry IDs to remove from the cache. + Entry IDs are the unique identifiers without the cache prefix. + keys (Optional[List[str]]): List of full Redis keys to remove from the cache. + Keys are the complete Redis keys including the cache prefix. + + Note: + At least one of ids or keys must be provided. + + Raises: + ValueError: If neither ids nor keys is provided. + """ + if ids is None and keys is None: + raise ValueError("At least one of ids or keys must be provided.") + + # Convert entry IDs to full Redis keys if provided + if ids is not None: + self._index.drop_keys([self._index.key(id) for id in ids]) + if keys is not None: + self._index.drop_keys(keys) + + async def adrop( + self, ids: Optional[List[str]] = None, keys: Optional[List[str]] = None + ) -> None: + """Async drop specific entries from the cache by ID or Redis key. + + Args: + ids (Optional[List[str]]): List of entry IDs to remove from the cache. + Entry IDs are the unique identifiers without the cache prefix. + keys (Optional[List[str]]): List of full Redis keys to remove from the cache. + Keys are the complete Redis keys including the cache prefix. + + Note: + At least one of ids or keys must be provided. + + Raises: + ValueError: If neither ids nor keys is provided. + """ + aindex = await self._get_async_index() + + if ids is None and keys is None: + raise ValueError("At least one of ids or keys must be provided.") + + # Convert entry IDs to full Redis keys if provided + if ids is not None: + await aindex.drop_keys([self._index.key(id) for id in ids]) + if keys is not None: + await aindex.drop_keys(keys) + + def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: + """Converts a text prompt to its vector representation using the + configured vectorizer.""" + if not isinstance(prompt, str): + raise TypeError("Prompt must be a string.") + + result = self._vectorizer.embed(prompt) + return result # type: ignore + + async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: + """Converts a text prompt to its vector representation using the + configured vectorizer.""" + if not isinstance(prompt, str): + raise TypeError("Prompt must be a string.") + + result = await self._vectorizer.aembed(prompt) + return result # type: ignore + + def _check_vector_dims(self, vector: List[float]): + """Checks the size of the provided vector and raises an error if it + doesn't match the search index vector dimensions.""" + schema_vector_dims = self._index.schema.fields[ + CACHE_VECTOR_FIELD_NAME + ].attrs.dims # type: ignore + validate_vector_dims(len(vector), schema_vector_dims) + + def check( + self, + prompt: Optional[str] = None, + vector: Optional[List[float]] = None, + num_results: int = 1, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[FilterExpression] = None, + distance_threshold: Optional[float] = None, + ) -> List[Dict[str, Any]]: + """Checks the semantic cache for results similar to the specified prompt + or vector. + + This method searches the cache using vector similarity with + either a raw text prompt (converted to a vector) or a provided vector as + input. It checks for semantically similar prompts and fetches the cached + LLM responses. + + Args: + prompt (Optional[str], optional): The text prompt to search for in + the cache. + vector (Optional[List[float]], optional): The vector representation + of the prompt to search for in the cache. + num_results (int, optional): The number of cached results to return. + Defaults to 1. + return_fields (Optional[List[str]], optional): The fields to include + in each returned result. If None, defaults to all available + fields in the cached entry. + filter_expression (Optional[FilterExpression]) : Optional filter expression + that can be used to filter cache results. Defaults to None and + the full cache will be searched. + distance_threshold (Optional[float]): The threshold for semantic + vector distance. + + Returns: + List[Dict[str, Any]]: A list of dicts containing the requested + return fields for each similar cached response. + + Raises: + ValueError: If neither a `prompt` nor a `vector` is specified. + ValueError: if 'vector' has incorrect dimensions. + TypeError: If `return_fields` is not a list when provided. + + .. code-block:: python + + response = cache.check( + prompt="What is the captial city of France?" + ) + """ + if not any([prompt, vector]): + raise ValueError("Either prompt or vector must be specified.") + if return_fields and not isinstance(return_fields, list): + raise TypeError("Return fields must be a list of values.") + + # Use overrides or defaults + distance_threshold = distance_threshold or self._distance_threshold + + # Vectorize prompt if not provided + if vector is None and prompt is not None: + vector = self._vectorize_prompt(prompt) + + # Validate the vector dimensions + if vector is not None: + self._check_vector_dims(vector) + else: + raise ValueError("Failed to generate a valid vector for the query.") + + # Create the vector search query + query = VectorRangeQuery( + vector=vector, + vector_field_name=CACHE_VECTOR_FIELD_NAME, + return_fields=self.return_fields, + distance_threshold=distance_threshold, + num_results=num_results, + return_score=True, + filter_expression=filter_expression, + dtype=self._vectorizer.dtype, + ) + + # Search the cache! + cache_search_results = self._index.query(query) + redis_keys, cache_hits = self._process_cache_results( + cache_search_results, + return_fields, # type: ignore + ) + + # Refresh TTL on all found keys + for key in redis_keys: + self.expire(key) + + return cache_hits + + async def acheck( + self, + prompt: Optional[str] = None, + vector: Optional[List[float]] = None, + num_results: int = 1, + return_fields: Optional[List[str]] = None, + filter_expression: Optional[FilterExpression] = None, + distance_threshold: Optional[float] = None, + ) -> List[Dict[str, Any]]: + """Async check the semantic cache for results similar to the specified prompt + or vector. + + This method searches the cache using vector similarity with + either a raw text prompt (converted to a vector) or a provided vector as + input. It checks for semantically similar prompts and fetches the cached + LLM responses. + + Args: + prompt (Optional[str], optional): The text prompt to search for in + the cache. + vector (Optional[List[float]], optional): The vector representation + of the prompt to search for in the cache. + num_results (int, optional): The number of cached results to return. + Defaults to 1. + return_fields (Optional[List[str]], optional): The fields to include + in each returned result. If None, defaults to all available + fields in the cached entry. + filter_expression (Optional[FilterExpression]) : Optional filter expression + that can be used to filter cache results. Defaults to None and + the full cache will be searched. + distance_threshold (Optional[float]): The threshold for semantic + vector distance. + + Returns: + List[Dict[str, Any]]: A list of dicts containing the requested + return fields for each similar cached response. + + Raises: + ValueError: If neither a `prompt` nor a `vector` is specified. + ValueError: if 'vector' has incorrect dimensions. + TypeError: If `return_fields` is not a list when provided. + + .. code-block:: python + + response = await cache.acheck( + prompt="What is the captial city of France?" + ) + """ + aindex = await self._get_async_index() + + if not any([prompt, vector]): + raise ValueError("Either prompt or vector must be specified.") + if return_fields and not isinstance(return_fields, list): + raise TypeError("Return fields must be a list of values.") + + # Use overrides or defaults + distance_threshold = distance_threshold or self._distance_threshold + + # Vectorize prompt if not provided + if vector is None and prompt is not None: + vector = await self._avectorize_prompt(prompt) + + # Validate the vector dimensions + if vector is not None: + self._check_vector_dims(vector) + else: + raise ValueError("Failed to generate a valid vector for the query.") + + # Create the vector search query + query = VectorRangeQuery( + vector=vector, + vector_field_name=CACHE_VECTOR_FIELD_NAME, + return_fields=self.return_fields, + distance_threshold=distance_threshold, + num_results=num_results, + return_score=True, + filter_expression=filter_expression, + normalize_vector_distance=True, + ) + + # Search the cache! + cache_search_results = await aindex.query(query) + redis_keys, cache_hits = self._process_cache_results( + cache_search_results, + return_fields, # type: ignore + ) + + # Refresh TTL on all found keys async + await asyncio.gather(*[self.aexpire(key) for key in redis_keys]) + + return cache_hits + + def _process_cache_results( + self, + cache_search_results: List[Dict[str, Any]], + return_fields: Optional[List[str]] = None, + ) -> Tuple[List[str], List[Dict[str, Any]]]: + """Process raw search results into cache hits.""" + redis_keys: List[str] = [] + cache_hits: List[Dict[Any, str]] = [] + + for cache_search_result in cache_search_results: + # Pop the redis key from the result + redis_key = cache_search_result.pop("id") + redis_keys.append(redis_key) + + # Create and process cache hit + cache_hit = CacheHit(**cache_search_result) + cache_hit_dict = cache_hit.to_dict() + + # Filter down to only selected return fields if needed + if isinstance(return_fields, list) and return_fields: + cache_hit_dict = { + k: v for k, v in cache_hit_dict.items() if k in return_fields + } + + # Add the Redis key to the result + cache_hit_dict[REDIS_KEY_FIELD_NAME] = redis_key + cache_hits.append(cache_hit_dict) + + return redis_keys, cache_hits + + def store( + self, + prompt: str, + response: str, + vector: Optional[List[float]] = None, + metadata: Optional[Dict[str, Any]] = None, + filters: Optional[Dict[str, Any]] = None, + ttl: Optional[int] = None, + ) -> str: + """Stores the specified key-value pair in the cache along with metadata. + + Args: + prompt (str): The user prompt to cache. + response (str): The LLM response to cache. + vector (Optional[List[float]], optional): The prompt vector to + cache. Defaults to None, and the prompt vector is generated on + demand. + metadata (Optional[Dict[str, Any]], optional): The optional metadata to cache + alongside the prompt and response. Defaults to None. + filters (Optional[Dict[str, Any]]): The optional tag to assign to the cache entry. + Defaults to None. + ttl (Optional[int]): The optional TTL override to use on this individual cache + entry. Defaults to the global TTL setting. + + Returns: + str: The Redis key for the entries added to the semantic cache. + + Raises: + ValueError: If neither prompt nor vector is specified. + ValueError: if vector has incorrect dimensions. + TypeError: If provided metadata is not a dictionary. + + .. code-block:: python + + key = cache.store( + prompt="What is the captial city of France?", + response="Paris", + metadata={"city": "Paris", "country": "France"} + ) + """ + # Vectorize prompt if necessary + vector = vector or self._vectorize_prompt(prompt) + self._check_vector_dims(vector) + + # Generate the entry ID + entry_id = self._make_entry_id(prompt, filters) + + # Build cache entry for the cache + cache_entry = CacheEntry( + entry_id=entry_id, + prompt=prompt, + response=response, + prompt_vector=vector, + metadata=metadata, + filters=filters, + ) + + # Load cache entry with TTL + ttl = ttl or self._ttl + keys = self._index.load( + data=[cache_entry.to_dict(self._vectorizer.dtype)], + ttl=ttl, + id_field=ENTRY_ID_FIELD_NAME, + ) + + # Return the key where the entry was stored + return keys[0] + + async def astore( + self, + prompt: str, + response: str, + vector: Optional[List[float]] = None, + metadata: Optional[Dict[str, Any]] = None, + filters: Optional[Dict[str, Any]] = None, + ttl: Optional[int] = None, + ) -> str: + """Async stores the specified key-value pair in the cache along with metadata. + + Args: + prompt (str): The user prompt to cache. + response (str): The LLM response to cache. + vector (Optional[List[float]], optional): The prompt vector to + cache. Defaults to None, and the prompt vector is generated on + demand. + metadata (Optional[Dict[str, Any]], optional): The optional metadata to cache + alongside the prompt and response. Defaults to None. + filters (Optional[Dict[str, Any]]): The optional tag to assign to the cache entry. + Defaults to None. + ttl (Optional[int]): The optional TTL override to use on this individual cache + entry. Defaults to the global TTL setting. + + Returns: + str: The Redis key for the entries added to the semantic cache. + + Raises: + ValueError: If neither prompt nor vector is specified. + ValueError: if vector has incorrect dimensions. + TypeError: If provided metadata is not a dictionary. + + .. code-block:: python + + key = await cache.astore( + prompt="What is the captial city of France?", + response="Paris", + metadata={"city": "Paris", "country": "France"} + ) + """ + aindex = await self._get_async_index() + + # Vectorize prompt if necessary + vector = vector or await self._avectorize_prompt(prompt) + self._check_vector_dims(vector) + + # Generate the entry ID + entry_id = self._make_entry_id(prompt, filters) + + # Build cache entry for the cache + cache_entry = CacheEntry( + entry_id=entry_id, + prompt=prompt, + response=response, + prompt_vector=vector, + metadata=metadata, + filters=filters, + ) + + # Load cache entry with TTL + ttl = ttl or self._ttl + keys = await aindex.load( + data=[cache_entry.to_dict(self._vectorizer.dtype)], + ttl=ttl, + id_field=ENTRY_ID_FIELD_NAME, + ) + + # Return the key where the entry was stored + return keys[0] + + def update(self, key: str, **kwargs) -> None: + """Update specific fields within an existing cache entry. If no fields + are passed, then only the document TTL is refreshed. + + Args: + key (str): the key of the document to update using kwargs. + + Raises: + ValueError if an incorrect mapping is provided as a kwarg. + TypeError if metadata is provided and not of type dict. + + .. code-block:: python + + key = cache.store('this is a prompt', 'this is a response') + cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"}) + """ + if kwargs: + for k, v in kwargs.items(): + # Make sure the item is in the index schema + if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]): + raise ValueError(f"{k} is not a valid field within the cache entry") + + # Check for metadata and serialize + if k == METADATA_FIELD_NAME: + if isinstance(v, dict): + kwargs[k] = serialize(v) + else: + raise TypeError( + "If specified, cached metadata must be a dictionary." + ) + + # Add updated timestamp + kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()}) + + # Update the hash in Redis - ensure client exists and handle type properly + client = self._get_redis_client() + client.hset(key, mapping=kwargs) # type: ignore + + # Refresh TTL regardless of whether fields were updated + self.expire(key) + + async def aupdate(self, key: str, **kwargs) -> None: + """Async update specific fields within an existing cache entry. If no fields + are passed, then only the document TTL is refreshed. + + Args: + key (str): the key of the document to update using kwargs. + + Raises: + ValueError if an incorrect mapping is provided as a kwarg. + TypeError if metadata is provided and not of type dict. + + .. code-block:: python + + key = await cache.astore('this is a prompt', 'this is a response') + await cache.aupdate( + key, + metadata={"hit_count": 1, "model_name": "Llama-2-7b"} + ) + """ + if kwargs: + for k, v in kwargs.items(): + # Make sure the item is in the index schema + if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]): + raise ValueError(f"{k} is not a valid field within the cache entry") + + # Check for metadata and serialize + if k == METADATA_FIELD_NAME: + if isinstance(v, dict): + kwargs[k] = serialize(v) + else: + raise TypeError( + "If specified, cached metadata must be a dictionary." + ) + + # Add updated timestamp + kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()}) + + # Update the hash in Redis - ensure client exists and handle type properly + client = await self._get_async_redis_client() + # Convert dict values to proper types for Redis + await client.hset(key, mapping=kwargs) # type: ignore + + # Refresh TTL regardless of whether fields were updated + await self.aexpire(key) + + def __enter__(self): + """Context manager entry point.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """Context manager exit point.""" + self.disconnect() + + async def __aenter__(self): + """Async context manager entry.""" + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + await self.adisconnect() + + def disconnect(self): + """Disconnect from Redis and search index. + + Closes all Redis connections and index connections. + """ + # Close the search index connections + if hasattr(self, "_index") and self._index: + self._index.disconnect() + + # Close the async search index connections + if hasattr(self, "_aindex") and self._aindex: + self._aindex.disconnect_sync() + + # Close the base Redis connections + super().disconnect() + + async def adisconnect(self): + """Asynchronously disconnect from Redis and search index. + + Closes all Redis connections and index connections. + """ + # Close the async search index connections + if hasattr(self, "_aindex") and self._aindex: + await self._aindex.disconnect() + self._aindex = None + + # Close the base Redis connections + await super().adisconnect() + + def _make_entry_id( + self, prompt: str, filters: Optional[Dict[str, Any]] = None + ) -> str: + """Generate a deterministic entry ID for the given prompt and optional filters. + + Args: + prompt (str): The prompt text. + filters (Optional[Dict[str, Any]]): Optional filter dictionary. + + Returns: + str: A deterministic entry ID based on the prompt and filters. + """ + return hashify(prompt, filters) diff --git a/redisvl/extensions/constants.py b/redisvl/extensions/constants.py index 7d7fb841..dfd2ceef 100644 --- a/redisvl/extensions/constants.py +++ b/redisvl/extensions/constants.py @@ -25,5 +25,11 @@ UPDATED_AT_FIELD_NAME: str = "updated_at" METADATA_FIELD_NAME: str = "metadata" +# EmbeddingsCache +TEXT_FIELD_NAME: str = "text" +MODEL_NAME_FIELD_NAME: str = "model_name" +EMBEDDING_FIELD_NAME: str = "embedding" +DIMENSIONS_FIELD_NAME: str = "dimensions" + # SemanticRouter ROUTE_VECTOR_FIELD_NAME: str = "vector" diff --git a/redisvl/extensions/llmcache/__init__.py b/redisvl/extensions/llmcache/__init__.py index d2eed359..732b0694 100644 --- a/redisvl/extensions/llmcache/__init__.py +++ b/redisvl/extensions/llmcache/__init__.py @@ -1,3 +1,30 @@ -from redisvl.extensions.llmcache.semantic import SemanticCache +""" +RedisVL LLM Cache Extensions (Deprecated Path) -__all__ = ["SemanticCache"] +This module is kept for backward compatibility. Please use `redisvl.extensions.cache` instead. +""" + +import warnings + +from redisvl.extensions.cache.llm.base import BaseLLMCache +from redisvl.extensions.cache.llm.schema import ( + CacheEntry, + CacheHit, + SemanticCacheIndexSchema, +) +from redisvl.extensions.cache.llm.semantic import SemanticCache + +warnings.warn( + "Importing from redisvl.extensions.llmcache is deprecated. " + "Please import from redisvl.extensions.cache.llm instead.", + DeprecationWarning, + stacklevel=2, +) + +__all__ = [ + "BaseLLMCache", + "SemanticCache", + "CacheEntry", + "CacheHit", + "SemanticCacheIndexSchema", +] diff --git a/redisvl/extensions/llmcache/base.py b/redisvl/extensions/llmcache/base.py index 8fce67ca..76ad97a5 100644 --- a/redisvl/extensions/llmcache/base.py +++ b/redisvl/extensions/llmcache/base.py @@ -1,79 +1,18 @@ -from typing import Any, Dict, List, Optional +""" +RedisVL Base LLM Cache (Deprecated Path) +This module is kept for backward compatibility. Please use `redisvl.extensions.cache.llm.base` instead. +""" -class BaseLLMCache: - def __init__(self, ttl: Optional[int] = None): - self._ttl: Optional[int] = None - self.set_ttl(ttl) +import warnings - @property - def ttl(self) -> Optional[int]: - """The default TTL, in seconds, for entries in the cache.""" - return self._ttl +from redisvl.extensions.cache.llm.base import BaseLLMCache - def set_ttl(self, ttl: Optional[int] = None): - """Set the default TTL, in seconds, for entries in the cache. +warnings.warn( + "Importing from redisvl.extensions.llmcache.base is deprecated. " + "Please import from redisvl.extensions.cache.llm.base instead.", + DeprecationWarning, + stacklevel=2, +) - Args: - ttl (Optional[int], optional): The optional time-to-live expiration - for the cache, in seconds. - - Raises: - ValueError: If the time-to-live value is not an integer. - """ - if ttl: - if not isinstance(ttl, int): - raise ValueError(f"TTL must be an integer value, got {ttl}") - self._ttl = int(ttl) - else: - self._ttl = None - - def clear(self) -> None: - """Clear the cache of all keys in the index.""" - raise NotImplementedError - - async def aclear(self) -> None: - """Async clear the cache of all keys in the index.""" - raise NotImplementedError - - def check( - self, - prompt: Optional[str] = None, - vector: Optional[List[float]] = None, - num_results: int = 1, - return_fields: Optional[List[str]] = None, - ) -> List[dict]: - """Check the cache based on a prompt or vector.""" - raise NotImplementedError - - async def acheck( - self, - prompt: Optional[str] = None, - vector: Optional[List[float]] = None, - num_results: int = 1, - return_fields: Optional[List[str]] = None, - ) -> List[dict]: - """Async check the cache based on a prompt or vector.""" - raise NotImplementedError - - def store( - self, - prompt: str, - response: str, - vector: Optional[List[float]] = None, - metadata: Optional[dict] = {}, - ) -> str: - """Store the specified key-value pair in the cache along with - metadata.""" - raise NotImplementedError - - async def astore( - self, - prompt: str, - response: str, - vector: Optional[List[float]] = None, - metadata: Optional[dict] = {}, - ) -> str: - """Async store the specified key-value pair in the cache along with - metadata.""" - raise NotImplementedError +__all__ = ["BaseLLMCache"] diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index fa6f720a..95ed3a6a 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -1,136 +1,22 @@ -from typing import Any, Dict, List, Optional +""" +RedisVL Semantic Cache Schema (Deprecated Path) -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +This module is kept for backward compatibility. Please use `redisvl.extensions.cache.llm.schema` instead. +""" -from redisvl.extensions.constants import ( - CACHE_VECTOR_FIELD_NAME, - INSERTED_AT_FIELD_NAME, - PROMPT_FIELD_NAME, - RESPONSE_FIELD_NAME, - UPDATED_AT_FIELD_NAME, -) -from redisvl.redis.utils import array_to_buffer, hashify -from redisvl.schema import IndexSchema -from redisvl.utils.utils import current_timestamp, deserialize, serialize - - -class CacheEntry(BaseModel): - """A single cache entry in Redis""" - - entry_id: Optional[str] = Field(default=None) - """Cache entry identifier""" - prompt: str - """Input prompt or question cached in Redis""" - response: str - """Response or answer to the question, cached in Redis""" - prompt_vector: List[float] - """Text embedding representation of the prompt""" - inserted_at: float = Field(default_factory=current_timestamp) - """Timestamp of when the entry was added to the cache""" - updated_at: float = Field(default_factory=current_timestamp) - """Timestamp of when the entry was updated in the cache""" - metadata: Optional[Dict[str, Any]] = Field(default=None) - """Optional metadata stored on the cache entry""" - filters: Optional[Dict[str, Any]] = Field(default=None) - """Optional filter data stored on the cache entry for customizing retrieval""" - - @model_validator(mode="before") - @classmethod - def generate_id(cls, values): - # Ensure entry_id is set - if not values.get("entry_id"): - values["entry_id"] = hashify(values["prompt"], values.get("filters")) - return values - - @field_validator("metadata") - @classmethod - def non_empty_metadata(cls, v): - if v is not None and not isinstance(v, dict): - raise TypeError("Metadata must be a dictionary.") - return v - - def to_dict(self, dtype: str) -> Dict: - data = self.model_dump(exclude_none=True) - data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype) - if self.metadata is not None: - data["metadata"] = serialize(self.metadata) - if self.filters is not None: - data.update(self.filters) - del data["filters"] - return data - - -class CacheHit(BaseModel): - """A cache hit based on some input query""" +import warnings - entry_id: str - """Cache entry identifier""" - prompt: str - """Input prompt or question cached in Redis""" - response: str - """Response or answer to the question, cached in Redis""" - vector_distance: float - """The semantic distance between the query vector and the stored prompt vector""" - inserted_at: float - """Timestamp of when the entry was added to the cache""" - updated_at: float - """Timestamp of when the entry was updated in the cache""" - metadata: Optional[Dict[str, Any]] = Field(default=None) - """Optional metadata stored on the cache entry""" - filters: Optional[Dict[str, Any]] = Field(default=None) - """Optional filter data stored on the cache entry for customizing retrieval""" - - # Allow extra fields to simplify handling filters - model_config = ConfigDict(extra="allow") - - @model_validator(mode="before") - @classmethod - def validate_cache_hit(cls, values: Dict[str, Any]) -> Dict[str, Any]: - # Deserialize metadata if necessary - if "metadata" in values and isinstance(values["metadata"], str): - values["metadata"] = deserialize(values["metadata"]) - - # Collect any extra fields and store them as filters - extra_data = values.pop("__pydantic_extra__", {}) or {} - if extra_data: - current_filters = values.get("filters") or {} - if not isinstance(current_filters, dict): - current_filters = {} - current_filters.update(extra_data) - values["filters"] = current_filters - - return values - - def to_dict(self) -> Dict[str, Any]: - """Convert this model to a dictionary, merging filters into the result.""" - data = self.model_dump(exclude_none=True) - if data.get("filters"): - data.update(data["filters"]) - del data["filters"] - return data - - -class SemanticCacheIndexSchema(IndexSchema): +from redisvl.extensions.cache.llm.schema import ( + CacheEntry, + CacheHit, + SemanticCacheIndexSchema, +) - @classmethod - def from_params(cls, name: str, prefix: str, vector_dims: int, dtype: str): +warnings.warn( + "Importing from redisvl.extensions.llmcache.schema is deprecated. " + "Please import from redisvl.extensions.cache.llm.schema instead.", + DeprecationWarning, + stacklevel=2, +) - return cls( - index={"name": name, "prefix": prefix}, # type: ignore - fields=[ # type: ignore - {"name": PROMPT_FIELD_NAME, "type": "text"}, - {"name": RESPONSE_FIELD_NAME, "type": "text"}, - {"name": INSERTED_AT_FIELD_NAME, "type": "numeric"}, - {"name": UPDATED_AT_FIELD_NAME, "type": "numeric"}, - { - "name": CACHE_VECTOR_FIELD_NAME, - "type": "vector", - "attrs": { - "dims": vector_dims, - "datatype": dtype, - "distance_metric": "cosine", - "algorithm": "flat", - }, - }, - ], - ) +__all__ = ["CacheEntry", "CacheHit", "SemanticCacheIndexSchema"] diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 13bab707..9c234aff 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -1,745 +1,18 @@ -import asyncio -import weakref -from typing import Any, Dict, List, Optional +""" +RedisVL Semantic Cache (Deprecated Path) -import numpy as np -from redis import Redis +This module is kept for backward compatibility. Please use `redisvl.extensions.cache.llm.semantic` instead. +""" -from redisvl.extensions.constants import ( - CACHE_VECTOR_FIELD_NAME, - ENTRY_ID_FIELD_NAME, - INSERTED_AT_FIELD_NAME, - METADATA_FIELD_NAME, - PROMPT_FIELD_NAME, - REDIS_KEY_FIELD_NAME, - RESPONSE_FIELD_NAME, - UPDATED_AT_FIELD_NAME, -) -from redisvl.extensions.llmcache.base import BaseLLMCache -from redisvl.extensions.llmcache.schema import ( - CacheEntry, - CacheHit, - SemanticCacheIndexSchema, -) -from redisvl.index import AsyncSearchIndex, SearchIndex -from redisvl.query import VectorRangeQuery -from redisvl.query.filter import FilterExpression -from redisvl.query.query import BaseQuery -from redisvl.redis.connection import RedisConnectionFactory -from redisvl.utils.log import get_logger -from redisvl.utils.utils import ( - current_timestamp, - deprecated_argument, - serialize, - sync_wrapper, - validate_vector_dims, -) -from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer - -logger = get_logger("[RedisVL]") - - -class SemanticCache(BaseLLMCache): - """Semantic Cache for Large Language Models.""" - - _index: SearchIndex - _aindex: Optional[AsyncSearchIndex] = None - - @deprecated_argument("dtype", "vectorizer") - def __init__( - self, - name: str = "llmcache", - distance_threshold: float = 0.1, - ttl: Optional[int] = None, - vectorizer: Optional[BaseVectorizer] = None, - filterable_fields: Optional[List[Dict[str, Any]]] = None, - redis_client: Optional[Redis] = None, - redis_url: str = "redis://localhost:6379", - connection_kwargs: Dict[str, Any] = {}, - overwrite: bool = False, - **kwargs, - ): - """Semantic Cache for Large Language Models. - - Args: - name (str, optional): The name of the semantic cache search index. - Defaults to "llmcache". - distance_threshold (float, optional): Semantic threshold for the - cache. Defaults to 0.1. - ttl (Optional[int], optional): The time-to-live for records cached - in Redis. Defaults to None. - vectorizer (Optional[BaseVectorizer], optional): The vectorizer for the cache. - Defaults to HFTextVectorizer. - filterable_fields (Optional[List[Dict[str, Any]]]): An optional list of RedisVL fields - that can be used to customize cache retrieval with filters. - redis_client(Optional[Redis], optional): A redis client connection instance. - Defaults to None. - redis_url (str, optional): The redis url. Defaults to redis://localhost:6379. - connection_kwargs (Dict[str, Any]): The connection arguments - for the redis client. Defaults to empty {}. - overwrite (bool): Whether or not to force overwrite the schema for - the semantic cache index. Defaults to false. - - Raises: - TypeError: If an invalid vectorizer is provided. - TypeError: If the TTL value is not an int. - ValueError: If the threshold is not between 0 and 1. - ValueError: If existing schema does not match new schema and overwrite is False. - """ - super().__init__(ttl) - - self.redis_kwargs = { - "redis_client": redis_client, - "redis_url": redis_url, - "connection_kwargs": connection_kwargs, - } - - # Use the index name as the key prefix by default - prefix = kwargs.pop("prefix", name) - dtype = kwargs.pop("dtype", None) - - # Validate a provided vectorizer or set the default - if vectorizer: - if not isinstance(vectorizer, BaseVectorizer): - raise TypeError("Must provide a valid redisvl.vectorizer class.") - if dtype and vectorizer.dtype != dtype: - raise ValueError( - f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}" - ) - else: - vectorizer_kwargs = kwargs - - if dtype: - vectorizer_kwargs.update(**{"dtype": dtype}) - - vectorizer = HFTextVectorizer( - model="sentence-transformers/all-mpnet-base-v2", - **vectorizer_kwargs, - ) - - self._vectorizer = vectorizer - - # Process fields and other settings - self.set_threshold(distance_threshold) - self.return_fields = [ - ENTRY_ID_FIELD_NAME, - PROMPT_FIELD_NAME, - RESPONSE_FIELD_NAME, - INSERTED_AT_FIELD_NAME, - UPDATED_AT_FIELD_NAME, - METADATA_FIELD_NAME, - ] - - # Create semantic cache schema and index - schema = SemanticCacheIndexSchema.from_params( - name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore - ) - schema = self._modify_schema(schema, filterable_fields) - - if redis_client: - self._owns_redis_client = False - else: - self._owns_redis_client = True - - self._index = SearchIndex( - schema=schema, - redis_client=redis_client, - redis_url=redis_url, - **connection_kwargs, - ) - - # Check for existing cache index - if not overwrite and self._index.exists(): - existing_index = SearchIndex.from_existing( - name, redis_client=self._index.client - ) - if existing_index.schema.to_dict() != self._index.schema.to_dict(): - raise ValueError( - f"Existing index {name} schema does not match the user provided schema for the semantic cache. " - "If you wish to overwrite the index schema, set overwrite=True during initialization." - ) - - # Create the search index in Redis - self._index.create(overwrite=overwrite, drop=False) - - def _modify_schema( - self, - schema: SemanticCacheIndexSchema, - filterable_fields: Optional[List[Dict[str, Any]]] = None, - ) -> SemanticCacheIndexSchema: - """Modify the base cache schema using the provided filterable fields""" - - if filterable_fields is not None: - protected_field_names = set(self.return_fields + [REDIS_KEY_FIELD_NAME]) - for filter_field in filterable_fields: - field_name = filter_field["name"] - if field_name in protected_field_names: - raise ValueError( - f"{field_name} is a reserved field name for the semantic cache schema" - ) - # Add to schema - schema.add_field(filter_field) - # Add to return fields too - self.return_fields.append(field_name) - - return schema - - async def _get_async_index(self) -> AsyncSearchIndex: - """Lazily construct the async search index class.""" - # Construct async index if necessary - async_client = None - if self._aindex is None: - client = self.redis_kwargs.get("redis_client") - if isinstance(client, Redis): - async_client = RedisConnectionFactory.sync_to_async_redis(client) - self._aindex = AsyncSearchIndex( - schema=self._index.schema, - redis_client=async_client, - redis_url=self.redis_kwargs["redis_url"], - **self.redis_kwargs["connection_kwargs"], - ) - return self._aindex - - @property - def index(self) -> SearchIndex: - """The underlying SearchIndex for the cache. - - Returns: - SearchIndex: The search index. - """ - return self._index - - @property - def aindex(self) -> Optional[AsyncSearchIndex]: - """The underlying AsyncSearchIndex for the cache. - - Returns: - AsyncSearchIndex: The async search index. - """ - return self._aindex - - @property - def distance_threshold(self) -> float: - """The semantic distance threshold for the cache. - - Returns: - float: The semantic distance threshold. - """ - return self._distance_threshold - - def set_threshold(self, distance_threshold: float) -> None: - """Sets the semantic distance threshold for the cache. - - Args: - distance_threshold (float): The semantic distance threshold for - the cache. - - Raises: - ValueError: If the threshold is not between 0 and 1. - """ - if not 0 <= float(distance_threshold) <= 2: - raise ValueError( - f"Distance must be between 0 and 2, got {distance_threshold}" - ) - self._distance_threshold = float(distance_threshold) - - def clear(self) -> None: - """Clear the cache of all keys while preserving the index.""" - self._index.clear() - - async def aclear(self) -> None: - """""" - aindex = await self._get_async_index() - await aindex.clear() - - def delete(self) -> None: - """Clear the semantic cache of all keys and remove the underlying search - index.""" - self._index.delete(drop=True) - - async def adelete(self) -> None: - """""" - aindex = await self._get_async_index() - await aindex.delete(drop=True) - - def drop( - self, ids: Optional[List[str]] = None, keys: Optional[List[str]] = None - ) -> None: - """Manually expire specific entries from the cache by id or specific - Redis key. - - Args: - ids (Optional[str]): The document ID or IDs to remove from the cache. - keys (Optional[str]): The Redis keys to remove from the cache. - """ - if ids is not None: - self._index.drop_keys([self._index.key(id) for id in ids]) - if keys is not None: - self._index.drop_keys(keys) - - async def adrop( - self, ids: Optional[List[str]] = None, keys: Optional[List[str]] = None - ) -> None: - """Async expire specific entries from the cache by id or specific - Redis key. - - Args: - ids (Optional[str]): The document ID or IDs to remove from the cache. - keys (Optional[str]): The Redis keys to remove from the cache. - """ - aindex = await self._get_async_index() - - if ids is not None: - await aindex.drop_keys([self._index.key(id) for id in ids]) - if keys is not None: - await aindex.drop_keys(keys) - - def _refresh_ttl(self, key: str) -> None: - """Refresh the time-to-live for the specified key.""" - if self._ttl: - self._index.expire_keys(key, self._ttl) - - async def _async_refresh_ttl(self, key: str) -> None: - """Async refresh the time-to-live for the specified key.""" - aindex = await self._get_async_index() - if self._ttl: - await aindex.expire_keys(key, self._ttl) - - def _vectorize_prompt(self, prompt: Optional[str]) -> List[float]: - """Converts a text prompt to its vector representation using the - configured vectorizer.""" - if not isinstance(prompt, str): - raise TypeError("Prompt must be a string.") - - result = self._vectorizer.embed(prompt) - return result # type: ignore - - async def _avectorize_prompt(self, prompt: Optional[str]) -> List[float]: - """Converts a text prompt to its vector representation using the - configured vectorizer.""" - if not isinstance(prompt, str): - raise TypeError("Prompt must be a string.") - - result = await self._vectorizer.aembed(prompt) - return result # type: ignore - - def _check_vector_dims(self, vector: List[float]): - """Checks the size of the provided vector and raises an error if it - doesn't match the search index vector dimensions.""" - schema_vector_dims = self._index.schema.fields[ - CACHE_VECTOR_FIELD_NAME - ].attrs.dims # type: ignore - validate_vector_dims(len(vector), schema_vector_dims) - - def check( - self, - prompt: Optional[str] = None, - vector: Optional[List[float]] = None, - num_results: int = 1, - return_fields: Optional[List[str]] = None, - filter_expression: Optional[FilterExpression] = None, - distance_threshold: Optional[float] = None, - ) -> List[Dict[str, Any]]: - """Checks the semantic cache for results similar to the specified prompt - or vector. - - This method searches the cache using vector similarity with - either a raw text prompt (converted to a vector) or a provided vector as - input. It checks for semantically similar prompts and fetches the cached - LLM responses. - - Args: - prompt (Optional[str], optional): The text prompt to search for in - the cache. - vector (Optional[List[float]], optional): The vector representation - of the prompt to search for in the cache. - num_results (int, optional): The number of cached results to return. - Defaults to 1. - return_fields (Optional[List[str]], optional): The fields to include - in each returned result. If None, defaults to all available - fields in the cached entry. - filter_expression (Optional[FilterExpression]) : Optional filter expression - that can be used to filter cache results. Defaults to None and - the full cache will be searched. - distance_threshold (Optional[float]): The threshold for semantic - vector distance. - - Returns: - List[Dict[str, Any]]: A list of dicts containing the requested - return fields for each similar cached response. - - Raises: - ValueError: If neither a `prompt` nor a `vector` is specified. - ValueError: if 'vector' has incorrect dimensions. - TypeError: If `return_fields` is not a list when provided. - - .. code-block:: python +import warnings - response = cache.check( - prompt="What is the captial city of France?" - ) - """ - if not any([prompt, vector]): - raise ValueError("Either prompt or vector must be specified.") - if return_fields and not isinstance(return_fields, list): - raise TypeError("Return fields must be a list of values.") +from redisvl.extensions.cache.llm.semantic import SemanticCache - # overrides - distance_threshold = distance_threshold or self._distance_threshold - vector = vector or self._vectorize_prompt(prompt) - self._check_vector_dims(vector) - - query = VectorRangeQuery( - vector=vector, - vector_field_name=CACHE_VECTOR_FIELD_NAME, - return_fields=self.return_fields, - distance_threshold=distance_threshold, - num_results=num_results, - return_score=True, - filter_expression=filter_expression, - dtype=self._vectorizer.dtype, - ) - - # Search the cache! - cache_search_results = self._index.query(query) - redis_keys, cache_hits = self._process_cache_results( - cache_search_results, - return_fields, # type: ignore - ) - # Extend TTL on keys - for key in redis_keys: - self._refresh_ttl(key) - - return cache_hits - - async def acheck( - self, - prompt: Optional[str] = None, - vector: Optional[List[float]] = None, - num_results: int = 1, - return_fields: Optional[List[str]] = None, - filter_expression: Optional[FilterExpression] = None, - distance_threshold: Optional[float] = None, - ) -> List[Dict[str, Any]]: - """Async check the semantic cache for results similar to the specified prompt - or vector. - - This method searches the cache using vector similarity with - either a raw text prompt (converted to a vector) or a provided vector as - input. It checks for semantically similar prompts and fetches the cached - LLM responses. - - Args: - prompt (Optional[str], optional): The text prompt to search for in - the cache. - vector (Optional[List[float]], optional): The vector representation - of the prompt to search for in the cache. - num_results (int, optional): The number of cached results to return. - Defaults to 1. - return_fields (Optional[List[str]], optional): The fields to include - in each returned result. If None, defaults to all available - fields in the cached entry. - filter_expression (Optional[FilterExpression]) : Optional filter expression - that can be used to filter cache results. Defaults to None and - the full cache will be searched. - distance_threshold (Optional[float]): The threshold for semantic - vector distance. - - Returns: - List[Dict[str, Any]]: A list of dicts containing the requested - return fields for each similar cached response. - - Raises: - ValueError: If neither a `prompt` nor a `vector` is specified. - ValueError: if 'vector' has incorrect dimensions. - TypeError: If `return_fields` is not a list when provided. - - .. code-block:: python - - response = await cache.acheck( - prompt="What is the captial city of France?" - ) - """ - aindex = await self._get_async_index() - - if not any([prompt, vector]): - raise ValueError("Either prompt or vector must be specified.") - if return_fields and not isinstance(return_fields, list): - raise TypeError("Return fields must be a list of values.") - - # overrides - distance_threshold = distance_threshold or self._distance_threshold - vector = vector or await self._avectorize_prompt(prompt) - self._check_vector_dims(vector) - - query = VectorRangeQuery( - vector=vector, - vector_field_name=CACHE_VECTOR_FIELD_NAME, - return_fields=self.return_fields, - distance_threshold=distance_threshold, - num_results=num_results, - return_score=True, - filter_expression=filter_expression, - normalize_vector_distance=True, - ) - - # Search the cache! - cache_search_results = await aindex.query(query) - redis_keys, cache_hits = self._process_cache_results( - cache_search_results, - return_fields, # type: ignore - ) - # Extend TTL on keys - await asyncio.gather(*[self._async_refresh_ttl(key) for key in redis_keys]) - - return cache_hits - - def _process_cache_results( - self, cache_search_results: List[Dict[str, Any]], return_fields: List[str] - ): - redis_keys: List[str] = [] - cache_hits: List[Dict[Any, str]] = [] - for cache_search_result in cache_search_results: - # Pop the redis key from the result - redis_key = cache_search_result.pop("id") - redis_keys.append(redis_key) - # Create and process cache hit - cache_hit = CacheHit(**cache_search_result) - cache_hit_dict = cache_hit.to_dict() - # Filter down to only selected return fields if needed - if isinstance(return_fields, list) and len(return_fields) > 0: - cache_hit_dict = { - k: v for k, v in cache_hit_dict.items() if k in return_fields - } - cache_hit_dict[REDIS_KEY_FIELD_NAME] = redis_key - cache_hits.append(cache_hit_dict) - return redis_keys, cache_hits - - def store( - self, - prompt: str, - response: str, - vector: Optional[List[float]] = None, - metadata: Optional[Dict[str, Any]] = None, - filters: Optional[Dict[str, Any]] = None, - ttl: Optional[int] = None, - ) -> str: - """Stores the specified key-value pair in the cache along with metadata. - - Args: - prompt (str): The user prompt to cache. - response (str): The LLM response to cache. - vector (Optional[List[float]], optional): The prompt vector to - cache. Defaults to None, and the prompt vector is generated on - demand. - metadata (Optional[Dict[str, Any]], optional): The optional metadata to cache - alongside the prompt and response. Defaults to None. - filters (Optional[Dict[str, Any]]): The optional tag to assign to the cache entry. - Defaults to None. - ttl (Optional[int]): The optional TTL override to use on this individual cache - entry. Defaults to the global TTL setting. - - Returns: - str: The Redis key for the entries added to the semantic cache. - - Raises: - ValueError: If neither prompt nor vector is specified. - ValueError: if vector has incorrect dimensions. - TypeError: If provided metadata is not a dictionary. - - .. code-block:: python - - key = cache.store( - prompt="What is the captial city of France?", - response="Paris", - metadata={"city": "Paris", "country": "France"} - ) - """ - # Vectorize prompt if necessary and create cache payload - vector = vector or self._vectorize_prompt(prompt) - self._check_vector_dims(vector) - - # Build cache entry for the cache - cache_entry = CacheEntry( - prompt=prompt, - response=response, - prompt_vector=vector, - metadata=metadata, - filters=filters, - ) - - # Load cache entry with TTL - ttl = ttl or self._ttl - keys = self._index.load( - data=[cache_entry.to_dict(self._vectorizer.dtype)], - ttl=ttl, - id_field=ENTRY_ID_FIELD_NAME, - ) - return keys[0] - - async def astore( - self, - prompt: str, - response: str, - vector: Optional[List[float]] = None, - metadata: Optional[Dict[str, Any]] = None, - filters: Optional[Dict[str, Any]] = None, - ttl: Optional[int] = None, - ) -> str: - """Async stores the specified key-value pair in the cache along with metadata. - - Args: - prompt (str): The user prompt to cache. - response (str): The LLM response to cache. - vector (Optional[List[float]], optional): The prompt vector to - cache. Defaults to None, and the prompt vector is generated on - demand. - metadata (Optional[Dict[str, Any]], optional): The optional metadata to cache - alongside the prompt and response. Defaults to None. - filters (Optional[Dict[str, Any]]): The optional tag to assign to the cache entry. - Defaults to None. - ttl (Optional[int]): The optional TTL override to use on this individual cache - entry. Defaults to the global TTL setting. - - Returns: - str: The Redis key for the entries added to the semantic cache. - - Raises: - ValueError: If neither prompt nor vector is specified. - ValueError: if vector has incorrect dimensions. - TypeError: If provided metadata is not a dictionary. - - .. code-block:: python - - key = await cache.astore( - prompt="What is the captial city of France?", - response="Paris", - metadata={"city": "Paris", "country": "France"} - ) - """ - aindex = await self._get_async_index() - - # Vectorize prompt if necessary and create cache payload - vector = vector or self._vectorize_prompt(prompt) - self._check_vector_dims(vector) - - # Build cache entry for the cache - cache_entry = CacheEntry( - prompt=prompt, - response=response, - prompt_vector=vector, - metadata=metadata, - filters=filters, - ) - - # Load cache entry with TTL - ttl = ttl or self._ttl - keys = await aindex.load( - data=[cache_entry.to_dict(self._vectorizer.dtype)], - ttl=ttl, - id_field=ENTRY_ID_FIELD_NAME, - ) - return keys[0] - - def update(self, key: str, **kwargs) -> None: - """Update specific fields within an existing cache entry. If no fields - are passed, then only the document TTL is refreshed. - - Args: - key (str): the key of the document to update using kwargs. - - Raises: - ValueError if an incorrect mapping is provided as a kwarg. - TypeError if metadata is provided and not of type dict. - - .. code-block:: python - - key = cache.store('this is a prompt', 'this is a response') - cache.update(key, metadata={"hit_count": 1, "model_name": "Llama-2-7b"}) - ) - """ - if kwargs: - for k, v in kwargs.items(): - # Make sure the item is in the index schema - if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]): - raise ValueError(f"{k} is not a valid field within the cache entry") - - # Check for metadata and deserialize - if k == METADATA_FIELD_NAME: - if isinstance(v, dict): - kwargs[k] = serialize(v) - else: - raise TypeError( - "If specified, cached metadata must be a dictionary." - ) - - kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()}) - - self._index.client.hset(key, mapping=kwargs) # type: ignore - - self._refresh_ttl(key) - - async def aupdate(self, key: str, **kwargs) -> None: - """Async update specific fields within an existing cache entry. If no fields - are passed, then only the document TTL is refreshed. - - Args: - key (str): the key of the document to update using kwargs. - - Raises: - ValueError if an incorrect mapping is provided as a kwarg. - TypeError if metadata is provided and not of type dict. - - .. code-block:: python - - key = await cache.astore('this is a prompt', 'this is a response') - await cache.aupdate( - key, - metadata={"hit_count": 1, "model_name": "Llama-2-7b"} - ) - """ - aindex = await self._get_async_index() - - if kwargs: - for k, v in kwargs.items(): - # Make sure the item is in the index schema - if k not in set(self._index.schema.field_names + [METADATA_FIELD_NAME]): - raise ValueError(f"{k} is not a valid field within the cache entry") - - # Check for metadata and deserialize - if k == METADATA_FIELD_NAME: - if isinstance(v, dict): - kwargs[k] = serialize(v) - else: - raise TypeError( - "If specified, cached metadata must be a dictionary." - ) - - kwargs.update({UPDATED_AT_FIELD_NAME: current_timestamp()}) - - await aindex.load(data=[kwargs], keys=[key]) - - await self._async_refresh_ttl(key) - - def disconnect(self): - if self._owns_redis_client is False: - return - if self._index: - self._index.disconnect() - if self._aindex: - self._aindex.disconnect_sync() - - async def adisconnect(self): - if not self._owns_redis_client: - return - if self._index: - self._index.disconnect() - if self._aindex: - await self._aindex.disconnect() - self._aindex = None - - async def __aenter__(self): - return self +warnings.warn( + "Importing from redisvl.extensions.llmcache.semantic is deprecated. " + "Please import from redisvl.extensions.cache.llm.semantic instead.", + DeprecationWarning, + stacklevel=2, +) - async def __aexit__(self, exc_type, exc_val, exc_tb): - await self.adisconnect() +__all__ = ["SemanticCache"] diff --git a/redisvl/utils/optimize/cache.py b/redisvl/utils/optimize/cache.py index 05c99f76..be5a05c6 100644 --- a/redisvl/utils/optimize/cache.py +++ b/redisvl/utils/optimize/cache.py @@ -3,7 +3,7 @@ import numpy as np from ranx import Qrels, Run, evaluate -from redisvl.extensions.llmcache.semantic import SemanticCache +from redisvl.extensions.cache.llm.semantic import SemanticCache from redisvl.query import RangeQuery from redisvl.utils.optimize.base import BaseThresholdOptimizer, EvalMetric from redisvl.utils.optimize.schema import LabeledData @@ -76,7 +76,7 @@ class CacheThresholdOptimizer(BaseThresholdOptimizer): .. code-block:: python - from redisvl.extensions.llmcache import SemanticCache + from redisvl.extensions.cache.llm import SemanticCache from redisvl.utils.optimize import CacheThresholdOptimizer sem_cache = SemanticCache( diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index ab13f16f..3b6511f9 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -56,12 +56,12 @@ def validate_vector_dims(v1: int, v2: int) -> None: ) -def serialize(data: Dict[str, Any]) -> str: +def serialize(data: Any) -> str: """Serlize the input into a string.""" return json.dumps(data) -def deserialize(data: str) -> Dict[str, Any]: +def deserialize(data: str) -> Any: """Deserialize the input from a string.""" return json.loads(data) diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index 189b6e1a..83e58fe4 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -1,4 +1,3 @@ -from abc import ABC, abstractmethod from enum import Enum from typing import Callable, List, Optional, Union @@ -18,7 +17,7 @@ class Vectorizers(Enum): voyageai = "voyageai" -class BaseVectorizer(BaseModel, ABC): +class BaseVectorizer(BaseModel): """Base vectorizer interface.""" model: str @@ -48,7 +47,6 @@ def check_dims(cls, value): raise ValueError("Dims must be a positive integer.") return value - @abstractmethod def embed( self, text: str, @@ -69,7 +67,6 @@ def embed( """ raise NotImplementedError - @abstractmethod def embed_many( self, texts: List[str], diff --git a/tests/integration/test_embedcache.py b/tests/integration/test_embedcache.py new file mode 100644 index 00000000..45150989 --- /dev/null +++ b/tests/integration/test_embedcache.py @@ -0,0 +1,776 @@ +import asyncio +import json +import time +from typing import Any, Dict, List, Optional + +import pytest +from redis.exceptions import ConnectionError + +from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache +from redisvl.redis.utils import hashify + + +@pytest.fixture +def cache(redis_url): + """Basic EmbeddingsCache fixture with cleanup.""" + cache_instance = EmbeddingsCache( + name="test_embed_cache", + redis_url=redis_url, + ) + yield cache_instance + # Clean up all keys with this prefix + cache_instance.clear() + + +@pytest.fixture +def cache_with_ttl(redis_url): + """EmbeddingsCache with TTL setting.""" + cache_instance = EmbeddingsCache( + name="test_ttl_cache", + ttl=2, # 2 second TTL for testing expiration + redis_url=redis_url, + ) + yield cache_instance + # Clean up all keys with this prefix + cache_instance.clear() + + +@pytest.fixture +def cache_with_redis_client(client): + """EmbeddingsCache with provided Redis client.""" + cache_instance = EmbeddingsCache( + name="test_client_cache", + redis_client=client, + ) + yield cache_instance + # Clean up all keys with this prefix + cache_instance.clear() + + +@pytest.fixture +def sample_embedding_data(): + """Sample data for embedding cache tests.""" + return [ + { + "text": "What is machine learning?", + "model_name": "text-embedding-ada-002", + "embedding": [0.1, 0.2, 0.3, 0.4, 0.5], + "metadata": {"source": "user_query", "category": "ai"}, + }, + { + "text": "How do neural networks work?", + "model_name": "text-embedding-ada-002", + "embedding": [0.2, 0.3, 0.4, 0.5, 0.6], + "metadata": {"source": "documentation", "category": "ai"}, + }, + { + "text": "What's the weather like today?", + "model_name": "text-embedding-ada-002", + "embedding": [0.5, 0.6, 0.7, 0.8, 0.9], + "metadata": {"source": "user_query", "category": "weather"}, + }, + ] + + +def test_cache_initialization(redis_url): + """Test that the cache can be initialized with different parameters.""" + # Default initialization + cache1 = EmbeddingsCache() + assert cache1.name == "embedcache" + assert cache1.ttl is None + + # Custom name and TTL + cache2 = EmbeddingsCache(name="custom_cache", ttl=60, redis_url=redis_url) + assert cache2.name == "custom_cache" + assert cache2.ttl == 60 + + # With redis client + cache3 = EmbeddingsCache(redis_url=redis_url) + client = cache3._get_redis_client() + cache4 = EmbeddingsCache(redis_client=client) + assert cache4._redis_client is client + + +def test_make_entry_id(): + """Test that entry IDs are generated consistently.""" + cache = EmbeddingsCache() + text = "Hello world" + model_name = "text-embedding-ada-002" + + # Test deterministic ID generation + entry_id1 = cache._make_entry_id(text, model_name) + entry_id2 = cache._make_entry_id(text, model_name) + assert entry_id1 == entry_id2 + + # Test different inputs produce different IDs + different_id = cache._make_entry_id("Different text", model_name) + assert entry_id1 != different_id + + # Test ID format + assert isinstance(entry_id1, str) + expected_id = hashify(f"{text}:{model_name}") + assert entry_id1 == expected_id + + +def test_make_cache_key(): + """Test that cache keys are constructed properly.""" + cache = EmbeddingsCache(name="test_cache") + text = "Hello world" + model_name = "text-embedding-ada-002" + + # Test key construction + key = cache._make_cache_key(text, model_name) + entry_id = cache._make_entry_id(text, model_name) + expected_key = f"test_cache:{entry_id}" + assert key == expected_key + + # Test with different cache name + cache2 = EmbeddingsCache(name="different_cache") + key2 = cache2._make_cache_key(text, model_name) + assert key2 == f"different_cache:{entry_id}" + + # Make sure keys are unique for different inputs + different_key = cache._make_cache_key("Different text", model_name) + assert key != different_key + + +def test_set_and_get(cache, sample_embedding_data): + """Test setting and retrieving entries from the cache.""" + sample = sample_embedding_data[0] + + # Set the entry + key = cache.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + metadata=sample["metadata"], + ) + + # Get the entry + result = cache.get(sample["text"], sample["model_name"]) + + # Verify the result + assert result is not None + assert result["text"] == sample["text"] + assert result["model_name"] == sample["model_name"] + assert "embedding" in result + assert result["metadata"] == sample["metadata"] + + # Test get_by_key + key_result = cache.get_by_key(key) + assert key_result is not None + assert key_result["text"] == sample["text"] + + # Test non-existent entry + missing = cache.get("NonexistentText", sample["model_name"]) + assert missing is None + + # Test non-existent key + missing_key = cache.get_by_key("nonexistent:key") + assert missing_key is None + + +def test_exists(cache, sample_embedding_data): + """Test checking if entries exist in the cache.""" + sample = sample_embedding_data[0] + + # Entry shouldn't exist yet + assert not cache.exists(sample["text"], sample["model_name"]) + + # Add the entry + key = cache.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + + # Now it should exist + assert cache.exists(sample["text"], sample["model_name"]) + + # Test exists_by_key + assert cache.exists_by_key(key) + + # Non-existent entries + assert not cache.exists("NonexistentText", sample["model_name"]) + assert not cache.exists_by_key("nonexistent:key") + + +def test_drop(cache, sample_embedding_data): + """Test removing entries from the cache.""" + sample = sample_embedding_data[0] + + # Add the entry + key = cache.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + + # Verify it exists + assert cache.exists_by_key(key) + + # Remove it + cache.drop(sample["text"], sample["model_name"]) + + # Verify it's gone + assert not cache.exists_by_key(key) + + # Test drop_by_key + key = cache.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + cache.drop_by_key(key) + assert not cache.exists_by_key(key) + + +def test_ttl_expiration(cache_with_ttl, sample_embedding_data): + """Test that entries expire after TTL.""" + sample = sample_embedding_data[0] + + # Add the entry + key = cache_with_ttl.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + + # Verify it exists + assert cache_with_ttl.exists_by_key(key) + + # Wait for it to expire (TTL is 2 seconds) + time.sleep(3) + + # Verify it's gone + assert not cache_with_ttl.exists_by_key(key) + + +def test_custom_ttl(cache, sample_embedding_data): + """Test setting a custom TTL for a specific entry.""" + sample = sample_embedding_data[0] + + # Add the entry with a 1 second TTL + key = cache.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ttl=1, + ) + + # Verify it exists + assert cache.exists_by_key(key) + + # Wait for it to expire + time.sleep(2) + + # Verify it's gone + assert not cache.exists_by_key(key) + + +def test_multiple_entries(cache, sample_embedding_data): + """Test storing and retrieving multiple entries.""" + # Store all samples + keys = [] + for sample in sample_embedding_data: + key = cache.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + metadata=sample.get("metadata"), + ) + keys.append(key) + + # Check they all exist + for i, key in enumerate(keys): + assert cache.exists_by_key(key) + result = cache.get_by_key(key) + assert result["text"] == sample_embedding_data[i]["text"] + + # Drop one entry + cache.drop_by_key(keys[0]) + assert not cache.exists_by_key(keys[0]) + assert cache.exists_by_key(keys[1]) + + +@pytest.mark.asyncio +async def test_async_set_and_get(cache, sample_embedding_data): + """Test async versions of set and get.""" + sample = sample_embedding_data[0] + + # Set the entry + key = await cache.aset( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + metadata=sample["metadata"], + ) + + # Get the entry + result = await cache.aget(sample["text"], sample["model_name"]) + + # Verify the result + assert result is not None + assert result["text"] == sample["text"] + assert result["model_name"] == sample["model_name"] + assert "embedding" in result + assert result["metadata"] == sample["metadata"] + + # Test aget_by_key + key_result = await cache.aget_by_key(key) + assert key_result is not None + assert key_result["text"] == sample["text"] + + +@pytest.mark.asyncio +async def test_async_exists(cache, sample_embedding_data): + """Test async version of exists.""" + sample = sample_embedding_data[0] + + # Entry shouldn't exist yet + assert not await cache.aexists(sample["text"], sample["model_name"]) + + # Add the entry + key = await cache.aset( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + + # Now it should exist + assert await cache.aexists(sample["text"], sample["model_name"]) + + # Test aexists_by_key + assert await cache.aexists_by_key(key) + + +@pytest.mark.asyncio +async def test_async_drop(cache, sample_embedding_data): + """Test async version of drop.""" + sample = sample_embedding_data[0] + + # Add the entry + key = await cache.aset( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + + # Verify it exists + assert await cache.aexists_by_key(key) + + # Remove it + await cache.adrop(sample["text"], sample["model_name"]) + + # Verify it's gone + assert not await cache.aexists_by_key(key) + + # Test adrop_by_key + key = await cache.aset( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + await cache.adrop_by_key(key) + assert not await cache.aexists_by_key(key) + + +@pytest.mark.asyncio +async def test_async_ttl_expiration(cache_with_ttl, sample_embedding_data): + """Test that entries expire after TTL in async mode.""" + sample = sample_embedding_data[0] + + # Add the entry + key = await cache_with_ttl.aset( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + + # Verify it exists + assert await cache_with_ttl.aexists_by_key(key) + + # Wait for it to expire (TTL is 2 seconds) + await asyncio.sleep(3) + + # Verify it's gone + assert not await cache_with_ttl.aexists_by_key(key) + + +def test_entry_id_consistency(cache, sample_embedding_data): + """Test that entry IDs are consistent between operations.""" + sample = sample_embedding_data[0] + + # Generate an entry ID directly + expected_id = cache._make_entry_id(sample["text"], sample["model_name"]) + + # Set an entry and extract its ID from the key + key = cache.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + + # Key should be cache_name:entry_id + parts = key.split(":") + actual_id = parts[1] + + # IDs should match + assert actual_id == expected_id + + # Get the entry and check its ID + result = cache.get_by_key(key) + assert result["entry_id"] == expected_id + + +def test_redis_client_reuse(cache_with_redis_client, sample_embedding_data): + """Test using the cache with a provided Redis client.""" + sample = sample_embedding_data[0] + + # Set and get an entry + key = cache_with_redis_client.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + + result = cache_with_redis_client.get_by_key(key) + assert result is not None + assert result["text"] == sample["text"] + + +def test_mset_and_mget(cache, sample_embedding_data): + """Test batch setting and getting of embeddings.""" + # Prepare batch items + batch_items = [] + for sample in sample_embedding_data: + batch_items.append( + { + "text": sample["text"], + "model_name": sample["model_name"], + "embedding": sample["embedding"], + "metadata": sample.get("metadata"), + } + ) + + # Use mset to store embeddings + keys = cache.mset(batch_items) + assert len(keys) == len(batch_items) + + # Get texts and model name for mget + texts = [item["text"] for item in batch_items] + model_name = batch_items[0]["model_name"] # Assuming same model + + # Test mget + results = cache.mget(texts, model_name) + assert len(results) == len(texts) + + # Verify all results are returned and in correct order + for i, result in enumerate(results): + assert result is not None + assert result["text"] == texts[i] + assert result["model_name"] == model_name + + +def test_mget_by_keys(cache, sample_embedding_data): + """Test getting multiple embeddings by their keys.""" + # Set embeddings individually and collect keys + keys = [] + for sample in sample_embedding_data: + key = cache.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + metadata=sample.get("metadata"), + ) + keys.append(key) + + # Test mget_by_keys + results = cache.mget_by_keys(keys) + assert len(results) == len(keys) + + # Verify all results match the original samples + for i, result in enumerate(results): + assert result is not None + assert result["text"] == sample_embedding_data[i]["text"] + assert result["model_name"] == sample_embedding_data[i]["model_name"] + + # Test with mix of existing and non-existing keys + non_existent_key = "test_embed_cache:nonexistent" + mixed_keys = keys[:1] + [non_existent_key] + keys[1:] + mixed_results = cache.mget_by_keys(mixed_keys) + + assert len(mixed_results) == len(mixed_keys) + assert mixed_results[0] is not None + assert mixed_results[1] is None # Non-existent key should return None + assert mixed_results[2] is not None + + +def test_mexists_and_mexists_by_keys(cache, sample_embedding_data): + """Test batch existence checks for embeddings.""" + # Set embeddings individually and collect data + keys = [] + texts = [] + for sample in sample_embedding_data: + key = cache.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + keys.append(key) + texts.append(sample["text"]) + + model_name = sample_embedding_data[0]["model_name"] # Assuming same model + + # Test mexists + exist_results = cache.mexists(texts, model_name) + assert len(exist_results) == len(texts) + assert all(exist_results) # All should exist + + # Test with mix of existing and non-existing texts + non_existent_text = "This text does not exist" + mixed_texts = texts[:1] + [non_existent_text] + texts[1:] + mixed_results = cache.mexists(mixed_texts, model_name) + + assert len(mixed_results) == len(mixed_texts) + assert mixed_results[0] is True + assert mixed_results[1] is False # Non-existent text should return False + assert mixed_results[2] is True + + # Test mexists_by_keys + key_exist_results = cache.mexists_by_keys(keys) + assert len(key_exist_results) == len(keys) + assert all(key_exist_results) # All should exist + + # Test with mix of existing and non-existing keys + non_existent_key = "test_embed_cache:nonexistent" + mixed_keys = keys[:1] + [non_existent_key] + keys[1:] + mixed_key_results = cache.mexists_by_keys(mixed_keys) + + assert len(mixed_key_results) == len(mixed_keys) + assert mixed_key_results[0] is True + assert mixed_key_results[1] is False # Non-existent key should return False + assert mixed_key_results[2] is True + + +def test_mdrop_and_mdrop_by_keys(cache, sample_embedding_data): + """Test batch deletion of embeddings.""" + # Set embeddings and collect data + keys = [] + texts = [] + for sample in sample_embedding_data: + key = cache.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + keys.append(key) + texts.append(sample["text"]) + + model_name = sample_embedding_data[0]["model_name"] # Assuming same model + + # Test mdrop_by_keys with subset of keys + subset_keys = keys[:2] + cache.mdrop_by_keys(subset_keys) + + # Verify only selected keys were dropped + for i, key in enumerate(keys): + if i < 2: + assert not cache.exists_by_key(key) # Should be dropped + else: + assert cache.exists_by_key(key) # Should still exist + + # Reset for mdrop test + cache.clear() + keys = [] + texts = [] + for sample in sample_embedding_data: + key = cache.set( + text=sample["text"], + model_name=sample["model_name"], + embedding=sample["embedding"], + ) + keys.append(key) + texts.append(sample["text"]) + + # Test mdrop with subset of texts + subset_texts = texts[:2] + cache.mdrop(subset_texts, model_name) + + # Verify only selected texts were dropped + for i, text in enumerate(texts): + if i < 2: + assert not cache.exists(text, model_name) # Should be dropped + else: + assert cache.exists(text, model_name) # Should still exist + + +@pytest.mark.asyncio +async def test_async_batch_operations(cache, sample_embedding_data): + """Test async batch operations (amset, amget, amexists, amdrop).""" + # Prepare batch items + batch_items = [] + for sample in sample_embedding_data: + batch_items.append( + { + "text": sample["text"], + "model_name": sample["model_name"], + "embedding": sample["embedding"], + "metadata": sample.get("metadata"), + } + ) + + # Use amset to store embeddings + keys = await cache.amset(batch_items) + assert len(keys) == len(batch_items) + + # Get texts and model name for amget + texts = [item["text"] for item in batch_items] + model_name = batch_items[0]["model_name"] # Assuming same model + + # Test amget + results = await cache.amget(texts, model_name) + assert len(results) == len(texts) + for i, result in enumerate(results): + assert result is not None + assert result["text"] == texts[i] + + # Test amget_by_keys + key_results = await cache.amget_by_keys(keys) + assert len(key_results) == len(keys) + for result in key_results: + assert result is not None + + # Test amexists + exist_results = await cache.amexists(texts, model_name) + assert len(exist_results) == len(texts) + assert all(exist_results) # All should exist + + # Test amexists_by_keys + key_exist_results = await cache.amexists_by_keys(keys) + assert len(key_exist_results) == len(keys) + assert all(key_exist_results) # All should exist + + # Test amdrop with first text + await cache.amdrop([texts[0]], model_name) + updated_exists = await cache.aexists(texts[0], model_name) + assert not updated_exists # Should be dropped + + # Test amdrop_by_keys with second key + await cache.amdrop_by_keys([keys[1]]) + updated_key_exists = await cache.aexists_by_key(keys[1]) + assert not updated_key_exists # Should be dropped + + +def test_batch_operations_with_missing_data(cache): + """Test batch operations with empty lists and missing cache entries.""" + # Test with empty lists + assert cache.mget_by_keys([]) == [] + assert cache.mexists_by_keys([]) == [] + cache.mdrop_by_keys([]) # Should not raise errors + + # Test mget with non-existent keys + non_existent_keys = [ + "test_embed_cache:nonexistent1", + "test_embed_cache:nonexistent2", + ] + results = cache.mget_by_keys(non_existent_keys) + assert len(results) == 2 + assert results[0] is None + assert results[1] is None + + # Test mexists with non-existent keys + exist_results = cache.mexists_by_keys(non_existent_keys) + assert len(exist_results) == 2 + assert not any(exist_results) # None should exist + + # Test with empty model names and texts + assert cache.mget([], "model") == [] + assert cache.mexists([], "model") == [] + cache.mdrop([], "model") # Should not raise errors + + +def test_batch_with_ttl(cache_with_ttl, sample_embedding_data): + """Test batch operations with TTL.""" + # Prepare batch items + batch_items = [] + for sample in sample_embedding_data: + batch_items.append( + { + "text": sample["text"], + "model_name": sample["model_name"], + "embedding": sample["embedding"], + "metadata": sample.get("metadata"), + } + ) + + # Store with default TTL (2 seconds from fixture) + keys = cache_with_ttl.mset(batch_items) + + # Verify all exist initially + exist_results = cache_with_ttl.mexists_by_keys(keys) + assert all(exist_results) + + # Wait for TTL to expire + time.sleep(3) + + # Verify all have expired + exist_results_after = cache_with_ttl.mexists_by_keys(keys) + assert not any(exist_results_after) + + # Test with custom TTL override + keys = cache_with_ttl.mset(batch_items, ttl=5) # 5 second TTL + + # Wait for 3 seconds (beyond default but before custom TTL) + time.sleep(3) + + # Should still exist with custom TTL + exist_results = cache_with_ttl.mexists_by_keys(keys) + assert all(exist_results) + + +def test_large_batch_operations(cache): + """Test operations with larger batches to ensure scalability.""" + # Create a larger batch of items + large_batch = [] + for i in range(100): + large_batch.append( + { + "text": f"Sample text {i}", + "model_name": "test-model", + "embedding": [float(i) / 100] * 5, + "metadata": {"index": i}, + } + ) + + # Test storing large batch + keys = cache.mset(large_batch) + assert len(keys) == 100 + + # Test retrieving large batch by keys + results = cache.mget_by_keys(keys) + assert len(results) == 100 + assert all(result is not None for result in results) + + # Get texts for batch retrieval + texts = [item["text"] for item in large_batch] + + # Test retrieving by texts + results = cache.mget(texts, "test-model") + assert len(results) == 100 + assert all(result is not None for result in results) + + # Test existence checks + exist_results = cache.mexists_by_keys(keys) + assert len(exist_results) == 100 + assert all(exist_results) + + # Test batch deletion + cache.mdrop_by_keys(keys[:50]) # Delete first half + + # Verify first half deleted, second half still exists + for i, key in enumerate(keys): + if i < 50: + assert not cache.exists_by_key(key) + else: + assert cache.exists_by_key(key) diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index b489caf9..526709c9 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -9,7 +9,7 @@ from redis.exceptions import ConnectionError from redisvl.exceptions import RedisModuleVersionError -from redisvl.extensions.llmcache import SemanticCache +from redisvl.extensions.cache import SemanticCache from redisvl.index.index import AsyncSearchIndex, SearchIndex from redisvl.query.filter import Num, Tag, Text from redisvl.utils.vectorize import HFTextVectorizer @@ -79,6 +79,12 @@ def disable_deprecation_warnings(): yield +def test_llmcache_backwards_compat(): + from redisvl.extensions.llmcache import SemanticCache as DeprecatedSemanticCache + + assert DeprecatedSemanticCache == SemanticCache + + def test_bad_ttl(cache): with pytest.raises(ValueError): cache.set_ttl(2.5) diff --git a/tests/integration/test_threshold_optimizer.py b/tests/integration/test_threshold_optimizer.py index e227af98..5242fd4f 100644 --- a/tests/integration/test_threshold_optimizer.py +++ b/tests/integration/test_threshold_optimizer.py @@ -5,7 +5,7 @@ if sys.version_info.major == 3 and sys.version_info.minor < 10: pytest.skip("Test requires Python 3.10 or higher", allow_module_level=True) -from redisvl.extensions.llmcache import SemanticCache +from redisvl.extensions.cache.llm import SemanticCache from redisvl.extensions.router import Route, SemanticRouter from redisvl.extensions.router.schema import RoutingConfig from redisvl.redis.connection import compare_versions diff --git a/tests/unit/test_embedcache_schema.py b/tests/unit/test_embedcache_schema.py new file mode 100644 index 00000000..a3d109da --- /dev/null +++ b/tests/unit/test_embedcache_schema.py @@ -0,0 +1,109 @@ +import json + +import pytest +from pydantic import ValidationError + +from redisvl.extensions.cache.embeddings.schema import CacheEntry +from redisvl.redis.utils import hashify + + +def test_valid_cache_entry_creation(): + # Generate an entry_id first + entry_id = hashify(f"What is AI?:text-embedding-ada-002") + entry = CacheEntry( + entry_id=entry_id, + text="What is AI?", + model_name="text-embedding-ada-002", + embedding=[0.1, 0.2, 0.3], + ) + assert entry.entry_id == entry_id + assert entry.text == "What is AI?" + assert entry.model_name == "text-embedding-ada-002" + assert entry.embedding == [0.1, 0.2, 0.3] + + +def test_cache_entry_with_given_entry_id(): + entry = CacheEntry( + entry_id="custom_id", + text="What is AI?", + model_name="text-embedding-ada-002", + embedding=[0.1, 0.2, 0.3], + ) + assert entry.entry_id == "custom_id" + + +def test_cache_entry_with_invalid_metadata(): + with pytest.raises(ValidationError): + CacheEntry( + entry_id="test_id", + text="What is AI?", + model_name="text-embedding-ada-002", + embedding=[0.1, 0.2, 0.3], + metadata="invalid_metadata", + ) + + +def test_cache_entry_to_dict(): + entry_id = hashify(f"What is AI?:text-embedding-ada-002") + entry = CacheEntry( + entry_id=entry_id, + text="What is AI?", + model_name="text-embedding-ada-002", + embedding=[0.1, 0.2, 0.3], + metadata={"author": "John"}, + ) + result = entry.to_dict() + assert result["entry_id"] == entry_id + assert result["text"] == "What is AI?" + assert result["model_name"] == "text-embedding-ada-002" + assert isinstance("embedding", str) + assert isinstance("metadata", str) + assert result["metadata"] == json.dumps({"author": "John"}) + + +def test_cache_entry_deserialization(): + """Test that a CacheEntry properly deserializes data from Redis format.""" + serialized_data = { + "entry_id": "test_id", + "text": "What is AI?", + "model_name": "text-embedding-ada-002", + "embedding": json.dumps([0.1, 0.2, 0.3]), # Serialized embedding + "metadata": json.dumps({"source": "user_query"}), # Serialized metadata + "inserted_at": 1625819123.123, + } + + entry = CacheEntry(**serialized_data) + assert entry.entry_id == "test_id" + assert entry.text == "What is AI?" + assert entry.model_name == "text-embedding-ada-002" + assert entry.embedding == [0.1, 0.2, 0.3] # Should be deserialized + assert entry.metadata == {"source": "user_query"} # Should be deserialized + assert entry.inserted_at == 1625819123.123 + + +def test_cache_entry_with_empty_optional_fields(): + entry = CacheEntry( + entry_id="test_id", + text="What is AI?", + model_name="text-embedding-ada-002", + embedding=[0.1, 0.2, 0.3], + ) + result = entry.to_dict() + assert "metadata" not in result # Empty metadata should be excluded + + +def test_cache_entry_timestamp_generation(): + """Test that inserted_at timestamp is automatically generated.""" + entry = CacheEntry( + entry_id="test_id", + text="What is AI?", + model_name="text-embedding-ada-002", + embedding=[0.1, 0.2, 0.3], + ) + assert hasattr(entry, "inserted_at") + assert isinstance(entry.inserted_at, float) + + # The timestamp should be included in the dict representation + result = entry.to_dict() + assert "inserted_at" in result + assert isinstance(result["inserted_at"], float) diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py index 72f230fc..17e51cc1 100644 --- a/tests/unit/test_llmcache_schema.py +++ b/tests/unit/test_llmcache_schema.py @@ -3,7 +3,7 @@ import pytest from pydantic import ValidationError -from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit +from redisvl.extensions.cache.llm import CacheEntry, CacheHit from redisvl.redis.utils import array_to_buffer, hashify