Skip to content

Commit 4453314

Browse files
committed
update for vector_norm map
1 parent 5ddb998 commit 4453314

File tree

9 files changed

+77
-31
lines changed

9 files changed

+77
-31
lines changed

redisvl/extensions/llmcache/semantic.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
SemanticCacheIndexSchema,
2323
)
2424
from redisvl.index import AsyncSearchIndex, SearchIndex
25-
from redisvl.query import RangeQuery
25+
from redisvl.query import VectorRangeQuery
2626
from redisvl.query.filter import FilterExpression
2727
from redisvl.query.query import BaseQuery
2828
from redisvl.redis.connection import RedisConnectionFactory
@@ -390,7 +390,7 @@ def check(
390390
vector = vector or self._vectorize_prompt(prompt)
391391
self._check_vector_dims(vector)
392392

393-
query = RangeQuery(
393+
query = VectorRangeQuery(
394394
vector=vector,
395395
vector_field_name=CACHE_VECTOR_FIELD_NAME,
396396
return_fields=self.return_fields,
@@ -473,14 +473,15 @@ async def acheck(
473473
vector = vector or await self._avectorize_prompt(prompt)
474474
self._check_vector_dims(vector)
475475

476-
query = RangeQuery(
476+
query = VectorRangeQuery(
477477
vector=vector,
478478
vector_field_name=CACHE_VECTOR_FIELD_NAME,
479479
return_fields=self.return_fields,
480480
distance_threshold=distance_threshold,
481481
num_results=num_results,
482482
return_score=True,
483483
filter_expression=filter_expression,
484+
normalize_vector_distance=True,
484485
)
485486

486487
# Search the cache!

redisvl/extensions/router/schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Route(BaseModel):
1818
"""List of reference phrases for the route."""
1919
metadata: Dict[str, Any] = Field(default={})
2020
"""Metadata associated with the route."""
21-
distance_threshold: Annotated[float, Field(strict=True, gt=0, le=1)] = 0.5
21+
distance_threshold: Annotated[float, Field(strict=True, gt=0, le=2)] = 0.5
2222
"""Distance threshold for matching the route."""
2323

2424
@field_validator("name")

redisvl/extensions/router/semantic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
SemanticRouterIndexSchema,
1818
)
1919
from redisvl.index import SearchIndex
20-
from redisvl.query import RangeQuery
20+
from redisvl.query import VectorRangeQuery
2121
from redisvl.redis.utils import convert_bytes, hashify, make_dict
2222
from redisvl.utils.log import get_logger
2323
from redisvl.utils.utils import deprecated_argument, model_to_dict
@@ -237,7 +237,7 @@ def _distance_threshold_filter(self) -> str:
237237

238238
def _build_aggregate_request(
239239
self,
240-
vector_range_query: RangeQuery,
240+
vector_range_query: VectorRangeQuery,
241241
aggregation_method: DistanceAggregationMethod,
242242
max_k: int,
243243
) -> AggregateRequest:
@@ -279,7 +279,7 @@ def _get_route_matches(
279279
# therefore you might take the max_threshold and further refine from there.
280280
distance_threshold = max(route.distance_threshold for route in self.routes)
281281

282-
vector_range_query = RangeQuery(
282+
vector_range_query = VectorRangeQuery(
283283
vector=vector,
284284
vector_field_name=ROUTE_VECTOR_FIELD_NAME,
285285
distance_threshold=float(distance_threshold),

redisvl/index/index.py

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
)
5050
from redisvl.redis.utils import convert_bytes
5151
from redisvl.schema import IndexSchema, StorageType
52-
from redisvl.schema.fields import VectorDistanceMetric
52+
from redisvl.schema.fields import VECTOR_NORM_MAP, VectorDistanceMetric
5353
from redisvl.utils.log import get_logger
5454

5555
logger = get_logger(__name__)
@@ -92,12 +92,20 @@ def process_results(
9292
and not query._return_fields # type: ignore
9393
)
9494

95-
normalize_cosine_distance = (
96-
(isinstance(query, VectorQuery) or isinstance(query, VectorRangeQuery))
97-
and query._normalize_cosine_distance
98-
and schema.fields[query._vector_field_name].attrs.distance_metric # type: ignore
99-
== VectorDistanceMetric.COSINE
100-
)
95+
if (
96+
isinstance(query, VectorQuery) or isinstance(query, VectorRangeQuery)
97+
) and query._normalize_vector_distance:
98+
dist_metric = VectorDistanceMetric(
99+
schema.fields[query._vector_field_name].attrs.distance_metric.upper() # type: ignore
100+
)
101+
if dist_metric == VectorDistanceMetric.IP:
102+
warnings.warn(
103+
"Attempting to normalize inner product distance metric. Use cosine distance instead which is normalized inner product by definition."
104+
)
105+
106+
norm_fn = VECTOR_NORM_MAP[dist_metric.value]
107+
else:
108+
norm_fn = None
101109

102110
# Process records
103111
def _process(doc: "Document") -> Dict[str, Any]:
@@ -112,10 +120,10 @@ def _process(doc: "Document") -> Dict[str, Any]:
112120
return {"id": doc_dict.get("id"), **json_data}
113121
raise ValueError(f"Unable to parse json data from Redis {json_data}")
114122

115-
if normalize_cosine_distance:
123+
if norm_fn:
116124
# convert float back to string to be consistent
117125
doc_dict[query.DISTANCE_ID] = str( # type: ignore
118-
norm_cosine_distance(float(doc_dict[query.DISTANCE_ID])) # type: ignore
126+
norm_fn(float(doc_dict[query.DISTANCE_ID])) # type: ignore
119127
)
120128

121129
# Remove 'payload' if present

redisvl/query/query.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def __init__(
198198
in_order: bool = False,
199199
hybrid_policy: Optional[str] = None,
200200
batch_size: Optional[int] = None,
201-
normalize_cosine_distance: bool = False,
201+
normalize_vector_distance: bool = False,
202202
):
203203
"""A query for running a vector search along with an optional filter
204204
expression.
@@ -234,9 +234,12 @@ def __init__(
234234
of vectors to fetch in each batch. Larger values may improve performance
235235
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
236236
Defaults to None, which lets Redis auto-select an appropriate batch size.
237-
normalize_cosine_distance (bool): by default Redis returns cosine distance as a value
238-
between 0 and 2 where 0 is the best match. If set to True, the cosine distance will be
239-
converted to cosine similarity with a value between 0 and 1 where 1 is the best match.
237+
normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
238+
IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
239+
COSINE distance returns a value between 0 and 2. IP returns a value determined by
240+
the magnitude of the vector. Setting this flag to true converts COSINE and L2 distance
241+
to a similarity score between 0 and 1. Note: setting this flag to true for IP will
242+
emit a warning and leave the score unchanged, since by definition COSINE similarity is normalized IP.
240243
241244
Raises:
242245
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -250,7 +253,7 @@ def __init__(
250253
self._num_results = num_results
251254
self._hybrid_policy: Optional[HybridPolicy] = None
252255
self._batch_size: Optional[int] = None
253-
self._normalize_cosine_distance = normalize_cosine_distance
256+
self._normalize_vector_distance = normalize_vector_distance
254257
self.set_filter(filter_expression)
255258
query_string = self._build_query_string()
256259

@@ -399,7 +402,7 @@ def __init__(
399402
in_order: bool = False,
400403
hybrid_policy: Optional[str] = None,
401404
batch_size: Optional[int] = None,
402-
normalize_cosine_distance: bool = False,
405+
normalize_vector_distance: bool = False,
403406
):
404407
"""A query for running a filtered vector search based on semantic
405408
distance threshold.
@@ -443,9 +446,12 @@ def __init__(
443446
of vectors to fetch in each batch. Larger values may improve performance
444447
at the cost of memory usage. Only applies when hybrid_policy="BATCHES".
445448
Defaults to None, which lets Redis auto-select an appropriate batch size.
446-
normalize_cosine_distance (bool): by default Redis returns cosine distance as a value
447-
between 0 and 2 where 0 is the best match. If set to True, the cosine distance will be
448-
converted to cosine similarity with a value between 0 and 1 where 1 is the best match.
449+
normalize_vector_distance (bool): Redis supports 3 distance metrics: L2 (euclidean),
450+
IP (inner product), and COSINE. By default, L2 distance returns an unbounded value.
451+
COSINE distance returns a value between 0 and 2. IP returns a value determined by
452+
the magnitude of the vector. Setting this flag to true converts COSINE and L2 distance
453+
to a similarity score between 0 and 1. Note: setting this flag to true for IP will
454+
emit a warning and leave the score unchanged, since by definition COSINE similarity is normalized IP.
449455
450456
Raises:
451457
TypeError: If filter_expression is not of type redisvl.query.FilterExpression
@@ -472,7 +478,7 @@ def __init__(
472478
if batch_size is not None:
473479
self.set_batch_size(batch_size)
474480

475-
self._normalize_cosine_distance = normalize_cosine_distance
481+
self._normalize_vector_distance = normalize_vector_distance
476482
self.set_distance_threshold(distance_threshold)
477483
self.set_filter(filter_expression)
478484
query_string = self._build_query_string()

redisvl/schema/fields.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@
1616
from redis.commands.search.field import TextField as RedisTextField
1717
from redis.commands.search.field import VectorField as RedisVectorField
1818

19+
from redisvl.utils.utils import norm_cosine_distance, norm_l2_distance
20+
21+
VECTOR_NORM_MAP = {
22+
"COSINE": norm_cosine_distance,
23+
"L2": norm_l2_distance,
24+
"IP": None, # normalized inner product is cosine similarity by definition
25+
}
26+
1927

2028
class FieldTypes(str, Enum):
2129
TAG = "tag"

redisvl/utils/utils.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,3 +198,10 @@ def norm_cosine_distance(value: float) -> float:
198198
Normalize the cosine distance to a similarity score between 0 and 1.
199199
"""
200200
return (2 - value) / 2
201+
202+
203+
def norm_l2_distance(value: float) -> float:
204+
"""
205+
Normalize the L2 distance to a similarity score between 0 and 1.
206+
"""
207+
return 1 / (1 + value)

tests/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ def sample_data(sample_datetimes):
141141
"last_updated": sample_datetimes["high"].timestamp(),
142142
"credit_score": "medium",
143143
"location": "-110.0839,37.3861",
144-
"user_embedding": [0.9, 0.9, 0.1],
144+
"user_embedding": [-0.1, -0.1, -0.5],
145145
},
146146
]
147147

tests/integration/test_query.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def vector_query():
2525
return VectorQuery(
2626
vector=[0.1, 0.1, 0.5],
2727
vector_field_name="user_embedding",
28+
return_score=True,
2829
return_fields=[
2930
"user",
3031
"credit_score",
@@ -58,7 +59,7 @@ def normalized_vector_query():
5859
return VectorQuery(
5960
vector=[0.1, 0.1, 0.5],
6061
vector_field_name="user_embedding",
61-
normalize_cosine_distance=True,
62+
normalize_vector_distance=True,
6263
return_score=True,
6364
return_fields=[
6465
"user",
@@ -107,7 +108,7 @@ def normalized_range_query():
107108
return RangeQuery(
108109
vector=[0.1, 0.1, 0.5],
109110
vector_field_name="user_embedding",
110-
normalize_cosine_distance=True,
111+
normalize_vector_distance=True,
111112
return_score=True,
112113
return_fields=["user", "credit_score", "age", "job", "location"],
113114
distance_threshold=0.2,
@@ -749,13 +750,28 @@ def test_query_normalize_cosine_distance(index, normalized_vector_query):
749750
assert 0 <= float(r["vector_distance"]) <= 1
750751

751752

752-
def test_query_normalize_cosine_distance_lp_distance(L2_index, normalized_vector_query):
753+
def test_query_cosine_distance_un_normalized(index, vector_query):
753754

754-
res = L2_index.query(normalized_vector_query)
755+
res = index.query(vector_query)
756+
757+
assert any(float(r["vector_distance"]) > 1 for r in res)
758+
759+
760+
def test_query_l2_distance_un_normalized(L2_index, vector_query):
761+
762+
res = L2_index.query(vector_query)
755763

756764
assert any(float(r["vector_distance"]) > 1 for r in res)
757765

758766

767+
def test_query_l2_distance_normalized(L2_index, normalized_vector_query):
768+
769+
res = L2_index.query(normalized_vector_query)
770+
771+
for r in res:
772+
assert 0 <= float(r["vector_distance"]) <= 1
773+
774+
759775
def test_range_query_normalize_cosine_distance(index, normalized_range_query):
760776

761777
res = index.query(normalized_range_query)

0 commit comments

Comments
 (0)