From 8dc308547b1520ab0c324976be489e43cef1e397 Mon Sep 17 00:00:00 2001 From: Andrew Crump Date: Fri, 30 Aug 2024 01:57:19 +0000 Subject: [PATCH] redisvector: fix score threshold option Vector range queries expect a radius rather than a score threshold. --- vectorstores/redisvector/index_search.go | 6 +++--- vectorstores/redisvector/redis_vector_test.go | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/vectorstores/redisvector/index_search.go b/vectorstores/redisvector/index_search.go index 6485dd99c..13dadc35b 100644 --- a/vectorstores/redisvector/index_search.go +++ b/vectorstores/redisvector/index_search.go @@ -98,19 +98,19 @@ func (s IndexVectorSearch) AsCommand() []string { const vectorField = "vector" const vectorFieldAs = defaultDistanceFieldKey - const disThresholdFiled = "distance_threshold" + const disThresholdField = "distance_threshold" const vectorKey = defaultContentVectorFieldKey params := []string{vectorField, VectorString32(s.vector)} if s.scoreThreshold > 0 && s.scoreThreshold < 1 { // Range search // "@content_vector:[VECTOR_RANGE $distance_threshold $vector]=>{$yield_distance_as: distance}" - filter := fmt.Sprintf("@%s:[VECTOR_RANGE $%s $%s]=>{$yield_distance_as: %s}", vectorKey, disThresholdFiled, vectorField, vectorFieldAs) + filter := fmt.Sprintf("@%s:[VECTOR_RANGE $%s $%s]=>{$yield_distance_as: %s}", vectorKey, disThresholdField, vectorField, vectorFieldAs) if len(s.preFilters) > 0 { filter = fmt.Sprintf("(%s) %s", s.preFilters, filter) } cmd = append(cmd, filter) - params = append(params, disThresholdFiled, strconv.FormatFloat(float64(s.scoreThreshold), 'f', -1, 32)) + params = append(params, disThresholdField, strconv.FormatFloat(float64(1.0-s.scoreThreshold), 'f', -1, 32)) } else { // KNN search // "(*)=>[KNN n @content_vector $vector AS distance]" diff --git a/vectorstores/redisvector/redis_vector_test.go b/vectorstores/redisvector/redis_vector_test.go index 0305d045e..f9c7e5d4b 100644 --- a/vectorstores/redisvector/redis_vector_test.go +++ b/vectorstores/redisvector/redis_vector_test.go @@ -303,8 +303,8 @@ func TestSimilaritySearch(t *testing.T) { assert.Len(t, docs[0].Metadata, 3) // search with score threshold - docs, err = store.SimilaritySearch(ctx, "Tokyo", 2, - vectorstores.WithScoreThreshold(0.5), + docs, err = store.SimilaritySearch(ctx, "Tokyo", 10, + vectorstores.WithScoreThreshold(0.8), ) require.NoError(t, err) assert.Len(t, docs, 2)