diff --git a/docs/user_guide/10_embeddings_cache.ipynb b/docs/user_guide/10_embeddings_cache.ipynb index d5a90096..b4deb489 100644 --- a/docs/user_guide/10_embeddings_cache.ipynb +++ b/docs/user_guide/10_embeddings_cache.ipynb @@ -51,13 +51,23 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/tyler.hutcherson/Library/Caches/pypoetry/virtualenvs/redisvl-VnTEShF2-py3.13/lib/python3.13/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n", + "Compiling the model with `torch.compile` and using a `torch.mps` device is not supported. Falling back to non-compiled mode.\n" + ] + } + ], "source": [ "# Initialize the vectorizer\n", "vectorizer = HFTextVectorizer(\n", - " model=\"sentence-transformers/all-mpnet-base-v2\",\n", + " model=\"redis/langcache-embed-v1\",\n", " cache_folder=os.getenv(\"SENTENCE_TRANSFORMERS_HOME\")\n", ")" ] @@ -103,21 +113,21 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Stored with key: embedcache:059d...\n" + "Stored with key: embedcache:909f...\n" ] } ], "source": [ "# Text to embed\n", "text = \"What is machine learning?\"\n", - "model_name = \"sentence-transformers/all-mpnet-base-v2\"\n", + "model_name = \"redis/langcache-embed-v1\"\n", "\n", "# Generate the embedding\n", "embedding = vectorizer.embed(text)\n", @@ -147,7 +157,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -155,7 +165,7 @@ "output_type": "stream", "text": [ "Found in cache: What is machine learning?\n", - "Model: sentence-transformers/all-mpnet-base-v2\n", + "Model: redis/langcache-embed-v1\n", "Metadata: {'category': 'ai', 'source': 'user_query'}\n", "Embedding shape: (768,)\n" ] @@ -184,7 +194,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -218,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -251,14 +261,14 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Stored with key: embedcache:059d...\n", + "Stored with key: embedcache:909f...\n", "Exists by key: True\n", "Retrieved by key: What is machine learning?\n" ] @@ -297,7 +307,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -382,7 +392,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [ { @@ -430,7 +440,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "metadata": {}, "outputs": [ { @@ -484,7 +494,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -533,18 +543,13 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Computing embedding for: What is artificial intelligence?\n", - "Computing embedding for: How does machine learning work?\n", - "Found in cache: What is artificial 
intelligence?\n", - "Computing embedding for: What are neural networks?\n", - "Found in cache: How does machine learning work?\n", "\n", "Statistics:\n", "Total queries: 5\n", @@ -562,25 +567,11 @@ " ttl=3600 # 1 hour TTL\n", ")\n", "\n", - "# Function to get embedding with caching\n", - "def get_cached_embedding(text, model_name):\n", - " # Check if it's in the cache first\n", - " if cached_result := example_cache.get(text=text, model_name=model_name):\n", - " print(f\"Found in cache: {text}\")\n", - " return cached_result[\"embedding\"]\n", - " \n", - " # Not in cache, compute the embedding\n", - " print(f\"Computing embedding for: {text}\")\n", - " embedding = vectorizer.embed(text)\n", - " \n", - " # Store in cache\n", - " example_cache.set(\n", - " text=text,\n", - " model_name=model_name,\n", - " embedding=embedding,\n", - " )\n", - " \n", - " return embedding\n", + "vectorizer = HFTextVectorizer(\n", + " model=model_name,\n", + " cache=example_cache,\n", + " cache_folder=os.getenv(\"SENTENCE_TRANSFORMERS_HOME\")\n", + ")\n", "\n", "# Simulate processing a stream of queries\n", "queries = [\n", @@ -604,7 +595,7 @@ " cache_hits += 1\n", " \n", " # Get embedding (will compute or use cache)\n", - " embedding = get_cached_embedding(query, model_name)\n", + " embedding = vectorizer.embed(query)\n", "\n", "# Report statistics\n", "cache_misses = total_queries - cache_hits\n", @@ -632,7 +623,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -640,24 +631,23 @@ "output_type": "stream", "text": [ "Benchmarking without caching:\n", - "Time taken without caching: 0.0940 seconds\n", - "Average time per embedding: 0.0094 seconds\n", + "Time taken without caching: 0.4735 seconds\n", + "Average time per embedding: 0.0474 seconds\n", "\n", "Benchmarking with caching:\n", - "Time taken with caching: 0.0237 seconds\n", - "Average time per embedding: 0.0024 seconds\n", + "Time taken with caching: 0.0663 seconds\n", + "Average time per embedding: 0.0066 seconds\n", "\n", "Performance comparison:\n", - "Speedup with caching: 3.96x faster\n", - "Time saved: 0.0703 seconds (74.8%)\n", - "Latency reduction: 0.0070 seconds per query\n" + "Speedup with caching: 7.14x faster\n", + "Time saved: 0.4073 seconds (86.0%)\n", + "Latency reduction: 0.0407 seconds per query\n" ] } ], "source": [ "# Text to use for benchmarking\n", "benchmark_text = \"This is a benchmark text to measure the performance of embedding caching.\"\n", - "benchmark_model = \"sentence-transformers/all-mpnet-base-v2\"\n", "\n", "# Create a fresh cache for benchmarking\n", "benchmark_cache = EmbeddingsCache(\n", @@ -665,23 +655,7 @@ " redis_url=\"redis://localhost:6379\",\n", " ttl=3600 # 1 hour TTL\n", ")\n", - "\n", - "# Function to get embeddings without caching\n", - "def get_embedding_without_cache(text, model_name):\n", - " return vectorizer.embed(text)\n", - "\n", - "# Function to get embeddings with caching\n", - "def get_embedding_with_cache(text, model_name):\n", - " if cached_result := benchmark_cache.get(text=text, model_name=model_name):\n", - " return cached_result[\"embedding\"]\n", - " \n", - " embedding = vectorizer.embed(text)\n", - " benchmark_cache.set(\n", - " text=text,\n", - " model_name=model_name,\n", - " embedding=embedding\n", - " )\n", - " return embedding\n", + "vectorizer.cache = benchmark_cache\n", "\n", "# Number of iterations for the benchmark\n", "n_iterations = 10\n", @@ -689,7 +663,8 @@ "# Benchmark without caching\n", "print(\"Benchmarking 
without caching:\")\n", "start_time = time.time()\n", - "get_embedding_without_cache(benchmark_text, benchmark_model)\n", + "for _ in range(n_iterations):\n", + " embedding = vectorizer.embed(text, skip_cache=True)\n", "no_cache_time = time.time() - start_time\n", "print(f\"Time taken without caching: {no_cache_time:.4f} seconds\")\n", "print(f\"Average time per embedding: {no_cache_time/n_iterations:.4f} seconds\")\n", @@ -697,7 +672,8 @@ "# Benchmark with caching\n", "print(\"\\nBenchmarking with caching:\")\n", "start_time = time.time()\n", - "get_embedding_with_cache(benchmark_text, benchmark_model)\n", + "for _ in range(n_iterations):\n", + " embedding = vectorizer.embed(text)\n", "cache_time = time.time() - start_time\n", "print(f\"Time taken with caching: {cache_time:.4f} seconds\")\n", "print(f\"Average time per embedding: {cache_time/n_iterations:.4f} seconds\")\n", @@ -785,7 +761,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.13.2" } }, "nbformat": 4, diff --git a/poetry.lock b/poetry.lock index df28f211..e819278a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -763,15 +763,15 @@ colorama = {version = "*", markers = "platform_system == \"Windows\""} [[package]] name = "cohere" -version = "5.13.12" +version = "5.15.0" description = "" optional = true python-versions = "<4.0,>=3.9" groups = ["main"] markers = "(python_version <= \"3.11\" or python_version >= \"3.12\") and extra == \"cohere\"" files = [ - {file = "cohere-5.13.12-py3-none-any.whl", hash = "sha256:2a043591a3e5280b47716a6b311e4c7f58e799364113a9cb81b50cd4f6c95f7e"}, - {file = "cohere-5.13.12.tar.gz", hash = "sha256:97bb9ac107e580780b941acbabd3aa5e71960e6835398292c46aaa8a0a4cab88"}, + {file = "cohere-5.15.0-py3-none-any.whl", hash = "sha256:22ff867c2a6f2fc2b585360c6072f584f11f275ef6d9242bac24e0fa2df1dfb5"}, + {file = "cohere-5.15.0.tar.gz", hash = "sha256:e802d4718ddb0bb655654382ebbce002756a3800faac30296cde7f1bdc6ff2cc"}, ] [package.dependencies] @@ -3627,8 +3627,8 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.21", markers = "python_version < \"3.10\""}, - {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.21.2", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.23.3", markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=2.1.0", markers = "python_version >= \"3.13\""}, {version = ">=1.26.0", markers = "python_version >= \"3.12\" and python_version < \"3.13\""}, ] diff --git a/redisvl/extensions/cache/__init__.py b/redisvl/extensions/cache/__init__.py index f3ce33ce..08e1a067 100644 --- a/redisvl/extensions/cache/__init__.py +++ b/redisvl/extensions/cache/__init__.py @@ -6,7 +6,5 @@ """ from redisvl.extensions.cache.base import BaseCache -from redisvl.extensions.cache.embeddings import EmbeddingsCache -from redisvl.extensions.cache.llm import SemanticCache -__all__ = ["BaseCache", "EmbeddingsCache", "SemanticCache"] +__all__ = ["BaseCache"] diff --git a/redisvl/extensions/cache/base.py b/redisvl/extensions/cache/base.py index 68e120e4..aabc548e 100644 --- a/redisvl/extensions/cache/base.py +++ b/redisvl/extensions/cache/base.py @@ -9,6 +9,8 @@ from redis import Redis from redis.asyncio import Redis as AsyncRedis +from redisvl.redis.connection import RedisConnectionFactory + class BaseCache: """Base abstract cache interface for all RedisVL caches. 
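# --- Illustrative sketch (not part of this diff) -----------------------------
# A minimal example, under stated assumptions, of how a cache built on BaseCache
# is used from both sync and async code after the changes to this file: the async
# Redis client is now resolved through RedisConnectionFactory, and when a sync
# `redis_client` is supplied it is converted with sync_to_async_redis() rather
# than re-parsing a URL (see the hunk that follows). Assumes a local Redis at
# redis://localhost:6379 and the keyword arguments shown in the notebook above.
import asyncio
from redisvl.extensions.cache.embeddings import EmbeddingsCache

cache = EmbeddingsCache(name="embedcache", redis_url="redis://localhost:6379")

# Synchronous path: store an entry keyed by (text, model_name).
cache.set(
    text="What is machine learning?",
    model_name="redis/langcache-embed-v1",
    embedding=[0.1, 0.2, 0.3],  # toy vector for illustration only
)

# Asynchronous path reuses the same cache instance; the async client is built
# lazily by _get_async_redis_client() as reworked below.
async def lookup():
    entry = await cache.aget(
        text="What is machine learning?",
        model_name="redis/langcache-embed-v1",
    )
    return entry["embedding"] if entry else None

print(asyncio.run(lookup()))
# ------------------------------------------------------------------------------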
@@ -121,10 +123,15 @@ async def _get_async_redis_client(self) -> AsyncRedis: AsyncRedis: An async Redis client instance. """ if not hasattr(self, "_async_redis_client") or self._async_redis_client is None: - # Create new async Redis client - url = self.redis_kwargs["redis_url"] - kwargs = self.redis_kwargs["connection_kwargs"] - self._async_redis_client = AsyncRedis.from_url(url, **kwargs) # type: ignore + client = self.redis_kwargs.get("redis_client") + if isinstance(client, Redis): + self._async_redis_client = RedisConnectionFactory.sync_to_async_redis( + client + ) + else: + url = self.redis_kwargs["redis_url"] + kwargs = self.redis_kwargs["connection_kwargs"] + self._async_redis_client = RedisConnectionFactory.get_async_redis_connection(url, **kwargs) # type: ignore return self._async_redis_client def expire(self, key: str, ttl: Optional[int] = None) -> None: diff --git a/redisvl/extensions/cache/llm/semantic.py b/redisvl/extensions/cache/llm/semantic.py index 5a97b03d..0232514f 100644 --- a/redisvl/extensions/cache/llm/semantic.py +++ b/redisvl/extensions/cache/llm/semantic.py @@ -22,7 +22,6 @@ from redisvl.index import AsyncSearchIndex, SearchIndex from redisvl.query import VectorRangeQuery from redisvl.query.filter import FilterExpression -from redisvl.redis.connection import RedisConnectionFactory from redisvl.redis.utils import hashify from redisvl.utils.log import get_logger from redisvl.utils.utils import ( @@ -31,7 +30,7 @@ serialize, validate_vector_dims, ) -from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +from redisvl.utils.vectorize.base import BaseVectorizer logger = get_logger("[RedisVL]") @@ -105,6 +104,8 @@ def __init__( ) self._vectorizer = vectorizer else: + from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer + # Create a default vectorizer vectorizer_kwargs = kwargs if dtype: @@ -183,11 +184,8 @@ def _modify_schema( async def _get_async_index(self) -> AsyncSearchIndex: """Lazily construct the async search index class.""" # Construct async index if necessary - async_client = None if self._aindex is None: - client = self.redis_kwargs.get("redis_client") - if isinstance(client, Redis): - async_client = RedisConnectionFactory.sync_to_async_redis(client) + async_client = await self._get_async_redis_client() self._aindex = AsyncSearchIndex( schema=self._index.schema, redis_client=async_client, diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index d9b38677..441cbd76 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -62,7 +62,7 @@ class RoutingConfig(BaseModel): """Configuration for routing behavior.""" """The maximum number of top matches to return.""" - max_k: Annotated[int, Field(strict=True, default=1, gt=0)] = 1 + max_k: Annotated[int, Field(strict=True, gt=0)] = 1 """Aggregation method to use to classify queries.""" aggregation_method: DistanceAggregationMethod = Field( default=DistanceAggregationMethod.avg diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 9ca886ab..c2ebb50d 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -21,11 +21,8 @@ from redisvl.redis.utils import convert_bytes, hashify, make_dict from redisvl.utils.log import get_logger from redisvl.utils.utils import deprecated_argument, model_to_dict -from redisvl.utils.vectorize import ( - BaseVectorizer, - HFTextVectorizer, - vectorizer_from_dict, -) +from redisvl.utils.vectorize.base import 
BaseVectorizer +from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer logger = get_logger(__name__) @@ -483,6 +480,8 @@ def from_dict( } router = SemanticRouter.from_dict(router_data) """ + from redisvl.utils.vectorize import vectorizer_from_dict + try: name = data["name"] routes_data = data["routes"] diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 9497d06c..d08a7002 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -20,7 +20,7 @@ from redisvl.query import FilterQuery, RangeQuery from redisvl.query.filter import Tag from redisvl.utils.utils import deprecated_argument, validate_vector_dims -from redisvl.utils.vectorize import BaseVectorizer, HFTextVectorizer +from redisvl.utils.vectorize.base import BaseVectorizer class SemanticSessionManager(BaseSessionManager): @@ -82,6 +82,8 @@ def __init__( f"Provided dtype {dtype} does not match vectorizer dtype {vectorizer.dtype}" ) else: + from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer + vectorizer_kwargs = kwargs if dtype: diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index 83e58fe4..c4ffcd3c 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -1,11 +1,16 @@ +import logging from enum import Enum -from typing import Callable, List, Optional, Union +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, ConfigDict, Field, field_validator +from typing_extensions import Annotated +from redisvl.extensions.cache.embeddings import EmbeddingsCache from redisvl.redis.utils import array_to_buffer from redisvl.schema.fields import VectorDataType +logger = logging.getLogger(__name__) + class Vectorizers(Enum): azure_openai = "azure_openai" @@ -18,19 +23,33 @@ class Vectorizers(Enum): class BaseVectorizer(BaseModel): - """Base vectorizer interface.""" + """Base RedisVL vectorizer interface. + + This class defines the interface for text vectorization with an optional + caching layer to improve performance by avoiding redundant API calls. + + Attributes: + model: The name of the embedding model. + dtype: The data type of the embeddings, defaults to "float32". + dims: The dimensionality of the vectors. + cache: Optional embedding cache to store and retrieve embeddings. + """ model: str dtype: str = "float32" - dims: Optional[int] = None + dims: Annotated[Optional[int], Field(strict=True, gt=0)] = None + cache: Optional[EmbeddingsCache] = Field(default=None) + model_config = ConfigDict(arbitrary_types_allowed=True) @property def type(self) -> str: + """Return the type of vectorizer.""" return "base" @field_validator("dtype") @classmethod def check_dtype(cls, dtype): + """Validate the data type is supported.""" try: VectorDataType(dtype.upper()) except ValueError: @@ -39,33 +58,63 @@ def check_dtype(cls, dtype): ) return dtype - @field_validator("dims") - @classmethod - def check_dims(cls, value): - """Ensures the dims are a positive integer.""" - if value <= 0: - raise ValueError("Dims must be a positive integer.") - return value - def embed( self, text: str, preprocess: Optional[Callable] = None, as_buffer: bool = False, + skip_cache: bool = False, **kwargs, ) -> Union[List[float], bytes]: - """Embed a chunk of text. + """Generate a vector embedding for a text string. 
Args: - text: Text to embed - preprocess: Optional function to preprocess text - as_buffer: If True, returns a bytes object instead of a list + text: The text to convert to a vector embedding + preprocess: Function to apply to the text before embedding + as_buffer: Return the embedding as a binary buffer instead of a list + skip_cache: Bypass the cache for this request + **kwargs: Additional model-specific parameters Returns: - Union[List[float], bytes]: Embedding as a list of floats, or as a bytes - object if as_buffer=True + The vector embedding as either a list of floats or binary buffer + + Examples: + >>> embedding = vectorizer.embed("Hello world") """ - raise NotImplementedError + # Apply preprocessing if provided + if preprocess is not None: + text = preprocess(text) + + # Check cache if available and not skipped + if self.cache is not None and not skip_cache: + try: + cache_result = self.cache.get(text=text, model_name=self.model) + if cache_result: + logger.debug(f"Cache hit for text with model {self.model}") + return self._process_embedding( + cache_result["embedding"], as_buffer, self.dtype + ) + except Exception as e: + logger.warning(f"Error accessing embedding cache: {str(e)}") + + # Generate embedding using provider-specific implementation + cache_metadata = kwargs.pop("metadata", {}) + embedding = self._embed(text, **kwargs) + + # Store in cache if available and not skipped + if self.cache is not None and not skip_cache: + try: + self.cache.set( + text=text, + model_name=self.model, + embedding=embedding, + metadata=cache_metadata, + ) + except Exception as e: + logger.warning(f"Error storing in embedding cache: {str(e)}") + + # Process and return result + return self._process_embedding(embedding, as_buffer, self.dtype) def embed_many( self, @@ -73,74 +122,382 @@ def embed_many( preprocess: Optional[Callable] = None, batch_size: int = 10, as_buffer: bool = False, + skip_cache: bool = False, **kwargs, ) -> Union[List[List[float]], List[bytes]]: - """Embed multiple chunks of text. + """Generate vector embeddings for multiple texts efficiently. 
Args: - texts: List of texts to embed - preprocess: Optional function to preprocess text - batch_size: Number of texts to process in each batch - as_buffer: If True, returns each embedding as a bytes object + texts: List of texts to convert to vector embeddings + preprocess: Function to apply to each text before embedding + batch_size: Number of texts to process in each API call + as_buffer: Return embeddings as binary buffers instead of lists + skip_cache: Bypass the cache for this request + **kwargs: Additional model-specific parameters Returns: - Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, - or as bytes objects if as_buffer=True + List of vector embeddings in the same order as the input texts + + Examples: + >>> embeddings = vectorizer.embed_many(["Hello", "World"], batch_size=2) """ - raise NotImplementedError + if not texts: + return [] - async def aembed_many( + # Apply preprocessing if provided + if preprocess is not None: + processed_texts = [preprocess(text) for text in texts] + else: + processed_texts = texts + + # Get cached embeddings and identify misses + results, cache_misses, cache_miss_indices = self._get_from_cache_batch( + processed_texts, skip_cache + ) + + # Generate embeddings for cache misses + if cache_misses: + cache_metadata = kwargs.pop("metadata", {}) + new_embeddings = self._embed_many( + texts=cache_misses, batch_size=batch_size, **kwargs + ) + + # Store new embeddings in cache + self._store_in_cache_batch( + cache_misses, new_embeddings, cache_metadata, skip_cache + ) + + # Insert new embeddings into results array + for idx, embedding in zip(cache_miss_indices, new_embeddings): + results[idx] = embedding + + # Process and return results + return [self._process_embedding(emb, as_buffer, self.dtype) for emb in results] + + async def aembed( self, - texts: List[str], + text: str, preprocess: Optional[Callable] = None, - batch_size: int = 10, as_buffer: bool = False, + skip_cache: bool = False, **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """Asynchronously embed multiple chunks of text. + ) -> Union[List[float], bytes]: + """Asynchronously generate a vector embedding for a text string. 
Args: - texts: List of texts to embed - preprocess: Optional function to preprocess text - batch_size: Number of texts to process in each batch - as_buffer: If True, returns each embedding as a bytes object + text: The text to convert to a vector embedding + preprocess: Function to apply to the text before embedding + as_buffer: Return the embedding as a binary buffer instead of a list + skip_cache: Bypass the cache for this request + **kwargs: Additional model-specific parameters Returns: - Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, - or as bytes objects if as_buffer=True + The vector embedding as either a list of floats or binary buffer + + Examples: + >>> embedding = await vectorizer.aembed("Hello world") """ - # Fallback to standard embedding call if no async support - return self.embed_many(texts, preprocess, batch_size, as_buffer, **kwargs) + # Apply preprocessing if provided + if preprocess is not None: + text = preprocess(text) - async def aembed( + # Check cache if available and not skipped + if self.cache is not None and not skip_cache: + try: + cache_result = await self.cache.aget(text=text, model_name=self.model) + if cache_result: + logger.debug(f"Async cache hit for text with model {self.model}") + return self._process_embedding( + cache_result["embedding"], as_buffer, self.dtype + ) + except Exception as e: + logger.warning( + f"Error accessing embedding cache asynchronously: {str(e)}" + ) + + # Generate embedding using provider-specific implementation + cache_metadata = kwargs.pop("metadata", {}) + embedding = await self._aembed(text, **kwargs) + + # Store in cache if available and not skipped + if self.cache is not None and not skip_cache: + try: + await self.cache.aset( + text=text, + model_name=self.model, + embedding=embedding, + metadata=cache_metadata, + ) + except Exception as e: + logger.warning( + f"Error storing in embedding cache asynchronously: {str(e)}" + ) + + # Process and return result + return self._process_embedding(embedding, as_buffer, self.dtype) + + async def aembed_many( self, - text: str, + texts: List[str], preprocess: Optional[Callable] = None, + batch_size: int = 10, as_buffer: bool = False, + skip_cache: bool = False, **kwargs, - ) -> Union[List[float], bytes]: - """Asynchronously embed a chunk of text. + ) -> Union[List[List[float]], List[bytes]]: + """Asynchronously generate vector embeddings for multiple texts efficiently. 
Args: - text: Text to embed - preprocess: Optional function to preprocess text - as_buffer: If True, returns a bytes object instead of a list + texts: List of texts to convert to vector embeddings + preprocess: Function to apply to each text before embedding + batch_size: Number of texts to process in each API call + as_buffer: Return embeddings as binary buffers instead of lists + skip_cache: Bypass the cache for this request + **kwargs: Additional model-specific parameters Returns: - Union[List[float], bytes]: Embedding as a list of floats, or as a bytes - object if as_buffer=True + List of vector embeddings in the same order as the input texts + + Examples: + >>> embeddings = await vectorizer.aembed_many(["Hello", "World"], batch_size=2) """ - # Fallback to standard embedding call if no async support - return self.embed(text, preprocess, as_buffer, **kwargs) + if not texts: + return [] + + # Apply preprocessing if provided + if preprocess is not None: + processed_texts = [preprocess(text) for text in texts] + else: + processed_texts = texts + + # Get cached embeddings and identify misses + results, cache_misses, cache_miss_indices = await self._aget_from_cache_batch( + processed_texts, skip_cache + ) + + # Generate embeddings for cache misses + if cache_misses: + cache_metadata = kwargs.pop("metadata", {}) + new_embeddings = await self._aembed_many( + texts=cache_misses, batch_size=batch_size, **kwargs + ) + + # Store new embeddings in cache + await self._astore_in_cache_batch( + cache_misses, new_embeddings, cache_metadata, skip_cache + ) + + # Insert new embeddings into results array + for idx, embedding in zip(cache_miss_indices, new_embeddings): + results[idx] = embedding + + # Process and return results + return [self._process_embedding(emb, as_buffer, self.dtype) for emb in results] + + def _embed(self, text: str, **kwargs) -> List[float]: + """Generate a vector embedding for a single text.""" + raise NotImplementedError + + def _embed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """Generate vector embeddings for a batch of texts.""" + raise NotImplementedError + + async def _aembed(self, text: str, **kwargs) -> List[float]: + """Asynchronously generate a vector embedding for a single text.""" + logger.warning( + "This vectorizer has no async embed method. Falling back to sync." + ) + return self._embed(text, **kwargs) + + async def _aembed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """Asynchronously generate vector embeddings for a batch of texts.""" + logger.warning( + "This vectorizer has no async embed_many method. Falling back to sync." + ) + return self._embed_many(texts, batch_size, **kwargs) + + def _get_from_cache_batch( + self, texts: List[str], skip_cache: bool + ) -> tuple[List[Optional[List[float]]], List[str], List[int]]: + """Get vector embeddings from cache and track cache misses. 
+ + Args: + texts: List of texts to get from cache + skip_cache: Whether to skip cache lookup + + Returns: + Tuple of (results, cache_misses, cache_miss_indices) + """ + results = [None] * len(texts) + cache_misses = [] + cache_miss_indices = [] + + # Skip cache if requested or no cache available + if skip_cache or self.cache is None: + return results, texts, list(range(len(texts))) # type: ignore + + try: + # Efficient batch cache lookup + cache_results = self.cache.mget(texts=texts, model_name=self.model) + + # Process cache hits and collect misses + for i, (text, cache_result) in enumerate(zip(texts, cache_results)): + if cache_result: + results[i] = cache_result["embedding"] + else: + cache_misses.append(text) + cache_miss_indices.append(i) + + logger.debug( + f"Cache hits: {len(texts) - len(cache_misses)}, misses: {len(cache_misses)}" + ) + except Exception as e: + logger.warning(f"Error accessing embedding cache in batch: {str(e)}") + # On cache error, process all texts + cache_misses = texts + cache_miss_indices = list(range(len(texts))) + + return results, cache_misses, cache_miss_indices # type: ignore + + async def _aget_from_cache_batch( + self, texts: List[str], skip_cache: bool + ) -> tuple[List[Optional[List[float]]], List[str], List[int]]: + """Asynchronously get vector embeddings from cache and track cache misses. + + Args: + texts: List of texts to get from cache + skip_cache: Whether to skip cache lookup + + Returns: + Tuple of (results, cache_misses, cache_miss_indices) + """ + results = [None] * len(texts) + cache_misses = [] + cache_miss_indices = [] + + # Skip cache if requested or no cache available + if skip_cache or self.cache is None: + return results, texts, list(range(len(texts))) # type: ignore + + try: + # Efficient batch cache lookup + cache_results = await self.cache.amget(texts=texts, model_name=self.model) + + # Process cache hits and collect misses + for i, (text, cache_result) in enumerate(zip(texts, cache_results)): + if cache_result: + results[i] = cache_result["embedding"] + else: + cache_misses.append(text) + cache_miss_indices.append(i) + + logger.debug( + f"Async cache hits: {len(texts) - len(cache_misses)}, misses: {len(cache_misses)}" + ) + except Exception as e: + logger.warning( + f"Error accessing embedding cache in batch asynchronously: {str(e)}" + ) + # On cache error, process all texts + cache_misses = texts + cache_miss_indices = list(range(len(texts))) + + return results, cache_misses, cache_miss_indices # type: ignore + + def _store_in_cache_batch( + self, + texts: List[str], + embeddings: List[List[float]], + metadata: Dict, + skip_cache: bool, + ) -> None: + """Store a batch of vector embeddings in the cache. + + Args: + texts: List of texts that were embedded + embeddings: List of vector embeddings + metadata: Metadata to store with the embeddings + skip_cache: Whether to skip cache storage + """ + if skip_cache or self.cache is None: + return + + try: + # Prepare batch cache storage items + cache_items = [ + { + "text": text, + "model_name": self.model, + "embedding": emb, + "metadata": metadata, + } + for text, emb in zip(texts, embeddings) + ] + self.cache.mset(items=cache_items) + except Exception as e: + logger.warning(f"Error storing batch in embedding cache: {str(e)}") + + async def _astore_in_cache_batch( + self, + texts: List[str], + embeddings: List[List[float]], + metadata: Dict, + skip_cache: bool, + ) -> None: + """Asynchronously store a batch of vector embeddings in the cache. 
+ + Args: + texts: List of texts that were embedded + embeddings: List of vector embeddings + metadata: Metadata to store with the embeddings + skip_cache: Whether to skip cache storage + """ + if skip_cache or self.cache is None: + return + + try: + # Prepare batch cache storage items + cache_items = [ + { + "text": text, + "model_name": self.model, + "embedding": emb, + "metadata": metadata, + } + for text, emb in zip(texts, embeddings) + ] + await self.cache.amset(items=cache_items) + except Exception as e: + logger.warning( + f"Error storing batch in embedding cache asynchronously: {str(e)}" + ) def batchify(self, seq: list, size: int, preprocess: Optional[Callable] = None): + """Split a sequence into batches of specified size. + + Args: + seq: Sequence to split into batches + size: Batch size + preprocess: Optional function to preprocess each item + + Yields: + Batches of the sequence + """ for pos in range(0, len(seq), size): if preprocess is not None: yield [preprocess(chunk) for chunk in seq[pos : pos + size]] else: yield seq[pos : pos + size] - def _process_embedding(self, embedding: List[float], as_buffer: bool, dtype: str): - if as_buffer: - return array_to_buffer(embedding, dtype) + def _process_embedding( + self, embedding: Optional[List[float]], as_buffer: bool, dtype: str + ): + """Process the vector embedding format based on the as_buffer flag.""" + if embedding is not None: + if as_buffer: + return array_to_buffer(embedding, dtype) return embedding diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index 410280e5..eaf61a3e 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -1,10 +1,13 @@ import os -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional -from pydantic import PrivateAttr +from pydantic import ConfigDict from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type +if TYPE_CHECKING: + from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache + from redisvl.utils.utils import deprecated_argument from redisvl.utils.vectorize.base import BaseVectorizer @@ -28,9 +31,12 @@ class AzureOpenAITextVectorizer(BaseVectorizer): allowing for batch processing of texts and flexibility in handling preprocessing tasks. + You can optionally enable caching to improve performance when generating + embeddings for repeated text inputs. + .. 
code-block:: python - # Synchronous embedding of a single text + # Basic usage vectorizer = AzureOpenAITextVectorizer( model="text-embedding-ada-002", api_config={ @@ -41,6 +47,26 @@ class AzureOpenAITextVectorizer(BaseVectorizer): ) embedding = vectorizer.embed("Hello, world!") + # With caching enabled + from redisvl.extensions.cache.embeddings import EmbeddingsCache + cache = EmbeddingsCache(name="azureopenai_embeddings_cache") + + vectorizer = AzureOpenAITextVectorizer( + model="text-embedding-ada-002", + api_config={ + "api_key": "your_api_key", + "api_version": "your_api_version", + "azure_endpoint": "your_azure_endpoint", + }, + cache=cache + ) + + # First call will compute and cache the embedding + embedding1 = vectorizer.embed("Hello, world!") + + # Second call will retrieve from cache + embedding2 = vectorizer.embed("Hello, world!") + # Asynchronous batch embedding of multiple texts embeddings = await vectorizer.aembed_many( ["Hello, world!", "How are you?"], @@ -49,14 +75,14 @@ class AzureOpenAITextVectorizer(BaseVectorizer): """ - _client: Any = PrivateAttr() - _aclient: Any = PrivateAttr() + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__( self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None, dtype: str = "float32", + cache: Optional["EmbeddingsCache"] = None, **kwargs, ): """Initialize the AzureOpenAI vectorizer. @@ -71,22 +97,37 @@ def __init__( dtype (str): the default datatype to use when embedding text as byte arrays. Used when setting `as_buffer=True` in calls to embed() and embed_many(). Defaults to 'float32'. + cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache instance to cache embeddings for + better performance with repeated texts. Defaults to None. Raises: ImportError: If the openai library is not installed. ValueError: If the AzureOpenAI API key, version, or endpoint are not provided. ValueError: If an invalid dtype is provided. """ - super().__init__(model=model, dtype=dtype) - # Init client + super().__init__(model=model, dtype=dtype, cache=cache) + # Initialize clients and set up the model + self._setup(api_config, **kwargs) + + def _setup(self, api_config: Optional[Dict], **kwargs): + """Set up the AzureOpenAI clients and determine the embedding dimensions.""" + # Initialize clients self._initialize_clients(api_config, **kwargs) - # Set model dimensions + # Set model dimensions after client initialization self.dims = self._set_model_dims() def _initialize_clients(self, api_config: Optional[Dict], **kwargs): """ - Setup the OpenAI clients using the provided API key or an - environment variable. + Setup the AzureOpenAI clients using the provided API key, API version, + and Azure endpoint. + + Args: + api_config: Dictionary with API configuration options + **kwargs: Additional arguments to pass to AzureOpenAI clients + + Raises: + ImportError: If the openai library is not installed + ValueError: If required parameters are not provided """ if api_config is None: api_config = {} @@ -96,8 +137,8 @@ def _initialize_clients(self, api_config: Optional[Dict], **kwargs): from openai import AsyncAzureOpenAI, AzureOpenAI except ImportError: raise ImportError( - "AzureOpenAI vectorizer requires the openai library. \ - Please install with `pip install openai`" + "AzureOpenAI vectorizer requires the openai library. 
" + "Please install with `pip install openai>=1.13.0`" ) # Fetch the API key, version and endpoint from api_config or environment variable @@ -110,8 +151,7 @@ def _initialize_clients(self, api_config: Optional[Dict], **kwargs): if not azure_endpoint: raise ValueError( "AzureOpenAI API endpoint is required. " - "Provide it in api_config or set the AZURE_OPENAI_ENDPOINT\ - environment variable." + "Provide it in api_config or set the AZURE_OPENAI_ENDPOINT environment variable." ) api_version = ( @@ -123,8 +163,7 @@ def _initialize_clients(self, api_config: Optional[Dict], **kwargs): if not api_version: raise ValueError( "AzureOpenAI API version is required. " - "Provide it in api_config or set the OPENAI_API_VERSION\ - environment variable." + "Provide it in api_config or set the OPENAI_API_VERSION environment variable." ) api_key = ( @@ -136,10 +175,10 @@ def _initialize_clients(self, api_config: Optional[Dict], **kwargs): if not api_key: raise ValueError( "AzureOpenAI API key is required. " - "Provide it in api_config or set the AZURE_OPENAI_API_KEY\ - environment variable." + "Provide it in api_config or set the AZURE_OPENAI_API_KEY environment variable." ) + # Store clients as regular attributes instead of PrivateAttr self._client = AzureOpenAI( api_key=api_key, api_version=api_version, @@ -156,198 +195,164 @@ def _initialize_clients(self, api_config: Optional[Dict], **kwargs): ) def _set_model_dims(self) -> int: + """ + Determine the dimensionality of the embedding model by making a test call. + + Returns: + int: Dimensionality of the embedding model + + Raises: + ValueError: If embedding dimensions cannot be determined + """ try: - embedding = self.embed("dimension check") + # Call the protected _embed method to avoid caching this test embedding + embedding = self._embed("dimension check") + return len(embedding) except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the AzureOpenAI API: {str(ke)}") except Exception as e: # pylint: disable=broad-except # fall back (TODO get more specific) raise ValueError(f"Error setting embedding model dimensions: {str(e)}") - return len(embedding) @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """Embed many chunks of texts using the AzureOpenAI API. + def _embed(self, text: str, **kwargs) -> List[float]: + """ + Generate a vector embedding for a single text using the AzureOpenAI API. Args: - texts (List[str]): List of text chunks to embed. - preprocess (Optional[Callable], optional): Optional preprocessing - callable to perform before vectorization. Defaults to None. - batch_size (int, optional): Batch size of texts to use when creating - embeddings. Defaults to 10. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + text: Text to embed + **kwargs: Additional parameters to pass to the AzureOpenAI API Returns: - Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, - or as bytes objects if as_buffer=True + List[float]: Vector embedding as a list of floats Raises: - TypeError: If the wrong input type is passed in for the test. 
+ TypeError: If text is not a string + ValueError: If embedding fails """ - if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(text, str): + raise TypeError("Must pass in a str value to embed.") - embeddings: List = [] - for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create( - input=batch, model=self.model, **kwargs + try: + result = self._client.embeddings.create( + input=[text], model=self.model, **kwargs ) - embeddings += [ - self._process_embedding(r.embedding, as_buffer, dtype) - for r in response.data - ] - return embeddings + return result.data[0].embedding + except Exception as e: + raise ValueError(f"Embedding text failed: {e}") @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[float], bytes]: - """Embed a chunk of text using the AzureOpenAI API. + def _embed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """ + Generate vector embeddings for a batch of texts using the AzureOpenAI API. Args: - text (str): Chunk of text to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + texts: List of texts to embed + batch_size: Number of texts to process in each API call + **kwargs: Additional parameters to pass to the AzureOpenAI API Returns: - Union[List[float], bytes]: Embedding as a list of floats, or as a bytes - object if as_buffer=True + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If the wrong input type is passed in for the test. 
+ TypeError: If texts is not a list of strings + ValueError: If embedding fails """ - if not isinstance(text, str): - raise TypeError("Must pass in a str value to embed.") - - if preprocess: - text = preprocess(text) - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(texts, list): + raise TypeError("Must pass in a list of str values to embed.") + if texts and not isinstance(texts[0], str): + raise TypeError("Must pass in a list of str values to embed.") - result = self._client.embeddings.create( - input=[text], model=self.model, **kwargs - ) - return self._process_embedding(result.data[0].embedding, as_buffer, dtype) + try: + embeddings: List = [] + for batch in self.batchify(texts, batch_size): + response = self._client.embeddings.create( + input=batch, model=self.model, **kwargs + ) + embeddings.extend([r.embedding for r in response.data]) + return embeddings + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - async def aembed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """Asynchronously embed many chunks of texts using the AzureOpenAI API. + async def _aembed(self, text: str, **kwargs) -> List[float]: + """ + Asynchronously generate a vector embedding for a single text using the AzureOpenAI API. Args: - texts (List[str]): List of text chunks to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - batch_size (int, optional): Batch size of texts to use when creating - embeddings. Defaults to 10. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + text: Text to embed + **kwargs: Additional parameters to pass to the AzureOpenAI API Returns: - Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, - or as bytes objects if as_buffer=True + List[float]: Vector embedding as a list of floats Raises: - TypeError: If the wrong input type is passed in for the test. 
+ TypeError: If text is not a string + ValueError: If embedding fails """ - if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(text, str): + raise TypeError("Must pass in a str value to embed.") - embeddings: List = [] - for batch in self.batchify(texts, batch_size, preprocess): - response = await self._aclient.embeddings.create( - input=batch, model=self.model, **kwargs + try: + result = await self._aclient.embeddings.create( + input=[text], model=self.model, **kwargs ) - embeddings += [ - self._process_embedding(r.embedding, as_buffer, dtype) - for r in response.data - ] - return embeddings + return result.data[0].embedding + except Exception as e: + raise ValueError(f"Embedding text failed: {e}") @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - async def aembed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[float], bytes]: - """Asynchronously embed a chunk of text using the OpenAI API. + async def _aembed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """ + Asynchronously generate vector embeddings for a batch of texts using the AzureOpenAI API. Args: - text (str): Chunk of text to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + texts: List of texts to embed + batch_size: Number of texts to process in each API call + **kwargs: Additional parameters to pass to the AzureOpenAI API Returns: - Union[List[float], bytes]: Embedding as a list of floats, or as a bytes - object if as_buffer=True + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If the wrong input type is passed in for the test. 
+ TypeError: If texts is not a list of strings + ValueError: If embedding fails """ - if not isinstance(text, str): - raise TypeError("Must pass in a str value to embed.") - - if preprocess: - text = preprocess(text) - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(texts, list): + raise TypeError("Must pass in a list of str values to embed.") + if texts and not isinstance(texts[0], str): + raise TypeError("Must pass in a list of str values to embed.") - result = await self._aclient.embeddings.create( - input=[text], model=self.model, **kwargs - ) - return self._process_embedding(result.data[0].embedding, as_buffer, dtype) + try: + embeddings: List = [] + for batch in self.batchify(texts, batch_size): + response = await self._aclient.embeddings.create( + input=batch, model=self.model, **kwargs + ) + embeddings.extend([r.embedding for r in response.data]) + return embeddings + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/bedrock.py b/redisvl/utils/vectorize/text/bedrock.py index 2d40685d..ac4bb415 100644 --- a/redisvl/utils/vectorize/text/bedrock.py +++ b/redisvl/utils/vectorize/text/bedrock.py @@ -1,11 +1,14 @@ import json import os -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union -from pydantic import PrivateAttr +from pydantic import ConfigDict from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type +if TYPE_CHECKING: + from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache + from redisvl.utils.utils import deprecated_argument from redisvl.utils.vectorize.base import BaseVectorizer @@ -24,9 +27,12 @@ class BedrockTextVectorizer(BaseVectorizer): The vectorizer supports synchronous operations with batch processing and preprocessing capabilities. + You can optionally enable caching to improve performance when generating + embeddings for repeated text inputs. + .. code-block:: python - # Initialize with explicit credentials + # Basic usage with explicit credentials vectorizer = AmazonBedrockTextVectorizer( model="amazon.titan-embed-text-v2:0", api_config={ @@ -36,21 +42,33 @@ class BedrockTextVectorizer(BaseVectorizer): } ) - # Initialize using environment variables - vectorizer = AmazonBedrockTextVectorizer() + # With environment variables and caching + from redisvl.extensions.cache.embeddings import EmbeddingsCache + cache = EmbeddingsCache(name="bedrock_embeddings_cache") - # Generate embeddings - embedding = vectorizer.embed("Hello, world!") + vectorizer = AmazonBedrockTextVectorizer( + model="amazon.titan-embed-text-v2:0", + cache=cache + ) + + # First call will compute and cache the embedding + embedding1 = vectorizer.embed("Hello, world!") + + # Second call will retrieve from cache + embedding2 = vectorizer.embed("Hello, world!") + + # Generate batch embeddings embeddings = vectorizer.embed_many(["Hello", "World"], batch_size=2) """ - _client: Any = PrivateAttr() + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__( self, model: str = "amazon.titan-embed-text-v2:0", api_config: Optional[Dict[str, str]] = None, dtype: str = "float32", + cache: Optional["EmbeddingsCache"] = None, **kwargs, ) -> None: """Initialize the AWS Bedrock Vectorizer. @@ -63,22 +81,37 @@ def __init__( dtype (str): the default datatype to use when embedding text as byte arrays. 
Used when setting `as_buffer=True` in calls to embed() and embed_many(). Defaults to 'float32'. + cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache instance to cache embeddings for + better performance with repeated texts. Defaults to None. Raises: ValueError: If credentials are not provided in config or environment. ImportError: If boto3 is not installed. ValueError: If an invalid dtype is provided. """ - super().__init__(model=model, dtype=dtype) - # Init client + super().__init__(model=model, dtype=dtype, cache=cache) + # Initialize client and set up the model + self._setup(api_config, **kwargs) + + def _setup(self, api_config: Optional[Dict], **kwargs): + """Set up the Bedrock client and determine the embedding dimensions.""" + # Initialize client self._initialize_client(api_config, **kwargs) - # Set model dimensions after init + # Set model dimensions after initialization self.dims = self._set_model_dims() def _initialize_client(self, api_config: Optional[Dict], **kwargs): """ Setup the Bedrock client using the provided API keys or environment variables. + + Args: + api_config: Dictionary with AWS credentials and configuration + **kwargs: Additional arguments to pass to boto3 client + + Raises: + ImportError: If boto3 is not installed + ValueError: If AWS credentials are not provided """ try: import boto3 # type: ignore @@ -105,6 +138,7 @@ def _initialize_client(self, api_config: Optional[Dict], **kwargs): "AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY" ) + # Store client as a regular attribute instead of PrivateAttr self._client = boto3.client( "bedrock-runtime", aws_access_key_id=aws_access_key_id, @@ -114,113 +148,106 @@ def _initialize_client(self, api_config: Optional[Dict], **kwargs): ) def _set_model_dims(self) -> int: + """ + Determine the dimensionality of the embedding model by making a test call. + + Returns: + int: Dimensionality of the embedding model + + Raises: + ValueError: If embedding dimensions cannot be determined + """ try: - embedding = self.embed("dimension check") + # Call the protected _embed method to avoid caching this test embedding + embedding = self._embed("dimension check") + return len(embedding) except (KeyError, IndexError) as ke: - raise ValueError(f"Unexpected response from the OpenAI API: {str(ke)}") + raise ValueError(f"Unexpected response from the Bedrock API: {str(ke)}") except Exception as e: # pylint: disable=broad-except # fall back (TODO get more specific) raise ValueError(f"Error setting embedding model dimensions: {str(e)}") - return len(embedding) @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[float], bytes]: - """Embed a chunk of text using the AWS Bedrock Embeddings API. + def _embed(self, text: str, **kwargs) -> List[float]: + """ + Generate a vector embedding for a single text using the AWS Bedrock API. Args: - text (str): Text to embed. - preprocess (Optional[Callable]): Optional preprocessing function. - as_buffer (bool): Whether to return as byte buffer. + text: Text to embed + **kwargs: Additional parameters to pass to the AWS Bedrock API Returns: - Union[List[float], bytes]: Embedding as a list of floats, or as a bytes - object if as_buffer=True + List[float]: Vector embedding as a list of floats Raises: - TypeError: If text is not a string. 
+ TypeError: If text is not a string + ValueError: If embedding fails """ if not isinstance(text, str): raise TypeError("Text must be a string") - if preprocess: - text = preprocess(text) - - response = self._client.invoke_model( - modelId=self.model, body=json.dumps({"inputText": text}), **kwargs - ) - response_body = json.loads(response["body"].read()) - embedding = response_body["embedding"] - - dtype = kwargs.pop("dtype", self.dtype) - return self._process_embedding(embedding, as_buffer, dtype) + try: + response = self._client.invoke_model( + modelId=self.model, body=json.dumps({"inputText": text}), **kwargs + ) + response_body = json.loads(response["body"].read()) + return response_body["embedding"] + except Exception as e: + raise ValueError(f"Embedding text failed: {e}") @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """Embed many chunks of text using the AWS Bedrock Embeddings API. + def _embed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """ + Generate vector embeddings for a batch of texts using the AWS Bedrock API. Args: - texts (List[str]): List of texts to embed. - preprocess (Optional[Callable]): Optional preprocessing function. - batch_size (int): Size of batches for processing. Defaults to 10. - as_buffer (bool): Whether to return as byte buffers. + texts: List of texts to embed + batch_size: Number of texts to process in each API call + **kwargs: Additional parameters to pass to the AWS Bedrock API Returns: - Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, - or as bytes objects if as_buffer=True + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If texts is not a list of strings. 
+ TypeError: If texts is not a list of strings + ValueError: If embedding fails """ if not isinstance(texts, list): raise TypeError("Texts must be a list of strings") if texts and not isinstance(texts[0], str): raise TypeError("Texts must be a list of strings") - embeddings: List[List[float]] = [] - dtype = kwargs.pop("dtype", self.dtype) - - for batch in self.batchify(texts, batch_size, preprocess): - # Process each text in the batch individually since Bedrock - # doesn't support batch embedding - batch_embeddings = [] - for text in batch: - response = self._client.invoke_model( - modelId=self.model, body=json.dumps({"inputText": text}), **kwargs - ) - response_body = json.loads(response["body"].read()) - batch_embeddings.append(response_body["embedding"]) - - embeddings.extend( - [ - self._process_embedding(embedding, as_buffer, dtype) - for embedding in batch_embeddings - ] - ) - - return embeddings + try: + embeddings: List[List[float]] = [] + + for batch in self.batchify(texts, batch_size): + # Process each text in the batch individually since Bedrock + # doesn't support batch embedding + batch_embeddings = [] + for text in batch: + response = self._client.invoke_model( + modelId=self.model, + body=json.dumps({"inputText": text}), + **kwargs, + ) + response_body = json.loads(response["body"].read()) + batch_embeddings.append(response_body["embedding"]) + + embeddings.extend(batch_embeddings) + + return embeddings + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index 4e6192e2..7e5c0e38 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -1,11 +1,14 @@ import os import warnings -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Union -from pydantic import PrivateAttr +from pydantic import ConfigDict from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type +if TYPE_CHECKING: + from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache + from redisvl.utils.utils import deprecated_argument from redisvl.utils.vectorize.base import BaseVectorizer @@ -27,10 +30,14 @@ class CohereTextVectorizer(BaseVectorizer): The vectorizer supports only synchronous operations, allows for batch processing of texts and flexibility in handling preprocessing tasks. + You can optionally enable caching to improve performance when generating + embeddings for repeated text inputs. + .. 
code-block:: python from redisvl.utils.vectorize import CohereTextVectorizer + # Basic usage vectorizer = CohereTextVectorizer( model="embed-english-v3.0", api_config={"api_key": "your-cohere-api-key"} # OR set COHERE_API_KEY in your env @@ -39,20 +46,43 @@ class CohereTextVectorizer(BaseVectorizer): text="your input query text here", input_type="search_query" ) - doc_embeddings = cohere.embed_many( + doc_embeddings = vectorizer.embed_many( texts=["your document text", "more document text"], input_type="search_document" ) + # With caching enabled + from redisvl.extensions.cache.embeddings import EmbeddingsCache + cache = EmbeddingsCache(name="cohere_embeddings_cache") + + vectorizer = CohereTextVectorizer( + model="embed-english-v3.0", + api_config={"api_key": "your-cohere-api-key"}, + cache=cache + ) + + # First call will compute and cache the embedding + embedding1 = vectorizer.embed( + text="your input query text here", + input_type="search_query" + ) + + # Second call will retrieve from cache + embedding2 = vectorizer.embed( + text="your input query text here", + input_type="search_query" + ) + """ - _client: Any = PrivateAttr() + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__( self, model: str = "embed-english-v3.0", api_config: Optional[Dict] = None, dtype: str = "float32", + cache: Optional["EmbeddingsCache"] = None, **kwargs, ): """Initialize the Cohere vectorizer. @@ -67,22 +97,37 @@ def __init__( Used when setting `as_buffer=True` in calls to embed() and embed_many(). 'float32' will use Cohere's float embeddings, 'int8' and 'uint8' will map to Cohere's corresponding embedding types. Defaults to 'float32'. + cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache instance to cache embeddings for + better performance with repeated texts. Defaults to None. Raises: ImportError: If the cohere library is not installed. ValueError: If the API key is not provided. ValueError: If an invalid dtype is provided. """ - super().__init__(model=model, dtype=dtype) - # Init client + super().__init__(model=model, dtype=dtype, cache=cache) + # Initialize client and set up the model + self._setup(api_config, **kwargs) + + def _setup(self, api_config: Optional[Dict], **kwargs): + """Set up the Cohere client and determine the embedding dimensions.""" + # Initialize client self._initialize_client(api_config, **kwargs) - # Set model dimensions after init + # Set model dimensions after initialization self.dims = self._set_model_dims() def _initialize_client(self, api_config: Optional[Dict], **kwargs): """ - Setup the Cohere clients using the provided API key or an + Setup the Cohere client using the provided API key or an environment variable. + + Args: + api_config: Dictionary with API configuration options + **kwargs: Additional arguments to pass to Cohere client + + Raises: + ImportError: If the cohere library is not installed + ValueError: If no API key is provided """ if api_config is None: api_config = {} @@ -92,8 +137,8 @@ def _initialize_client(self, api_config: Optional[Dict], **kwargs): from cohere import Client except ImportError: raise ImportError( - "Cohere vectorizer requires the cohere library. \ - Please install with `pip install cohere`" + "Cohere vectorizer requires the cohere library. " + "Please install with `pip install cohere`" ) api_key = ( @@ -104,20 +149,39 @@ def _initialize_client(self, api_config: Optional[Dict], **kwargs): "Cohere API key is required. " "Provide it in api_config or set the COHERE_API_KEY environment variable." 
) + self._client = Client(api_key=api_key, client_name="redisvl", **kwargs) def _set_model_dims(self) -> int: + """ + Determine the dimensionality of the embedding model by making a test call. + + Returns: + int: Dimensionality of the embedding model + + Raises: + ValueError: If embedding dimensions cannot be determined + """ try: - embedding = self.embed("dimension check", input_type="search_document") + # Call the protected _embed method to avoid caching this test embedding + embedding = self._embed("dimension check", input_type="search_document") + return len(embedding) except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the Cohere API: {str(ke)}") except Exception as e: # pylint: disable=broad-except # fall back (TODO get more specific) raise ValueError(f"Error setting embedding model dimensions: {str(e)}") - return len(embedding) def _get_cohere_embedding_type(self, dtype: str) -> List[str]: - """Map dtype to appropriate Cohere embedding_types value.""" + """ + Map dtype to appropriate Cohere embedding_types value. + + Args: + dtype: The data type to map to Cohere embedding types + + Returns: + List of embedding type strings compatible with Cohere API + """ if dtype == "int8": return ["int8"] elif dtype == "uint8": @@ -125,40 +189,30 @@ def _get_cohere_embedding_type(self, dtype: str) -> List[str]: else: return ["float"] - @deprecated_argument("dtype") - def embed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[float], List[int], bytes]: - """Embed a chunk of text using the Cohere Embeddings API. + def _validate_input_type(self, input_type) -> None: + """ + Validate that a proper input_type parameter was provided. - Must provide the embedding `input_type` as a `kwarg` to this method - that specifies the type of input you're giving to the model. + Args: + input_type: The input type parameter to validate - Supported input types: - - ``search_document``: Used for embeddings stored in a vector database for search use-cases. - - ``search_query``: Used for embeddings of search queries run against a vector DB to find relevant documents. - - ``classification``: Used for embeddings passed through a text classifier - - ``clustering``: Used for the embeddings run through a clustering algorithm. + Raises: + TypeError: If input_type is not a string + """ + if not isinstance(input_type, str): + raise TypeError( + "Must pass in a str value for cohere embedding input_type. " + "See https://docs.cohere.com/reference/embed." + ) - When hydrating your Redis DB, the documents you want to search over - should be embedded with input_type= "search_document" and when you are - querying the database, you should set the input_type = "search query". - If you want to use the embeddings for a classification or clustering - task downstream, you should set input_type= "classification" or - "clustering". + def _embed(self, text: str, **kwargs) -> List[Union[float, int]]: + """ + Generate a vector embedding for a single text using the Cohere API. Args: - text (str): Chunk of text to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. - input_type (str): Specifies the type of input passed to the model. - Required for embedding models v3 and higher. 
+ text: Text to embed + **kwargs: Additional parameters to pass to the Cohere API, + must include 'input_type' Returns: Union[List[float], List[int], bytes]: @@ -168,23 +222,14 @@ def embed( - For dtype="int8" or "uint8": Returns a list of integers Raises: - TypeError: In an invalid input_type is provided. - + TypeError: If text is not a string or input_type is not provided + ValueError: If embedding fails """ - input_type = kwargs.pop("input_type", None) - if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") - if not isinstance(input_type, str): - raise TypeError( - "Must pass in a str value for cohere embedding input_type. \ - See https://docs.cohere.com/reference/embed." - ) - - if preprocess: - text = preprocess(text) - dtype = kwargs.pop("dtype", self.dtype) + input_type = kwargs.pop("input_type", None) + self._validate_input_type(input_type) # Check if embedding_types was provided and warn user if "embedding_types" in kwargs: @@ -197,93 +242,59 @@ def embed( kwargs.pop("embedding_types") # Map dtype to appropriate embedding_type - embedding_types = self._get_cohere_embedding_type(dtype) - - response = self._client.embed( - texts=[text], - model=self.model, - input_type=input_type, - embedding_types=embedding_types, - **kwargs, - ) + embedding_types = self._get_cohere_embedding_type(self.dtype) - # Extract the appropriate embedding based on embedding_types - embed_type = embedding_types[0] - if hasattr(response.embeddings, embed_type): - embedding = getattr(response.embeddings, embed_type)[0] - else: - embedding = response.embeddings[0] # Fallback for older API versions + try: + response = self._client.embed( + texts=[text], + model=self.model, + input_type=input_type, + embedding_types=embedding_types, + **kwargs, + ) - return self._process_embedding(embedding, as_buffer, dtype) + # Extract the appropriate embedding based on embedding_types + embed_type = embedding_types[0] + if hasattr(response.embeddings, embed_type): + embedding = getattr(response.embeddings, embed_type)[0] + else: + embedding = response.embeddings[0] # type: ignore + + return embedding + except Exception as e: + raise ValueError(f"Embedding text failed: {e}") @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[List[int]], List[bytes]]: - """Embed many chunks of text using the Cohere Embeddings API. - - Must provide the embedding `input_type` as a `kwarg` to this method - that specifies the type of input you're giving to the model. - - Supported input types: - - ``search_document``: Used for embeddings stored in a vector database for search use-cases. - - ``search_query``: Used for embeddings of search queries run against a vector DB to find relevant documents. - - ``classification``: Used for embeddings passed through a text classifier - - ``clustering``: Used for the embeddings run through a clustering algorithm. - - - When hydrating your Redis DB, the documents you want to search over - should be embedded with input_type= "search_document" and when you are - querying the database, you should set the input_type = "search query". - If you want to use the embeddings for a classification or clustering - task downstream, you should set input_type= "classification" or - "clustering". 
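# Hedged usage sketch for the refactored Cohere vectorizer: per the (removed) docstring
# above, documents stored for search should use input_type="search_document" and
# queries should use input_type="search_query". Assumes COHERE_API_KEY is set and the
# cohere package is installed; the texts and model name are placeholders.
from redisvl.utils.vectorize import CohereTextVectorizer

vectorizer = CohereTextVectorizer(model="embed-english-v3.0")
query_vec = vectorizer.embed("what is vector search?", input_type="search_query")
doc_vecs = vectorizer.embed_many(
    texts=["Redis supports vector search.", "RedisVL wraps several embedding providers."],
    input_type="search_document",
    batch_size=10,
)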
+ def _embed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[Union[float, int]]]: + """ + Generate vector embeddings for a batch of texts using the Cohere API. Args: - texts (List[str]): List of text chunks to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - batch_size (int, optional): Batch size of texts to use when creating - embeddings. Defaults to 10. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. - input_type (str): Specifies the type of input passed to the model. - Required for embedding models v3 and higher. + texts: List of texts to embed + batch_size: Number of texts to process in each API call + **kwargs: Additional parameters to pass to the Cohere API, + must include 'input_type' Returns: - Union[List[List[float]], List[List[int]], List[bytes]]: - - If as_buffer=True: Returns a list of bytes objects - - If as_buffer=False: - - For dtype="float32": Returns a list of lists of floats - - For dtype="int8" or "uint8": Returns a list of lists of integers + List[List[Union[float, int]]]: List of vector embeddings Raises: - TypeError: In an invalid input_type is provided. - + TypeError: If texts is not a list of strings or input_type is not provided + ValueError: If embedding fails """ - input_type = kwargs.pop("input_type", None) - if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): + if texts and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") - if not isinstance(input_type, str): - raise TypeError( - "Must pass in a str value for cohere embedding input_type.\ - See https://docs.cohere.com/reference/embed." 
- ) - dtype = kwargs.pop("dtype", self.dtype) + input_type = kwargs.pop("input_type", None) + self._validate_input_type(input_type) # Check if embedding_types was provided and warn user if "embedding_types" in kwargs: @@ -296,31 +307,31 @@ def embed_many( kwargs.pop("embedding_types") # Map dtype to appropriate embedding_type - embedding_types = self._get_cohere_embedding_type(dtype) + embedding_types = self._get_cohere_embedding_type(self.dtype) embeddings: List = [] - for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embed( - texts=batch, - model=self.model, - input_type=input_type, - embedding_types=embedding_types, - **kwargs, - ) + for batch in self.batchify(texts, batch_size): + try: + response = self._client.embed( + texts=batch, + model=self.model, + input_type=input_type, + embedding_types=embedding_types, + **kwargs, + ) + + # Extract the appropriate embeddings based on embedding_types + embed_type = embedding_types[0] + if hasattr(response.embeddings, embed_type): + batch_embeddings = getattr(response.embeddings, embed_type) + else: + # Fallback for older API versions + batch_embeddings = response.embeddings + + embeddings.extend(batch_embeddings) + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") - # Extract the appropriate embeddings based on embedding_types - embed_type = embedding_types[0] - if hasattr(response.embeddings, embed_type): - batch_embeddings = getattr(response.embeddings, embed_type) - else: - batch_embeddings = ( - response.embeddings - ) # Fallback for older API versions - - embeddings += [ - self._process_embedding(embedding, as_buffer, dtype) - for embedding in batch_embeddings - ] return embeddings @property diff --git a/redisvl/utils/vectorize/text/custom.py b/redisvl/utils/vectorize/text/custom.py index ed284d29..a4e80787 100644 --- a/redisvl/utils/vectorize/text/custom.py +++ b/redisvl/utils/vectorize/text/custom.py @@ -1,8 +1,10 @@ -from typing import Any, Callable, List, Optional, Union +from typing import TYPE_CHECKING, Callable, List, Optional -from pydantic import PrivateAttr +from pydantic import ConfigDict + +if TYPE_CHECKING: + from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache -from redisvl.utils.utils import deprecated_argument from redisvl.utils.vectorize.base import BaseVectorizer @@ -32,30 +34,6 @@ def _check_vector(result: list, method_name: str) -> None: raise ValueError(f"{method_name} must return a list of floats.") -def validate_async(method): - """ - Decorator that lazily validates the output of async methods (aembed, aembed_many). - On first call, it checks the returned embeddings with _check_vector, then sets a flag - so subsequent calls skip re-validation. - """ - - async def wrapper(self, *args, **kwargs): - result = await method(self, *args, **kwargs) - method_name = method.__name__ - validated_attr = f"_{method_name}_validated" - - try: - if not getattr(self, validated_attr): - _check_vector(result, method_name) - setattr(self, validated_attr, True) - except Exception as e: - raise ValueError(f"Invalid embedding method: {e}") - - return result - - return wrapper - - class CustomTextVectorizer(BaseVectorizer): """The CustomTextVectorizer class wraps user-defined embedding methods to create embeddings for text data. @@ -66,14 +44,32 @@ class CustomTextVectorizer(BaseVectorizer): allows for batch processing of texts, but at a minimum only syncronous embedding is required to satisfy the 'embed()' method. 
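# The _embed_many implementations in this diff iterate with
# `for batch in self.batchify(texts, batch_size)`. batchify lives on BaseVectorizer
# and is not part of this diff; it presumably yields successive slices of the input,
# roughly like this stand-in:
from typing import Iterator, List, Sequence

def batchify(seq: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    for i in range(0, len(seq), batch_size):
        yield list(seq[i : i + batch_size])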
+ You can optionally enable caching to improve performance when generating + embeddings for repeated text inputs. + .. code-block:: python - # Synchronous embedding of a single text + # Basic usage with a custom embedding function vectorizer = CustomTextVectorizer( embed = my_vectorizer.generate_embedding ) embedding = vectorizer.embed("Hello, world!") + # With caching enabled + from redisvl.extensions.cache.embeddings import EmbeddingsCache + cache = EmbeddingsCache(name="my_embeddings_cache") + + vectorizer = CustomTextVectorizer( + embed=my_vectorizer.generate_embedding, + cache=cache + ) + + # First call will compute and cache the embedding + embedding1 = vectorizer.embed("Hello, world!") + + # Second call will retrieve from cache + embedding2 = vectorizer.embed("Hello, world!") + # Asynchronous batch embedding of multiple texts embeddings = await vectorizer.aembed_many( ["Hello, world!", "How are you?"], @@ -82,15 +78,7 @@ class CustomTextVectorizer(BaseVectorizer): """ - # User-provided callables - _embed: Callable = PrivateAttr() - _embed_many: Optional[Callable] = PrivateAttr() - _aembed: Optional[Callable] = PrivateAttr() - _aembed_many: Optional[Callable] = PrivateAttr() - - # Validation flags for async methods - _aembed_validated: bool = PrivateAttr(default=False) - _aembed_many_validated: bool = PrivateAttr(default=False) + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__( self, @@ -99,234 +87,193 @@ def __init__( aembed: Optional[Callable] = None, aembed_many: Optional[Callable] = None, dtype: str = "float32", + cache: Optional["EmbeddingsCache"] = None, ): """Initialize the Custom vectorizer. Args: embed (Callable): a Callable function that accepts a string object and returns a list of floats. - embed_many (Optional[Callable)]: a Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None. + embed_many (Optional[Callable]): a Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None. aembed (Optional[Callable]): an asyncronous Callable function that accepts a string object and returns a lists of floats. Defaults to None. - aembed_many (Optional[Callable]): an asyncronous Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None. + aembed_many (Optional[Callable]): an asyncronous Callable function that accepts a list of string objects and returns a list containing lists of floats. Defaults to None. dtype (str): the default datatype to use when embedding text as byte arrays. Used when setting `as_buffer=True` in calls to embed() and embed_many(). Defaults to 'float32'. + cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache instance to cache embeddings for + better performance with repeated texts. Defaults to None. Raises: ValueError: if embedding validation fails. 
""" - super().__init__(model=self.type, dtype=dtype) + # First, determine the dimensions + try: + test_result = embed("dimension test") + _check_vector(test_result, "embed") + dims = len(test_result) + except Exception as e: + raise ValueError(f"Failed to validate embed method: {e}") + + # Initialize parent with known information + super().__init__(model="custom", dtype=dtype, dims=dims, cache=cache) - # Store user-provided callables - self._embed = embed - self._embed_many = embed_many - self._aembed = aembed - self._aembed_many = aembed_many + # Now setup the functions and validation flags + self._setup_functions(embed, embed_many, aembed, aembed_many) - # Set dims - self.dims = self._validate_sync_callables() + def _setup_functions(self, embed, embed_many, aembed, aembed_many): + """Setup the user-provided embedding functions.""" + self._embed_func = embed + self._embed_func_many = embed_many + self._aembed_func = aembed + self._aembed_func_many = aembed_many + + # Initialize validation flags + self._aembed_validated = False + self._aembed_many_validated = False + + # Validate the other functions if provided + self._validate_optional_funcs() @property def type(self) -> str: return "custom" - def _validate_sync_callables(self) -> int: + def _validate_optional_funcs(self) -> None: """ - Validate the sync embed function with a test call and discover the dimension. - Optionally validate embed_many if provided. Returns the discovered dimension. + Optionally validate the other user-provided functions if they exist. Raises: - ValueError: If embed or embed_many produce malformed results or fail entirely. + ValueError: If any provided function produces invalid results. """ - # Check embed - try: - test_single = self._embed("dimension test") - _check_vector(test_single, "embed") - dims = len(test_single) - except Exception as e: - raise ValueError(f"Invalid embedding method: {e}") - - # Check embed_many - if self._embed_many: + # Check embed_many if provided + if self._embed_func_many: try: - test_batch = self._embed_many(["dimension test (many)"]) + test_batch = self._embed_func_many(["dimension test (many)"]) _check_vector(test_batch, "embed_many") except Exception as e: - raise ValueError(f"Invalid embedding method: {e}") + raise ValueError(f"Invalid embed_many function: {e}") - return dims - - @deprecated_argument("dtype") - def embed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[float], bytes]: - """ - Generate an embedding for a single piece of text using your sync embed function. + def _embed(self, text: str, **kwargs) -> List[float]: + """Generate a vector embedding for a single text using the provided user function. Args: - text (str): The text to embed. - preprocess (Optional[Callable]): An optional callable to preprocess the text. - as_buffer (bool): If True, return the embedding as a byte buffer. + text: Text to embed + **kwargs: Additional parameters to pass to the user function Returns: - Union[List[float], bytes]: The embedding of the input text. + List[float]: Vector embedding as a list of floats Raises: - TypeError: If the input is not a string. 
+ TypeError: If text is not a string + ValueError: If embedding fails """ if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") - if preprocess: - text = preprocess(text) - - dtype = kwargs.pop("dtype", self.dtype) - try: - result = self._embed(text, **kwargs) + result = self._embed_func(text, **kwargs) + return result except Exception as e: raise ValueError(f"Embedding text failed: {e}") - return self._process_embedding(result, as_buffer, dtype) - - @deprecated_argument("dtype") - def embed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """ - Generate embeddings for multiple pieces of text in batches using your sync embed_many function. + def _embed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """Generate vector embeddings for a batch of texts using the provided user function. Args: - texts (List[str]): A list of texts to embed. - preprocess (Optional[Callable]): Optional preprocessing for each text. - batch_size (int): Number of texts per batch. - as_buffer (bool): If True, convert each embedding to a byte buffer. + texts: List of texts to embed + batch_size: Number of texts to process in each batch + **kwargs: Additional parameters to pass to the user function Returns: - Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes. + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If the input is not a list of strings. - NotImplementedError: If no embed_many function was provided. + TypeError: If texts is not a list of strings + ValueError: If embedding fails """ if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") if texts and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") - if not self._embed_many: - raise NotImplementedError("No embed_many function was provided.") - - dtype = kwargs.pop("dtype", self.dtype) - embeddings: Union[List[List[float]], List[bytes]] = [] + if not self._embed_func_many: + # Fallback: Use _embed for each text if no batch function provided + return [self._embed(text, **kwargs) for text in texts] try: - for batch in self.batchify(texts, batch_size, preprocess): - results = self._embed_many(batch, **kwargs) - processed = [ - self._process_embedding(r, as_buffer, dtype) for r in results - ] - embeddings.extend(processed) + results = self._embed_func_many(texts, **kwargs) + return results except Exception as e: - raise ValueError(f"Embedding text failed: {e}") + raise ValueError(f"Embedding texts failed: {e}") - return embeddings - - @validate_async - @deprecated_argument("dtype") - async def aembed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> List[float]: - """ - Asynchronously generate an embedding for a single piece of text. + async def _aembed(self, text: str, **kwargs) -> List[float]: + """Asynchronously generate a vector embedding for a single text. Args: - text (str): The text to embed. - preprocess (Optional[Callable]): An optional callable to preprocess the text. - as_buffer (bool): If True, return the embedding as a byte buffer. + text: Text to embed + **kwargs: Additional parameters to pass to the user async function Returns: - List[float]: The embedding of the input text. 
+ List[float]: Vector embedding as a list of floats Raises: - TypeError: If the input is not a string. - NotImplementedError: If no aembed function was provided. + TypeError: If text is not a string + NotImplementedError: If no aembed function was provided + ValueError: If embedding fails """ if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") - if not self._aembed: - raise NotImplementedError("No aembed function was provided.") + if not self._aembed_func: + return self._embed(text, **kwargs) - if preprocess: - text = preprocess(text) + try: + result = await self._aembed_func(text, **kwargs) - dtype = kwargs.pop("dtype", self.dtype) + # Validate result on first call + if not self._aembed_validated: + _check_vector(result, "aembed") + self._aembed_validated = True - try: - result = await self._aembed(text, **kwargs) + return result except Exception as e: raise ValueError(f"Embedding text failed: {e}") - return self._process_embedding(result, as_buffer, dtype) - - @validate_async - @deprecated_argument("dtype") - async def aembed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """ - Asynchronously generate embeddings for multiple pieces of text in batches. + async def _aembed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """Asynchronously generate vector embeddings for a batch of texts. Args: - texts (List[str]): The texts to embed. - preprocess (Optional[Callable]): Optional preprocessing for each text. - batch_size (int): Number of texts per batch. - as_buffer (bool): If True, convert each embedding to a byte buffer. + texts: List of texts to embed + batch_size: Number of texts to process in each batch + **kwargs: Additional parameters to pass to the user async function Returns: - Union[List[List[float]], List[bytes]]: A list of embeddings, where each embedding is a list of floats or bytes. + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If the input is not a list of strings. - NotImplementedError: If no aembed_many function was provided. 
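# Sketch tying the CustomTextVectorizer pieces above together: a user-supplied callable
# that satisfies the _check_vector probe (str in, non-empty list of floats out), plus
# the fallback behavior when only `embed` is provided (embed_many loops over _embed,
# and the async paths reuse the sync implementations). The hashing trick and 32-dim
# output are purely illustrative, not a real embedding model.
import hashlib
from typing import List

from redisvl.utils.vectorize import CustomTextVectorizer

def my_embed(text: str) -> List[float]:
    digest = hashlib.sha256(text.encode("utf-8")).digest()
    return [b / 255.0 for b in digest]  # deterministic 32-dim toy embedding

vectorizer = CustomTextVectorizer(embed=my_embed)   # dims discovered as 32 at init
single = vectorizer.embed("hello")                  # uses my_embed directly
batch = vectorizer.embed_many(["hello", "world"])   # no embed_many given: per-text fallback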
+ TypeError: If texts is not a list of strings + NotImplementedError: If no aembed_many function was provided + ValueError: If embedding fails """ if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") if texts and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") - if not self._aembed_many: - raise NotImplementedError("No aembed_many function was provided.") - - dtype = kwargs.pop("dtype", self.dtype) - embeddings: Union[List[List[float]], List[bytes]] = [] + if not self._aembed_func_many: + return self._embed_many(texts, batch_size, **kwargs) try: - for batch in self.batchify(texts, batch_size, preprocess): - results = await self._aembed_many(batch, **kwargs) - processed = [ - self._process_embedding(r, as_buffer, dtype) for r in results - ] - embeddings.extend(processed) - except Exception as e: - raise ValueError(f"Embedding text failed: {e}") + results = await self._aembed_func_many(texts, **kwargs) + + # Validate result on first call + if not self._aembed_many_validated: + _check_vector(results, "aembed_many") + self._aembed_many_validated = True - return embeddings + return results + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index bafba41d..2188d9e4 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -1,34 +1,59 @@ -from typing import Any, Callable, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Union from pydantic.v1 import PrivateAttr +if TYPE_CHECKING: + from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache + from redisvl.utils.utils import deprecated_argument from redisvl.utils.vectorize.base import BaseVectorizer class HFTextVectorizer(BaseVectorizer): - """The HFTextVectorizer class is designed to leverage the power of Hugging - Face's Sentence Transformers for generating text embeddings. This vectorizer - is particularly useful in scenarios where advanced natural language + """The HFTextVectorizer class leverages Hugging Face's Sentence Transformers + for generating vector embeddings from text input. + + This vectorizer is particularly useful in scenarios where advanced natural language processing and understanding are required, and ideal for running on your own - hardware (for free). + hardware without usage fees. + + You can optionally enable caching to improve performance when generating + embeddings for repeated text inputs. Utilizing this vectorizer involves specifying a pre-trained model from Hugging Face's vast collection of Sentence Transformers. These models are trained on a variety of datasets and tasks, ensuring versatility and - robust performance across different text embedding needs. Additionally, - make sure the `sentence-transformers` library is installed with - `pip install sentence-transformers==2.2.2`. + robust performance across different embedding needs. + + Requirements: + - The `sentence-transformers` library must be installed with pip. .. 
code-block:: python - # Embedding a single text + # Basic usage vectorizer = HFTextVectorizer(model="sentence-transformers/all-mpnet-base-v2") embedding = vectorizer.embed("Hello, world!") - # Embedding a batch of texts - embeddings = vectorizer.embed_many(["Hello, world!", "How are you?"], batch_size=2) + # With caching enabled + from redisvl.extensions.cache.embeddings import EmbeddingsCache + cache = EmbeddingsCache(name="my_embeddings_cache") + vectorizer = HFTextVectorizer( + model="sentence-transformers/all-mpnet-base-v2", + cache=cache + ) + + # First call will compute and cache the embedding + embedding1 = vectorizer.embed("Hello, world!") + + # Second call will retrieve from cache + embedding2 = vectorizer.embed("Hello, world!") + + # Batch processing + embeddings = vectorizer.embed_many( + ["Hello, world!", "How are you?"], + batch_size=2 + ) """ _client: Any = PrivateAttr() @@ -37,6 +62,7 @@ def __init__( self, model: str = "sentence-transformers/all-mpnet-base-v2", dtype: str = "float32", + cache: Optional["EmbeddingsCache"] = None, **kwargs, ): """Initialize the Hugging Face text vectorizer. @@ -48,13 +74,17 @@ def __init__( dtype (str): the default datatype to use when embedding text as byte arrays. Used when setting `as_buffer=True` in calls to embed() and embed_many(). Defaults to 'float32'. + cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache instance to cache embeddings for + better performance with repeated texts. Defaults to None. + **kwargs: Additional parameters to pass to the SentenceTransformer + constructor. Raises: ImportError: If the sentence-transformers library is not installed. ValueError: If there is an error setting the embedding model dimensions. ValueError: If an invalid dtype is provided. """ - super().__init__(model=model, dtype=dtype) + super().__init__(model=model, dtype=dtype, cache=cache) # Init client self._initialize_client(model, **kwargs) # Set model dimensions after init @@ -74,7 +104,7 @@ def _initialize_client(self, model: str, **kwargs): def _set_model_dims(self): try: - embedding = self.embed("dimension check") + embedding = self._embed("dimension check") except (KeyError, IndexError) as ke: raise ValueError(f"Empty response from the embedding model: {str(ke)}") except Exception as e: # pylint: disable=broad-except @@ -82,85 +112,50 @@ def _set_model_dims(self): raise ValueError(f"Error setting embedding model dimensions: {str(e)}") return len(embedding) - @deprecated_argument("dtype") - def embed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[float], bytes]: - """Embed a chunk of text using the Hugging Face sentence transformer. + def _embed(self, text: str, **kwargs) -> List[float]: + """Generate a vector embedding for a single text using the Hugging Face model. Args: - text (str): Chunk of text to embed. - preprocess (Optional[Callable], optional): Optional preprocessing - callable to perform before vectorization. Defaults to None. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + text: Text to embed + **kwargs: Additional model-specific parameters Returns: - Union[List[float], bytes]: Embedding as a list of floats, or as a bytes - object if as_buffer=True + List[float]: Vector embedding as a list of floats Raises: - TypeError: If the wrong input type is passed in for the text. 
+ TypeError: If the input is not a string """ if not isinstance(text, str): raise TypeError("Must pass in a str value to embed.") - if preprocess: - text = preprocess(text) - - dtype = kwargs.pop("dtype", self.dtype) - embedding = self._client.encode([text], **kwargs)[0] - return self._process_embedding(embedding.tolist(), as_buffer, dtype) + return embedding.tolist() - @deprecated_argument("dtype") - def embed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """Asynchronously embed many chunks of texts using the Hugging Face - sentence transformer. + def _embed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """Generate vector embeddings for a batch of texts using the Hugging Face model. Args: - texts (List[str]): List of text chunks to embed. - preprocess (Optional[Callable], optional): Optional preprocessing - callable to perform before vectorization. Defaults to None. - batch_size (int, optional): Batch size of texts to use when creating - embeddings. Defaults to 10. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + texts: List of texts to embed + batch_size: Number of texts to process in each batch + **kwargs: Additional model-specific parameters Returns: - Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, - or as bytes objects if as_buffer=True + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If the wrong input type is passed in for the test. + TypeError: If the input is not a list of strings """ if not isinstance(texts, list): raise TypeError("Must pass in a list of str values to embed.") if len(texts) > 0 and not isinstance(texts[0], str): raise TypeError("Must pass in a list of str values to embed.") - dtype = kwargs.pop("dtype", self.dtype) - embeddings: List = [] - for batch in self.batchify(texts, batch_size, preprocess): + for batch in self.batchify(texts, batch_size, None): batch_embeddings = self._client.encode(batch, **kwargs) - embeddings.extend( - [ - self._process_embedding(embedding.tolist(), as_buffer, dtype) - for embedding in batch_embeddings - ] - ) + embeddings.extend([embedding.tolist() for embedding in batch_embeddings]) return embeddings @property diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py index 05133b37..576acd87 100644 --- a/redisvl/utils/vectorize/text/mistral.py +++ b/redisvl/utils/vectorize/text/mistral.py @@ -1,10 +1,13 @@ import os -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional -from pydantic import PrivateAttr +from pydantic import ConfigDict from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type +if TYPE_CHECKING: + from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache + from redisvl.utils.utils import deprecated_argument from redisvl.utils.vectorize.base import BaseVectorizer @@ -27,15 +30,34 @@ class MistralAITextVectorizer(BaseVectorizer): allowing for batch processing of texts and flexibility in handling preprocessing tasks. + You can optionally enable caching to improve performance when generating + embeddings for repeated text inputs. + .. 
code-block:: python - # Synchronous embedding of a single text + # Basic usage vectorizer = MistralAITextVectorizer( - model="mistral-embed" + model="mistral-embed", api_config={"api_key": "your_api_key"} # OR set MISTRAL_API_KEY in your env ) embedding = vectorizer.embed("Hello, world!") + # With caching enabled + from redisvl.extensions.cache.embeddings import EmbeddingsCache + cache = EmbeddingsCache(name="mistral_embeddings_cache") + + vectorizer = MistralAITextVectorizer( + model="mistral-embed", + api_config={"api_key": "your_api_key"}, + cache=cache + ) + + # First call will compute and cache the embedding + embedding1 = vectorizer.embed("Hello, world!") + + # Second call will retrieve from cache + embedding2 = vectorizer.embed("Hello, world!") + # Asynchronous batch embedding of multiple texts embeddings = await vectorizer.aembed_many( ["Hello, world!", "How are you?"], @@ -44,41 +66,57 @@ class MistralAITextVectorizer(BaseVectorizer): """ - _client: Any = PrivateAttr() + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__( self, model: str = "mistral-embed", api_config: Optional[Dict] = None, dtype: str = "float32", + cache: Optional["EmbeddingsCache"] = None, **kwargs, ): """Initialize the MistralAI vectorizer. Args: model (str): Model to use for embedding. Defaults to - 'text-embedding-ada-002'. + 'mistral-embed'. api_config (Optional[Dict], optional): Dictionary containing the API key. Defaults to None. dtype (str): the default datatype to use when embedding text as byte arrays. Used when setting `as_buffer=True` in calls to embed() and embed_many(). Defaults to 'float32'. + cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache instance to cache embeddings for + better performance with repeated texts. Defaults to None. Raises: ImportError: If the mistralai library is not installed. ValueError: If the Mistral API key is not provided. ValueError: If an invalid dtype is provided. """ - super().__init__(model=model, dtype=dtype) - # Init client + super().__init__(model=model, dtype=dtype, cache=cache) + # Initialize client and set up the model + self._setup(api_config, **kwargs) + + def _setup(self, api_config: Optional[Dict], **kwargs): + """Set up the MistralAI client and determine the embedding dimensions.""" + # Initialize client self._initialize_client(api_config, **kwargs) - # Set model dimensions after init + # Set model dimensions after initialization self.dims = self._set_model_dims() def _initialize_client(self, api_config: Optional[Dict], **kwargs): """ - Setup the Mistral clients using the provided API key or an + Setup the Mistral client using the provided API key or an environment variable. + + Args: + api_config: Dictionary with API configuration options + **kwargs: Additional arguments to pass to MistralAI client + + Raises: + ImportError: If the mistralai library is not installed + ValueError: If no API key is provided """ if api_config is None: api_config = {} @@ -88,8 +126,8 @@ def _initialize_client(self, api_config: Optional[Dict], **kwargs): from mistralai import Mistral except ImportError: raise ImportError( - "MistralAI vectorizer requires the mistralai library. \ - Please install with `pip install mistralai`" + "MistralAI vectorizer requires the mistralai library. " + "Please install with `pip install mistralai`" ) # Fetch the API key from api_config or environment variable @@ -99,203 +137,171 @@ def _initialize_client(self, api_config: Optional[Dict], **kwargs): if not api_key: raise ValueError( "MISTRAL API key is required. 
" - "Provide it in api_config or set the MISTRAL_API_KEY\ - environment variable." + "Provide it in api_config or set the MISTRAL_API_KEY environment variable." ) + # Store client as a regular attribute instead of PrivateAttr self._client = Mistral(api_key=api_key, **kwargs) def _set_model_dims(self) -> int: + """ + Determine the dimensionality of the embedding model by making a test call. + + Returns: + int: Dimensionality of the embedding model + + Raises: + ValueError: If embedding dimensions cannot be determined + """ try: - embedding = self.embed("dimension check") + # Call the protected _embed method to avoid caching this test embedding + embedding = self._embed("dimension check") + return len(embedding) except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the MISTRAL API: {str(ke)}") except Exception as e: # pylint: disable=broad-except # fall back (TODO get more specific) raise ValueError(f"Error setting embedding model dimensions: {str(e)}") - return len(embedding) @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """Embed many chunks of texts using the Mistral API. + def _embed(self, text: str, **kwargs) -> List[float]: + """ + Generate a vector embedding for a single text using the MistralAI API. Args: - texts (List[str]): List of text chunks to embed. - preprocess (Optional[Callable], optional): Optional preprocessing - callable to perform before vectorization. Defaults to None. - batch_size (int, optional): Batch size of texts to use when creating - embeddings. Defaults to 10. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + text: Text to embed + **kwargs: Additional parameters to pass to the MistralAI API Returns: - Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, - or as bytes objects if as_buffer=True + List[float]: Vector embedding as a list of floats Raises: - TypeError: If the wrong input type is passed in for the test. 
+ TypeError: If text is not a string + ValueError: If embedding fails """ - if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(text, str): + raise TypeError("Must pass in a str value to embed.") - embeddings: List = [] - for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create( - model=self.model, inputs=batch, **kwargs + try: + result = self._client.embeddings.create( + model=self.model, inputs=[text], **kwargs ) - embeddings += [ - self._process_embedding(r.embedding, as_buffer, dtype) - for r in response.data - ] - return embeddings + return result.data[0].embedding # type: ignore + except Exception as e: + raise ValueError(f"Embedding text failed: {e}") @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[float], bytes]: - """Embed a chunk of text using the Mistral API. + def _embed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """ + Generate vector embeddings for a batch of texts using the MistralAI API. Args: - text (str): Chunk of text to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + texts: List of texts to embed + batch_size: Number of texts to process in each API call + **kwargs: Additional parameters to pass to the MistralAI API Returns: - Union[List[float], bytes]: Embedding as a list of floats, or as a bytes - object if as_buffer=True + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If the wrong input type is passed in for the test. + TypeError: If texts is not a list of strings + ValueError: If embedding fails """ - if not isinstance(text, str): - raise TypeError("Must pass in a str value to embed.") - - if preprocess: - text = preprocess(text) - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(texts, list): + raise TypeError("Must pass in a list of str values to embed.") + if texts and not isinstance(texts[0], str): + raise TypeError("Must pass in a list of str values to embed.") - result = self._client.embeddings.create( - model=self.model, inputs=[text], **kwargs - ) - return self._process_embedding(result.data[0].embedding, as_buffer, dtype) + try: + embeddings: List = [] + for batch in self.batchify(texts, batch_size): + response = self._client.embeddings.create( + model=self.model, inputs=batch, **kwargs + ) + embeddings.extend([r.embedding for r in response.data]) + return embeddings + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - async def aembed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> List[List[float]]: - """Asynchronously embed many chunks of texts using the Mistral API. 
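# Sketch of driving the async Mistral path defined below; _aembed/_aembed_many back the
# public aembed/aembed_many methods shown in the class docstring above. Assumes
# MISTRAL_API_KEY is set and the mistralai package is installed.
import asyncio

from redisvl.utils.vectorize import MistralAITextVectorizer

async def main() -> None:
    vectorizer = MistralAITextVectorizer(model="mistral-embed")
    vecs = await vectorizer.aembed_many(["Hello, world!", "How are you?"], batch_size=2)
    print(len(vecs), len(vecs[0]))

if __name__ == "__main__":
    asyncio.run(main())  # requires live API credentials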
+ async def _aembed(self, text: str, **kwargs) -> List[float]: + """ + Asynchronously generate a vector embedding for a single text using the MistralAI API. Args: - texts (List[str]): List of text chunks to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - batch_size (int, optional): Batch size of texts to use when creating - embeddings. Defaults to 10. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + text: Text to embed + **kwargs: Additional parameters to pass to the MistralAI API Returns: - List[List[float]]: List of embeddings. + List[float]: Vector embedding as a list of floats Raises: - TypeError: If the wrong input type is passed in for the test. + TypeError: If text is not a string + ValueError: If embedding fails """ - if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(text, str): + raise TypeError("Must pass in a str value to embed.") - embeddings: List = [] - for batch in self.batchify(texts, batch_size, preprocess): - response = await self._client.embeddings.create_async( - model=self.model, inputs=batch, **kwargs + try: + result = await self._client.embeddings.create_async( + model=self.model, inputs=[text], **kwargs ) - embeddings += [ - self._process_embedding(r.embedding, as_buffer, dtype) - for r in response.data - ] - return embeddings + return result.data[0].embedding # type: ignore + except Exception as e: + raise ValueError(f"Embedding text failed: {e}") @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - async def aembed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> List[float]: - """Asynchronously embed a chunk of text using the MistralAPI. + async def _aembed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """ + Asynchronously generate vector embeddings for a batch of texts using the MistralAI API. Args: - text (str): Chunk of text to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + texts: List of texts to embed + batch_size: Number of texts to process in each API call + **kwargs: Additional parameters to pass to the MistralAI API Returns: - List[float]: Embedding. + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If the wrong input type is passed in for the test. 
+ TypeError: If texts is not a list of strings + ValueError: If embedding fails """ - if not isinstance(text, str): - raise TypeError("Must pass in a str value to embed.") - - if preprocess: - text = preprocess(text) - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(texts, list): + raise TypeError("Must pass in a list of str values to embed.") + if texts and not isinstance(texts[0], str): + raise TypeError("Must pass in a list of str values to embed.") - result = await self._client.embeddings.create_async( - model=self.model, inputs=[text], **kwargs - ) - return self._process_embedding(result.data[0].embedding, as_buffer, dtype) + try: + embeddings: List = [] + for batch in self.batchify(texts, batch_size): + response = await self._client.embeddings.create_async( + model=self.model, inputs=batch, **kwargs + ) + embeddings.extend([r.embedding for r in response.data]) + return embeddings + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index eee0764a..6ff1f99c 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -1,10 +1,13 @@ import os -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional -from pydantic import PrivateAttr +from pydantic import ConfigDict from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type +if TYPE_CHECKING: + from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache + from redisvl.utils.utils import deprecated_argument from redisvl.utils.vectorize.base import BaseVectorizer @@ -27,15 +30,34 @@ class OpenAITextVectorizer(BaseVectorizer): allowing for batch processing of texts and flexibility in handling preprocessing tasks. + You can optionally enable caching to improve performance when generating + embeddings for repeated text inputs. + .. code-block:: python - # Synchronous embedding of a single text + # Basic usage with OpenAI embeddings vectorizer = OpenAITextVectorizer( model="text-embedding-ada-002", api_config={"api_key": "your_api_key"} # OR set OPENAI_API_KEY in your env ) embedding = vectorizer.embed("Hello, world!") + # With caching enabled + from redisvl.extensions.cache.embeddings import EmbeddingsCache + cache = EmbeddingsCache(name="openai_embeddings_cache") + + vectorizer = OpenAITextVectorizer( + model="text-embedding-ada-002", + api_config={"api_key": "your_api_key"}, + cache=cache + ) + + # First call will compute and cache the embedding + embedding1 = vectorizer.embed("Hello, world!") + + # Second call will retrieve from cache + embedding2 = vectorizer.embed("Hello, world!") + # Asynchronous batch embedding of multiple texts embeddings = await vectorizer.aembed_many( ["Hello, world!", "How are you?"], @@ -44,14 +66,14 @@ class OpenAITextVectorizer(BaseVectorizer): """ - _client: Any = PrivateAttr() - _aclient: Any = PrivateAttr() + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__( self, model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None, dtype: str = "float32", + cache: Optional["EmbeddingsCache"] = None, **kwargs, ): """Initialize the OpenAI vectorizer. @@ -64,22 +86,37 @@ def __init__( dtype (str): the default datatype to use when embedding text as byte arrays. Used when setting `as_buffer=True` in calls to embed() and embed_many(). 
Defaults to 'float32'. + cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache instance to cache embeddings for + better performance with repeated texts. Defaults to None. Raises: ImportError: If the openai library is not installed. ValueError: If the OpenAI API key is not provided. ValueError: If an invalid dtype is provided. """ - super().__init__(model=model, dtype=dtype) - # Init clients + super().__init__(model=model, dtype=dtype, cache=cache) + # Initialize clients and set up the model + self._setup(api_config, **kwargs) + + def _setup(self, api_config: Optional[Dict], **kwargs): + """Set up the OpenAI clients and determine the embedding dimensions.""" + # Initialize clients self._initialize_clients(api_config, **kwargs) - # Set model dimensions after init + # Set model dimensions after client initialization self.dims = self._set_model_dims() def _initialize_clients(self, api_config: Optional[Dict], **kwargs): """ Setup the OpenAI clients using the provided API key or an environment variable. + + Args: + api_config: Dictionary with API configuration options + **kwargs: Additional arguments to pass to OpenAI clients + + Raises: + ImportError: If the openai library is not installed + ValueError: If no API key is provided """ if api_config is None: api_config = {} @@ -89,8 +126,8 @@ def _initialize_clients(self, api_config: Optional[Dict], **kwargs): from openai import AsyncOpenAI, OpenAI except ImportError: raise ImportError( - "OpenAI vectorizer requires the openai library. \ - Please install with `pip install openai`" + "OpenAI vectorizer requires the openai library. " + "Please install with `pip install openai>=1.13.0`" ) api_key = ( @@ -99,206 +136,167 @@ def _initialize_clients(self, api_config: Optional[Dict], **kwargs): if not api_key: raise ValueError( "OpenAI API key is required. " - "Provide it in api_config or set the OPENAI_API_KEY\ - environment variable." + "Provide it in api_config or set the OPENAI_API_KEY environment variable." ) self._client = OpenAI(api_key=api_key, **api_config, **kwargs) self._aclient = AsyncOpenAI(api_key=api_key, **api_config, **kwargs) def _set_model_dims(self) -> int: + """ + Determine the dimensionality of the embedding model by making a test call. + + Returns: + int: Dimensionality of the embedding model + + Raises: + ValueError: If embedding dimensions cannot be determined + """ try: - embedding = self.embed("dimension check") + # Use the parent embed() method which handles caching + embedding = self._embed("dimension check") + return len(embedding) except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the OpenAI API: {str(ke)}") except Exception as e: # pylint: disable=broad-except # fall back (TODO get more specific) raise ValueError(f"Error setting embedding model dimensions: {str(e)}") - return len(embedding) @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """Embed many chunks of texts using the OpenAI API. + def _embed(self, text: str, **kwargs) -> List[float]: + """Generate a vector embedding for a single text using the OpenAI API. Args: - texts (List[str]): List of text chunks to embed. - preprocess (Optional[Callable], optional): Optional preprocessing - callable to perform before vectorization. 
Defaults to None. - batch_size (int, optional): Batch size of texts to use when creating - embeddings. Defaults to 10. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + text: Text to embed + **kwargs: Additional parameters to pass to the OpenAI API Returns: - Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, - or as bytes objects if as_buffer=True + List[float]: Vector embedding as a list of floats Raises: - TypeError: If the wrong input type is passed in for the text. + TypeError: If text is not a string + ValueError: If embedding fails """ - if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(text, str): + raise TypeError("Must pass in a str value to embed.") - embeddings: List = [] - for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embeddings.create( - input=batch, model=self.model, **kwargs + try: + result = self._client.embeddings.create( + input=[text], model=self.model, **kwargs ) - embeddings += [ - self._process_embedding(r.embedding, as_buffer, dtype) - for r in response.data - ] - return embeddings + return result.data[0].embedding + except Exception as e: + raise ValueError(f"Embedding text failed: {e}") @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[float], bytes]: - """Embed a chunk of text using the OpenAI API. + def _embed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """Generate vector embeddings for a batch of texts using the OpenAI API. Args: - text (str): Chunk of text to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + texts: List of texts to embed + batch_size: Number of texts to process in each API call + **kwargs: Additional parameters to pass to the OpenAI API Returns: - Union[List[float], bytes]: Embedding as a list of floats, or as a bytes - object if as_buffer=True + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If the wrong input type is passed in for the text. 
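# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): how the refactored OpenAI
# vectorizer is driven through the public embed() wrapper, which handles cache
# lookups while the protected _embed()/_embed_many() methods shown here only
# call the API. Assumes OPENAI_API_KEY is set, Redis is reachable at the given
# URL, and the model/cache names are placeholders.
from redisvl.extensions.cache.embeddings import EmbeddingsCache
from redisvl.utils.vectorize import OpenAITextVectorizer

cache = EmbeddingsCache(
    name="openai_embeddings_cache",
    redis_url="redis://localhost:6379",
    ttl=3600,
)
vectorizer = OpenAITextVectorizer(model="text-embedding-ada-002", cache=cache)

embedding = vectorizer.embed("What is machine learning?")  # computed via the API, then cached
embedding = vectorizer.embed("What is machine learning?")  # repeat call is served from the cache
print(len(embedding), vectorizer.dims)
# ---------------------------------------------------------------------------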
+ TypeError: If texts is not a list of strings + ValueError: If embedding fails """ - if not isinstance(text, str): - raise TypeError("Must pass in a str value to embed.") - - if preprocess: - text = preprocess(text) - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(texts, list): + raise TypeError("Must pass in a list of str values to embed.") + if texts and not isinstance(texts[0], str): + raise TypeError("Must pass in a list of str values to embed.") - result = self._client.embeddings.create( - input=[text], model=self.model, **kwargs - ) - return self._process_embedding(result.data[0].embedding, as_buffer, dtype) + embeddings: List = [] + for batch in self.batchify(texts, batch_size): + try: + response = self._client.embeddings.create( + input=batch, model=self.model, **kwargs + ) + embeddings += [r.embedding for r in response.data] + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") + return embeddings @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - async def aembed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """Asynchronously embed many chunks of texts using the OpenAI API. + async def _aembed(self, text: str, **kwargs) -> List[float]: + """Asynchronously generate a vector embedding for a single text using the OpenAI API. Args: - texts (List[str]): List of text chunks to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - batch_size (int, optional): Batch size of texts to use when creating - embeddings. Defaults to 10. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + text: Text to embed + **kwargs: Additional parameters to pass to the OpenAI API Returns: - Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, - or as bytes objects if as_buffer=True + List[float]: Vector embedding as a list of floats Raises: - TypeError: If the wrong input type is passed in for the text. 
+ TypeError: If text is not a string + ValueError: If embedding fails """ - if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(text, str): + raise TypeError("Must pass in a str value to embed.") - embeddings: List = [] - for batch in self.batchify(texts, batch_size, preprocess): - response = await self._aclient.embeddings.create( - input=batch, model=self.model, **kwargs + try: + result = await self._aclient.embeddings.create( + input=[text], model=self.model, **kwargs ) - embeddings += [ - self._process_embedding(r.embedding, as_buffer, dtype) - for r in response.data - ] - return embeddings + return result.data[0].embedding + except Exception as e: + raise ValueError(f"Embedding text failed: {e}") @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - async def aembed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[float], bytes]: - """Asynchronously embed a chunk of text using the OpenAI API. + async def _aembed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """Asynchronously generate vector embeddings for a batch of texts using the OpenAI API. Args: - text (str): Chunk of text to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + texts: List of texts to embed + batch_size: Number of texts to process in each API call + **kwargs: Additional parameters to pass to the OpenAI API Returns: - Union[List[float], bytes]: Embedding as a list of floats, or as a bytes - object if as_buffer=True + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If the wrong input type is passed in for the text. 
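# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): the async path mirrors the
# sync one -- the public aembed()/aembed_many() wrappers check the cache and
# fall back to the _aembed()/_aembed_many() methods shown here. Assumes the
# same OPENAI_API_KEY and Redis setup as the previous sketch.
import asyncio

from redisvl.extensions.cache.embeddings import EmbeddingsCache
from redisvl.utils.vectorize import OpenAITextVectorizer


async def main() -> None:
    cache = EmbeddingsCache(name="openai_embeddings_cache", redis_url="redis://localhost:6379")
    vectorizer = OpenAITextVectorizer(model="text-embedding-ada-002", cache=cache)
    embeddings = await vectorizer.aembed_many(
        ["What is machine learning?", "How do embeddings work?"]
    )
    print(len(embeddings), len(embeddings[0]))


asyncio.run(main())
# ---------------------------------------------------------------------------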
+ TypeError: If texts is not a list of strings + ValueError: If embedding fails """ - if not isinstance(text, str): - raise TypeError("Must pass in a str value to embed.") - - if preprocess: - text = preprocess(text) - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(texts, list): + raise TypeError("Must pass in a list of str values to embed.") + if texts and not isinstance(texts[0], str): + raise TypeError("Must pass in a list of str values to embed.") - result = await self._aclient.embeddings.create( - input=[text], model=self.model, **kwargs - ) - return self._process_embedding(result.data[0].embedding, as_buffer, dtype) + embeddings: List = [] + for batch in self.batchify(texts, batch_size): + try: + response = await self._aclient.embeddings.create( + input=batch, model=self.model, **kwargs + ) + embeddings += [r.embedding for r in response.data] + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") + return embeddings @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index ebe2a625..f20f4866 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -1,10 +1,13 @@ import os -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union -from pydantic import PrivateAttr +from pydantic import ConfigDict from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type +if TYPE_CHECKING: + from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache + from redisvl.utils.utils import deprecated_argument from redisvl.utils.vectorize.base import BaseVectorizer @@ -23,9 +26,12 @@ class VertexAITextVectorizer(BaseVectorizer): env var. Additionally, the vertexai python client must be installed with `pip install google-cloud-aiplatform>=1.26`. + You can optionally enable caching to improve performance when generating + embeddings for repeated text inputs. + .. code-block:: python - # Synchronous embedding of a single text + # Basic usage vectorizer = VertexAITextVectorizer( model="textembedding-gecko", api_config={ @@ -34,21 +40,41 @@ class VertexAITextVectorizer(BaseVectorizer): }) embedding = vectorizer.embed("Hello, world!") - # Asynchronous batch embedding of multiple texts - embeddings = await vectorizer.embed_many( + # With caching enabled + from redisvl.extensions.cache.embeddings import EmbeddingsCache + cache = EmbeddingsCache(name="vertexai_embeddings_cache") + + vectorizer = VertexAITextVectorizer( + model="textembedding-gecko", + api_config={ + "project_id": "your_gcp_project_id", + "location": "your_gcp_location", + }, + cache=cache + ) + + # First call will compute and cache the embedding + embedding1 = vectorizer.embed("Hello, world!") + + # Second call will retrieve from cache + embedding2 = vectorizer.embed("Hello, world!") + + # Batch embedding of multiple texts + embeddings = vectorizer.embed_many( ["Hello, world!", "Goodbye, world!"], batch_size=2 ) """ - _client: Any = PrivateAttr() + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__( self, model: str = "textembedding-gecko", api_config: Optional[Dict] = None, dtype: str = "float32", + cache: Optional["EmbeddingsCache"] = None, **kwargs, ): """Initialize the VertexAI vectorizer. @@ -61,22 +87,37 @@ def __init__( dtype (str): the default datatype to use when embedding text as byte arrays. 
Used when setting `as_buffer=True` in calls to embed() and embed_many(). Defaults to 'float32'. + cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache instance to cache embeddings for + better performance with repeated texts. Defaults to None. Raises: ImportError: If the google-cloud-aiplatform library is not installed. ValueError: If the API key is not provided. ValueError: If an invalid dtype is provided. """ - super().__init__(model=model, dtype=dtype) - # Init client + super().__init__(model=model, dtype=dtype, cache=cache) + # Initialize client and set up the model + self._setup(api_config, **kwargs) + + def _setup(self, api_config: Optional[Dict], **kwargs): + """Set up the VertexAI client and determine the embedding dimensions.""" + # Initialize client self._initialize_client(api_config, **kwargs) - # Set model dimensions after init + # Set model dimensions after initialization self.dims = self._set_model_dims() def _initialize_client(self, api_config: Optional[Dict], **kwargs): """ - Setup the VertexAI clients using the provided API key or an - environment variable. + Setup the VertexAI client using the provided config options or + environment variables. + + Args: + api_config: Dictionary with GCP configuration options + **kwargs: Additional arguments for initialization + + Raises: + ImportError: If the google-cloud-aiplatform library is not installed + ValueError: If required parameters are not provided """ # Fetch the project_id and location from api_config or environment variables project_id = ( @@ -116,104 +157,94 @@ def _initialize_client(self, api_config: Optional[Dict], **kwargs): "Please install with `pip install google-cloud-aiplatform>=1.26`" ) + # Store client as a regular attribute instead of PrivateAttr self._client = TextEmbeddingModel.from_pretrained(self.model) def _set_model_dims(self) -> int: + """ + Determine the dimensionality of the embedding model by making a test call. + + Returns: + int: Dimensionality of the embedding model + + Raises: + ValueError: If embedding dimensions cannot be determined + """ try: - embedding = self.embed("dimension check") + # Call the protected _embed method to avoid caching this test embedding + embedding = self._embed("dimension check") + return len(embedding) except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the VertexAI API: {str(ke)}") except Exception as e: # pylint: disable=broad-except # fall back (TODO get more specific) raise ValueError(f"Error setting embedding model dimensions: {str(e)}") - return len(embedding) @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: int = 10, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """Embed many chunks of text using the VertexAI Embeddings API. + def _embed(self, text: str, **kwargs) -> List[float]: + """ + Generate a vector embedding for a single text using the VertexAI API. Args: - texts (List[str]): List of text chunks to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - batch_size (int, optional): Batch size of texts to use when creating - embeddings. Defaults to 10. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. 
+ text: Text to embed + **kwargs: Additional parameters to pass to the VertexAI API Returns: - Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, - or as bytes objects if as_buffer=True + List[float]: Vector embedding as a list of floats Raises: - TypeError: If the wrong input type is passed in for the test. + TypeError: If text is not a string + ValueError: If embedding fails """ - if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(text, str): + raise TypeError("Must pass in a str value to embed.") - embeddings: List = [] - for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.get_embeddings(batch, **kwargs) - embeddings += [ - self._process_embedding(r.values, as_buffer, dtype) for r in response - ] - return embeddings + try: + result = self._client.get_embeddings([text], **kwargs) + return result[0].values + except Exception as e: + raise ValueError(f"Embedding text failed: {e}") @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[float], bytes]: - """Embed a chunk of text using the VertexAI Embeddings API. + def _embed_many( + self, texts: List[str], batch_size: int = 10, **kwargs + ) -> List[List[float]]: + """ + Generate vector embeddings for a batch of texts using the VertexAI API. Args: - text (str): Chunk of text to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. + texts: List of texts to embed + batch_size: Number of texts to process in each API call + **kwargs: Additional parameters to pass to the VertexAI API Returns: - Union[List[float], bytes]: Embedding as a list of floats, or as a bytes - object if as_buffer=True + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If the wrong input type is passed in for the test. 
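# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): batch embedding with the
# refactored VertexAI vectorizer. Assumes a GCP project with Vertex AI enabled,
# google-cloud-aiplatform>=1.26 installed, and a local Redis; the project id
# and location values are placeholders.
from redisvl.extensions.cache.embeddings import EmbeddingsCache
from redisvl.utils.vectorize import VertexAITextVectorizer

cache = EmbeddingsCache(name="vertexai_embeddings_cache", redis_url="redis://localhost:6379")
vectorizer = VertexAITextVectorizer(
    model="textembedding-gecko",
    api_config={"project_id": "your_gcp_project_id", "location": "us-central1"},
    cache=cache,
)
docs = ["Hello, world!", "Goodbye, world!"]
embeddings = vectorizer.embed_many(docs, batch_size=2)  # repeat calls are served from the cache
# ---------------------------------------------------------------------------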
+ TypeError: If texts is not a list of strings + ValueError: If embedding fails """ - if not isinstance(text, str): - raise TypeError("Must pass in a str value to embed.") - - if preprocess: - text = preprocess(text) - - dtype = kwargs.pop("dtype", self.dtype) + if not isinstance(texts, list): + raise TypeError("Must pass in a list of str values to embed.") + if texts and not isinstance(texts[0], str): + raise TypeError("Must pass in a list of str values to embed.") - result = self._client.get_embeddings([text], **kwargs) - return self._process_embedding(result[0].values, as_buffer, dtype) + try: + embeddings: List = [] + for batch in self.batchify(texts, batch_size): + response = self._client.get_embeddings(batch, **kwargs) + embeddings.extend([r.values for r in response]) + return embeddings + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/voyageai.py b/redisvl/utils/vectorize/text/voyageai.py index 9d015a81..8b181689 100644 --- a/redisvl/utils/vectorize/text/voyageai.py +++ b/redisvl/utils/vectorize/text/voyageai.py @@ -1,10 +1,13 @@ import os -from typing import Any, Callable, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional -from pydantic import PrivateAttr +from pydantic import ConfigDict from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type +if TYPE_CHECKING: + from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache + from redisvl.utils.utils import deprecated_argument from redisvl.utils.vectorize.base import BaseVectorizer @@ -26,10 +29,14 @@ class VoyageAITextVectorizer(BaseVectorizer): The vectorizer supports both synchronous and asynchronous operations, allows for batch processing of texts and flexibility in handling preprocessing tasks. + You can optionally enable caching to improve performance when generating + embeddings for repeated text inputs. + .. code-block:: python from redisvl.utils.vectorize import VoyageAITextVectorizer + # Basic usage vectorizer = VoyageAITextVectorizer( model="voyage-large-2", api_config={"api_key": "your-voyageai-api-key"} # OR set VOYAGE_API_KEY in your env @@ -43,16 +50,38 @@ class VoyageAITextVectorizer(BaseVectorizer): input_type="document" ) + # With caching enabled + from redisvl.extensions.cache.embeddings import EmbeddingsCache + cache = EmbeddingsCache(name="voyageai_embeddings_cache") + + vectorizer = VoyageAITextVectorizer( + model="voyage-large-2", + api_config={"api_key": "your-voyageai-api-key"}, + cache=cache + ) + + # First call will compute and cache the embedding + embedding1 = vectorizer.embed( + text="your input query text here", + input_type="query" + ) + + # Second call will retrieve from cache + embedding2 = vectorizer.embed( + text="your input query text here", + input_type="query" + ) + """ - _client: Any = PrivateAttr() - _aclient: Any = PrivateAttr() + model_config = ConfigDict(arbitrary_types_allowed=True) def __init__( self, model: str = "voyage-large-2", api_config: Optional[Dict] = None, dtype: str = "float32", + cache: Optional["EmbeddingsCache"] = None, **kwargs, ): """Initialize the VoyageAI vectorizer. @@ -66,22 +95,37 @@ def __init__( dtype (str): the default datatype to use when embedding text as byte arrays. Used when setting `as_buffer=True` in calls to embed() and embed_many(). Defaults to 'float32'. 
+ cache (Optional[EmbeddingsCache]): Optional EmbeddingsCache instance to cache embeddings for + better performance with repeated texts. Defaults to None. Raises: ImportError: If the voyageai library is not installed. ValueError: If the API key is not provided. """ - super().__init__(model=model, dtype=dtype) - # Init client + super().__init__(model=model, dtype=dtype, cache=cache) + # Initialize client and set up the model + self._setup(api_config, **kwargs) + + def _setup(self, api_config: Optional[Dict], **kwargs): + """Set up the VoyageAI client and determine the embedding dimensions.""" + # Initialize client self._initialize_client(api_config, **kwargs) - # Set model dimensions after init + # Set model dimensions after initialization self.dims = self._set_model_dims() def _initialize_client(self, api_config: Optional[Dict], **kwargs): """ Setup the VoyageAI clients using the provided API key or an environment variable. + + Args: + api_config: Dictionary with API configuration options + **kwargs: Additional arguments to pass to VoyageAI clients + + Raises: + ImportError: If the voyageai library is not installed + ValueError: If no API key is provided """ if api_config is None: api_config = {} @@ -91,8 +135,8 @@ def _initialize_client(self, api_config: Optional[Dict], **kwargs): from voyageai import AsyncClient, Client except ImportError: raise ImportError( - "VoyageAI vectorizer requires the voyageai library. \ - Please install with `pip install voyageai`" + "VoyageAI vectorizer requires the voyageai library. " + "Please install with `pip install voyageai`" ) # Fetch the API key from api_config or environment variable @@ -104,265 +148,206 @@ def _initialize_client(self, api_config: Optional[Dict], **kwargs): "VoyageAI API key is required. " "Provide it in api_config or set the VOYAGE_API_KEY environment variable." ) + self._client = Client(api_key=api_key, **kwargs) self._aclient = AsyncClient(api_key=api_key, **kwargs) def _set_model_dims(self) -> int: + """ + Determine the dimensionality of the embedding model by making a test call. + + Returns: + int: Dimensionality of the embedding model + + Raises: + ValueError: If embedding dimensions cannot be determined + """ try: - embedding = self.embed("dimension check", input_type="document") + # Call the protected _embed method to avoid caching this test embedding + embedding = self._embed("dimension check", input_type="document") + return len(embedding) except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the VoyageAI API: {str(ke)}") except Exception as e: # pylint: disable=broad-except # fall back (TODO get more specific) raise ValueError(f"Error setting embedding model dimensions: {str(e)}") - return len(embedding) - @deprecated_argument("dtype") - def embed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[float], bytes]: - """Embed a chunk of text using the VoyageAI Embeddings API. + def _get_batch_size(self) -> int: + """ + Determine the appropriate batch size based on the model being used. - Can provide the embedding `input_type` as a `kwarg` to this method - that specifies the type of input you're giving to the model. For retrieval/search use cases, - we recommend specifying this argument when encoding queries or documents to enhance retrieval quality. - Embeddings generated with and without the input_type argument are compatible. 
+ Returns: + int: Recommended batch size for the current model + """ + if self.model in ["voyage-2", "voyage-02"]: + return 72 + elif self.model == "voyage-3-lite": + return 30 + elif self.model == "voyage-3": + return 10 + else: + return 7 # Default for other models + + def _validate_input( + self, texts: List[str], input_type: Optional[str], truncation: Optional[bool] + ): + """ + Validate the inputs to the embedding methods. - Supported input types are ``document`` and ``query`` + Args: + texts: List of texts to embed + input_type: Type of input (document or query) + truncation: Whether to truncate long texts - When hydrating your Redis DB, the documents you want to search over - should be embedded with input_type="document" and when you are - querying the database, you should set the input_type="query". + Raises: + TypeError: If inputs are invalid + """ + if not isinstance(texts, list): + raise TypeError("Must pass in a list of str values to embed.") + if texts and not isinstance(texts[0], str): + raise TypeError("Must pass in a list of str values to embed.") + if input_type is not None and input_type not in ["document", "query"]: + raise TypeError( + "Must pass in a allowed value for voyageai embedding input_type. " + "See https://docs.voyageai.com/docs/embeddings." + ) + if truncation is not None and not isinstance(truncation, bool): + raise TypeError("Truncation (optional) parameter is a bool.") + + def _embed(self, text: str, **kwargs) -> List[float]: + """ + Generate a vector embedding for a single text using the VoyageAI API. Args: - text (str): Chunk of text to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. - input_type (str): Specifies the type of input passed to the model. - truncation (bool): Whether to truncate the input texts to fit within the context length. - Check https://docs.voyageai.com/docs/embeddings + text: Text to embed + **kwargs: Additional parameters to pass to the VoyageAI API Returns: - Union[List[float], bytes]: Embedding as a list of floats, or as a bytes - object if as_buffer=True + List[float]: Vector embedding as a list of floats Raises: - TypeError: If an invalid input_type is provided. + TypeError: If text is not a string or parameters are invalid + ValueError: If embedding fails """ - return self.embed_many( - texts=[text], preprocess=preprocess, as_buffer=as_buffer, **kwargs - )[0] + # Simply call _embed_many with a single text and return the first result + result = self._embed_many([text], **kwargs) + return result[0] @retry( wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6), retry=retry_if_not_exception_type(TypeError), ) - @deprecated_argument("dtype") - def embed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: Optional[int] = None, - as_buffer: bool = False, - **kwargs, - ) -> Union[List[List[float]], List[bytes]]: - """Embed many chunks of text using the VoyageAI Embeddings API. - - Can provide the embedding `input_type` as a `kwarg` to this method - that specifies the type of input you're giving to the model. For retrieval/search use cases, - we recommend specifying this argument when encoding queries or documents to enhance retrieval quality. - Embeddings generated with and without the input_type argument are compatible. 
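# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch above): when batch_size is
# omitted, _get_batch_size() picks a model-specific default (72 for voyage-2,
# 30 for voyage-3-lite, 10 for voyage-3, 7 otherwise). Assumes VOYAGE_API_KEY
# is set; the sample texts are placeholders.
from redisvl.utils.vectorize import VoyageAITextVectorizer

vectorizer = VoyageAITextVectorizer(model="voyage-3-lite")
docs = ["Redis is an in-memory data store.", "Vector search finds similar items."]
doc_embeddings = vectorizer.embed_many(docs, input_type="document")       # default batch_size=30
query_embedding = vectorizer.embed("what is redis?", input_type="query")  # single text, same path
# ---------------------------------------------------------------------------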
- - Supported input types are ``document`` and ``query`` - - When hydrating your Redis DB, the documents you want to search over - should be embedded with input_type="document" and when you are - querying the database, you should set the input_type="query". + def _embed_many( + self, texts: List[str], batch_size: Optional[int] = None, **kwargs + ) -> List[List[float]]: + """ + Generate vector embeddings for a batch of texts using the VoyageAI API. Args: - texts (List[str]): List of text chunks to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - batch_size (int, optional): Batch size of texts to use when creating - embeddings. . - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. - input_type (str): Specifies the type of input passed to the model. - truncation (bool): Whether to truncate the input texts to fit within the context length. - Check https://docs.voyageai.com/docs/embeddings + texts: List of texts to embed + batch_size: Number of texts to process in each API call + **kwargs: Additional parameters to pass to the VoyageAI API Returns: - Union[List[List[float]], List[bytes]]: List of embeddings as lists of floats, - or as bytes objects if as_buffer=True + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: If an invalid input_type is provided. - + TypeError: If texts is not a list of strings or parameters are invalid + ValueError: If embedding fails """ input_type = kwargs.pop("input_type", None) truncation = kwargs.pop("truncation", None) - dtype = kwargs.pop("dtype", self.dtype) - if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") - if input_type is not None and input_type not in ["document", "query"]: - raise TypeError( - "Must pass in a allowed value for voyageai embedding input_type. \ - See https://docs.voyageai.com/docs/embeddings." 
- ) - - if truncation is not None and not isinstance(truncation, bool): - raise TypeError("Truncation (optional) parameter is a bool.") + # Validate inputs + self._validate_input(texts, input_type, truncation) + # Determine batch size if not provided if batch_size is None: - batch_size = ( - 72 - if self.model in ["voyage-2", "voyage-02"] - else ( - 30 - if self.model == "voyage-3-lite" - else (10 if self.model == "voyage-3" else 7) + batch_size = self._get_batch_size() + + try: + embeddings: List = [] + for batch in self.batchify(texts, batch_size): + response = self._client.embed( + texts=batch, + model=self.model, + input_type=input_type, + truncation=truncation, + **kwargs, ) - ) + embeddings.extend(response.embeddings) + return embeddings + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") - embeddings: List = [] - for batch in self.batchify(texts, batch_size, preprocess): - response = self._client.embed( - texts=batch, model=self.model, input_type=input_type, **kwargs - ) - embeddings += [ - self._process_embedding(embedding, as_buffer, dtype) - for embedding in response.embeddings - ] - return embeddings - - @deprecated_argument("dtype") - async def aembed_many( - self, - texts: List[str], - preprocess: Optional[Callable] = None, - batch_size: Optional[int] = None, - as_buffer: bool = False, - **kwargs, - ) -> List[List[float]]: - """Embed many chunks of text using the VoyageAI Embeddings API. + async def _aembed(self, text: str, **kwargs) -> List[float]: + """ + Asynchronously generate a vector embedding for a single text using the VoyageAI API. - Can provide the embedding `input_type` as a `kwarg` to this method - that specifies the type of input you're giving to the model. For retrieval/search use cases, - we recommend specifying this argument when encoding queries or documents to enhance retrieval quality. - Embeddings generated with and without the input_type argument are compatible. + Args: + text: Text to embed + **kwargs: Additional parameters to pass to the VoyageAI API - Supported input types are ``document`` and ``query`` + Returns: + List[float]: Vector embedding as a list of floats - When hydrating your Redis DB, the documents you want to search over - should be embedded with input_type="document" and when you are - querying the database, you should set the input_type="query". + Raises: + TypeError: If text is not a string or parameters are invalid + ValueError: If embedding fails + """ + # Simply call _aembed_many with a single text and return the first result + result = await self._aembed_many([text], **kwargs) + return result[0] + + @retry( + wait=wait_random_exponential(min=1, max=60), + stop=stop_after_attempt(6), + retry=retry_if_not_exception_type(TypeError), + ) + async def _aembed_many( + self, texts: List[str], batch_size: Optional[int] = None, **kwargs + ) -> List[List[float]]: + """ + Asynchronously generate vector embeddings for a batch of texts using the VoyageAI API. Args: - texts (List[str]): List of text chunks to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - batch_size (int, optional): Batch size of texts to use when creating - embeddings. . - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. - input_type (str): Specifies the type of input passed to the model. - truncation (bool): Whether to truncate the input texts to fit within the context length. 
- Check https://docs.voyageai.com/docs/embeddings + texts: List of texts to embed + batch_size: Number of texts to process in each API call + **kwargs: Additional parameters to pass to the VoyageAI API Returns: - List[List[float]]: List of embeddings. + List[List[float]]: List of vector embeddings as lists of floats Raises: - TypeError: In an invalid input_type is provided. - + TypeError: If texts is not a list of strings or parameters are invalid + ValueError: If embedding fails """ input_type = kwargs.pop("input_type", None) truncation = kwargs.pop("truncation", None) - dtype = kwargs.pop("dtype", self.dtype) - if not isinstance(texts, list): - raise TypeError("Must pass in a list of str values to embed.") - if len(texts) > 0 and not isinstance(texts[0], str): - raise TypeError("Must pass in a list of str values to embed.") - if input_type is not None and input_type not in ["document", "query"]: - raise TypeError( - "Must pass in a allowed value for voyageai embedding input_type. \ - See https://docs.voyageai.com/docs/embeddings." - ) - - if truncation is not None and not isinstance(truncation, bool): - raise TypeError("Truncation (optional) parameter is a bool.") + # Validate inputs + self._validate_input(texts, input_type, truncation) + # Determine batch size if not provided if batch_size is None: - batch_size = ( - 72 - if self.model in ["voyage-2", "voyage-02"] - else ( - 30 - if self.model == "voyage-3-lite" - else (10 if self.model == "voyage-3" else 7) - ) - ) - - embeddings: List = [] - for batch in self.batchify(texts, batch_size, preprocess): - response = await self._aclient.embed( - texts=batch, model=self.model, input_type=input_type, **kwargs - ) - embeddings += [ - self._process_embedding(embedding, as_buffer, dtype) - for embedding in response.embeddings - ] - return embeddings - - @deprecated_argument("dtype") - async def aembed( - self, - text: str, - preprocess: Optional[Callable] = None, - as_buffer: bool = False, - **kwargs, - ) -> List[float]: - """Embed a chunk of text using the VoyageAI Embeddings API. - - Can provide the embedding `input_type` as a `kwarg` to this method - that specifies the type of input you're giving to the model. For retrieval/search use cases, - we recommend specifying this argument when encoding queries or documents to enhance retrieval quality. - Embeddings generated with and without the input_type argument are compatible. - - Supported input types are ``document`` and ``query`` - - When hydrating your Redis DB, the documents you want to search over - should be embedded with input_type="document" and when you are - querying the database, you should set the input_type="query". - - Args: - text (str): Chunk of text to embed. - preprocess (Optional[Callable], optional): Optional preprocessing callable to - perform before vectorization. Defaults to None. - as_buffer (bool, optional): Whether to convert the raw embedding - to a byte string. Defaults to False. - input_type (str): Specifies the type of input passed to the model. - truncation (bool): Whether to truncate the input texts to fit within the context length. - Check https://docs.voyageai.com/docs/embeddings - - Returns: - List[float]: Embedding. + batch_size = self._get_batch_size() - Raises: - TypeError: In an invalid input_type is provided. 
- """ - result = await self.aembed_many( - texts=[text], preprocess=preprocess, as_buffer=as_buffer, **kwargs - ) - return result[0] + try: + embeddings: List = [] + for batch in self.batchify(texts, batch_size): + response = await self._aclient.embed( + texts=batch, + model=self.model, + input_type=input_type, + truncation=truncation, + **kwargs, + ) + embeddings.extend(response.embeddings) + return embeddings + except Exception as e: + raise ValueError(f"Embedding texts failed: {e}") + + @property + def type(self) -> str: + return "voyageai" diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 526709c9..43bad503 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -9,7 +9,7 @@ from redis.exceptions import ConnectionError from redisvl.exceptions import RedisModuleVersionError -from redisvl.extensions.cache import SemanticCache +from redisvl.extensions.cache.llm import SemanticCache from redisvl.index.index import AsyncSearchIndex, SearchIndex from redisvl.query.filter import Num, Tag, Text from redisvl.utils.vectorize import HFTextVectorizer diff --git a/tests/integration/test_vectorizers.py b/tests/integration/test_vectorizers.py index 36e444de..d5727664 100644 --- a/tests/integration/test_vectorizers.py +++ b/tests/integration/test_vectorizers.py @@ -3,6 +3,8 @@ import numpy as np import pytest +from redisvl.extensions.cache.embeddings.embeddings import EmbeddingsCache +from redisvl.utils.utils import create_ulid from redisvl.utils.vectorize import ( AzureOpenAITextVectorizer, BedrockTextVectorizer, @@ -15,6 +17,25 @@ VoyageAITextVectorizer, ) +# Constants for testing +TEST_TEXT = "This is a test sentence." +TEST_TEXTS = ["This is the first test sentence.", "This is the second test sentence."] +TEST_VECTOR = [1.1, 2.2, 3.3, 4.4] + + +@pytest.fixture +def embeddings_cache(client): + """Create a real EmbeddingsCache for testing with a unique namespace.""" + # Use a unique prefix for this test run to avoid conflicts + unique_prefix = f"test_cache_{create_ulid()}" + + # Create the cache with a short TTL + cache = EmbeddingsCache(name=unique_prefix, ttl=10, redis_client=client) + + yield cache + + cache.clear() + @pytest.fixture( params=[ @@ -53,25 +74,49 @@ def vectorizer(request): elif request.param == CustomTextVectorizer: def embed(text): - return [1.1, 2.2, 3.3, 4.4] + return TEST_VECTOR def embed_many(texts): - return [[1.1, 2.2, 3.3, 4.4]] * len(texts) + return [TEST_VECTOR] * len(texts) + + async def aembed_func(text): + return TEST_VECTOR + + async def aembed_many_func(texts): + return [TEST_VECTOR] * len(texts) return request.param(embed=embed, embed_many=embed_many) @pytest.fixture -def bedrock_vectorizer(): - return BedrockTextVectorizer( - model=os.getenv("BEDROCK_MODEL_ID", "amazon.titan-embed-text-v2:0") +def cached_vectorizer(embeddings_cache): + """Create a simple custom vectorizer for testing.""" + + def embed(text): + return TEST_VECTOR + + def embed_many(texts): + return [TEST_VECTOR] * len(texts) + + async def aembed(text): + return TEST_VECTOR + + async def aembed_many(texts): + return [TEST_VECTOR] * len(texts) + + return CustomTextVectorizer( + embed=embed, + embed_many=embed_many, + aembed=aembed, + aembed_many=aembed_many, + cache=embeddings_cache, ) @pytest.fixture def custom_embed_func(): def embed(text: str): - return [1.1, 2.2, 3.3, 4.4] + return TEST_VECTOR return embed @@ -80,10 +125,10 @@ def embed(text: str): def custom_embed_class(): class MyEmbedder: def embed(self, text: str): - 
return [1.1, 2.2, 3.3, 4.4] + return TEST_VECTOR def embed_with_args(self, text: str, max_len=None): - return [1.1, 2.2, 3.3, 4.4][0:max_len] + return TEST_VECTOR[0:max_len] def embed_many(self, text_list): return [[1.1, 2.2, 3.3], [4.4, 5.5, 6.6]] @@ -99,7 +144,7 @@ def embed_many_with_args(self, texts, param=True): @pytest.mark.requires_api_keys def test_vectorizer_embed(vectorizer): - text = "This is a test sentence." + text = TEST_TEXT if isinstance(vectorizer, CohereTextVectorizer): embedding = vectorizer.embed(text, input_type="search_document") elif isinstance(vectorizer, VoyageAITextVectorizer): @@ -113,7 +158,7 @@ def test_vectorizer_embed(vectorizer): @pytest.mark.requires_api_keys def test_vectorizer_embed_many(vectorizer): - texts = ["This is the first test sentence.", "This is the second test sentence."] + texts = TEST_TEXTS if isinstance(vectorizer, CohereTextVectorizer): embeddings = vectorizer.embed_many(texts, input_type="search_document") elif isinstance(vectorizer, VoyageAITextVectorizer): @@ -140,6 +185,131 @@ def test_vectorizer_bad_input(vectorizer): vectorizer.embed_many(42) +def test_vectorizer_with_cache(cached_vectorizer): + """Test the complete cache flow - miss, store, hit.""" + # First call - should be a cache miss + first_result = cached_vectorizer.embed(TEST_TEXT) + assert first_result == TEST_VECTOR + + # Second call - should be a cache hit + second_result = cached_vectorizer.embed(TEST_TEXT) + assert second_result == TEST_VECTOR + + # Verify it's actually using the cache by checking the cached value exists + cached_entry = cached_vectorizer.cache.get( + text=TEST_TEXT, model_name=cached_vectorizer.model + ) + assert cached_entry is not None + assert cached_entry["embedding"] == TEST_VECTOR + + +def test_vectorizer_with_cache_skip(cached_vectorizer): + """Test embedding with skip_cache=True.""" + # Store a value in the cache + cached_vectorizer.embed(TEST_TEXT) + + # Call embed with skip_cache=True - should bypass cache + cached_vectorizer.cache.drop(text=TEST_TEXT, model_name=cached_vectorizer.model) + + # Store a deliberately different value in the cache + cached_vectorizer.cache.set( + text=TEST_TEXT, + model_name=cached_vectorizer.model, + embedding=[9.9, 8.8, 7.7, 6.6], + ) + + # Now call with skip_cache=True + result = cached_vectorizer.embed(TEST_TEXT, skip_cache=True) + + # Should generate fresh result, not use cached value + assert result == TEST_VECTOR + + # Cache should still have the original value + cached_entry = cached_vectorizer.cache.get( + text=TEST_TEXT, model_name=cached_vectorizer.model + ) + assert cached_entry["embedding"] == [9.9, 8.8, 7.7, 6.6] + + +def test_vectorizer_with_cache_many(cached_vectorizer): + """Test embedding many texts with partial cache hits/misses.""" + # Store an embedding for the first text only + cached_vectorizer.cache.set( + text=TEST_TEXTS[0], + model_name=cached_vectorizer.model, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Call embed_many - should hit cache for first text, miss for second + results = cached_vectorizer.embed_many(TEST_TEXTS) + + # Verify results + assert results[0] == [0.1, 0.2, 0.3, 0.4] # From cache + assert results[1] == TEST_VECTOR # Generated + + # Both should now be in cache + for text in TEST_TEXTS: + assert cached_vectorizer.cache.exists( + text=text, model_name=cached_vectorizer.model + ) + + +def test_vectorizer_with_cached_metadata(cached_vectorizer): + """Test passing metadata through to the cache.""" + # Call embed with metadata + test_metadata = {"source": "test", "importance": 
"high"} + cached_vectorizer.embed(TEST_TEXT, metadata=test_metadata) + + # Verify metadata was stored in cache + cached_entry = cached_vectorizer.cache.get( + text=TEST_TEXT, model_name=cached_vectorizer.model + ) + assert cached_entry["metadata"] == test_metadata + + +@pytest.mark.asyncio +async def test_vectorizer_with_cache_async(cached_vectorizer): + """Test async embedding with cache.""" + # First call - should be a cache miss + first_result = await cached_vectorizer.aembed(TEST_TEXT) + assert first_result == TEST_VECTOR + + # Second call - should be a cache hit + second_result = await cached_vectorizer.aembed(TEST_TEXT) + assert second_result == TEST_VECTOR + + # Verify it's actually using the cache + cached_entry = await cached_vectorizer.cache.aget( + text=TEST_TEXT, model_name=cached_vectorizer.model + ) + assert cached_entry is not None + assert cached_entry["embedding"] == TEST_VECTOR + + +@pytest.mark.asyncio +async def test_vectorizer_with_cache_async_many(cached_vectorizer): + """Test async embedding many texts with partial cache hits/misses.""" + # Store an embedding for the first text only + await cached_vectorizer.cache.aset( + text=TEST_TEXTS[0], + model_name=cached_vectorizer.model, + embedding=[0.1, 0.2, 0.3, 0.4], + ) + + # Call aembed_many - should hit cache for first text, miss for second + results = await cached_vectorizer.aembed_many(TEST_TEXTS) + + # Verify results + assert results[0] == [0.1, 0.2, 0.3, 0.4] # From cache + assert results[1] == TEST_VECTOR # Generated + + # Both should now be in cache + for text in TEST_TEXTS: + assert await cached_vectorizer.cache.aexists( + text=text, model_name=cached_vectorizer.model + ) + + @pytest.mark.requires_api_keys def test_bedrock_bad_credentials(): with pytest.raises(ValueError): @@ -152,7 +322,7 @@ def test_bedrock_bad_credentials(): @pytest.mark.requires_api_keys -def test_bedrock_invalid_model(bedrock_vectorizer): +def test_bedrock_invalid_model(): with pytest.raises(ValueError): bedrock = BedrockTextVectorizer(model="invalid-model") bedrock.embed("test") @@ -161,15 +331,15 @@ def test_bedrock_invalid_model(bedrock_vectorizer): def test_custom_vectorizer_embed(custom_embed_class, custom_embed_func): custom_wrapper = CustomTextVectorizer(embed=custom_embed_func) embedding = custom_wrapper.embed("This is a test sentence.") - assert embedding == [1.1, 2.2, 3.3, 4.4] + assert embedding == TEST_VECTOR custom_wrapper = CustomTextVectorizer(embed=custom_embed_class().embed) embedding = custom_wrapper.embed("This is a test sentence.") - assert embedding == [1.1, 2.2, 3.3, 4.4] + assert embedding == TEST_VECTOR custom_wrapper = CustomTextVectorizer(embed=custom_embed_class().embed_with_args) embedding = custom_wrapper.embed("This is a test sentence.", max_len=4) - assert embedding == [1.1, 2.2, 3.3, 4.4] + assert embedding == TEST_VECTOR embedding = custom_wrapper.embed("This is a test sentence.", max_len=2) assert embedding == [1.1, 2.2] @@ -331,70 +501,38 @@ def test_non_supported_dtypes(vectorizer_): vectorizer_(dtype=None) -@pytest.fixture( - params=[ - OpenAITextVectorizer, - BedrockTextVectorizer, - MistralAITextVectorizer, - CustomTextVectorizer, - VoyageAITextVectorizer, - ] -) -def avectorizer(request): - if request.param == CustomTextVectorizer: - - def embed_func(text): - return [1.1, 2.2, 3.3, 4.4] - - async def aembed_func(text): - return [1.1, 2.2, 3.3, 4.4] - - async def aembed_many_func(texts): - return [[1.1, 2.2, 3.3, 4.4]] * len(texts) - - return request.param( - embed=embed_func, aembed=aembed_func, 
aembed_many=aembed_many_func - ) - else: - return request.param() - - @pytest.mark.requires_api_keys @pytest.mark.asyncio -async def test_vectorizer_aembed(avectorizer): - text = "This is a test sentence." - embedding = await avectorizer.aembed(text) - +async def test_vectorizer_aembed(vectorizer): + text = TEST_TEXT + if isinstance(vectorizer, CohereTextVectorizer): + embedding = await vectorizer.aembed(text, input_type="search_document") + elif isinstance(vectorizer, VoyageAITextVectorizer): + embedding = await vectorizer.aembed(text, input_type="document") + else: + embedding = await vectorizer.aembed(text) assert isinstance(embedding, list) - assert len(embedding) == avectorizer.dims + assert len(embedding) == vectorizer.dims @pytest.mark.requires_api_keys @pytest.mark.asyncio -async def test_vectorizer_aembed_many(avectorizer): - texts = ["This is the first test sentence.", "This is the second test sentence."] - embeddings = await avectorizer.aembed_many(texts) +async def test_vectorizer_aembed_many(vectorizer): + texts = TEST_TEXTS + if isinstance(vectorizer, CohereTextVectorizer): + embeddings = await vectorizer.aembed_many(texts, input_type="search_document") + elif isinstance(vectorizer, VoyageAITextVectorizer): + embeddings = await vectorizer.aembed_many(texts, input_type="document") + else: + embeddings = await vectorizer.aembed_many(texts) assert isinstance(embeddings, list) assert len(embeddings) == len(texts) assert all( - isinstance(emb, list) and len(emb) == avectorizer.dims for emb in embeddings + isinstance(emb, list) and len(emb) == vectorizer.dims for emb in embeddings ) -@pytest.mark.requires_api_keys -@pytest.mark.asyncio -async def test_avectorizer_bad_input(avectorizer): - with pytest.raises(TypeError): - avectorizer.embed(1) - - with pytest.raises(TypeError): - avectorizer.embed({"foo": "bar"}) - - with pytest.raises(TypeError): - avectorizer.embed_many(42) - - @pytest.mark.requires_api_keys @pytest.mark.parametrize( "dtype,expected_type", @@ -406,8 +544,8 @@ async def test_avectorizer_bad_input(avectorizer): ) def test_cohere_dtype_support(dtype, expected_type): """Test that CohereTextVectorizer properly handles different dtypes for embeddings.""" - text = "This is a test sentence." - texts = ["First test sentence.", "Second test sentence."] + text = TEST_TEXT + texts = TEST_TEXTS # Create vectorizer with specified dtype vectorizer = CohereTextVectorizer(dtype=dtype) @@ -464,8 +602,8 @@ def test_cohere_dtype_support(dtype, expected_type): @pytest.mark.requires_api_keys def test_cohere_embedding_types_warning(): """Test that a warning is raised when embedding_types parameter is passed.""" - text = "This is a test sentence." - texts = ["First test sentence.", "Second test sentence."] + text = TEST_TEXT + texts = TEST_TEXTS vectorizer = CohereTextVectorizer() # Test warning for single embedding