diff --git a/docs/user_guide/01_getting_started.ipynb b/docs/user_guide/01_getting_started.ipynb index dfa2b581..9b7d0f66 100644 --- a/docs/user_guide/01_getting_started.ipynb +++ b/docs/user_guide/01_getting_started.ipynb @@ -209,9 +209,20 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from redis import Redis\n", "\n", @@ -238,7 +249,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 5, @@ -293,8 +304,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[32m11:53:23\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", - "\u001b[32m11:53:23\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_simple\n" + "\u001b[32m11:50:15\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m Indices:\n", + "\u001b[32m11:50:15\u001b[0m \u001b[34m[RedisVL]\u001b[0m \u001b[1;30mINFO\u001b[0m 1. user_simple\n" ] } ], @@ -320,15 +331,15 @@ "│ user_simple │ HASH │ ['user_simple_docs'] │ [] │ 0 │\n", "╰──────────────┴────────────────┴──────────────────────┴─────────────────┴────────────╯\n", "Index Fields:\n", - "╭────────────────┬────────────────┬─────────┬────────────────┬────────────────╮\n", - "│ Name │ Attribute │ Type │ Field Option │ Option Value │\n", - "├────────────────┼────────────────┼─────────┼────────────────┼────────────────┤\n", - "│ user │ user │ TAG │ SEPARATOR │ , │\n", - "│ credit_score │ credit_score │ TAG │ SEPARATOR │ , │\n", - "│ job │ job │ TEXT │ WEIGHT │ 1 │\n", - "│ age │ age │ NUMERIC │ │ │\n", - "│ user_embedding │ user_embedding │ VECTOR │ │ │\n", - "╰────────────────┴────────────────┴─────────┴────────────────┴────────────────╯\n" + "╭────────────────┬────────────────┬─────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬────────────────┬─────────────────┬────────────────╮\n", + "│ Name │ Attribute │ Type │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │ Field Option │ Option Value │\n", + "├────────────────┼────────────────┼─────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼────────────────┼─────────────────┼────────────────┤\n", + "│ user │ user │ TAG │ SEPARATOR │ , │ │ │ │ │ │ │\n", + "│ credit_score │ credit_score │ TAG │ SEPARATOR │ , │ │ │ │ │ │ │\n", + "│ job │ job │ TEXT │ WEIGHT │ 1 │ │ │ │ │ │ │\n", + "│ age │ age │ NUMERIC │ │ │ │ │ │ │ │ │\n", + "│ user_embedding │ user_embedding │ VECTOR │ algorithm │ FLAT │ data_type │ FLOAT32 │ dim │ 3 │ distance_metric │ COSINE │\n", + "╰────────────────┴────────────────┴─────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴────────────────┴─────────────────┴────────────────╯\n" ] } ], @@ -354,7 +365,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "['user_simple_docs:d424b73c516442f7919cc11ed3bb1882', 'user_simple_docs:6da16f88342048e79b3500bec5448805', 'user_simple_docs:ef5a590ef85e4d4888fd8ebe79ae1e8c']\n" + "['user_simple_docs:01JM2NWFWNH0BNA640MT5DS8BD', 'user_simple_docs:01JM2NWFWNF4S2V4E4HYG25CVA', 'user_simple_docs:01JM2NWFWNBFXJJ4PV9F4KMJSE']\n" ] } ], @@ -388,7 +399,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "['user_simple_docs:9806a362604f4700b17513cc94fcf10d']\n" + "['user_simple_docs:01JM2NWJGYMJ0QTR5YB4MB0BX9']\n" ] } 
], @@ -476,9 +487,50 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "{'index': {'name': 'user_simple', 'prefix': 'user_simple_docs'},\n", + " 'fields': [{'name': 'user', 'type': 'tag'},\n", + " {'name': 'credit_score', 'type': 'tag'},\n", + " {'name': 'job', 'type': 'text'},\n", + " {'name': 'age', 'type': 'numeric'},\n", + " {'name': 'user_embedding',\n", + " 'type': 'vector',\n", + " 'attrs': {'dims': 3,\n", + " 'distance_metric': 'cosine',\n", + " 'algorithm': 'flat',\n", + " 'datatype': 'float32'}}]}" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "schema" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "from redisvl.index import AsyncSearchIndex\n", "from redis.asyncio import Redis\n", @@ -491,7 +543,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -532,7 +584,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -620,24 +672,24 @@ "│ Stat Key │ Value │\n", "├─────────────────────────────┼─────────────┤\n", "│ num_docs │ 4 │\n", - "│ num_terms │ 0 │\n", + "│ num_terms │ 4 │\n", "│ max_doc_id │ 4 │\n", - "│ num_records │ 20 │\n", + "│ num_records │ 22 │\n", "│ percent_indexed │ 1 │\n", "│ hash_indexing_failures │ 0 │\n", - "│ number_of_uses │ 2 │\n", - "│ bytes_per_record_avg │ 1 │\n", - "│ doc_table_size_mb │ 0.00044632 │\n", - "│ inverted_sz_mb │ 1.90735e-05 │\n", - "│ key_table_size_mb │ 0.000138283 │\n", - "│ offset_bits_per_record_avg │ nan │\n", - "│ offset_vectors_sz_mb │ 0 │\n", - "│ offsets_per_term_avg │ 0 │\n", - "│ records_per_doc_avg │ 5 │\n", + "│ number_of_uses │ 5 │\n", + "│ bytes_per_record_avg │ 50.9091 │\n", + "│ doc_table_size_mb │ 0.000423431 │\n", + "│ inverted_sz_mb │ 0.00106812 │\n", + "│ key_table_size_mb │ 0.000165939 │\n", + "│ offset_bits_per_record_avg │ 8 │\n", + "│ offset_vectors_sz_mb │ 5.72205e-06 │\n", + "│ offsets_per_term_avg │ 0.272727 │\n", + "│ records_per_doc_avg │ 5.5 │\n", "│ sortable_values_size_mb │ 0 │\n", - "│ total_indexing_time │ 1.796 │\n", - "│ total_inverted_index_blocks │ 11 │\n", - "│ vector_index_sz_mb │ 0.235603 │\n", + "│ total_indexing_time │ 0.197 │\n", + "│ total_inverted_index_blocks │ 12 │\n", + "│ vector_index_sz_mb │ 0.0201416 │\n", "╰─────────────────────────────┴─────────────╯\n" ] } @@ -657,7 +709,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Below we will clean up after our work. First, you can optionally flush all data from Redis associated with the index by\n", + "Below we will clean up after our work. First, you can flush all data from Redis associated with the index by\n", "using the `.clear()` method. 
This will leave the secondary index in place for future insertions or updates.\n", "\n", "But if you want to clean up everything, including the index, just use `.delete()`\n", @@ -666,31 +718,53 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 19, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "4" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# (optionally) clear all data from Redis associated with the index\n", + "# Clear all data from Redis associated with the index\n", "await index.clear()" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 20, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "# but the index is still in place\n", + "# But the index is still in place\n", "await index.exists()" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ - "# remove / delete the index in its entirety\n", + "# Remove / delete the index in its entirety\n", "await index.delete()" ] } @@ -711,7 +785,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.13.2" }, "orig_nbformat": 4 }, diff --git a/docs/user_guide/08_semantic_router.ipynb b/docs/user_guide/08_semantic_router.ipynb index ab3bf2cc..360108d0 100644 --- a/docs/user_guide/08_semantic_router.ipynb +++ b/docs/user_guide/08_semantic_router.ipynb @@ -421,7 +421,7 @@ "source": [ "router2 = SemanticRouter.from_dict(router.to_dict(), redis_url=\"redis://localhost:6379\")\n", "\n", - "assert router2 == router" + "assert router2.to_dict() == router.to_dict()" ] }, { @@ -449,7 +449,7 @@ "source": [ "router3 = SemanticRouter.from_yaml(\"router.yaml\", redis_url=\"redis://localhost:6379\")\n", "\n", - "assert router3 == router2 == router" + "assert router3.to_dict() == router2.to_dict() == router.to_dict()" ] }, { diff --git a/redisvl/extensions/llmcache/schema.py b/redisvl/extensions/llmcache/schema.py index ad72d749..fa6f720a 100644 --- a/redisvl/extensions/llmcache/schema.py +++ b/redisvl/extensions/llmcache/schema.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional -from pydantic.v1 import BaseModel, Field, root_validator, validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator from redisvl.extensions.constants import ( CACHE_VECTOR_FIELD_NAME, @@ -34,7 +34,7 @@ class CacheEntry(BaseModel): filters: Optional[Dict[str, Any]] = Field(default=None) """Optional filter data stored on the cache entry for customizing retrieval""" - @root_validator(pre=True) + @model_validator(mode="before") @classmethod def generate_id(cls, values): # Ensure entry_id is set @@ -42,14 +42,15 @@ def generate_id(cls, values): values["entry_id"] = hashify(values["prompt"], values.get("filters")) return values - @validator("metadata") + @field_validator("metadata") + @classmethod def non_empty_metadata(cls, v): if v is not None and not isinstance(v, dict): raise TypeError("Metadata must be a dictionary.") return v def to_dict(self, dtype: str) -> Dict: - data = self.dict(exclude_none=True) + data = self.model_dump(exclude_none=True) data["prompt_vector"] = array_to_buffer(self.prompt_vector, dtype) if self.metadata is not None: data["metadata"] = serialize(self.metadata) @@ -79,33 +80,33 @@
class CacheHit(BaseModel): filters: Optional[Dict[str, Any]] = Field(default=None) """Optional filter data stored on the cache entry for customizing retrieval""" - @root_validator(pre=True) + # Allow extra fields to simplify handling filters + model_config = ConfigDict(extra="allow") + + @model_validator(mode="before") @classmethod - def validate_cache_hit(cls, values): + def validate_cache_hit(cls, values: Dict[str, Any]) -> Dict[str, Any]: # Deserialize metadata if necessary if "metadata" in values and isinstance(values["metadata"], str): values["metadata"] = deserialize(values["metadata"]) - # Separate filters from other fields - known_fields = set(cls.__fields__.keys()) - filters = {k: v for k, v in values.items() if k not in known_fields} - - # Add filters to values - if filters: - values["filters"] = filters - - # Remove filter fields from the main values - for k in filters: - values.pop(k) + # Collect any extra fields and store them as filters + extra_data = values.pop("__pydantic_extra__", {}) or {} + if extra_data: + current_filters = values.get("filters") or {} + if not isinstance(current_filters, dict): + current_filters = {} + current_filters.update(extra_data) + values["filters"] = current_filters return values - def to_dict(self) -> Dict: - data = self.dict(exclude_none=True) - if self.filters: - data.update(self.filters) + def to_dict(self) -> Dict[str, Any]: + """Convert this model to a dictionary, merging filters into the result.""" + data = self.model_dump(exclude_none=True) + if data.get("filters"): + data.update(data["filters"]) del data["filters"] return data diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py index 5553bfa4..86cdb02e 100644 --- a/redisvl/extensions/llmcache/semantic.py +++ b/redisvl/extensions/llmcache/semantic.py @@ -125,7 +125,7 @@ def __init__( # Create semantic cache schema and index schema = SemanticCacheIndexSchema.from_params( - name, prefix, vectorizer.dims, vectorizer.dtype + name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore ) schema = self._modify_schema(schema, filterable_fields) self._index = SearchIndex(schema=schema) @@ -141,7 +141,7 @@ def __init__( existing_index = SearchIndex.from_existing( name, redis_client=self._index.client ) - if existing_index.schema != self._index.schema: + if existing_index.schema.to_dict() != self._index.schema.to_dict(): raise ValueError( f"Existing index {name} schema does not match the user provided schema for the semantic cache. " "If you wish to overwrite the index schema, set overwrite=True during initialization."
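The `CacheHit` rewrite above replaces manual bookkeeping against `cls.__fields__` with pydantic v2's `extra="allow"` model config, which retains unknown keys as extras so they can be folded into `filters`. A minimal sketch of that mechanism, using a hypothetical model rather than the real `CacheHit`:

```python
from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict


class Hit(BaseModel):
    """Hypothetical stand-in for CacheHit, shown only to illustrate extra handling."""

    model_config = ConfigDict(extra="allow")

    prompt: str
    filters: Optional[Dict[str, Any]] = None


hit = Hit(prompt="hello", user_group="beta")  # unknown field is accepted, not rejected
print(hit.model_extra)                        # {'user_group': 'beta'}
print(hit.model_dump(exclude_none=True))      # extras serialize alongside declared fields
```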
diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py index 9c0d24c6..1b1d6dc8 100644 --- a/redisvl/extensions/router/schema.py +++ b/redisvl/extensions/router/schema.py @@ -1,7 +1,9 @@ +import warnings from enum import Enum -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional -from pydantic.v1 import BaseModel, Field, validator +from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator +from typing_extensions import Annotated from redisvl.extensions.constants import ROUTE_VECTOR_FIELD_NAME from redisvl.schema import IndexSchema @@ -14,18 +16,20 @@ class Route(BaseModel): """The name of the route.""" references: List[str] """List of reference phrases for the route.""" - metadata: Dict[str, str] = Field(default={}) + metadata: Dict[str, Any] = Field(default={}) """Metadata associated with the route.""" - distance_threshold: float = Field(default=0.5) + distance_threshold: Annotated[float, Field(strict=True, gt=0, le=1)] = 0.5 """Distance threshold for matching the route.""" - @validator("name") + @field_validator("name") + @classmethod def name_must_not_be_empty(cls, v): if not v or not v.strip(): raise ValueError("Route name must not be empty") return v - @validator("references") + @field_validator("references") + @classmethod def references_must_not_be_empty(cls, v): if not v: raise ValueError("References must not be empty") @@ -33,12 +37,6 @@ def references_must_not_be_empty(cls, v): raise ValueError("All references must be non-empty strings") return v - @validator("distance_threshold") - def distance_threshold_must_be_positive(cls, v): - if v is not None and v <= 0: - raise ValueError("Route distance threshold must be greater than zero") - return v - class RouteMatch(BaseModel): """Model representing a matched route with distance information.""" @@ -63,27 +61,26 @@ class DistanceAggregationMethod(Enum): class RoutingConfig(BaseModel): """Configuration for routing behavior.""" - # distance_threshold: float = Field(default=0.5) - """The threshold for semantic distance.""" - max_k: int = Field(default=1) - + """The maximum number of top matches to return.""" + max_k: Annotated[int, Field(strict=True, default=1, gt=0)] = 1 """Aggregation method to use to classify queries.""" aggregation_method: DistanceAggregationMethod = Field( default=DistanceAggregationMethod.avg ) - """The maximum number of top matches to return.""" - distance_threshold: float = Field( - default=0.5, - deprecated=True, - description="Global distance threshold is deprecated all distance_thresholds now apply at route level.", - ) + model_config = ConfigDict(extra="ignore") - @validator("max_k") - def max_k_must_be_positive(cls, v): - if v <= 0: - raise ValueError("max_k must be a positive integer") - return v + @model_validator(mode="before") + @classmethod + def remove_distance_threshold(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if "distance_threshold" in values: + warnings.warn( + "The 'distance_threshold' field is deprecated and will be ignored. 
Set distance_threshold per Route.", + DeprecationWarning, + stacklevel=2, + ) + values.pop("distance_threshold") + return values class SemanticRouterIndexSchema(IndexSchema): diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py index 8a349f46..be519dd8 100644 --- a/redisvl/extensions/router/semantic.py +++ b/redisvl/extensions/router/semantic.py @@ -3,7 +3,7 @@ import redis.commands.search.reducers as reducers import yaml -from pydantic.v1 import BaseModel, Field, PrivateAttr +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr from redis import Redis from redis.commands.search.aggregation import AggregateRequest, AggregateResult, Reducer from redis.exceptions import ResponseError @@ -44,8 +44,7 @@ class SemanticRouter(BaseModel): _index: SearchIndex = PrivateAttr() - class Config: - arbitrary_types_allowed = True + model_config = ConfigDict(arbitrary_types_allowed=True) @deprecated_argument("dtype", "vectorizer") def __init__( @@ -109,7 +108,7 @@ def _initialize_index( ): """Initialize the search index and handle Redis connection.""" schema = SemanticRouterIndexSchema.from_params( - self.name, self.vectorizer.dims, self.vectorizer.dtype + self.name, self.vectorizer.dims, self.vectorizer.dtype # type: ignore ) self._index = SearchIndex(schema=schema) @@ -124,7 +123,7 @@ def _initialize_index( existing_index = SearchIndex.from_existing( self.name, redis_client=self._index.client ) - if existing_index.schema != self._index.schema: + if existing_index.schema.to_dict() != self._index.schema.to_dict(): raise ValueError( f"Existing index {self.name} schema does not match the user provided schema for the semantic router. " "If you wish to overwrite the index schema, set overwrite=True during initialization." 
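With the hand-written `max_k` and `distance_threshold` validators removed, the constraints now live in `Annotated` types and pydantic v2 emits its built-in messages — which is why the unit tests later in this diff assert on `"Input should be greater than 0"`. A small sketch of the pattern, assuming only pydantic v2 and `typing_extensions`:

```python
from pydantic import BaseModel, Field, ValidationError
from typing_extensions import Annotated


class Sketch(BaseModel):
    """Hypothetical config mirroring the constrained fields above."""

    distance_threshold: Annotated[float, Field(strict=True, gt=0, le=1)] = 0.5
    max_k: Annotated[int, Field(strict=True, gt=0)] = 1


try:
    Sketch(max_k=0)
except ValidationError as exc:
    print(exc)  # message includes "Input should be greater than 0"
```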
diff --git a/redisvl/extensions/session_manager/schema.py b/redisvl/extensions/session_manager/schema.py index f943889d..6be28f22 100644 --- a/redisvl/extensions/session_manager/schema.py +++ b/redisvl/extensions/session_manager/schema.py @@ -1,6 +1,6 @@ from typing import Dict, List, Optional -from pydantic.v1 import BaseModel, Field, root_validator +from pydantic import BaseModel, ConfigDict, Field, model_validator from redisvl.extensions.constants import ( CONTENT_FIELD_NAME, @@ -33,11 +33,9 @@ class ChatMessage(BaseModel): """An optional identifier for a tool call associated with the message.""" vector_field: Optional[List[float]] = Field(default=None) """The vector representation of the message content.""" + model_config = ConfigDict(arbitrary_types_allowed=True) - class Config: - arbitrary_types_allowed = True - - @root_validator(pre=True) + @model_validator(mode="before") @classmethod def generate_id(cls, values): if TIMESTAMP_FIELD_NAME not in values: @@ -49,7 +47,7 @@ def generate_id(cls, values): return values def to_dict(self, dtype: Optional[str] = None) -> Dict: - data = self.dict(exclude_none=True) + data = self.model_dump(exclude_none=True) # handle optional fields if SESSION_VECTOR_FIELD_NAME in data: diff --git a/redisvl/extensions/session_manager/semantic_session.py b/redisvl/extensions/session_manager/semantic_session.py index 74d8f4ab..6924ce5d 100644 --- a/redisvl/extensions/session_manager/semantic_session.py +++ b/redisvl/extensions/session_manager/semantic_session.py @@ -94,7 +94,7 @@ def __init__( self.set_distance_threshold(distance_threshold) schema = SemanticSessionIndexSchema.from_params( - name, prefix, self._vectorizer.dims, vectorizer.dtype + name, prefix, vectorizer.dims, vectorizer.dtype # type: ignore ) self._index = SearchIndex(schema=schema) @@ -110,7 +110,7 @@ def __init__( existing_index = SearchIndex.from_existing( name, redis_client=self._index.client ) - if existing_index.schema != self._index.schema: + if existing_index.schema.to_dict() != self._index.schema.to_dict(): raise ValueError( f"Existing index {name} schema does not match the user provided schema for the semantic session. " "If you wish to overwrite the index schema, set overwrite=True during initialization." 
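The session-manager models get the same two mechanical renames used throughout this PR: the inner `class Config` becomes `model_config = ConfigDict(...)`, and `.dict()` becomes `.model_dump()`. A quick sketch with a hypothetical model:

```python
from typing import List, Optional

from pydantic import BaseModel, ConfigDict


class Message(BaseModel):
    """Hypothetical stand-in for ChatMessage."""

    model_config = ConfigDict(arbitrary_types_allowed=True)  # was: class Config

    content: str
    vector_field: Optional[List[float]] = None


# exclude_none drops the unset vector, which is what to_dict() relies on
print(Message(content="hi").model_dump(exclude_none=True))  # {'content': 'hi'}
```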
diff --git a/redisvl/index/storage.py b/redisvl/index/storage.py index 108237f3..2be386c0 100644 --- a/redisvl/index/storage.py +++ b/redisvl/index/storage.py @@ -1,7 +1,7 @@ import asyncio from typing import Any, Callable, Dict, Iterable, List, Optional -from pydantic.v1 import BaseModel +from pydantic import BaseModel from redis import Redis from redis.asyncio import Redis as AsyncRedis from redis.commands.search.indexDefinition import IndexType diff --git a/redisvl/schema/fields.py b/redisvl/schema/fields.py index 2b5db557..17714480 100644 --- a/redisvl/schema/fields.py +++ b/redisvl/schema/fields.py @@ -6,9 +6,9 @@ """ from enum import Enum -from typing import Any, Dict, Optional, Tuple, Type, Union +from typing import Any, Dict, Literal, Optional, Tuple, Type, Union -from pydantic.v1 import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator from redis.commands.search.field import Field as RedisField from redis.commands.search.field import GeoField as RedisGeoField from redis.commands.search.field import NumericField as RedisNumericField @@ -16,7 +16,13 @@ from redis.commands.search.field import TextField as RedisTextField from redis.commands.search.field import VectorField as RedisVectorField -### Attribute Enums ### + +class FieldTypes(str, Enum): + TAG = "tag" + TEXT = "text" + NUMERIC = "numeric" + GEO = "geo" + VECTOR = "vector" class VectorDistanceMetric(str, Enum): @@ -99,7 +105,7 @@ class BaseVectorFieldAttributes(BaseModel): initial_cap: Optional[int] = None """Initial vector capacity in the index affecting memory allocation size of the index""" - @validator("algorithm", "datatype", "distance_metric", pre=True) + @field_validator("algorithm", "datatype", "distance_metric", mode="before") @classmethod def uppercase_strings(cls, v): """Validate that provided values are cast to uppercase""" @@ -121,9 +127,7 @@ def field_data(self) -> Dict[str, Any]: class FlatVectorFieldAttributes(BaseVectorFieldAttributes): """FLAT vector field attributes""" - algorithm: VectorIndexAlgorithm = Field( - default=VectorIndexAlgorithm.FLAT, const=True - ) + algorithm: Literal[VectorIndexAlgorithm.FLAT] = VectorIndexAlgorithm.FLAT """The indexing algorithm for the vector field""" block_size: Optional[int] = None """Block size to hold amount of vectors in a contiguous array. 
This is useful when the index is dynamic with respect to addition and deletion""" @@ -132,9 +136,7 @@ class FlatVectorFieldAttributes(BaseVectorFieldAttributes): class HNSWVectorFieldAttributes(BaseVectorFieldAttributes): """HNSW vector field attributes""" - algorithm: VectorIndexAlgorithm = Field( - default=VectorIndexAlgorithm.HNSW, const=True - ) + algorithm: Literal[VectorIndexAlgorithm.HNSW] = VectorIndexAlgorithm.HNSW """The indexing algorithm for the vector field""" m: int = Field(default=16) """Number of max outgoing edges for each graph node in each layer""" @@ -173,7 +175,7 @@ def as_redis_field(self) -> RedisField: class TextField(BaseField): """Text field supporting a full text search index""" - type: str = Field(default="text", const=True) + type: Literal[FieldTypes.TEXT] = FieldTypes.TEXT attrs: TextFieldAttributes = Field(default_factory=TextFieldAttributes) def as_redis_field(self) -> RedisField: @@ -191,7 +193,7 @@ def as_redis_field(self) -> RedisField: class TagField(BaseField): """Tag field for simple boolean-style filtering""" - type: str = Field(default="tag", const=True) + type: Literal[FieldTypes.TAG] = FieldTypes.TAG attrs: TagFieldAttributes = Field(default_factory=TagFieldAttributes) def as_redis_field(self) -> RedisField: @@ -208,7 +210,7 @@ def as_redis_field(self) -> RedisField: class NumericField(BaseField): """Numeric field for numeric range filtering""" - type: str = Field(default="numeric", const=True) + type: Literal[FieldTypes.NUMERIC] = FieldTypes.NUMERIC attrs: NumericFieldAttributes = Field(default_factory=NumericFieldAttributes) def as_redis_field(self) -> RedisField: @@ -223,7 +225,7 @@ def as_redis_field(self) -> RedisField: class GeoField(BaseField): """Geo field with a geo-spatial index for location based search""" - type: str = Field(default="geo", const=True) + type: Literal[FieldTypes.GEO] = FieldTypes.GEO attrs: GeoFieldAttributes = Field(default_factory=GeoFieldAttributes) def as_redis_field(self) -> RedisField: @@ -238,7 +240,7 @@ def as_redis_field(self) -> RedisField: class FlatVectorField(BaseField): "Vector field with a FLAT index (brute force nearest neighbors search)" - type: str = Field(default="vector", const=True) + type: Literal[FieldTypes.VECTOR] = FieldTypes.VECTOR attrs: FlatVectorFieldAttributes def as_redis_field(self) -> RedisField: @@ -253,7 +255,7 @@ def as_redis_field(self) -> RedisField: class HNSWVectorField(BaseField): """Vector field with an HNSW index (approximate nearest neighbors search)""" - type: str = Field(default="vector", const=True) + type: Literal[FieldTypes.VECTOR] = FieldTypes.VECTOR attrs: HNSWVectorFieldAttributes def as_redis_field(self) -> RedisField: @@ -271,20 +273,21 @@ def as_redis_field(self) -> RedisField: return RedisVectorField(name, self.attrs.algorithm, field_data, as_name=as_name) -class FieldFactory: - """Factory class to create fields from client data and kwargs.""" +FIELD_TYPE_MAP = { + "tag": TagField, + "text": TextField, + "numeric": NumericField, + "geo": GeoField, +} - FIELD_TYPE_MAP = { - "tag": TagField, - "text": TextField, - "numeric": NumericField, - "geo": GeoField, - } +VECTOR_FIELD_TYPE_MAP = { + "flat": FlatVectorField, + "hnsw": HNSWVectorField, +} - VECTOR_FIELD_TYPE_MAP = { - "flat": FlatVectorField, - "hnsw": HNSWVectorField, - } + +class FieldFactory: + """Factory class to create fields from client data and kwargs.""" @classmethod def pick_vector_field_type(cls, attrs: Dict[str, Any]) -> 
Type[BaseField]: raise ValueError("Must provide dims param for the vector field.") algorithm = attrs["algorithm"].lower() - if algorithm not in cls.VECTOR_FIELD_TYPE_MAP: + if algorithm not in VECTOR_FIELD_TYPE_MAP: raise ValueError(f"Unknown vector field algorithm: {algorithm}") - return cls.VECTOR_FIELD_TYPE_MAP[algorithm] # type: ignore + return VECTOR_FIELD_TYPE_MAP[algorithm] # type: ignore @classmethod def create_field( @@ -314,8 +317,14 @@ def create_field( if type == "vector": field_class = cls.pick_vector_field_type(attrs) else: - if type not in cls.FIELD_TYPE_MAP: + if type not in FIELD_TYPE_MAP: raise ValueError(f"Unknown field type: {type}") - field_class = cls.FIELD_TYPE_MAP[type] # type: ignore + field_class = FIELD_TYPE_MAP[type] # type: ignore - return field_class(name=name, path=path, attrs=attrs) # type: ignore + return field_class.model_validate( + { + "name": name, + "path": path, + "attrs": attrs, + } + ) diff --git a/redisvl/schema/schema.py b/redisvl/schema/schema.py index a4f2e3b3..33dfd9c7 100644 --- a/redisvl/schema/schema.py +++ b/redisvl/schema/schema.py @@ -1,10 +1,10 @@ import re from enum import Enum from pathlib import Path -from typing import Any, Dict, List +from typing import Any, Dict, List, Literal import yaml -from pydantic.v1 import BaseModel, Field, root_validator +from pydantic import BaseModel, model_validator from redis.commands.search.field import Field as RedisField from redisvl.schema.fields import BaseField, FieldFactory @@ -12,7 +12,6 @@ from redisvl.utils.utils import model_to_dict logger = get_logger(__name__) -SCHEMA_VERSION = "0.1.0" class StorageType(Enum): @@ -145,7 +144,7 @@ class IndexSchema(BaseModel): """Details of the basic index configurations.""" fields: Dict[str, BaseField] = {} """Fields associated with the search index and their properties""" - version: str = Field(default=SCHEMA_VERSION, const=True) + version: Literal["0.1.0"] = "0.1.0" """Version of the underlying index schema.""" @staticmethod @@ -168,7 +167,7 @@ def _make_field(storage_type, **field_inputs) -> BaseField: field.path = None return field - @root_validator(pre=True) + @model_validator(mode="before") @classmethod def validate_and_create_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]: """ @@ -222,7 +221,7 @@ def from_yaml(cls, file_path: str) -> "IndexSchema": with open(fp, "r") as f: yaml_data = yaml.safe_load(f) - return cls(**yaml_data) + return cls.model_validate(yaml_data) @classmethod def from_dict(cls, data: Dict[str, Any]) -> "IndexSchema": @@ -260,7 +259,8 @@ def from_dict(cls, data: Dict[str, Any]) -> "IndexSchema": ] }) """ - return cls(**data) + schema_dict = data.copy() + return cls.model_validate(schema_dict) @property def field_names(self) -> List[str]: @@ -413,7 +413,7 @@ def generate_fields( FieldFactory.create_field( field_type, field_name, - ).dict() + ).model_dump() ) except ValueError as e: if strict: diff --git a/redisvl/utils/rerank/base.py b/redisvl/utils/rerank/base.py index f4602662..d36abeb5 100644 --- a/redisvl/utils/rerank/base.py +++ b/redisvl/utils/rerank/base.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic.v1 import BaseModel, validator +from pydantic import BaseModel, field_validator class BaseReranker(BaseModel, ABC): @@ -10,7 +10,7 @@ class BaseReranker(BaseModel, ABC): limit: int return_score: bool - @validator("limit") + @field_validator("limit") @classmethod def check_limit(cls, value): """Ensures the limit is a positive integer.""" @@ 
-18,7 +18,7 @@ def check_limit(cls, value): raise ValueError("Limit must be a positive integer.") return value - @validator("rank_by") + @field_validator("rank_by") @classmethod def check_rank_by(cls, value): """Ensures that rank_by is a list of strings if provided.""" diff --git a/redisvl/utils/rerank/cohere.py b/redisvl/utils/rerank/cohere.py index 87163c98..edeb8e72 100644 --- a/redisvl/utils/rerank/cohere.py +++ b/redisvl/utils/rerank/cohere.py @@ -1,7 +1,7 @@ import os from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic.v1 import PrivateAttr +from pydantic import PrivateAttr from redisvl.utils.rerank.base import BaseReranker @@ -45,7 +45,8 @@ def __init__( limit: int = 5, return_score: bool = True, api_config: Optional[Dict] = None, - ) -> None: + **kwargs, + ): """ Initialize the CohereReranker with specified model, ranking criteria, and API configuration. @@ -71,9 +72,9 @@ def __init__( super().__init__( model=model, rank_by=rank_by, limit=limit, return_score=return_score ) - self._initialize_clients(api_config) + self._initialize_clients(api_config, **kwargs) - def _initialize_clients(self, api_config: Optional[Dict]): + def _initialize_clients(self, api_config: Optional[Dict], **kwargs): """ Setup the Cohere clients using the provided API key or an environment variable. @@ -96,8 +97,8 @@ def _initialize_clients(self, api_config: Optional[Dict]): "Cohere API key is required. " "Provide it in api_config or set the COHERE_API_KEY environment variable." ) - self._client = Client(api_key=api_key, client_name="redisvl") - self._aclient = AsyncClient(api_key=api_key, client_name="redisvl") + self._client = Client(api_key=api_key, client_name="redisvl", **kwargs) + self._aclient = AsyncClient(api_key=api_key, client_name="redisvl", **kwargs) def _preprocess( self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs diff --git a/redisvl/utils/rerank/hf_cross_encoder.py b/redisvl/utils/rerank/hf_cross_encoder.py index ff8a1300..f04ec8d0 100644 --- a/redisvl/utils/rerank/hf_cross_encoder.py +++ b/redisvl/utils/rerank/hf_cross_encoder.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic.v1 import PrivateAttr +from pydantic import PrivateAttr from redisvl.utils.rerank.base import BaseReranker @@ -39,7 +39,7 @@ def __init__( limit: int = 3, return_score: bool = True, **kwargs, - ) -> None: + ): """ Initialize the HFCrossEncoderReranker with a specified model and ranking criteria. diff --git a/redisvl/utils/rerank/voyageai.py b/redisvl/utils/rerank/voyageai.py index 2c3741ce..86ccd892 100644 --- a/redisvl/utils/rerank/voyageai.py +++ b/redisvl/utils/rerank/voyageai.py @@ -1,7 +1,7 @@ import os from typing import Any, Dict, List, Optional, Tuple, Union -from pydantic.v1 import PrivateAttr +from pydantic import PrivateAttr from redisvl.utils.rerank.base import BaseReranker @@ -45,7 +45,8 @@ def __init__( limit: int = 5, return_score: bool = True, api_config: Optional[Dict] = None, - ) -> None: + **kwargs, + ): """ Initialize the VoyageAIReranker with specified model, ranking criteria, and API configuration. 
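Both reranker constructors now accept `**kwargs` and forward them to the underlying SDK clients, so connection-level options can be tuned without adding named parameters. A hedged usage sketch — `timeout` is only an assumed example of an option the cohere client accepts:

```python
from redisvl.utils.rerank.cohere import CohereReranker

# Extra keyword arguments flow through to cohere.Client / cohere.AsyncClient.
reranker = CohereReranker(
    limit=3,
    api_config={"api_key": "your-key"},
    timeout=30,  # assumed SDK option, shown for illustration only
)
```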
@@ -70,9 +71,9 @@ def __init__( super().__init__( model=model, rank_by=rank_by, limit=limit, return_score=return_score ) - self._initialize_clients(api_config) + self._initialize_clients(api_config, **kwargs) - def _initialize_clients(self, api_config: Optional[Dict]): + def _initialize_clients(self, api_config: Optional[Dict], **kwargs): """ Setup the VoyageAI clients using the provided API key or an environment variable. @@ -95,8 +96,8 @@ def _initialize_clients(self, api_config: Optional[Dict]): "VoyageAI API key is required. " "Provide it in api_config or set the VOYAGE_API_KEY environment variable." ) - self._client = Client(api_key=api_key) - self._aclient = AsyncClient(api_key=api_key) + self._client = Client(api_key=api_key, **kwargs) + self._aclient = AsyncClient(api_key=api_key, **kwargs) def _preprocess( self, query: str, docs: Union[List[Dict[str, Any]], List[str]], **kwargs diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py index 164fb2ac..a8f07338 100644 --- a/redisvl/utils/utils.py +++ b/redisvl/utils/utils.py @@ -1,14 +1,14 @@ import inspect import json import warnings -from contextlib import ContextDecorator, contextmanager +from contextlib import contextmanager from enum import Enum from functools import wraps from time import time from typing import Any, Callable, Dict, Optional from warnings import warn -from pydantic.v1 import BaseModel +from pydantic import BaseModel from ulid import ULID @@ -38,7 +38,7 @@ def serialize_item(item): else: return item - serialized_data = model.dict(exclude_none=True) + serialized_data = model.model_dump(exclude_none=True) for key, value in serialized_data.items(): serialized_data[key] = serialize_item(value) return serialized_data diff --git a/redisvl/utils/vectorize/__init__.py b/redisvl/utils/vectorize/__init__.py index d305306d..62341874 100644 --- a/redisvl/utils/vectorize/__init__.py +++ b/redisvl/utils/vectorize/__init__.py @@ -27,16 +27,18 @@ def vectorizer_from_dict(vectorizer: dict) -> BaseVectorizer: vectorizer_type = Vectorizers(vectorizer["type"]) model = vectorizer["model"] if vectorizer_type == Vectorizers.cohere: - return CohereTextVectorizer(model) + return CohereTextVectorizer(model=model) elif vectorizer_type == Vectorizers.openai: - return OpenAITextVectorizer(model) + return OpenAITextVectorizer(model=model) elif vectorizer_type == Vectorizers.azure_openai: - return AzureOpenAITextVectorizer(model) + return AzureOpenAITextVectorizer(model=model) elif vectorizer_type == Vectorizers.hf: - return HFTextVectorizer(model) + return HFTextVectorizer(model=model) elif vectorizer_type == Vectorizers.mistral: - return MistralAITextVectorizer(model) + return MistralAITextVectorizer(model=model) elif vectorizer_type == Vectorizers.vertexai: - return VertexAITextVectorizer(model) + return VertexAITextVectorizer(model=model) elif vectorizer_type == Vectorizers.voyageai: - return VoyageAITextVectorizer(model) + return VoyageAITextVectorizer(model=model) + else: + raise ValueError(f"Unsupported vectorizer type: {vectorizer_type}") diff --git a/redisvl/utils/vectorize/base.py b/redisvl/utils/vectorize/base.py index e05c7fee..b3a63fa9 100644 --- a/redisvl/utils/vectorize/base.py +++ b/redisvl/utils/vectorize/base.py @@ -2,7 +2,7 @@ from enum import Enum from typing import Callable, List, Optional -from pydantic.v1 import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator from redisvl.redis.utils import array_to_buffer from redisvl.schema.fields import VectorDataType @@ -19,16 +19,19 @@ class 
Vectorizers(Enum): class BaseVectorizer(BaseModel, ABC): + """Base vectorizer interface.""" + model: str - dims: int - dtype: str = Field(default="float32") + dtype: str = "float32" + dims: Optional[int] = None @property def type(self) -> str: return "base" - @validator("dtype") - def check_dtype(dtype): + @field_validator("dtype") + @classmethod + def check_dtype(cls, dtype): try: VectorDataType(dtype.upper()) except ValueError: @@ -37,7 +40,7 @@ def check_dtype(dtype): ) return dtype - @validator("dims") + @field_validator("dims") @classmethod def check_dims(cls, value): """Ensures the dims are a positive integer.""" diff --git a/redisvl/utils/vectorize/text/azureopenai.py b/redisvl/utils/vectorize/text/azureopenai.py index 5a77f7dd..7b3b7d01 100644 --- a/redisvl/utils/vectorize/text/azureopenai.py +++ b/redisvl/utils/vectorize/text/azureopenai.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, List, Optional -from pydantic.v1 import PrivateAttr +from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -57,6 +57,7 @@ def __init__( model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None, dtype: str = "float32", + **kwargs, ): """Initialize the AzureOpenAI vectorizer. @@ -76,10 +77,13 @@ def __init__( ValueError: If the AzureOpenAI API key, version, or endpoint are not provided. ValueError: If an invalid dtype is provided. """ - self._initialize_clients(api_config) - super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype) + super().__init__(model=model, dtype=dtype) + # Init client + self._initialize_clients(api_config, **kwargs) + # Set model dimensions + self.dims = self._set_model_dims() - def _initialize_clients(self, api_config: Optional[Dict]): + def _initialize_clients(self, api_config: Optional[Dict], **kwargs): """ Setup the OpenAI clients using the provided API key or an environment variable. @@ -141,21 +145,19 @@ def _initialize_clients(self, api_config: Optional[Dict]): api_version=api_version, azure_endpoint=azure_endpoint, **api_config, + **kwargs, ) self._aclient = AsyncAzureOpenAI( api_key=api_key, api_version=api_version, azure_endpoint=azure_endpoint, **api_config, + **kwargs, ) - def _set_model_dims(self, model) -> int: + def _set_model_dims(self) -> int: try: - embedding = ( - self._client.embeddings.create(input=["dimension test"], model=model) - .data[0] - .embedding - ) + embedding = self.embed("dimension check") except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the AzureOpenAI API: {str(ke)}") except Exception as e: # pylint: disable=broad-except diff --git a/redisvl/utils/vectorize/text/bedrock.py b/redisvl/utils/vectorize/text/bedrock.py index d64ca9ca..5858aff8 100644 --- a/redisvl/utils/vectorize/text/bedrock.py +++ b/redisvl/utils/vectorize/text/bedrock.py @@ -2,7 +2,7 @@ import os from typing import Any, Callable, Dict, List, Optional -from pydantic.v1 import PrivateAttr +from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -51,6 +51,7 @@ def __init__( model: str = "amazon.titan-embed-text-v2:0", api_config: Optional[Dict[str, str]] = None, dtype: str = "float32", + **kwargs, ) -> None: """Initialize the AWS Bedrock Vectorizer. @@ -68,6 +69,17 @@ def __init__( ImportError: If boto3 is not installed. ValueError: If an invalid dtype is provided. 
""" + super().__init__(model=model, dtype=dtype) + # Init client + self._initialize_client(api_config, **kwargs) + # Set model dimensions after init + self.dims = self._set_model_dims() + + def _initialize_client(self, api_config: Optional[Dict], **kwargs): + """ + Setup the Bedrock client using the provided API keys or + environment variables. + """ try: import boto3 # type: ignore except ImportError: @@ -98,21 +110,18 @@ def __init__( aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key, region_name=aws_region, + **kwargs, ) - super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype) - - def _set_model_dims(self, model: str) -> int: - """Initialize model and determine embedding dimensions.""" + def _set_model_dims(self) -> int: try: - response = self._client.invoke_model( - modelId=model, body=json.dumps({"inputText": "dimension test"}) - ) - response_body = json.loads(response["body"].read()) - embedding = response_body["embedding"] - return len(embedding) - except Exception as e: - raise ValueError(f"Error initializing Bedrock model: {str(e)}") + embedding = self.embed("dimension check") + except (KeyError, IndexError) as ke: + raise ValueError(f"Unexpected response from the OpenAI API: {str(ke)}") + except Exception as e: # pylint: disable=broad-except + # fall back (TODO get more specific) + raise ValueError(f"Error setting embedding model dimensions: {str(e)}") + return len(embedding) @retry( wait=wait_random_exponential(min=1, max=60), diff --git a/redisvl/utils/vectorize/text/cohere.py b/redisvl/utils/vectorize/text/cohere.py index 859fd83e..bd6481fe 100644 --- a/redisvl/utils/vectorize/text/cohere.py +++ b/redisvl/utils/vectorize/text/cohere.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, List, Optional -from pydantic.v1 import PrivateAttr +from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -52,6 +52,7 @@ def __init__( model: str = "embed-english-v3.0", api_config: Optional[Dict] = None, dtype: str = "float32", + **kwargs, ): """Initialize the Cohere vectorizer. @@ -70,24 +71,29 @@ def __init__( ValueError: If the API key is not provided. ValueError: If an invalid dtype is provided. """ - self._initialize_client(api_config) - super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype) + super().__init__(model=model, dtype=dtype) + # Init client + self._initialize_client(api_config, **kwargs) + # Set model dimensions after init + self.dims = self._set_model_dims() - def _initialize_client(self, api_config: Optional[Dict]): + def _initialize_client(self, api_config: Optional[Dict], **kwargs): """ Setup the Cohere clients using the provided API key or an environment variable. """ + if api_config is None: + api_config = {} + # Dynamic import of the cohere module try: - from cohere import AsyncClient, Client + from cohere import Client except ImportError: raise ImportError( "Cohere vectorizer requires the cohere library. \ Please install with `pip install cohere`" ) - # Fetch the API key from api_config or environment variable api_key = ( api_config.get("api_key") if api_config else os.getenv("COHERE_API_KEY") ) @@ -96,15 +102,11 @@ def _initialize_client(self, api_config: Optional[Dict]): "Cohere API key is required. " "Provide it in api_config or set the COHERE_API_KEY environment variable." 
) - self._client = Client(api_key=api_key, client_name="redisvl") + self._client = Client(api_key=api_key, client_name="redisvl", **kwargs) - def _set_model_dims(self, model) -> int: + def _set_model_dims(self) -> int: try: - embedding = self._client.embed( - texts=["dimension test"], - model=model, - input_type="search_document", - ).embeddings[0] + embedding = self.embed("dimension check", input_type="search_document") except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the Cohere API: {str(ke)}") except Exception as e: # pylint: disable=broad-except diff --git a/redisvl/utils/vectorize/text/custom.py b/redisvl/utils/vectorize/text/custom.py index 59122435..4558d4d7 100644 --- a/redisvl/utils/vectorize/text/custom.py +++ b/redisvl/utils/vectorize/text/custom.py @@ -1,6 +1,6 @@ from typing import Any, Callable, List, Optional -from pydantic.v1 import PrivateAttr +from pydantic import PrivateAttr from redisvl.utils.utils import deprecated_argument from redisvl.utils.vectorize.base import BaseVectorizer @@ -114,17 +114,16 @@ def __init__( Raises: ValueError: if embedding validation fails. """ + super().__init__(model=self.type, dtype=dtype) + # Store user-provided callables self._embed = embed self._embed_many = embed_many self._aembed = aembed self._aembed_many = aembed_many - # Manually validate sync methods to discover dimension - dims = self._validate_sync_callables() - - # Initialize the base class now that we know the dimension - super().__init__(model=self.type, dims=dims, dtype=dtype) + # Set dims + self.dims = self._validate_sync_callables() @property def type(self) -> str: diff --git a/redisvl/utils/vectorize/text/huggingface.py b/redisvl/utils/vectorize/text/huggingface.py index c387f04e..8f81b85c 100644 --- a/redisvl/utils/vectorize/text/huggingface.py +++ b/redisvl/utils/vectorize/text/huggingface.py @@ -54,12 +54,14 @@ def __init__( ValueError: If there is an error setting the embedding model dimensions. ValueError: If an invalid dtype is provided. 
""" - self._initialize_client(model) - super().__init__(model=model, dims=self._set_model_dims(), dtype=dtype) + super().__init__(model=model, dtype=dtype) + # Init client + self._initialize_client(model, **kwargs) + # Set model dimensions after init + self.dims = self._set_model_dims() - def _initialize_client(self, model: str): + def _initialize_client(self, model: str, **kwargs): """Setup the HuggingFace client""" - # Dynamic import of the cohere module\ try: from sentence_transformers import SentenceTransformer except ImportError: @@ -68,11 +70,11 @@ def _initialize_client(self, model: str): "Please install with `pip install sentence-transformers`" ) - self._client = SentenceTransformer(model) + self._client = SentenceTransformer(model, **kwargs) def _set_model_dims(self): try: - embedding = self._client.encode(["dimension check"])[0] + embedding = self.embed("dimension check") except (KeyError, IndexError) as ke: raise ValueError(f"Empty response from the embedding model: {str(ke)}") except Exception as e: # pylint: disable=broad-except diff --git a/redisvl/utils/vectorize/text/mistral.py b/redisvl/utils/vectorize/text/mistral.py index 44f189d6..e930b3a4 100644 --- a/redisvl/utils/vectorize/text/mistral.py +++ b/redisvl/utils/vectorize/text/mistral.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, List, Optional -from pydantic.v1 import PrivateAttr +from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -51,6 +51,7 @@ def __init__( model: str = "mistral-embed", api_config: Optional[Dict] = None, dtype: str = "float32", + **kwargs, ): """Initialize the MistralAI vectorizer. @@ -68,14 +69,20 @@ def __init__( ValueError: If the Mistral API key is not provided. ValueError: If an invalid dtype is provided. """ - self._initialize_clients(api_config) - super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype) + super().__init__(model=model, dtype=dtype) + # Init client + self._initialize_client(api_config, **kwargs) + # Set model dimensions after init + self.dims = self._set_model_dims() - def _initialize_clients(self, api_config: Optional[Dict]): + def _initialize_client(self, api_config: Optional[Dict], **kwargs): """ Setup the Mistral clients using the provided API key or an environment variable. """ + if api_config is None: + api_config = {} + # Dynamic import of the mistralai module try: from mistralai import Mistral @@ -96,15 +103,11 @@ def _initialize_clients(self, api_config: Optional[Dict]): environment variable." 
) - self._client = Mistral(api_key=api_key) + self._client = Mistral(api_key=api_key, **kwargs) - def _set_model_dims(self, model) -> int: + def _set_model_dims(self) -> int: try: - embedding = ( - self._client.embeddings.create(model=model, inputs=["dimension test"]) - .data[0] - .embedding - ) + embedding = self.embed("dimension check") except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the MISTRAL API: {str(ke)}") except Exception as e: # pylint: disable=broad-except diff --git a/redisvl/utils/vectorize/text/openai.py b/redisvl/utils/vectorize/text/openai.py index 9fa9cc18..25b21c67 100644 --- a/redisvl/utils/vectorize/text/openai.py +++ b/redisvl/utils/vectorize/text/openai.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, List, Optional -from pydantic.v1 import PrivateAttr +from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -52,6 +52,7 @@ def __init__( model: str = "text-embedding-ada-002", api_config: Optional[Dict] = None, dtype: str = "float32", + **kwargs, ): """Initialize the OpenAI vectorizer. @@ -69,10 +70,13 @@ def __init__( ValueError: If the OpenAI API key is not provided. ValueError: If an invalid dtype is provided. """ - self._initialize_clients(api_config) - super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype) + super().__init__(model=model, dtype=dtype) + # Init clients + self._initialize_clients(api_config, **kwargs) + # Set model dimensions after init + self.dims = self._set_model_dims() - def _initialize_clients(self, api_config: Optional[Dict]): + def _initialize_clients(self, api_config: Optional[Dict], **kwargs): """ Setup the OpenAI clients using the provided API key or an environment variable. @@ -89,7 +93,6 @@ def _initialize_clients(self, api_config: Optional[Dict]): Please install with `pip install openai`" ) - # Pull the API key from api_config or environment variable api_key = ( api_config.pop("api_key") if api_config else os.getenv("OPENAI_API_KEY") ) @@ -100,16 +103,12 @@ def _initialize_clients(self, api_config: Optional[Dict]): environment variable." ) - self._client = OpenAI(api_key=api_key, **api_config) - self._aclient = AsyncOpenAI(api_key=api_key, **api_config) + self._client = OpenAI(api_key=api_key, **api_config, **kwargs) + self._aclient = AsyncOpenAI(api_key=api_key, **api_config, **kwargs) - def _set_model_dims(self, model) -> int: + def _set_model_dims(self) -> int: try: - embedding = ( - self._client.embeddings.create(input=["dimension test"], model=model) - .data[0] - .embedding - ) + embedding = self.embed("dimension check") except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the OpenAI API: {str(ke)}") except Exception as e: # pylint: disable=broad-except diff --git a/redisvl/utils/vectorize/text/vertexai.py b/redisvl/utils/vectorize/text/vertexai.py index 1be10c7a..6d455c67 100644 --- a/redisvl/utils/vectorize/text/vertexai.py +++ b/redisvl/utils/vectorize/text/vertexai.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, List, Optional -from pydantic.v1 import PrivateAttr +from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -49,6 +49,7 @@ def __init__( model: str = "textembedding-gecko", api_config: Optional[Dict] = None, dtype: str = "float32", + **kwargs, ): """Initialize the VertexAI vectorizer. 
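Every vectorizer in this PR swaps its construction order the same way: `dims` is now `Optional` on `BaseVectorizer`, so `super().__init__()` validates the model first, the SDK client is built second (receiving `**kwargs`), and a single probe embedding fills in `dims`. A self-contained sketch of the pattern — the dummy `embed` stands in for a real API call:

```python
from typing import List, Optional

from pydantic import BaseModel


class SketchVectorizer(BaseModel):
    """Hypothetical vectorizer illustrating the deferred-dims init order."""

    model: str
    dtype: str = "float32"
    dims: Optional[int] = None

    def __init__(self, model: str, dtype: str = "float32", **kwargs):
        super().__init__(model=model, dtype=dtype)  # dims is still None here
        # a real subclass would build its SDK client here, forwarding **kwargs
        self.dims = len(self.embed("dimension check"))

    def embed(self, text: str) -> List[float]:
        return [0.0, 0.0, 0.0]  # stand-in for a real embeddings API call


print(SketchVectorizer(model="demo").dims)  # 3
```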
@@ -66,10 +67,13 @@ def __init__( ValueError: If the API key is not provided. ValueError: If an invalid dtype is provided. """ - self._initialize_client(model, api_config) - super().__init__(model=model, dims=self._set_model_dims(), dtype=dtype) + super().__init__(model=model, dtype=dtype) + # Init client + self._initialize_client(api_config, **kwargs) + # Set model dimensions after init + self.dims = self._set_model_dims() - def _initialize_client(self, model: str, api_config: Optional[Dict]): + def _initialize_client(self, api_config: Optional[Dict], **kwargs): """ Setup the VertexAI clients using the provided API key or an environment variable. @@ -112,11 +116,11 @@ def _initialize_client(self, model: str, api_config: Optional[Dict]): "Please install with `pip install google-cloud-aiplatform>=1.26`" ) - self._client = TextEmbeddingModel.from_pretrained(model) + self._client = TextEmbeddingModel.from_pretrained(self.model) def _set_model_dims(self) -> int: try: - embedding = self._client.get_embeddings(["dimension test"])[0].values + embedding = self.embed("dimension check") except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the VertexAI API: {str(ke)}") except Exception as e: # pylint: disable=broad-except diff --git a/redisvl/utils/vectorize/text/voyageai.py b/redisvl/utils/vectorize/text/voyageai.py index b1f4ac72..fbcbfd9e 100644 --- a/redisvl/utils/vectorize/text/voyageai.py +++ b/redisvl/utils/vectorize/text/voyageai.py @@ -1,7 +1,7 @@ import os from typing import Any, Callable, Dict, List, Optional -from pydantic.v1 import PrivateAttr +from pydantic import PrivateAttr from tenacity import retry, stop_after_attempt, wait_random_exponential from tenacity.retry import retry_if_not_exception_type @@ -53,6 +53,7 @@ def __init__( model: str = "voyage-large-2", api_config: Optional[Dict] = None, dtype: str = "float32", + **kwargs, ): """Initialize the VoyageAI vectorizer. @@ -71,14 +72,20 @@ def __init__( ValueError: If the API key is not provided. """ - self._initialize_client(api_config) - super().__init__(model=model, dims=self._set_model_dims(model), dtype=dtype) + super().__init__(model=model, dtype=dtype) + # Init client + self._initialize_client(api_config, **kwargs) + # Set model dimensions after init + self.dims = self._set_model_dims() - def _initialize_client(self, api_config: Optional[Dict]): + def _initialize_client(self, api_config: Optional[Dict], **kwargs): """ Setup the VoyageAI clients using the provided API key or an environment variable. """ + if api_config is None: + api_config = {} + # Dynamic import of the voyageai module try: from voyageai import AsyncClient, Client @@ -97,16 +104,12 @@ def _initialize_client(self, api_config: Optional[Dict]): "VoyageAI API key is required. " "Provide it in api_config or set the VOYAGE_API_KEY environment variable." 
) - self._client = Client(api_key=api_key) - self._aclient = AsyncClient(api_key=api_key) + self._client = Client(api_key=api_key, **kwargs) + self._aclient = AsyncClient(api_key=api_key, **kwargs) - def _set_model_dims(self, model) -> int: + def _set_model_dims(self) -> int: try: - embedding = self._client.embed( - texts=["dimension test"], - model=model, - input_type="document", - ).embeddings[0] + embedding = self.embed("dimension check", input_type="document") except (KeyError, IndexError) as ke: raise ValueError(f"Unexpected response from the VoyageAI API: {str(ke)}") except Exception as e: # pylint: disable=broad-except diff --git a/schemas/semantic_router.yaml b/schemas/semantic_router.yaml index 25aa08e1..0f5bab19 100644 --- a/schemas/semantic_router.yaml +++ b/schemas/semantic_router.yaml @@ -20,4 +20,3 @@ vectorizer: routing_config: max_k: 2 aggregation_method: avg - distance_threshold: 0.3 diff --git a/tests/integration/test_llmcache.py b/tests/integration/test_llmcache.py index 380a00b8..5b32b918 100644 --- a/tests/integration/test_llmcache.py +++ b/tests/integration/test_llmcache.py @@ -5,7 +5,7 @@ from time import sleep, time import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from redis.exceptions import ConnectionError from redisvl.exceptions import RedisModuleVersionError diff --git a/tests/integration/test_semantic_router.py b/tests/integration/test_semantic_router.py index d88fcc6f..ed8f2f0e 100644 --- a/tests/integration/test_semantic_router.py +++ b/tests/integration/test_semantic_router.py @@ -6,9 +6,12 @@ from redis.exceptions import ConnectionError from redisvl.exceptions import RedisModuleVersionError -from redisvl.extensions.llmcache.semantic import SemanticCache from redisvl.extensions.router import SemanticRouter -from redisvl.extensions.router.schema import Route, RoutingConfig +from redisvl.extensions.router.schema import ( + DistanceAggregationMethod, + Route, + RoutingConfig, +) from redisvl.redis.connection import compare_versions from redisvl.utils.vectorize.text.huggingface import HFTextVectorizer @@ -58,7 +61,6 @@ def disable_deprecation_warnings(): def test_initialize_router(semantic_router): assert semantic_router.name == "test-router" assert len(semantic_router.routes) == 2 - assert semantic_router.routing_config.distance_threshold == 0.3 assert semantic_router.routing_config.max_k == 2 @@ -114,10 +116,13 @@ def test_multiple_query(semantic_router): def test_update_routing_config(semantic_router): - new_config = RoutingConfig(distance_threshold=0.5, max_k=1) + new_config = RoutingConfig(max_k=27, aggregation_method="min") semantic_router.update_routing_config(new_config) - assert semantic_router.routing_config.distance_threshold == 0.5 - assert semantic_router.routing_config.max_k == 1 + assert semantic_router.routing_config.max_k == 27 + assert ( + semantic_router.routing_config.aggregation_method + == DistanceAggregationMethod.min + ) def test_vector_query(semantic_router): @@ -189,7 +194,7 @@ def test_from_dict(semantic_router): new_router = SemanticRouter.from_dict( router_dict, redis_client=semantic_router._index.client, overwrite=True ) - assert new_router == semantic_router + assert new_router.to_dict() == router_dict def test_to_yaml(semantic_router): @@ -203,7 +208,7 @@ def test_from_yaml(semantic_router): new_router = SemanticRouter.from_yaml( yaml_file, redis_client=semantic_router._index.client, overwrite=True ) - assert new_router == semantic_router + assert new_router.to_dict() == 
semantic_router.to_dict() def test_to_dict_missing_fields(): @@ -290,14 +295,14 @@ def test_different_vector_dtypes(redis_url, routes): def test_bad_dtype_connecting_to_exiting_router(redis_url, routes): try: router = SemanticRouter( - name="float64 router", + name="float64-router", routes=routes, dtype="float64", redis_url=redis_url, ) same_type = SemanticRouter( - name="float64 router", + name="float64-router", routes=routes, dtype="float64", redis_url=redis_url, @@ -308,7 +313,7 @@ def test_bad_dtype_connecting_to_exiting_router(redis_url, routes): with pytest.raises(ValueError): bad_type = SemanticRouter( - name="float64 router", + name="float64-router", routes=routes, dtype="float16", redis_url=redis_url, diff --git a/tests/unit/test_llmcache_schema.py b/tests/unit/test_llmcache_schema.py index a686ebe0..72f230fc 100644 --- a/tests/unit/test_llmcache_schema.py +++ b/tests/unit/test_llmcache_schema.py @@ -1,7 +1,7 @@ import json import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from redisvl.extensions.llmcache.schema import CacheEntry, CacheHit from redisvl.redis.utils import array_to_buffer, hashify diff --git a/tests/unit/test_route_schema.py b/tests/unit/test_route_schema.py index 746c1182..19fe1f39 100644 --- a/tests/unit/test_route_schema.py +++ b/tests/unit/test_route_schema.py @@ -1,5 +1,5 @@ import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from redisvl.extensions.router.schema import ( DistanceAggregationMethod, @@ -74,7 +74,7 @@ def test_route_invalid_threshold_zero(): metadata={"key": "value"}, distance_threshold=0, ) - assert "Route distance threshold must be greater than zero" in str(excinfo.value) + assert "Input should be greater than 0" in str(excinfo.value) def test_route_invalid_threshold_negative(): @@ -85,7 +85,7 @@ def test_route_invalid_threshold_negative(): metadata={"key": "value"}, distance_threshold=-0.1, ) - assert "Route distance threshold must be greater than zero" in str(excinfo.value) + assert "Input should be greater than 0" in str(excinfo.value) def test_route_match(): @@ -115,4 +115,4 @@ def test_routing_config_valid(): def test_routing_config_invalid_max_k(): with pytest.raises(ValidationError) as excinfo: RoutingConfig(max_k=0) - assert "max_k must be a positive integer" in str(excinfo.value) + assert "Input should be greater than 0" in str(excinfo.value) diff --git a/tests/unit/test_session_schema.py b/tests/unit/test_session_schema.py index df6b8d91..5bd2c221 100644 --- a/tests/unit/test_session_schema.py +++ b/tests/unit/test_session_schema.py @@ -1,5 +1,5 @@ import pytest -from pydantic.v1 import ValidationError +from pydantic import ValidationError from redisvl.extensions.session_manager.schema import ChatMessage from redisvl.redis.utils import array_to_buffer
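Taken together, the test updates mirror the decorator renames applied across the codebase: `@root_validator(pre=True)` becomes `@model_validator(mode="before")`, `@validator(...)` becomes `@field_validator(...)`, and custom error strings give way to pydantic v2's built-in messages. A condensed sketch of the mapping, using a hypothetical model:

```python
from pydantic import BaseModel, field_validator, model_validator


class Entry(BaseModel):
    """Hypothetical model showing the v1 -> v2 validator migration."""

    prompt: str
    entry_id: str = ""

    @model_validator(mode="before")  # was: @root_validator(pre=True)
    @classmethod
    def ensure_id(cls, values):
        if not values.get("entry_id"):
            values["entry_id"] = str(hash(values.get("prompt", "")))
        return values

    @field_validator("prompt")  # was: @validator("prompt")
    @classmethod
    def non_empty(cls, v):
        if not v.strip():
            raise ValueError("prompt must not be empty")
        return v


print(Entry(prompt="hi").entry_id)  # filled in by the before-validator
```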