diff --git a/redisvl/extensions/llmcache/semantic.py b/redisvl/extensions/llmcache/semantic.py
index 2b70ae09..2be657f7 100644
--- a/redisvl/extensions/llmcache/semantic.py
+++ b/redisvl/extensions/llmcache/semantic.py
@@ -22,7 +22,7 @@
     SemanticCacheIndexSchema,
 )
 from redisvl.index import AsyncSearchIndex, SearchIndex
-from redisvl.query import RangeQuery
+from redisvl.query import VectorRangeQuery
 from redisvl.query.filter import FilterExpression
 from redisvl.query.query import BaseQuery
 from redisvl.redis.connection import RedisConnectionFactory
@@ -238,9 +238,9 @@ def set_threshold(self, distance_threshold: float) -> None:
         Raises:
             ValueError: If the threshold is not between 0 and 1.
         """
-        if not 0 <= float(distance_threshold) <= 1:
+        if not 0 <= float(distance_threshold) <= 2:
             raise ValueError(
-                f"Distance must be between 0 and 1, got {distance_threshold}"
+                f"Distance must be between 0 and 2, got {distance_threshold}"
             )
         self._distance_threshold = float(distance_threshold)
 
@@ -390,7 +390,7 @@ def check(
         vector = vector or self._vectorize_prompt(prompt)
         self._check_vector_dims(vector)
 
-        query = RangeQuery(
+        query = VectorRangeQuery(
             vector=vector,
             vector_field_name=CACHE_VECTOR_FIELD_NAME,
             return_fields=self.return_fields,
@@ -473,7 +473,7 @@ async def acheck(
         vector = vector or await self._avectorize_prompt(prompt)
         self._check_vector_dims(vector)
 
-        query = RangeQuery(
+        query = VectorRangeQuery(
             vector=vector,
             vector_field_name=CACHE_VECTOR_FIELD_NAME,
             return_fields=self.return_fields,
@@ -481,6 +481,7 @@ async def acheck(
             num_results=num_results,
             return_score=True,
             filter_expression=filter_expression,
+            normalize_vector_distance=True,
         )
 
         # Search the cache!
diff --git a/redisvl/extensions/router/schema.py b/redisvl/extensions/router/schema.py
index 1b1d6dc8..d9b38677 100644
--- a/redisvl/extensions/router/schema.py
+++ b/redisvl/extensions/router/schema.py
@@ -18,7 +18,7 @@ class Route(BaseModel):
     """List of reference phrases for the route."""
     metadata: Dict[str, Any] = Field(default={})
     """Metadata associated with the route."""
-    distance_threshold: Annotated[float, Field(strict=True, gt=0, le=1)] = 0.5
+    distance_threshold: Annotated[float, Field(strict=True, gt=0, le=2)] = 0.5
     """Distance threshold for matching the route."""
 
     @field_validator("name")
diff --git a/redisvl/extensions/router/semantic.py b/redisvl/extensions/router/semantic.py
index c06789e1..c7c91c86 100644
--- a/redisvl/extensions/router/semantic.py
+++ b/redisvl/extensions/router/semantic.py
@@ -17,7 +17,7 @@
     SemanticRouterIndexSchema,
 )
 from redisvl.index import SearchIndex
-from redisvl.query import RangeQuery
+from redisvl.query import VectorRangeQuery
 from redisvl.redis.utils import convert_bytes, hashify, make_dict
 from redisvl.utils.log import get_logger
 from redisvl.utils.utils import deprecated_argument, model_to_dict
@@ -237,7 +237,7 @@ def _distance_threshold_filter(self) -> str:
 
     def _build_aggregate_request(
         self,
-        vector_range_query: RangeQuery,
+        vector_range_query: VectorRangeQuery,
         aggregation_method: DistanceAggregationMethod,
         max_k: int,
     ) -> AggregateRequest:
@@ -279,7 +279,7 @@ def _get_route_matches(
         # therefore you might take the max_threshold and further refine from there.
         distance_threshold = max(route.distance_threshold for route in self.routes)
 
-        vector_range_query = RangeQuery(
+        vector_range_query = VectorRangeQuery(
             vector=vector,
             vector_field_name=ROUTE_VECTOR_FIELD_NAME,
             distance_threshold=float(distance_threshold),
diff --git a/redisvl/index/index.py b/redisvl/index/index.py
index c4e5de62..4cf44263 100644
--- a/redisvl/index/index.py
+++ b/redisvl/index/index.py
@@ -34,7 +34,7 @@
 
 from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
 from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
-from redisvl.query import BaseQuery, CountQuery, FilterQuery
+from redisvl.query import BaseQuery, BaseVectorQuery, CountQuery, FilterQuery
 from redisvl.query.filter import FilterExpression
 from redisvl.redis.connection import (
     RedisConnectionFactory,
@@ -42,6 +42,7 @@
 )
 from redisvl.redis.utils import convert_bytes
 from redisvl.schema import IndexSchema, StorageType
+from redisvl.schema.fields import VECTOR_NORM_MAP, VectorDistanceMetric
 from redisvl.utils.log import get_logger
 
 logger = get_logger(__name__)
@@ -62,7 +63,7 @@
 
 
 def process_results(
-    results: "Result", query: BaseQuery, storage_type: StorageType
+    results: "Result", query: BaseQuery, schema: IndexSchema
 ) -> List[Dict[str, Any]]:
     """Convert a list of search Result objects into a list of document dictionaries.
@@ -87,11 +88,24 @@ def process_results(
 
     # Determine if unpacking JSON is needed
    unpack_json = (
-        (storage_type == StorageType.JSON)
+        (schema.index.storage_type == StorageType.JSON)
         and isinstance(query, FilterQuery)
         and not query._return_fields  # type: ignore
     )
 
+    if (isinstance(query, BaseVectorQuery)) and query._normalize_vector_distance:
+        dist_metric = VectorDistanceMetric(
+            schema.fields[query._vector_field_name].attrs.distance_metric.upper()  # type: ignore
+        )
+        if dist_metric == VectorDistanceMetric.IP:
+            warnings.warn(
+                "Attempting to normalize inner product distance metric. Use cosine distance instead which is normalized inner product by definition."
+            )
+
+        norm_fn = VECTOR_NORM_MAP[dist_metric.value]
+    else:
+        norm_fn = None
+
     # Process records
     def _process(doc: "Document") -> Dict[str, Any]:
         doc_dict = doc.__dict__
@@ -105,6 +119,12 @@ def _process(doc: "Document") -> Dict[str, Any]:
                 return {"id": doc_dict.get("id"), **json_data}
             raise ValueError(f"Unable to parse json data from Redis {json_data}")
 
+        if norm_fn:
+            # convert float back to string to be consistent
+            doc_dict[query.DISTANCE_ID] = str(  # type: ignore
+                norm_fn(float(doc_dict[query.DISTANCE_ID]))  # type: ignore
+            )
+
         # Remove 'payload' if present
         doc_dict.pop("payload", None)
 
@@ -757,11 +777,7 @@ def batch_query(
             )
         all_parsed = []
         for query, batch_results in zip(queries, results):
-            parsed = process_results(
-                batch_results,
-                query=query,
-                storage_type=self.schema.index.storage_type,
-            )
+            parsed = process_results(batch_results, query=query, schema=self.schema)
             # Create separate lists of parsed results for each query
             # passed in to the batch_search method, so that callers can
             # access the results for each query individually
@@ -771,9 +787,7 @@
     def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
         """Execute a query and process results."""
         results = self.search(query.query, query_params=query.params)
-        return process_results(
-            results, query=query, storage_type=self.schema.index.storage_type
-        )
+        return process_results(results, query=query, schema=self.schema)
 
     def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
         """Execute a query on the index.
@@ -1403,7 +1417,7 @@ async def batch_query(
             parsed = process_results(
                 batch_results,
                 query=query,
-                storage_type=self.schema.index.storage_type,
+                schema=self.schema,
             )
             # Create separate lists of parsed results for each query
             # passed in to the batch_search method, so that callers can
@@ -1415,9 +1429,7 @@
     async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
         """Asynchronously execute a query and process results."""
         results = await self.search(query.query, query_params=query.params)
-        return process_results(
-            results, query=query, storage_type=self.schema.index.storage_type
-        )
+        return process_results(results, query=query, schema=self.schema)
 
     async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
         """Asynchronously execute a query on the index.
diff --git a/redisvl/query/__init__.py b/redisvl/query/__init__.py
index 8246794f..dcdbcabe 100644
--- a/redisvl/query/__init__.py
+++ b/redisvl/query/__init__.py
@@ -1,5 +1,6 @@
 from redisvl.query.query import (
     BaseQuery,
+    BaseVectorQuery,
     CountQuery,
     FilterQuery,
     RangeQuery,
@@ -9,6 +10,7 @@
 
 __all__ = [
     "BaseQuery",
+    "BaseVectorQuery",
     "VectorQuery",
     "FilterQuery",
     "RangeQuery",
diff --git a/redisvl/query/query.py b/redisvl/query/query.py
index 7dbddc91..3909df00 100644
--- a/redisvl/query/query.py
+++ b/redisvl/query/query.py
@@ -5,6 +5,7 @@
 
 from redisvl.query.filter import FilterExpression
 from redisvl.redis.utils import array_to_buffer
+from redisvl.utils.utils import denorm_cosine_distance
 
 
 class BaseQuery(RedisQuery):
@@ -175,6 +176,8 @@ class BaseVectorQuery:
     DISTANCE_ID: str = "vector_distance"
     VECTOR_PARAM: str = "vector"
 
+    _normalize_vector_distance: bool = False
+
 
 class HybridPolicy(str, Enum):
     """Enum for valid hybrid policy options in vector queries."""
@@ -198,6 +201,7 @@ def __init__(
         in_order: bool = False,
         hybrid_policy: Optional[str] = None,
         batch_size: Optional[int] = None,
+        normalize_vector_distance: bool = False,
     ):
         """A query for running a vector search along with an optional filter
         expression.
@@ -233,6 +237,12 @@ def __init__(
                 of vectors to fetch in each batch. Larger values may improve performance at the cost
                 of memory usage. Only applies when hybrid_policy="BATCHES". Defaults to None,
                 which lets Redis auto-select an appropriate batch size.
+            normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
+                IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
+                COSINE distance returns a value between 0 and 2. IP returns a value determined by
+                the magnitude of the vector. Setting this flag to true converts COSINE and L2 distance
+                to a similarity score between 0 and 1. Note: setting this flag to true for IP will
+                throw a warning since by definition COSINE similarity is normalized IP.
 
         Raises:
             TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -246,6 +256,7 @@ def __init__(
         self._num_results = num_results
         self._hybrid_policy: Optional[HybridPolicy] = None
         self._batch_size: Optional[int] = None
+        self._normalize_vector_distance = normalize_vector_distance
         self.set_filter(filter_expression)
         query_string = self._build_query_string()
@@ -394,6 +405,7 @@ def __init__(
         in_order: bool = False,
         hybrid_policy: Optional[str] = None,
         batch_size: Optional[int] = None,
+        normalize_vector_distance: bool = False,
     ):
         """A query for running a filtered vector search based on semantic
         distance threshold.
@@ -437,6 +449,19 @@ def __init__(
                 of vectors to fetch in each batch.
                 Larger values may improve performance at the cost of memory usage.
                 Only applies when hybrid_policy="BATCHES". Defaults to None, which lets
                 Redis auto-select an appropriate batch size.
+            normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
+                IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
+                COSINE distance returns a value between 0 and 2. IP returns a value determined by
+                the magnitude of the vector. Setting this flag to true converts COSINE and L2 distance
+                to a similarity score between 0 and 1. Note: setting this flag to true for IP will
+                throw a warning since by definition COSINE similarity is normalized IP.
+
+        Raises:
+            TypeError: If filter_expression is not of type redisvl.query.FilterExpression
+
+        Note:
+            Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query
+
         """
         self._vector = vector
         self._vector_field_name = vector_field_name
@@ -456,6 +481,7 @@ def __init__(
 
         if batch_size is not None:
             self.set_batch_size(batch_size)
 
+        self._normalize_vector_distance = normalize_vector_distance
         self.set_distance_threshold(distance_threshold)
         self.set_filter(filter_expression)
         query_string = self._build_query_string()
@@ -493,6 +519,14 @@ def set_distance_threshold(self, distance_threshold: float):
             raise TypeError("distance_threshold must be of type float or int")
         if distance_threshold < 0:
             raise ValueError("distance_threshold must be non-negative")
+        if self._normalize_vector_distance:
+            if distance_threshold > 1:
+                raise ValueError(
+                    "distance_threshold must be between 0 and 1 when normalize_vector_distance is set to True"
+                )
+
+            # User sets a normalized value in [0, 1]; denormalize it for use in the DB
+            distance_threshold = denorm_cosine_distance(distance_threshold)
         self._distance_threshold = distance_threshold
 
         # Reset the query string
diff --git a/redisvl/schema/fields.py b/redisvl/schema/fields.py
index 17714480..73e8da4e 100644
--- a/redisvl/schema/fields.py
+++ b/redisvl/schema/fields.py
@@ -16,6 +16,14 @@
 from redis.commands.search.field import TextField as RedisTextField
 from redis.commands.search.field import VectorField as RedisVectorField
 
+from redisvl.utils.utils import norm_cosine_distance, norm_l2_distance
+
+VECTOR_NORM_MAP = {
+    "COSINE": norm_cosine_distance,
+    "L2": norm_l2_distance,
+    "IP": None,  # normalized inner product is cosine similarity by definition
+}
+
 
 class FieldTypes(str, Enum):
     TAG = "tag"
diff --git a/redisvl/utils/utils.py b/redisvl/utils/utils.py
index 4c40d41a..016c40fa 100644
--- a/redisvl/utils/utils.py
+++ b/redisvl/utils/utils.py
@@ -191,3 +191,22 @@ def wrapper():
             return
 
     return wrapper
+
+
+def norm_cosine_distance(value: float) -> float:
+    """
+    Normalize the cosine distance to a similarity score between 0 and 1.
+    """
+    return max((2 - value) / 2, 0)
+
+
+def denorm_cosine_distance(value: float) -> float:
+    """Denormalize the distance threshold from [0, 1] to [0, 2] for our db."""
+    return max(2 - 2 * value, 0)
+
+
+def norm_l2_distance(value: float) -> float:
+    """
+    Normalize the unbounded L2 distance to a similarity score between 0 and 1.
+ """ + return 1 / (1 + value) diff --git a/tests/conftest.py b/tests/conftest.py index 24da05e5..83d34c21 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -141,7 +141,7 @@ def sample_data(sample_datetimes): "last_updated": sample_datetimes["high"].timestamp(), "credit_score": "medium", "location": "-110.0839,37.3861", - "user_embedding": [0.9, 0.9, 0.1], + "user_embedding": [-0.1, -0.1, -0.5], }, ] diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index 3f30c02e..9bee41f6 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -4,7 +4,7 @@ from redis.commands.search.result import Result from redisvl.index import SearchIndex -from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery +from redisvl.query import CountQuery, FilterQuery, VectorQuery, VectorRangeQuery from redisvl.query.filter import ( FilterExpression, Geo, @@ -25,6 +25,7 @@ def vector_query(): return VectorQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", + return_score=True, return_fields=[ "user", "credit_score", @@ -53,6 +54,24 @@ def sorted_vector_query(): ) +@pytest.fixture +def normalized_vector_query(): + return VectorQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + normalize_vector_distance=True, + return_score=True, + return_fields=[ + "user", + "credit_score", + "age", + "job", + "location", + "last_updated", + ], + ) + + @pytest.fixture def filter_query(): return FilterQuery( @@ -84,9 +103,21 @@ def sorted_filter_query(): ) +@pytest.fixture +def normalized_range_query(): + return VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + normalize_vector_distance=True, + return_score=True, + return_fields=["user", "credit_score", "age", "job", "location"], + distance_threshold=0.2, + ) + + @pytest.fixture def range_query(): - return RangeQuery( + return VectorRangeQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", return_fields=["user", "credit_score", "age", "job", "location"], @@ -96,7 +127,7 @@ def range_query(): @pytest.fixture def sorted_range_query(): - return RangeQuery( + return VectorRangeQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", return_fields=["user", "credit_score", "age", "job", "location"], @@ -155,6 +186,56 @@ def hash_preprocess(item: dict) -> dict: index.delete(drop=True) +@pytest.fixture +def L2_index(sample_data, redis_url): + # construct a search index from the schema + index = SearchIndex.from_dict( + { + "index": { + "name": "L2_index", + "prefix": "L2_index", + "storage_type": "hash", + }, + "fields": [ + {"name": "credit_score", "type": "tag"}, + {"name": "job", "type": "text"}, + {"name": "age", "type": "numeric"}, + {"name": "last_updated", "type": "numeric"}, + {"name": "location", "type": "geo"}, + { + "name": "user_embedding", + "type": "vector", + "attrs": { + "dims": 3, + "distance_metric": "L2", + "algorithm": "flat", + "datatype": "float32", + }, + }, + ], + }, + redis_url=redis_url, + ) + + # create the index (no data yet) + index.create(overwrite=True) + + # Prepare and load the data + def hash_preprocess(item: dict) -> dict: + return { + **item, + "user_embedding": array_to_buffer(item["user_embedding"], "float32"), + } + + index.load(sample_data, preprocess=hash_preprocess) + + # run the test + yield index + + # clean up + index.delete(drop=True) + + def test_search_and_query(index): # *=>[KNN 7 @user_embedding $vector AS vector_distance] v = VectorQuery( @@ -191,7 +272,7 @@ def 
test_search_and_query(index): def test_range_query(index): - r = RangeQuery( + r = VectorRangeQuery( vector=[0.1, 0.1, 0.5], vector_field_name="user_embedding", return_fields=["user", "credit_score", "age", "job"], @@ -262,7 +343,7 @@ def search( assert doc.location == location # if range query, test results by distance threshold - if isinstance(query, RangeQuery): + if isinstance(query, VectorRangeQuery): for doc in results.docs: print(doc.vector_distance) assert float(doc.vector_distance) <= distance_threshold @@ -273,7 +354,7 @@ def search( # check results are in sorted order if sort: - if isinstance(query, RangeQuery): + if isinstance(query, VectorRangeQuery): assert [int(doc.age) for doc in results.docs] == [12, 14, 18, 100] else: assert [int(doc.age) for doc in results.docs] == [ @@ -289,7 +370,7 @@ def search( @pytest.fixture( params=["vector_query", "filter_query", "range_query"], - ids=["VectorQuery", "FilterQuery", "RangeQuery"], + ids=["VectorQuery", "FilterQuery", "VectorRangeQuery"], ) def query(request): return request.getfixturevalue(request.param) @@ -659,3 +740,53 @@ def test_range_query_with_filter_and_hybrid_policy(index): for result in results: assert result["credit_score"] == "high" assert float(result["vector_distance"]) <= 0.5 + + +def test_query_normalize_cosine_distance(index, normalized_vector_query): + + res = index.query(normalized_vector_query) + + for r in res: + assert 0 <= float(r["vector_distance"]) <= 1 + + +def test_query_cosine_distance_un_normalized(index, vector_query): + + res = index.query(vector_query) + + assert any(float(r["vector_distance"]) > 1 for r in res) + + +def test_query_l2_distance_un_normalized(L2_index, vector_query): + + res = L2_index.query(vector_query) + + assert any(float(r["vector_distance"]) > 1 for r in res) + + +def test_query_l2_distance_normalized(L2_index, normalized_vector_query): + + res = L2_index.query(normalized_vector_query) + + for r in res: + assert 0 <= float(r["vector_distance"]) <= 1 + + +def test_range_query_normalize_cosine_distance(index, normalized_range_query): + + res = index.query(normalized_range_query) + + for r in res: + assert 0 <= float(r["vector_distance"]) <= 1 + + +def test_range_query_normalize_bad_input(index): + with pytest.raises(ValueError): + VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + normalize_vector_distance=True, + return_score=True, + return_fields=["user", "credit_score", "age", "job", "location"], + distance_threshold=1.2, + ) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 83300d0c..2251998f 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -16,11 +16,30 @@ ) from redisvl.utils.utils import ( assert_no_warnings, + denorm_cosine_distance, deprecated_argument, deprecated_function, + norm_cosine_distance, ) +def test_norm_cosine_distance(): + input = 2 + expected = 0 + assert norm_cosine_distance(input) == expected + + +def test_denorm_cosine_distance(): + input = 0 + expected = 2 + assert denorm_cosine_distance(input) == expected + + +def test_norm_denorm_cosine(): + input = 0.6 + assert input == round(denorm_cosine_distance(norm_cosine_distance(input)), 6) + + def test_even_number_of_elements(): """Test with an even number of elements""" values = ["key1", "value1", "key2", "value2"]
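
For reference, here is the math behind the new helpers in redisvl/utils/utils.py as a standalone sketch. Redis reports cosine distance in [0, 2] and L2 distance in [0, inf); (2 - d) / 2 and 1 / (1 + d) map these onto a 0-1 similarity-style scale, and denorm_cosine_distance inverts the cosine mapping so a user-facing threshold can be translated back to a raw distance before it reaches Redis. The functions below mirror the formulas in the patch rather than importing the library:

```python
# Sketch of the normalization math introduced in redisvl/utils/utils.py.


def norm_cosine_distance(value: float) -> float:
    # Cosine distance d is in [0, 2]; (2 - d) / 2 maps it to [0, 1],
    # clamped at 0 for safety.
    return max((2 - value) / 2, 0)


def denorm_cosine_distance(value: float) -> float:
    # Inverse mapping: a normalized threshold in [0, 1] becomes a raw
    # cosine distance in [0, 2] before it is sent to Redis.
    return max(2 - 2 * value, 0)


def norm_l2_distance(value: float) -> float:
    # L2 distance is unbounded; 1 / (1 + d) maps it to (0, 1].
    return 1 / (1 + value)


if __name__ == "__main__":
    assert norm_cosine_distance(2.0) == 0.0   # opposite vectors -> similarity 0
    assert norm_cosine_distance(0.0) == 1.0   # identical vectors -> similarity 1
    assert denorm_cosine_distance(0.9) == pytest_free_check = 0.2 if False else 0.2  # threshold 0.9 -> raw distance 0.2
    assert round(denorm_cosine_distance(norm_cosine_distance(0.6)), 6) == 0.6
```

The last assertion is the same round-trip property exercised by tests/unit/test_utils.py::test_norm_denorm_cosine.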
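A usage sketch of the flag from the query side follows. This is illustrative only: it assumes a local Redis Stack at redis://localhost:6379, a throwaway index named norm_demo, and a 3-dim cosine vector field mirroring the test fixtures.

```python
from redisvl.index import SearchIndex
from redisvl.query import VectorQuery, VectorRangeQuery
from redisvl.redis.utils import array_to_buffer

# Hash index with a single 3-dim float32 cosine vector field.
index = SearchIndex.from_dict(
    {
        "index": {"name": "norm_demo", "prefix": "norm_demo", "storage_type": "hash"},
        "fields": [
            {
                "name": "user_embedding",
                "type": "vector",
                "attrs": {
                    "dims": 3,
                    "distance_metric": "cosine",
                    "algorithm": "flat",
                    "datatype": "float32",
                },
            },
        ],
    },
    redis_url="redis://localhost:6379",
)
index.create(overwrite=True)
index.load([{"user_embedding": array_to_buffer([0.1, 0.1, 0.5], "float32")}])

# With normalize_vector_distance=True, the returned vector_distance is a
# 0-1 score instead of a raw cosine distance in [0, 2].
knn = VectorQuery(
    vector=[0.1, 0.1, 0.5],
    vector_field_name="user_embedding",
    return_score=True,
    normalize_vector_distance=True,
)
print(index.query(knn))

# For range queries the threshold is given on the normalized 0-1 scale and is
# denormalized internally via denorm_cosine_distance before querying Redis.
rng = VectorRangeQuery(
    vector=[0.1, 0.1, 0.5],
    vector_field_name="user_embedding",
    distance_threshold=0.2,
    normalize_vector_distance=True,
)
print(index.query(rng))

index.delete(drop=True)
```

Note that when normalize_vector_distance=True, VectorRangeQuery restricts distance_threshold to [0, 1], which is why the new test_range_query_normalize_bad_input expects a ValueError for 1.2.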