Skip to content

Add option to normalize vector distances on query #298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions redisvl/extensions/llmcache/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
SemanticCacheIndexSchema,
)
from redisvl.index import AsyncSearchIndex, SearchIndex
from redisvl.query import RangeQuery
from redisvl.query import VectorRangeQuery
from redisvl.query.filter import FilterExpression
from redisvl.query.query import BaseQuery
from redisvl.redis.connection import RedisConnectionFactory
Expand Down Expand Up @@ -238,9 +238,9 @@ def set_threshold(self, distance_threshold: float) -> None:
Raises:
ValueError: If the threshold is not between 0 and 1.
"""
if not 0 <= float(distance_threshold) <= 1:
if not 0 <= float(distance_threshold) <= 2:
raise ValueError(
f"Distance must be between 0 and 1, got {distance_threshold}"
f"Distance must be between 0 and 2, got {distance_threshold}"
)
self._distance_threshold = float(distance_threshold)

Expand Down Expand Up @@ -390,7 +390,7 @@ def check(
vector = vector or self._vectorize_prompt(prompt)
self._check_vector_dims(vector)

query = RangeQuery(
query = VectorRangeQuery(
vector=vector,
vector_field_name=CACHE_VECTOR_FIELD_NAME,
return_fields=self.return_fields,
Expand Down Expand Up @@ -473,14 +473,15 @@ async def acheck(
vector = vector or await self._avectorize_prompt(prompt)
self._check_vector_dims(vector)

query = RangeQuery(
query = VectorRangeQuery(
vector=vector,
vector_field_name=CACHE_VECTOR_FIELD_NAME,
return_fields=self.return_fields,
distance_threshold=distance_threshold,
num_results=num_results,
return_score=True,
filter_expression=filter_expression,
normalize_vector_distance=True,
)

# Search the cache!
Expand Down
2 changes: 1 addition & 1 deletion redisvl/extensions/router/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class Route(BaseModel):
"""List of reference phrases for the route."""
metadata: Dict[str, Any] = Field(default={})
"""Metadata associated with the route."""
distance_threshold: Annotated[float, Field(strict=True, gt=0, le=1)] = 0.5
distance_threshold: Annotated[float, Field(strict=True, gt=0, le=2)] = 0.5
"""Distance threshold for matching the route."""

@field_validator("name")
Expand Down
6 changes: 3 additions & 3 deletions redisvl/extensions/router/semantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
SemanticRouterIndexSchema,
)
from redisvl.index import SearchIndex
from redisvl.query import RangeQuery
from redisvl.query import VectorRangeQuery
from redisvl.redis.utils import convert_bytes, hashify, make_dict
from redisvl.utils.log import get_logger
from redisvl.utils.utils import deprecated_argument, model_to_dict
Expand Down Expand Up @@ -237,7 +237,7 @@ def _distance_threshold_filter(self) -> str:

def _build_aggregate_request(
self,
vector_range_query: RangeQuery,
vector_range_query: VectorRangeQuery,
aggregation_method: DistanceAggregationMethod,
max_k: int,
) -> AggregateRequest:
Expand Down Expand Up @@ -279,7 +279,7 @@ def _get_route_matches(
# therefore you might take the max_threshold and further refine from there.
distance_threshold = max(route.distance_threshold for route in self.routes)

vector_range_query = RangeQuery(
vector_range_query = VectorRangeQuery(
vector=vector,
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
distance_threshold=float(distance_threshold),
Expand Down
42 changes: 27 additions & 15 deletions redisvl/index/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,15 @@

from redisvl.exceptions import RedisModuleVersionError, RedisSearchError
from redisvl.index.storage import BaseStorage, HashStorage, JsonStorage
from redisvl.query import BaseQuery, CountQuery, FilterQuery
from redisvl.query import BaseQuery, BaseVectorQuery, CountQuery, FilterQuery
from redisvl.query.filter import FilterExpression
from redisvl.redis.connection import (
RedisConnectionFactory,
convert_index_info_to_schema,
)
from redisvl.redis.utils import convert_bytes
from redisvl.schema import IndexSchema, StorageType
from redisvl.schema.fields import VECTOR_NORM_MAP, VectorDistanceMetric
from redisvl.utils.log import get_logger

logger = get_logger(__name__)
Expand All @@ -62,7 +63,7 @@


def process_results(
results: "Result", query: BaseQuery, storage_type: StorageType
results: "Result", query: BaseQuery, schema: IndexSchema
) -> List[Dict[str, Any]]:
"""Convert a list of search Result objects into a list of document
dictionaries.
Expand All @@ -87,11 +88,24 @@ def process_results(

# Determine if unpacking JSON is needed
unpack_json = (
(storage_type == StorageType.JSON)
(schema.index.storage_type == StorageType.JSON)
and isinstance(query, FilterQuery)
and not query._return_fields # type: ignore
)

if (isinstance(query, BaseVectorQuery)) and query._normalize_vector_distance:
dist_metric = VectorDistanceMetric(
schema.fields[query._vector_field_name].attrs.distance_metric.upper() # type: ignore
)
if dist_metric == VectorDistanceMetric.IP:
warnings.warn(
"Attempting to normalize inner product distance metric. Use cosine distance instead which is normalized inner product by definition."
)

norm_fn = VECTOR_NORM_MAP[dist_metric.value]
else:
norm_fn = None

# Process records
def _process(doc: "Document") -> Dict[str, Any]:
doc_dict = doc.__dict__
Expand All @@ -105,6 +119,12 @@ def _process(doc: "Document") -> Dict[str, Any]:
return {"id": doc_dict.get("id"), **json_data}
raise ValueError(f"Unable to parse json data from Redis {json_data}")

if norm_fn:
# convert float back to string to be consistent
doc_dict[query.DISTANCE_ID] = str( # type: ignore
norm_fn(float(doc_dict[query.DISTANCE_ID])) # type: ignore
)

# Remove 'payload' if present
doc_dict.pop("payload", None)

Expand Down Expand Up @@ -757,11 +777,7 @@ def batch_query(
)
all_parsed = []
for query, batch_results in zip(queries, results):
parsed = process_results(
batch_results,
query=query,
storage_type=self.schema.index.storage_type,
)
parsed = process_results(batch_results, query=query, schema=self.schema)
# Create separate lists of parsed results for each query
# passed in to the batch_search method, so that callers can
# access the results for each query individually
Expand All @@ -771,9 +787,7 @@ def batch_query(
def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
"""Execute a query and process results."""
results = self.search(query.query, query_params=query.params)
return process_results(
results, query=query, storage_type=self.schema.index.storage_type
)
return process_results(results, query=query, schema=self.schema)

def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
"""Execute a query on the index.
Expand Down Expand Up @@ -1403,7 +1417,7 @@ async def batch_query(
parsed = process_results(
batch_results,
query=query,
storage_type=self.schema.index.storage_type,
schema=self.schema,
)
# Create separate lists of parsed results for each query
# passed in to the batch_search method, so that callers can
Expand All @@ -1415,9 +1429,7 @@ async def batch_query(
async def _query(self, query: BaseQuery) -> List[Dict[str, Any]]:
"""Asynchronously execute a query and process results."""
results = await self.search(query.query, query_params=query.params)
return process_results(
results, query=query, storage_type=self.schema.index.storage_type
)
return process_results(results, query=query, schema=self.schema)

async def query(self, query: BaseQuery) -> List[Dict[str, Any]]:
"""Asynchronously execute a query on the index.
Expand Down
2 changes: 2 additions & 0 deletions redisvl/query/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from redisvl.query.query import (
BaseQuery,
BaseVectorQuery,
CountQuery,
FilterQuery,
RangeQuery,
Expand All @@ -9,6 +10,7 @@

__all__ = [
"BaseQuery",
"BaseVectorQuery",
"VectorQuery",
"FilterQuery",
"RangeQuery",
Expand Down
34 changes: 34 additions & 0 deletions redisvl/query/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from redisvl.query.filter import FilterExpression
from redisvl.redis.utils import array_to_buffer
from redisvl.utils.utils import denorm_cosine_distance


class BaseQuery(RedisQuery):
Expand Down Expand Up @@ -175,6 +176,8 @@ class BaseVectorQuery:
DISTANCE_ID: str = "vector_distance"
VECTOR_PARAM: str = "vector"

_normalize_vector_distance: bool = False


class HybridPolicy(str, Enum):
"""Enum for valid hybrid policy options in vector queries."""
Expand All @@ -198,6 +201,7 @@ def __init__(
in_order: bool = False,
hybrid_policy: Optional[str] = None,
batch_size: Optional[int] = None,
normalize_vector_distance: bool = False,
):
"""A query for running a vector search along with an optional filter
expression.
Expand Down Expand Up @@ -233,6 +237,12 @@ def __init__(
of vectors to fetch in each batch. Larger values may improve performance
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
Defaults to None, which lets Redis auto-select an appropriate batch size.
normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
COSINE distance returns a value between 0 and 2. IP returns a value determined by
the magnitude of the vector. Setting this flag to true converts COSINE and L2 distance
to a similarity score between 0 and 1. Note: setting this flag to true for IP will
throw a warning since by definition COSINE similarity is normalized IP.

Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
Expand All @@ -246,6 +256,7 @@ def __init__(
self._num_results = num_results
self._hybrid_policy: Optional[HybridPolicy] = None
self._batch_size: Optional[int] = None
self._normalize_vector_distance = normalize_vector_distance
self.set_filter(filter_expression)
query_string = self._build_query_string()

Expand Down Expand Up @@ -394,6 +405,7 @@ def __init__(
in_order: bool = False,
hybrid_policy: Optional[str] = None,
batch_size: Optional[int] = None,
normalize_vector_distance: bool = False,
):
"""A query for running a filtered vector search based on semantic
distance threshold.
Expand Down Expand Up @@ -437,6 +449,19 @@ def __init__(
of vectors to fetch in each batch. Larger values may improve performance
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
Defaults to None, which lets Redis auto-select an appropriate batch size.
normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
COSINE distance returns a value between 0 and 2. IP returns a value determined by
the magnitude of the vector. Setting this flag to true converts COSINE and L2 distance
to a similarity score between 0 and 1. Note: setting this flag to true for IP will
throw a warning since by definition COSINE similarity is normalized IP.

Raises:
TypeError: If filter_expression is not of type redisvl.query.FilterExpression

Note:
Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query

"""
self._vector = vector
self._vector_field_name = vector_field_name
Expand All @@ -456,6 +481,7 @@ def __init__(
if batch_size is not None:
self.set_batch_size(batch_size)

self._normalize_vector_distance = normalize_vector_distance
self.set_distance_threshold(distance_threshold)
self.set_filter(filter_expression)
query_string = self._build_query_string()
Expand Down Expand Up @@ -493,6 +519,14 @@ def set_distance_threshold(self, distance_threshold: float):
raise TypeError("distance_threshold must be of type float or int")
if distance_threshold < 0:
raise ValueError("distance_threshold must be non-negative")
if self._normalize_vector_distance:
if distance_threshold > 1:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

raise ValueError(
"distance_threshold must be between 0 and 1 when normalize_vector_distance is set to True"
)

# User sets normalized value 0-1 denormalize for use in DB
distance_threshold = denorm_cosine_distance(distance_threshold)
self._distance_threshold = distance_threshold

# Reset the query string
Expand Down
8 changes: 8 additions & 0 deletions redisvl/schema/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,14 @@
from redis.commands.search.field import TextField as RedisTextField
from redis.commands.search.field import VectorField as RedisVectorField

from redisvl.utils.utils import norm_cosine_distance, norm_l2_distance

VECTOR_NORM_MAP = {
"COSINE": norm_cosine_distance,
"L2": norm_l2_distance,
"IP": None, # normalized inner product is cosine similarity by definition
}


class FieldTypes(str, Enum):
TAG = "tag"
Expand Down
19 changes: 19 additions & 0 deletions redisvl/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,3 +191,22 @@ def wrapper():
return

return wrapper


def norm_cosine_distance(value: float) -> float:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could add additional check logic to this function kept it simple stupid to start

"""
Normalize the cosine distance to a similarity score between 0 and 1.
"""
return max((2 - value) / 2, 0)


def denorm_cosine_distance(value: float) -> float:
"""Denormalize the distance threshold from [0, 1] to [0, 1] for our db."""
return max(2 - 2 * value, 0)


def norm_l2_distance(value: float) -> float:
"""
Normalize the L2 distance.
"""
return 1 / (1 + value)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def sample_data(sample_datetimes):
"last_updated": sample_datetimes["high"].timestamp(),
"credit_score": "medium",
"location": "-110.0839,37.3861",
"user_embedding": [0.9, 0.9, 0.1],
"user_embedding": [-0.1, -0.1, -0.5],
},
]

Expand Down
Loading