diff --git a/docs/api/index.md b/docs/api/index.md index 5b7b6261..f7c1c661 100644 --- a/docs/api/index.md +++ b/docs/api/index.md @@ -15,6 +15,7 @@ Reference documentation for the RedisVL API. schema searchindex +vector query filter vectorizer diff --git a/docs/api/query.rst b/docs/api/query.rst index fa92230e..c2ba04f9 100644 --- a/docs/api/query.rst +++ b/docs/api/query.rst @@ -88,3 +88,17 @@ CountQuery :inherited-members: :show-inheritance: :exclude-members: add_filter,get_args,highlight,return_field,summarize + + + +MultiVectorQuery +================ + +.. currentmodule:: redisvl.query + + +.. autoclass:: MultiVectorQuery + :members: + :inherited-members: + :show-inheritance: + :exclude-members: add_filter,get_args,highlight,return_field,summarize diff --git a/docs/api/vector.rst b/docs/api/vector.rst new file mode 100644 index 00000000..9d28d9cc --- /dev/null +++ b/docs/api/vector.rst @@ -0,0 +1,17 @@ + +****** +Vector +****** + +The Vector class in RedisVL is a container that encapsulates a numerical vector, its datatype, corresponding index field name, and optional importance weight. It is used when constructing multi-vector queries using the MultiVectorQuery class. + + +Vector +=========== + +.. currentmodule:: redisvl.query + + +.. 
autoclass:: Vector + :members: + :exclude-members: diff --git a/redisvl/query/__init__.py b/redisvl/query/__init__.py index 30d35562..8cae93b2 100644 --- a/redisvl/query/__init__.py +++ b/redisvl/query/__init__.py @@ -1,4 +1,9 @@ -from redisvl.query.aggregate import AggregationQuery, HybridQuery +from redisvl.query.aggregate import ( + AggregationQuery, + HybridQuery, + MultiVectorQuery, + Vector, +) from redisvl.query.query import ( BaseQuery, BaseVectorQuery, @@ -21,4 +26,6 @@ "TextQuery", "AggregationQuery", "HybridQuery", + "MultiVectorQuery", + "Vector", ] diff --git a/redisvl/query/aggregate.py b/redisvl/query/aggregate.py index fd066bce..a3a31e05 100644 --- a/redisvl/query/aggregate.py +++ b/redisvl/query/aggregate.py @@ -1,9 +1,11 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union +from pydantic import BaseModel, field_validator from redis.commands.search.aggregation import AggregateRequest, Desc from redisvl.query.filter import FilterExpression from redisvl.redis.utils import array_to_buffer +from redisvl.schema.fields import VectorDataType from redisvl.utils.token_escaper import TokenEscaper from redisvl.utils.utils import lazy_import @@ -11,6 +13,29 @@ nltk_stopwords = lazy_import("nltk.corpus.stopwords") +class Vector(BaseModel): + """ + Simple object containing the necessary arguments to perform a multi vector query. + """ + + vector: Union[List[float], bytes] + field_name: str + dtype: str = "float32" + weight: float = 1.0 + + @field_validator("dtype") + @classmethod + def validate_dtype(cls, dtype: str) -> str: + try: + VectorDataType(dtype.upper()) + except ValueError: + raise ValueError( + f"Invalid data type: {dtype}. Supported types are: {[t.lower() for t in VectorDataType]}" + ) + + return dtype + + class AggregationQuery(AggregateRequest): """ Base class for aggregation queries used to create aggregation queries for Redis. 
@@ -227,3 +252,149 @@ def _build_query_string(self) -> str: def __str__(self) -> str: """Return the string representation of the query.""" return " ".join([str(x) for x in self.build_args()]) + + +class MultiVectorQuery(AggregationQuery): + """ + MultiVectorQuery allows for search over multiple vector fields in a document simultaneously. + The final score will be a weighted combination of the individual vector similarity scores + following the formula: + + score = (w_1 * score_1 + w_2 * score_2 + w_3 * score_3 + ... ) + + Vectors may be of different size and datatype, but must be indexed using the 'cosine' distance_metric. + + .. code-block:: python + + from redisvl.query import MultiVectorQuery, Vector + from redisvl.index import SearchIndex + + index = SearchIndex.from_yaml("path/to/index.yaml") + + vector_1 = Vector( + vector=[0.1, 0.2, 0.3], + field_name="text_vector", + dtype="float32", + weight=0.7, + ) + vector_2 = Vector( + vector=[0.5, 0.5], + field_name="image_vector", + dtype="bfloat16", + weight=0.2, + ) + vector_3 = Vector( + vector=[0.1, 0.2, 0.3], + field_name="audio_vector", + dtype="float64", + weight=0.5, + ) + + query = MultiVectorQuery( + vectors=[vector_1, vector_2, vector_3], + filter_expression=None, + num_results=10, + return_fields=["field1", "field2"], + dialect=2, + ) + + results = index.query(query) + """ + + _vectors: List[Vector] + + def __init__( + self, + vectors: Union[Vector, List[Vector]], + return_fields: Optional[List[str]] = None, + filter_expression: Optional[Union[str, FilterExpression]] = None, + num_results: int = 10, + dialect: int = 2, + ): + """ + Instantiates a MultiVectorQuery object. + + Args: + vectors (Union[Vector, List[Vector]]): The Vectors to perform vector similarity search. + return_fields (Optional[List[str]], optional): The fields to return. Defaults to None. + filter_expression (Optional[Union[str, FilterExpression]]): The filter expression to use. + Defaults to None. 
+ num_results (int, optional): The number of results to return. Defaults to 10. + dialect (int, optional): The Redis dialect version. Defaults to 2. + """ + + self._filter_expression = filter_expression + self._num_results = num_results + + if isinstance(vectors, Vector): + self._vectors = [vectors] + else: + self._vectors = vectors # type: ignore + + if not all([isinstance(v, Vector) for v in self._vectors]): + raise TypeError( + "vector argument must be a Vector object or list of Vector objects." + ) + + query_string = self._build_query_string() + super().__init__(query_string) + + # calculate the respective vector similarities + for i in range(len(self._vectors)): + self.apply(**{f"score_{i}": f"(2 - @distance_{i})/2"}) + + # construct the scoring string based on the vector similarity scores and weights + combined_scores = [] + for i, w in enumerate([v.weight for v in self._vectors]): + combined_scores.append(f"@score_{i} * {w}") + combined_score_string = " + ".join(combined_scores) + + self.apply(combined_score=combined_score_string) + + self.sort_by(Desc("@combined_score"), max=num_results) # type: ignore + self.dialect(dialect) + if return_fields: + self.load(*return_fields) # type: ignore[arg-type] + + @property + def params(self) -> Dict[str, Any]: + """Return the parameters for the aggregation. + + Returns: + Dict[str, Any]: The parameters for the aggregation. 
+ """ + params = {} + for i, (vector, dtype) in enumerate( + [(v.vector, v.dtype) for v in self._vectors] + ): + if isinstance(vector, list): + vector = array_to_buffer(vector, dtype=dtype) # type: ignore + params[f"vector_{i}"] = vector + return params + + def _build_query_string(self) -> str: + """Build the full query string for multi-vector range search with optional filtering.""" + + # one VECTOR_RANGE clause per vector, joined with '|' + range_queries = [] + for i, (vector, field) in enumerate( + [(v.vector, v.field_name) for v in self._vectors] + ): + range_queries.append( + f"@{field}:[VECTOR_RANGE 2.0 $vector_{i}]=>{{$YIELD_DISTANCE_AS: distance_{i}}}" + ) + + range_query = " | ".join(range_queries) + + filter_expression = self._filter_expression + if isinstance(self._filter_expression, FilterExpression): + filter_expression = str(self._filter_expression) + + if filter_expression: + return f"({range_query}) AND ({filter_expression})" + else: + return f"{range_query}" + + def __str__(self) -> str: + """Return the string representation of the query.""" + return " ".join([str(x) for x in self.build_args()]) diff --git a/tests/conftest.py b/tests/conftest.py index 8a3bdae6..692ce77d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -308,6 +308,96 @@ def sample_data(sample_datetimes): ] +@pytest.fixture +def multi_vector_data(sample_datetimes): + return [ + { + "user": "john", + "age": 18, + "job": "engineer", + "description": "engineers conduct trains that ride on train tracks", + "last_updated": sample_datetimes["low"].timestamp(), + "credit_score": "high", + "location": "-122.4194,37.7749", + "user_embedding": [0.1, 0.1, 0.5], + "image_embedding": [0.1, 0.1, 0.1, 0.1, 0.1], + "audio_embedding": [34, 18.5, -6.0, -12, 115, 96.5], + }, + { + "user": "mary", + "age": 14, + "job": "doctor", + "description": "a medical professional who treats diseases and helps people stay healthy", + "last_updated": sample_datetimes["low"].timestamp(), + "credit_score": "low", + "location": "-122.4194,37.7749", + 
"user_embedding": [0.1, 0.1, 0.5], + "image_embedding": [0.1, 0.2, 0.3, 0.4, 0.5], + "audio_embedding": [0.0, -1.06, 4.55, -1.93, 0.0, 1.53], + }, + { + "user": "nancy", + "age": 94, + "job": "doctor", + "description": "a research scientist specializing in cancers and diseases of the lungs", + "last_updated": sample_datetimes["mid"].timestamp(), + "credit_score": "high", + "location": "-122.4194,37.7749", + "user_embedding": [0.7, 0.1, 0.5], + "image_embedding": [0.1, 0.1, 0.3, 0.3, 0.5], + "audio_embedding": [2.75, -0.33, -3.01, -0.52, 5.59, -2.30], + }, + { + "user": "tyler", + "age": 100, + "job": "engineer", + "description": "a software developer with expertise in mathematics and computer science", + "last_updated": sample_datetimes["mid"].timestamp(), + "credit_score": "high", + "location": "-110.0839,37.3861", + "user_embedding": [0.1, 0.4, 0.5], + "image_embedding": [-0.1, -0.2, -0.3, -0.4, -0.5], + "audio_embedding": [1.11, -6.73, 5.41, 1.04, 3.92, 0.73], + }, + { + "user": "tim", + "age": 12, + "job": "dermatologist", + "description": "a medical professional specializing in diseases of the skin", + "last_updated": sample_datetimes["mid"].timestamp(), + "credit_score": "high", + "location": "-110.0839,37.3861", + "user_embedding": [0.4, 0.4, 0.5], + "image_embedding": [-0.1, 0.0, 0.6, 0.0, -0.9], + "audio_embedding": [0.03, -2.67, -2.08, 4.57, -2.33, 0.0], + }, + { + "user": "taimur", + "age": 15, + "job": "CEO", + "description": "high stress, but financially rewarding position at the head of a company", + "last_updated": sample_datetimes["high"].timestamp(), + "credit_score": "low", + "location": "-110.0839,37.3861", + "user_embedding": [0.6, 0.1, 0.5], + "image_embedding": [1.1, 1.2, -0.3, -4.1, 5.0], + "audio_embedding": [0.68, 0.26, 2.08, 2.96, 0.01, 5.13], + }, + { + "user": "joe", + "age": 35, + "job": "dentist", + "description": "like the tooth fairy because they'll take your teeth, but you have to pay them!", + "last_updated": 
sample_datetimes["high"].timestamp(), + "credit_score": "medium", + "location": "-110.0839,37.3861", + "user_embedding": [-0.1, -0.1, -0.5], + "image_embedding": [-0.8, 2.0, 3.1, 1.5, -1.6], + "audio_embedding": [0.91, 7.10, -2.14, -0.52, -6.08, -5.53], + }, + ] + + def pytest_addoption(parser: pytest.Parser) -> None: parser.addoption( "--run-api-tests", diff --git a/tests/integration/test_aggregation.py b/tests/integration/test_aggregation.py index 3561b1de..f08815a6 100644 --- a/tests/integration/test_aggregation.py +++ b/tests/integration/test_aggregation.py @@ -1,14 +1,15 @@ import pytest from redisvl.index import SearchIndex -from redisvl.query import HybridQuery +from redisvl.query import HybridQuery, MultiVectorQuery, Vector from redisvl.query.filter import FilterExpression, Geo, GeoRadius, Num, Tag, Text from redisvl.redis.utils import array_to_buffer from tests.conftest import skip_if_redis_version_below @pytest.fixture -def index(sample_data, redis_url, worker_id): +def index(multi_vector_data, redis_url, worker_id): + index = SearchIndex.from_dict( { "index": { @@ -33,6 +34,26 @@ def index(sample_data, redis_url, worker_id): "datatype": "float32", }, }, + { + "name": "image_embedding", + "type": "vector", + "attrs": { + "dims": 5, + "distance_metric": "cosine", + "algorithm": "flat", + "datatype": "float32", + }, + }, + { + "name": "audio_embedding", + "type": "vector", + "attrs": { + "dims": 6, + "distance_metric": "cosine", + "algorithm": "flat", + "datatype": "float64", + }, + }, ], }, redis_url=redis_url, @@ -46,9 +67,11 @@ def hash_preprocess(item: dict) -> dict: return { **item, "user_embedding": array_to_buffer(item["user_embedding"], "float32"), + "image_embedding": array_to_buffer(item["image_embedding"], "float32"), + "audio_embedding": array_to_buffer(item["audio_embedding"], "float64"), } - index.load(sample_data, preprocess=hash_preprocess) + index.load(multi_vector_data, preprocess=hash_preprocess) # run the test yield index @@ -57,7 +80,7 
@@ def hash_preprocess(item: dict) -> dict: index.delete(drop=True) -def test_aggregation_query(index): +def test_hybrid_query(index): skip_if_redis_version_below(index.client, "7.2.0") text = "a medical professional with expertise in lung cancer" @@ -136,7 +159,7 @@ def test_empty_query_string(): ) -def test_aggregation_query_with_filter(index): +def test_hybrid_query_with_filter(index): skip_if_redis_version_below(index.client, "7.2.0") text = "a medical professional with expertise in lung cancer" @@ -162,7 +185,7 @@ def test_aggregation_query_with_filter(index): assert int(result["age"]) > 30 -def test_aggregation_query_with_geo_filter(index): +def test_hybrid_query_with_geo_filter(index): skip_if_redis_version_below(index.client, "7.2.0") text = "a medical professional with expertise in lung cancer" @@ -188,7 +211,7 @@ def test_aggregation_query_with_geo_filter(index): @pytest.mark.parametrize("alpha", [0.1, 0.5, 0.9]) -def test_aggregate_query_alpha(index, alpha): +def test_hybrid_query_alpha(index, alpha): skip_if_redis_version_below(index.client, "7.2.0") text = "a medical professional with expertise in lung cancer" @@ -215,7 +238,7 @@ def test_aggregate_query_alpha(index, alpha): ) # allow for small floating point error -def test_aggregate_query_stopwords(index): +def test_hybrid_query_stopwords(index): skip_if_redis_version_below(index.client, "7.2.0") text = "a medical professional with expertise in lung cancer" @@ -249,7 +272,7 @@ def test_aggregate_query_stopwords(index): ) # allow for small floating point error -def test_aggregate_query_with_text_filter(index): +def test_hybrid_query_with_text_filter(index): skip_if_redis_version_below(index.client, "7.2.0") text = "a medical professional with expertise in lung cancer" @@ -292,3 +315,330 @@ def test_aggregate_query_with_text_filter(index): for result in results: assert "medical" in result[text_field].lower() assert "research" not in result[text_field].lower() + + +def test_multivector_query(index): + 
skip_if_redis_version_below(index.client, "7.2.0") + + vector_vals = [[0.1, 0.1, 0.5], [0.3, 0.4, 0.7, 0.2, -0.3]] + vector_fields = ["user_embedding", "image_embedding"] + vectors = [] + for vector, field in zip(vector_vals, vector_fields): + vectors.append(Vector(vector=vector, field_name=field)) + + return_fields = ["user", "credit_score", "age", "job", "location", "description"] + + multi_query = MultiVectorQuery( + vectors=vectors, + return_fields=return_fields, + ) + + results = index.query(multi_query) + assert isinstance(results, list) + assert len(results) == 7 + for doc in results: + assert doc["user"] in [ + "john", + "derrick", + "nancy", + "tyler", + "tim", + "taimur", + "joe", + "mary", + ] + assert int(doc["age"]) in [18, 14, 94, 100, 12, 15, 35] + assert doc["job"] in ["engineer", "doctor", "dermatologist", "CEO", "dentist"] + assert doc["credit_score"] in ["high", "low", "medium"] + + multi_query = MultiVectorQuery( + vectors=vectors, + num_results=3, + ) + + results = index.query(multi_query) + assert len(results) == 3 + assert ( + results[0]["combined_score"] + >= results[1]["combined_score"] + >= results[2]["combined_score"] + ) + + +def test_multivector_query_with_filter(index): + skip_if_redis_version_below(index.client, "7.2.0") + + text_field = "description" + vector_vals = [[0.1, 0.1, 0.5], [0.3, 0.4, 0.7, 0.2, -0.3]] + vector_fields = ["user_embedding", "image_embedding"] + filter_expression = Text(text_field) == ("medical") + + vectors = [] + for vector, field in zip(vector_vals, vector_fields): + vectors.append(Vector(vector=vector, field_name=field)) + + # make sure we can still apply filters to the same text field we are querying + multi_query = MultiVectorQuery( + vectors=vectors, + filter_expression=filter_expression, + return_fields=["job", "description"], + ) + + results = index.query(multi_query) + assert len(results) == 2 + for result in results: + assert "medical" in result[text_field].lower() + + filter_expression = 
(Text(text_field) == ("medical")) & ( + (Text(text_field) != ("research")) + ) + multi_query = MultiVectorQuery( + vectors=vectors, + filter_expression=filter_expression, + return_fields=["description"], + ) + + results = index.query(multi_query) + assert len(results) == 2 + for result in results: + assert "medical" in result[text_field].lower() + assert "research" not in result[text_field].lower() + + filter_expression = (Num("age") > 30) & ((Num("age") < 30)) + multi_query = MultiVectorQuery( + vectors=vectors, + filter_expression=filter_expression, + return_fields=["description"], + ) + + results = index.query(multi_query) + assert len(results) == 0 + + +def test_multivector_query_with_geo_filter(index): + skip_if_redis_version_below(index.client, "7.2.0") + + vector_vals = [[0.2, 0.4, 0.1], [0.1, 0.8, 0.3, -0.2, 0.3]] + vector_fields = ["user_embedding", "image_embedding"] + return_fields = ["user", "credit_score", "age", "job", "location", "description"] + filter_expression = Geo("location") == GeoRadius(-122.4194, 37.7749, 1000, "m") + + vectors = [] + for vector, field in zip(vector_vals, vector_fields): + vectors.append(Vector(vector=vector, field_name=field)) + + multi_query = MultiVectorQuery( + vectors=vectors, + filter_expression=filter_expression, + return_fields=return_fields, + ) + + results = index.query(multi_query) + assert len(results) == 3 + for result in results: + assert result["location"] is not None + + +def test_multivector_query_weights(index): + skip_if_redis_version_below(index.client, "7.2.0") + + vector_vals = [[0.1, 0.2, 0.5], [0.3, 0.4, 0.7, 0.2, -0.3]] + vector_fields = ["user_embedding", "image_embedding"] + return_fields = [ + "distance_0", + "distance_1", + "score_0", + "score_1", + "user_embedding", + "image_embedding", + ] + + vectors = [] + for vector, field in zip(vector_vals, vector_fields): + vectors.append(Vector(vector=vector, field_name=field)) + + # changing the weights does indeed change the result order + 
multi_query_1 = MultiVectorQuery( + vectors=vectors, + return_fields=return_fields, + ) + results_1 = index.query(multi_query_1) + + weights = [0.2, 0.9] + vectors = [] + for vector, field, weight in zip(vector_vals, vector_fields, weights): + vectors.append(Vector(vector=vector, field_name=field, weight=weight)) + + multi_query_2 = MultiVectorQuery( + vectors=vectors, + return_fields=return_fields, + ) + results_2 = index.query(multi_query_2) + + assert results_1 != results_2 + + for i in range(1, len(results_1)): + assert results_1[i]["combined_score"] <= results_1[i - 1]["combined_score"] + + for i in range(1, len(results_2)): + assert results_2[i]["combined_score"] <= results_2[i - 1]["combined_score"] + + # weights can be negative, 0.0, or greater than 1.0 + weights = [-5.2, 0.0] + vectors = [] + for vector, field, weight in zip(vector_vals, vector_fields, weights): + vectors.append(Vector(vector=vector, field_name=field, weight=weight)) + + multi_query = MultiVectorQuery( + vectors=vectors, + return_fields=return_fields, + ) + + results = index.query(multi_query) + assert results + for r in results: + score = float(r["score_0"]) * weights[0] + assert ( + abs(float(r["combined_score"]) - score) <= 0.0001 + ) # allow for small floating point error + + # verify we're doing the combined score math correctly + weights = [-1.322, 0.851] + vectors = [] + for vector, field, weight in zip(vector_vals, vector_fields, weights): + vectors.append(Vector(vector=vector, field_name=field, weight=weight)) + + multi_query = MultiVectorQuery( + vectors=vectors, + return_fields=return_fields, + ) + + results = index.query(multi_query) + assert results + for r in results: + score = float(r["score_0"]) * weights[0] + float(r["score_1"]) * weights[1] + assert ( + abs(float(r["combined_score"]) - score) <= 0.0001 + ) # allow for small floating point error + + +def test_multivector_query_datatypes(index): + skip_if_redis_version_below(index.client, "7.2.0") + + vector_vals = [[0.1, 0.2, 0.5], 
[1.2, 0.3, -0.4, 0.7, 0.2, -0.3]] + vector_fields = ["user_embedding", "audio_embedding"] + dtypes = ["float32", "float64"] + return_fields = [ + "distance_0", + "distance_1", + "score_0", + "score_1", + "user_embedding", + "audio_embedding", + ] + + vectors = [] + for vector, field, dtype in zip(vector_vals, vector_fields, dtypes): + vectors.append(Vector(vector=vector, field_name=field, dtype=dtype)) + + multi_query = MultiVectorQuery( + vectors=vectors, + return_fields=return_fields, + ) + results = index.query(multi_query) + + for i in range(1, len(results)): + assert results[i]["combined_score"] <= results[i - 1]["combined_score"] + + # verify we're doing the combined score math correctly + weights = [-1.322, 0.851] + vectors = [] + for vector, field, weight, dtype in zip( + vector_vals, vector_fields, weights, dtypes + ): + vectors.append( + Vector(vector=vector, field_name=field, weight=weight, dtype=dtype) + ) + + multi_query = MultiVectorQuery( + vectors=vectors, + return_fields=return_fields, + ) + + results = index.query(multi_query) + assert results + for r in results: + score = float(r["score_0"]) * weights[0] + float(r["score_1"]) * weights[1] + assert ( + abs(float(r["combined_score"]) - score) <= 0.0001 + ) # allow for small floating point error + + +def test_multivector_query_mixed_index(index): + # test that we can do multi vector queries on indices with both a 'flat' and 'hnsw' index + skip_if_redis_version_below(index.client, "7.2.0") + try: + index.schema.remove_field("audio_embedding") + index.schema.add_field( + { + "name": "audio_embedding", + "type": "vector", + "attrs": { + "dims": 6, + "distance_metric": "cosine", + "algorithm": "hnsw", + "datatype": "float64", + }, + }, + ) + + except Exception: + pytest.skip("Required Redis modules not available or version too low") + + vector_vals = [[0.1, 0.2, 0.5], [1.2, 0.3, -0.4, 0.7, 0.2, -0.3]] + vector_fields = ["user_embedding", "audio_embedding"] + dtypes = ["float32", "float64"] + return_fields = [ + 
"distance_0", + "distance_1", + "score_0", + "score_1", + "user_embedding", + "audio_embedding", + ] + + vectors = [] + for vector, field, dtype in zip(vector_vals, vector_fields, dtypes): + vectors.append(Vector(vector=vector, field_name=field, dtype=dtype)) + + multi_query = MultiVectorQuery( + vectors=vectors, + return_fields=return_fields, + ) + results = index.query(multi_query) + + for i in range(1, len(results)): + assert results[i]["combined_score"] <= results[i - 1]["combined_score"] + + # verify we're doing the combined score math correctly + weights = [-1.322, 0.851] + vectors = [] + for vector, field, dtype, weight in zip( + vector_vals, vector_fields, dtypes, weights + ): + vectors.append( + Vector(vector=vector, field_name=field, dtype=dtype, weight=weight) + ) + + multi_query = MultiVectorQuery( + vectors=vectors, + return_fields=return_fields, + ) + + results = index.query(multi_query) + assert results + for r in results: + score = float(r["score_0"]) * weights[0] + float(r["score_1"]) * weights[1] + assert ( + abs(float(r["combined_score"]) - score) <= 0.0001 + ) # allow for small floating point error diff --git a/tests/unit/test_aggregation_types.py b/tests/unit/test_aggregation_types.py index a13e87f5..f2b6be86 100644 --- a/tests/unit/test_aggregation_types.py +++ b/tests/unit/test_aggregation_types.py @@ -4,13 +4,17 @@ from redis.commands.search.result import Result from redisvl.index.index import process_results -from redisvl.query.aggregate import HybridQuery +from redisvl.query.aggregate import HybridQuery, MultiVectorQuery, Vector from redisvl.query.filter import Tag # Sample data for testing sample_vector = [0.1, 0.2, 0.3, 0.4] sample_text = "the toon squad play basketball against a gang of aliens" +sample_vector_2 = [0.1, 0.2, 0.3, 0.4] +sample_vector_3 = [0.5, 0.5] +sample_vector_4 = [0.1, 0.1, 0.1] + # Test Cases def test_aggregate_hybrid_query(): @@ -87,6 +91,7 @@ def test_aggregate_hybrid_query(): stopwords=["the", "a", "of"], ) assert 
hybrid_query.stopwords == set(["the", "a", "of"]) + hybrid_query = HybridQuery( sample_text, text_field_name, @@ -137,7 +142,6 @@ def test_hybrid_query_with_string_filter(): ) # Check that filter is stored correctly - print("hybrid_query.filter ===", hybrid_query.filter) assert hybrid_query._filter_expression == string_filter # Check that the generated query string includes both text search and filter @@ -190,3 +194,123 @@ def test_hybrid_query_with_string_filter(): query_string_wildcard = str(hybrid_query_wildcard) assert f"@{text_field_name}:(search | document | 12345)" in query_string_wildcard assert "AND" not in query_string_wildcard + + +def test_multi_vector_query(): + # test we require Vector objects + with pytest.raises(TypeError): + _ = MultiVectorQuery() + + with pytest.raises(TypeError): + _ = MultiVectorQuery(vector=[sample_vector]) + + with pytest.raises(TypeError): + _ = MultiVectorQuery(vectors=[[0.1, 0.1, 0.1], "field_1"]) + + # test we can initialize with a single vector and single field name + multivector_query = MultiVectorQuery( + Vector(vector=sample_vector, field_name="field_1") + ) + + # check default properties + assert multivector_query._vectors == [ + Vector(vector=sample_vector, field_name="field_1") + ] + assert multivector_query._vectors[0].field_name == "field_1" + assert multivector_query._vectors[0].weight == 1.0 + assert multivector_query._vectors[0].dtype == "float32" + assert multivector_query._filter_expression == None + assert multivector_query._num_results == 10 + assert multivector_query._loadfields == [] + assert multivector_query._dialect == 2 + + # test we can initialize with multiple Vectors + vectors = [sample_vector, sample_vector_2, sample_vector_3, sample_vector_4] + vector_field_names = ["field_1", "field_2", "field_3", "field_4"] + weights = [0.2, 0.5, 0.6, 0.1] + dtypes = ["float32", "float32", "float32", "float32"] + + args = [] + for vec, field, weight, dtype in zip(vectors, vector_field_names, weights, dtypes): + 
args.append(Vector(vector=vec, field_name=field, weight=weight, dtype=dtype)) + + multivector_query = MultiVectorQuery(vectors=args) + + assert len(multivector_query._vectors) == 4 + assert multivector_query._vectors == args + + # test defaults can be overwritten + filter_expression = Tag("user group") == ["group A", "group C"] + + multivector_query = MultiVectorQuery( + vectors=args, + filter_expression=filter_expression, + num_results=5, + return_fields=["field_1", "user name", "address"], + dialect=4, + ) + + assert multivector_query._filter_expression == filter_expression + assert multivector_query._num_results == 5 + assert multivector_query._loadfields == ["field_1", "user name", "address"] + assert multivector_query._dialect == 4 + + +def test_multi_vector_query_string(): + # if a single weight is passed it is applied to all similarity scores + field_1 = "text embedding" + field_2 = "image embedding" + weight_1 = 0.2 + weight_2 = 0.7 + multi_vector_query = MultiVectorQuery( + vectors=[ + Vector(vector=sample_vector_2, field_name=field_1, weight=weight_1), + Vector(vector=sample_vector_3, field_name=field_2, weight=weight_2), + ] + ) + + assert ( + str(multi_vector_query) + == f"@{field_1}:[VECTOR_RANGE 2.0 $vector_0]=>{{$YIELD_DISTANCE_AS: distance_0}} | @{field_2}:[VECTOR_RANGE 2.0 $vector_1]=>{{$YIELD_DISTANCE_AS: distance_1}} SCORER TFIDF DIALECT 2 APPLY (2 - @distance_0)/2 AS score_0 APPLY (2 - @distance_1)/2 AS score_1 APPLY @score_0 * {weight_1} + @score_1 * {weight_2} AS combined_score SORTBY 2 @combined_score DESC MAX 10" + ) + + +def test_vector_object_validation(): + # test an error is raised if none of the field names are present + with pytest.raises(ValueError): + _ = Vector() + + with pytest.raises(ValueError): + _ = Vector( + vector=[], + field_name=[], + ) + + # test an error is raised if the type of vector or fields are incorrect + # no list of list of floats + with pytest.raises(ValueError): + _ = Vector( + vector=[sample_vector, 
sample_vector_2, sample_vector_3], + field_name="text embedding", + ) + + # no list as field name + with pytest.raises(ValueError): + _ = Vector( + vector=sample_vector, + field_name=["text embedding", "image embedding", "features"], + ) + + # dtype must be one of the supported values + with pytest.raises(ValueError): + _ = Vector(vector=sample_vector, field_name="text embedding", dtype="float") + + with pytest.raises(ValueError): + _ = Vector(vector=sample_vector, field_name="text embedding", dtype="normal") + + with pytest.raises(ValueError): + _ = Vector(vector=sample_vector, field_name="text embedding", dtype="") + + for dtype in ["bfloat16", "float16", "float32", "float64", "int8", "uint8"]: + vec = Vector(vector=sample_vector, field_name="text embedding", dtype=dtype) + assert isinstance(vec, Vector)