
Enable embedding caching on all vectorizers #320


Merged
merged 4 commits on Apr 17, 2025
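In short: instead of wrapping a vectorizer in hand-rolled caching helpers, the cache is now handed to the vectorizer itself, and embed() consults it transparently. A minimal sketch of the resulting pattern, assembled from the notebook cells changed below (the cache= constructor argument and the skip_cache flag are taken from the updated cells; import paths follow the existing redisvl API):

import os

from redisvl.extensions.cache.embeddings import EmbeddingsCache
from redisvl.utils.vectorize import HFTextVectorizer

# A standalone embeddings cache backed by Redis, as in the guide.
cache = EmbeddingsCache(
    name="embedcache",
    redis_url="redis://localhost:6379",
    ttl=3600,  # optional: entries expire after one hour
)

# New in this PR: attach the cache at construction time.
vectorizer = HFTextVectorizer(
    model="redis/langcache-embed-v1",
    cache=cache,
    cache_folder=os.getenv("SENTENCE_TRANSFORMERS_HOME"),
)

# First call computes and stores the embedding; repeat calls are served from Redis.
embedding = vectorizer.embed("What is machine learning?")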
118 changes: 47 additions & 71 deletions docs/user_guide/10_embeddings_cache.ipynb
@@ -51,13 +51,23 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/tyler.hutcherson/Library/Caches/pypoetry/virtualenvs/redisvl-VnTEShF2-py3.13/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. Falling back to non-compiled mode.\n"
]
}
],
"source": [
"# Initialize the vectorizer\n",
"vectorizer = HFTextVectorizer(\n",
" model=\"sentence-transformers/all-mpnet-base-v2\",\n",
" model=\"redis/langcache-embed-v1\",\n",
" cache_folder=os.getenv(\"SENTENCE_TRANSFORMERS_HOME\")\n",
")"
]
@@ -103,21 +113,21 @@
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Stored with key: embedcache:059d...\n"
"Stored with key: embedcache:909f...\n"
]
}
],
"source": [
"# Text to embed\n",
"text = \"What is machine learning?\"\n",
"model_name = \"sentence-transformers/all-mpnet-base-v2\"\n",
"model_name = \"redis/langcache-embed-v1\"\n",
"\n",
"# Generate the embedding\n",
"embedding = vectorizer.embed(text)\n",
@@ -147,15 +157,15 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Found in cache: What is machine learning?\n",
"Model: sentence-transformers/all-mpnet-base-v2\n",
"Model: redis/langcache-embed-v1\n",
"Metadata: {'category': 'ai', 'source': 'user_query'}\n",
"Embedding shape: (768,)\n"
]
@@ -184,7 +194,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 7,
"metadata": {},
"outputs": [
{
@@ -218,7 +228,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -251,14 +261,14 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Stored with key: embedcache:059d...\n",
"Stored with key: embedcache:909f...\n",
"Exists by key: True\n",
"Retrieved by key: What is machine learning?\n"
]
@@ -297,7 +307,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 10,
"metadata": {},
"outputs": [
{
@@ -382,7 +392,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 11,
"metadata": {},
"outputs": [
{
@@ -430,7 +440,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 12,
"metadata": {},
"outputs": [
{
@@ -484,7 +494,7 @@
},
{
"cell_type": "code",
"execution_count": 12,
"execution_count": 13,
"metadata": {},
"outputs": [
{
@@ -533,18 +543,13 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Computing embedding for: What is artificial intelligence?\n",
"Computing embedding for: How does machine learning work?\n",
"Found in cache: What is artificial intelligence?\n",
"Computing embedding for: What are neural networks?\n",
"Found in cache: How does machine learning work?\n",
"\n",
"Statistics:\n",
"Total queries: 5\n",
@@ -562,25 +567,11 @@
" ttl=3600 # 1 hour TTL\n",
")\n",
"\n",
"# Function to get embedding with caching\n",
"def get_cached_embedding(text, model_name):\n",
" # Check if it's in the cache first\n",
" if cached_result := example_cache.get(text=text, model_name=model_name):\n",
" print(f\"Found in cache: {text}\")\n",
" return cached_result[\"embedding\"]\n",
" \n",
" # Not in cache, compute the embedding\n",
" print(f\"Computing embedding for: {text}\")\n",
" embedding = vectorizer.embed(text)\n",
" \n",
" # Store in cache\n",
" example_cache.set(\n",
" text=text,\n",
" model_name=model_name,\n",
" embedding=embedding,\n",
" )\n",
" \n",
" return embedding\n",
"vectorizer = HFTextVectorizer(\n",
" model=model_name,\n",
" cache=example_cache,\n",
" cache_folder=os.getenv(\"SENTENCE_TRANSFORMERS_HOME\")\n",
")\n",
"\n",
"# Simulate processing a stream of queries\n",
"queries = [\n",
@@ -604,7 +595,7 @@
" cache_hits += 1\n",
" \n",
" # Get embedding (will compute or use cache)\n",
" embedding = get_cached_embedding(query, model_name)\n",
" embedding = vectorizer.embed(query)\n",
"\n",
"# Report statistics\n",
"cache_misses = total_queries - cache_hits\n",
@@ -632,72 +623,57 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Benchmarking without caching:\n",
"Time taken without caching: 0.0940 seconds\n",
"Average time per embedding: 0.0094 seconds\n",
"Time taken without caching: 0.4735 seconds\n",
"Average time per embedding: 0.0474 seconds\n",
"\n",
"Benchmarking with caching:\n",
"Time taken with caching: 0.0237 seconds\n",
"Average time per embedding: 0.0024 seconds\n",
"Time taken with caching: 0.0663 seconds\n",
"Average time per embedding: 0.0066 seconds\n",
"\n",
"Performance comparison:\n",
"Speedup with caching: 3.96x faster\n",
"Time saved: 0.0703 seconds (74.8%)\n",
"Latency reduction: 0.0070 seconds per query\n"
"Speedup with caching: 7.14x faster\n",
"Time saved: 0.4073 seconds (86.0%)\n",
"Latency reduction: 0.0407 seconds per query\n"
]
}
],
"source": [
"# Text to use for benchmarking\n",
"benchmark_text = \"This is a benchmark text to measure the performance of embedding caching.\"\n",
"benchmark_model = \"sentence-transformers/all-mpnet-base-v2\"\n",
"\n",
"# Create a fresh cache for benchmarking\n",
"benchmark_cache = EmbeddingsCache(\n",
" name=\"benchmark_cache\",\n",
" redis_url=\"redis://localhost:6379\",\n",
" ttl=3600 # 1 hour TTL\n",
")\n",
"\n",
"# Function to get embeddings without caching\n",
"def get_embedding_without_cache(text, model_name):\n",
" return vectorizer.embed(text)\n",
"\n",
"# Function to get embeddings with caching\n",
"def get_embedding_with_cache(text, model_name):\n",
" if cached_result := benchmark_cache.get(text=text, model_name=model_name):\n",
" return cached_result[\"embedding\"]\n",
" \n",
" embedding = vectorizer.embed(text)\n",
" benchmark_cache.set(\n",
" text=text,\n",
" model_name=model_name,\n",
" embedding=embedding\n",
" )\n",
" return embedding\n",
"vectorizer.cache = benchmark_cache\n",
"\n",
"# Number of iterations for the benchmark\n",
"n_iterations = 10\n",
"\n",
"# Benchmark without caching\n",
"print(\"Benchmarking without caching:\")\n",
"start_time = time.time()\n",
"get_embedding_without_cache(benchmark_text, benchmark_model)\n",
"for _ in range(n_iterations):\n",
" embedding = vectorizer.embed(text, skip_cache=True)\n",
"no_cache_time = time.time() - start_time\n",
"print(f\"Time taken without caching: {no_cache_time:.4f} seconds\")\n",
"print(f\"Average time per embedding: {no_cache_time/n_iterations:.4f} seconds\")\n",
"\n",
"# Benchmark with caching\n",
"print(\"\\nBenchmarking with caching:\")\n",
"start_time = time.time()\n",
"get_embedding_with_cache(benchmark_text, benchmark_model)\n",
"for _ in range(n_iterations):\n",
" embedding = vectorizer.embed(text)\n",
"cache_time = time.time() - start_time\n",
"print(f\"Time taken with caching: {cache_time:.4f} seconds\")\n",
"print(f\"Average time per embedding: {cache_time/n_iterations:.4f} seconds\")\n",
@@ -785,7 +761,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.13.2"
}
},
"nbformat": 4,
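Two smaller API points carry the rewritten benchmark cell above: the cache attribute of a live vectorizer can be reassigned, and skip_cache=True forces a fresh forward pass for a single call. Both appear verbatim in the updated cell; a condensed sketch, reusing the vectorizer from the sketch at the top:

# Assumes `vectorizer` from the earlier sketch, already built with a cache.
benchmark_text = "This is a benchmark text to measure the performance of embedding caching."

benchmark_cache = EmbeddingsCache(
    name="benchmark_cache",
    redis_url="redis://localhost:6379",
    ttl=3600,
)

vectorizer.cache = benchmark_cache                       # swap caches on a live vectorizer
vec = vectorizer.embed(benchmark_text)                   # reads/writes benchmark_cache
vec = vectorizer.embed(benchmark_text, skip_cache=True)  # bypass: always run the model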
8 changes: 4 additions & 4 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 1 addition & 3 deletions redisvl/extensions/cache/__init__.py
@@ -6,7 +6,5 @@
"""

from redisvl.extensions.cache.base import BaseCache
from redisvl.extensions.cache.embeddings import EmbeddingsCache
from redisvl.extensions.cache.llm import SemanticCache

__all__ = ["BaseCache", "EmbeddingsCache", "SemanticCache"]
__all__ = ["BaseCache"]
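A side effect worth noting: EmbeddingsCache and SemanticCache are no longer re-exported from redisvl.extensions.cache, presumably to keep the package root import-light now that the vectorizers depend on the cache extension. Callers import from the submodules instead, which the removed lines above already name:

# Import paths after this change; the submodules themselves are untouched.
from redisvl.extensions.cache.embeddings import EmbeddingsCache
from redisvl.extensions.cache.llm import SemanticCache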
15 changes: 11 additions & 4 deletions redisvl/extensions/cache/base.py
@@ -9,6 +9,8 @@
from redis import Redis
from redis.asyncio import Redis as AsyncRedis

from redisvl.redis.connection import RedisConnectionFactory


class BaseCache:
"""Base abstract cache interface for all RedisVL caches.
@@ -121,10 +123,15 @@ async def _get_async_redis_client(self) -> AsyncRedis:
AsyncRedis: An async Redis client instance.
"""
if not hasattr(self, "_async_redis_client") or self._async_redis_client is None:
# Create new async Redis client
url = self.redis_kwargs["redis_url"]
kwargs = self.redis_kwargs["connection_kwargs"]
self._async_redis_client = AsyncRedis.from_url(url, **kwargs) # type: ignore
client = self.redis_kwargs.get("redis_client")
if isinstance(client, Redis):
self._async_redis_client = RedisConnectionFactory.sync_to_async_redis(
client
)
else:
url = self.redis_kwargs["redis_url"]
kwargs = self.redis_kwargs["connection_kwargs"]
self._async_redis_client = RedisConnectionFactory.get_async_redis_connection(url, **kwargs) # type: ignore
return self._async_redis_client

def expire(self, key: str, ttl: Optional[int] = None) -> None:
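The updated _get_async_redis_client means a cache built around an existing synchronous client can now serve async callers without a separate redis_url: the async client is derived from the sync one via RedisConnectionFactory.sync_to_async_redis. A rough usage sketch; note that the redis_client constructor argument is only implied by the redis_kwargs.get("redis_client") lookup above, and aget is assumed to be the async counterpart of get from the embeddings-cache guide:

import asyncio

from redis import Redis
from redisvl.extensions.cache.embeddings import EmbeddingsCache

# Hand an existing sync client to the cache (assumed kwarg, implied by
# the redis_kwargs["redis_client"] lookup in the diff above).
client = Redis.from_url("redis://localhost:6379")
cache = EmbeddingsCache(name="embedcache", redis_client=client)

async def lookup():
    # BaseCache now converts the sync client instead of opening a second
    # connection from a URL.
    return await cache.aget(  # assumed async accessor
        text="What is machine learning?",
        model_name="redis/langcache-embed-v1",
    )

print(asyncio.run(lookup()) is not None)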