From 8d63bc41553a1234d2b930f9203048e654bc76b7 Mon Sep 17 00:00:00 2001 From: Brian Sam-Bodden Date: Mon, 24 Mar 2025 15:49:34 -0700 Subject: [PATCH 1/2] Adds support for HYBRID_POLICY on KNN queries (VectorQuery and VectorRangeQuery) with filters --- redisvl/query/query.py | 280 +++++++++++++++++++++++++++++--- tests/integration/test_query.py | 128 +++++++++++++++ tests/unit/test_query_types.py | 275 ++++++++++++++++++++++++++++++- 3 files changed, 657 insertions(+), 26 deletions(-) diff --git a/redisvl/query/query.py b/redisvl/query/query.py index d99a080e..8182952e 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -188,6 +188,8 @@ def __init__( dialect: int = 2, sort_by: Optional[str] = None, in_order: bool = False, + hybrid_policy: Optional[str] = None, + batch_size: Optional[int] = None, ): """A query for running a vector search along with an optional filter expression. @@ -213,6 +215,16 @@ def __init__( in_order (bool): Requires the terms in the field to have the same order as the terms in the query filter, regardless of the offsets between them. Defaults to False. + hybrid_policy (Optional[str]): Controls how filters are applied during vector search. + Options are "BATCHES" (paginates through small batches of nearest neighbors) or + "ADHOC_BF" (computes scores for all vectors passing the filter). + "BATCHES" mode is typically faster for queries with selective filters. + "ADHOC_BF" mode is better when filters match a large portion of the dataset. + Defaults to None, which lets Redis auto-select the optimal policy. + batch_size (Optional[int]): When hybrid_policy is "BATCHES", controls the number + of vectors to fetch in each batch. Larger values may improve performance + at the cost of memory usage. Only applies when hybrid_policy="BATCHES". + Defaults to None, which lets Redis auto-select an appropriate batch size. Raises: TypeError: If filter_expression is not of type redisvl.query.FilterExpression @@ -224,6 +236,8 @@ def __init__( self._vector_field_name = vector_field_name self._dtype = dtype self._num_results = num_results + self._hybrid_policy: Optional[str] = None + self._batch_size: Optional[int] = None self.set_filter(filter_expression) query_string = self._build_query_string() @@ -246,12 +260,89 @@ def __init__( if in_order: self.in_order() + if hybrid_policy is not None: + self.set_hybrid_policy(hybrid_policy) + + if batch_size is not None: + self.set_batch_size(batch_size) + def _build_query_string(self) -> str: """Build the full query string for vector search with optional filtering.""" filter_expression = self._filter_expression if isinstance(filter_expression, FilterExpression): filter_expression = str(filter_expression) - return f"{filter_expression}=>[KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM} AS {self.DISTANCE_ID}]" + + # Base KNN query + knn_query = ( + f"KNN {self._num_results} @{self._vector_field_name} ${self.VECTOR_PARAM}" + ) + + # Add hybrid policy parameters if specified + if self._hybrid_policy: + knn_query += f" HYBRID_POLICY {self._hybrid_policy}" + + # Add batch size if specified and using BATCHES policy + if self._hybrid_policy == "BATCHES" and self._batch_size: + knn_query += f" BATCH_SIZE {self._batch_size}" + + # Add distance field alias + knn_query += f" AS {self.DISTANCE_ID}" + + return f"{filter_expression}=>[{knn_query}]" + + def set_hybrid_policy(self, hybrid_policy: str): + """Set the hybrid policy for the query. + + Args: + hybrid_policy (str): The hybrid policy to use. Options are "BATCHES" + or "ADHOC_BF". + + Raises: + ValueError: If hybrid_policy is not one of the valid options + """ + if hybrid_policy not in {"BATCHES", "ADHOC_BF"}: + raise ValueError("hybrid_policy must be one of {'BATCHES', 'ADHOC_BF'}") + self._hybrid_policy = hybrid_policy + + # Reset the query string + self._query_string = self._build_query_string() + + def set_batch_size(self, batch_size: int): + """Set the batch size for the query. + + Args: + batch_size (int): The batch size to use when hybrid_policy is "BATCHES". + + Raises: + TypeError: If batch_size is not an integer + ValueError: If batch_size is not positive + """ + if not isinstance(batch_size, int): + raise TypeError("batch_size must be an integer") + if batch_size <= 0: + raise ValueError("batch_size must be positive") + self._batch_size = batch_size + + # Reset the query string + self._query_string = self._build_query_string() + + @property + def hybrid_policy(self) -> Optional[str]: + """Return the hybrid policy for the query. + + Returns: + Optional[str]: The hybrid policy for the query. + """ + return self._hybrid_policy + + @property + def batch_size(self) -> Optional[int]: + """Return the batch size for the query. + + Returns: + Optional[int]: The batch size for the query. + """ + return self._batch_size @property def params(self) -> Dict[str, Any]: @@ -265,11 +356,16 @@ def params(self) -> Dict[str, Any]: else: vector = array_to_buffer(self._vector, dtype=self._dtype) - return {self.VECTOR_PARAM: vector} + params = {self.VECTOR_PARAM: vector} + + return params class VectorRangeQuery(BaseVectorQuery, BaseQuery): DISTANCE_THRESHOLD_PARAM: str = "distance_threshold" + EPSILON_PARAM: str = "EPSILON" # Parameter name for epsilon + HYBRID_POLICY_PARAM: str = "HYBRID_POLICY" # Parameter name for hybrid policy + BATCH_SIZE_PARAM: str = "BATCH_SIZE" # Parameter name for batch size def __init__( self, @@ -279,11 +375,14 @@ def __init__( filter_expression: Optional[Union[str, FilterExpression]] = None, dtype: str = "float32", distance_threshold: float = 0.2, + epsilon: Optional[float] = None, num_results: int = 10, return_score: bool = True, dialect: int = 2, sort_by: Optional[str] = None, in_order: bool = False, + hybrid_policy: Optional[str] = None, + batch_size: Optional[int] = None, ): """A query for running a filtered vector search based on semantic distance threshold. @@ -298,9 +397,14 @@ def __init__( along with the range query. Defaults to None. dtype (str, optional): The dtype of the vector. Defaults to "float32". - distance_threshold (str, float): The threshold for vector distance. + distance_threshold (float): The threshold for vector distance. A smaller threshold indicates a stricter semantic search. Defaults to 0.2. + epsilon (Optional[float]): The relative factor for vector range queries, + setting boundaries for candidates within radius * (1 + epsilon). + This controls how extensive the search is beyond the specified radius. + Higher values increase recall at the expense of performance. + Defaults to None, which uses the index-defined epsilon (typically 0.01). num_results (int): The MAX number of results to return. Defaults to 10. return_score (bool, optional): Whether to return the vector @@ -312,18 +416,35 @@ def __init__( in_order (bool): Requires the terms in the field to have the same order as the terms in the query filter, regardless of the offsets between them. Defaults to False. - - Raises: - TypeError: If filter_expression is not of type redisvl.query.FilterExpression - - Note: - Learn more about vector range queries: https://redis.io/docs/interact/search-and-query/search/vectors/#range-query - + hybrid_policy (Optional[str]): Controls how filters are applied during vector search. + Options are "BATCHES" (paginates through small batches of nearest neighbors) or + "ADHOC_BF" (computes scores for all vectors passing the filter). + "BATCHES" mode is typically faster for queries with selective filters. + "ADHOC_BF" mode is better when filters match a large portion of the dataset. + Defaults to None, which lets Redis auto-select the optimal policy. + batch_size (Optional[int]): When hybrid_policy is "BATCHES", controls the number + of vectors to fetch in each batch. Larger values may improve performance + at the cost of memory usage. Only applies when hybrid_policy="BATCHES". + Defaults to None, which lets Redis auto-select an appropriate batch size. """ self._vector = vector self._vector_field_name = vector_field_name self._dtype = dtype self._num_results = num_results + self._distance_threshold: float = 0.2 # Initialize with default + self._epsilon: Optional[float] = None + self._hybrid_policy: Optional[str] = None + self._batch_size: Optional[int] = None + + if epsilon is not None: + self.set_epsilon(epsilon) + + if hybrid_policy is not None: + self.set_hybrid_policy(hybrid_policy) + + if batch_size is not None: + self.set_batch_size(batch_size) + self.set_distance_threshold(distance_threshold) self.set_filter(filter_expression) query_string = self._build_query_string() @@ -347,27 +468,104 @@ def __init__( if in_order: self.in_order() + def set_distance_threshold(self, distance_threshold: float): + """Set the distance threshold for the query. + + Args: + distance_threshold (float): Vector distance threshold. + + Raises: + TypeError: If distance_threshold is not a float or int + ValueError: If distance_threshold is negative + """ + if not isinstance(distance_threshold, (float, int)): + raise TypeError("distance_threshold must be of type float or int") + if distance_threshold < 0: + raise ValueError("distance_threshold must be non-negative") + self._distance_threshold = distance_threshold + + # Reset the query string + self._query_string = self._build_query_string() + + def set_epsilon(self, epsilon: float): + """Set the epsilon parameter for the range query. + + Args: + epsilon (float): The relative factor for vector range queries, + setting boundaries for candidates within radius * (1 + epsilon). + + Raises: + TypeError: If epsilon is not a float or int + ValueError: If epsilon is negative + """ + if not isinstance(epsilon, (float, int)): + raise TypeError("epsilon must be of type float or int") + if epsilon < 0: + raise ValueError("epsilon must be non-negative") + self._epsilon = epsilon + + # Reset the query string + self._query_string = self._build_query_string() + + def set_hybrid_policy(self, hybrid_policy: str): + """Set the hybrid policy for the query. + + Args: + hybrid_policy (str): The hybrid policy to use. Options are "BATCHES" + or "ADHOC_BF". + + Raises: + ValueError: If hybrid_policy is not one of the valid options + """ + if hybrid_policy not in {"BATCHES", "ADHOC_BF"}: + raise ValueError("hybrid_policy must be one of {'BATCHES', 'ADHOC_BF'}") + self._hybrid_policy = hybrid_policy + + # Reset the query string + self._query_string = self._build_query_string() + + def set_batch_size(self, batch_size: int): + """Set the batch size for the query. + + Args: + batch_size (int): The batch size to use when hybrid_policy is "BATCHES". + + Raises: + TypeError: If batch_size is not an integer + ValueError: If batch_size is not positive + """ + if not isinstance(batch_size, int): + raise TypeError("batch_size must be an integer") + if batch_size <= 0: + raise ValueError("batch_size must be positive") + self._batch_size = batch_size + + # Reset the query string + self._query_string = self._build_query_string() + def _build_query_string(self) -> str: """Build the full query string for vector range queries with optional filtering""" + # Build base query with vector range only base_query = f"@{self._vector_field_name}:[VECTOR_RANGE ${self.DISTANCE_THRESHOLD_PARAM} ${self.VECTOR_PARAM}]" + # Build query attributes section + attr_parts = [] + attr_parts.append(f"$YIELD_DISTANCE_AS: {self.DISTANCE_ID}") + + if self._epsilon is not None: + attr_parts.append(f"$EPSILON: {self._epsilon}") + + # Add query attributes section + attr_section = f"=>{{{'; '.join(attr_parts)}}}" + + # Add filter expression if present filter_expression = self._filter_expression if isinstance(filter_expression, FilterExpression): filter_expression = str(filter_expression) if filter_expression == "*": - return f"{base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}}" - return f"({base_query}=>{{$yield_distance_as: {self.DISTANCE_ID}}} {filter_expression})" - - def set_distance_threshold(self, distance_threshold: float): - """Set the distance threshold for the query. - - Args: - distance_threshold (float): vector distance - """ - if not isinstance(distance_threshold, (float, int)): - raise TypeError("distance_threshold must be of type int or float") - self._distance_threshold = distance_threshold + return f"{base_query}{attr_section}" + return f"({base_query}{attr_section} {filter_expression})" @property def distance_threshold(self) -> float: @@ -378,6 +576,33 @@ def distance_threshold(self) -> float: """ return self._distance_threshold + @property + def epsilon(self) -> Optional[float]: + """Return the epsilon for the query. + + Returns: + Optional[float]: The epsilon for the query, or None if not set. + """ + return self._epsilon + + @property + def hybrid_policy(self) -> Optional[str]: + """Return the hybrid policy for the query. + + Returns: + Optional[str]: The hybrid policy for the query. + """ + return self._hybrid_policy + + @property + def batch_size(self) -> Optional[int]: + """Return the batch size for the query. + + Returns: + Optional[int]: The batch size for the query. + """ + return self._batch_size + @property def params(self) -> Dict[str, Any]: """Return the parameters for the query. @@ -390,11 +615,20 @@ def params(self) -> Dict[str, Any]: else: vector_param = array_to_buffer(self._vector, dtype=self._dtype) - return { + params = { self.VECTOR_PARAM: vector_param, self.DISTANCE_THRESHOLD_PARAM: self._distance_threshold, } + # Add hybrid policy and batch size as query parameters (not in query string) + if self._hybrid_policy: + params[self.HYBRID_POLICY_PARAM] = self._hybrid_policy + + if self._hybrid_policy == "BATCHES" and self._batch_size: + params[self.BATCH_SIZE_PARAM] = self._batch_size + + return params + class RangeQuery(VectorRangeQuery): # keep for backwards compatibility diff --git a/tests/integration/test_query.py b/tests/integration/test_query.py index deb58cbc..3f30c02e 100644 --- a/tests/integration/test_query.py +++ b/tests/integration/test_query.py @@ -14,6 +14,7 @@ Text, Timestamp, ) +from redisvl.query.query import VectorRangeQuery from redisvl.redis.utils import array_to_buffer # TODO expand to multiple schema types and sync + async @@ -531,3 +532,130 @@ def test_query_with_chunk_number_zero(): assert ( str(filter_conditions) == expected_query_str ), "Query with chunk_number zero is incorrect" + + +def test_hybrid_policy_batches_mode(index, vector_query): + """Test vector query with BATCHES hybrid policy.""" + # Create a filter + t = Tag("credit_score") == "high" + + # Set hybrid policy to BATCHES + vector_query.set_hybrid_policy("BATCHES") + vector_query.set_batch_size(2) + + # Set the filter + vector_query.set_filter(t) + + # Check query string + assert "HYBRID_POLICY BATCHES BATCH_SIZE 2" in str(vector_query) + + # Execute query + results = index.query(vector_query) + + # Check results - should have filtered to "high" credit scores + assert len(results) > 0 + for result in results: + assert result["credit_score"] == "high" + + +def test_hybrid_policy_adhoc_bf_mode(index, vector_query): + """Test vector query with ADHOC_BF hybrid policy.""" + # Create a filter + t = Tag("credit_score") == "high" + + # Set hybrid policy to ADHOC_BF + vector_query.set_hybrid_policy("ADHOC_BF") + + # Set the filter + vector_query.set_filter(t) + + # Check query string + assert "HYBRID_POLICY ADHOC_BF" in str(vector_query) + + # Execute query + results = index.query(vector_query) + + # Check results - should have filtered to "high" credit scores + assert len(results) > 0 + for result in results: + assert result["credit_score"] == "high" + + +def test_range_query_with_epsilon(index): + """Integration test: Execute range query with epsilon parameter against Redis.""" + # Create a range query with epsilon + epsilon_query = VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score", "age", "job"], + distance_threshold=0.3, + epsilon=0.5, # Larger than default to get potentially more results + ) + + # Verify query string contains epsilon attribute + query_string = str(epsilon_query) + assert "$EPSILON: 0.5" in query_string + + # Verify epsilon property is set + assert epsilon_query.epsilon == 0.5 + + # Test setting epsilon + epsilon_query.set_epsilon(0.1) + assert epsilon_query.epsilon == 0.1 + assert "$EPSILON: 0.1" in str(epsilon_query) + + # Execute basic query without epsilon to ensure functionality + basic_query = VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score", "age", "job"], + distance_threshold=0.2, + ) + + results = index.query(basic_query) + + # Check results + for result in results: + assert float(result["vector_distance"]) <= 0.2 + + +def test_range_query_with_filter_and_hybrid_policy(index): + """Integration test: Test construction of a range query with filter and hybrid policy.""" + # Create a filter for high credit score + credit_filter = Tag("credit_score") == "high" + + # Create a range query with filter and hybrid policy + query = VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score", "age", "job"], + filter_expression=credit_filter, + distance_threshold=0.5, + hybrid_policy="BATCHES", + batch_size=2, + ) + + # Check query string and parameters + query_string = str(query) + assert "@credit_score:{high}" in query_string + assert "HYBRID_POLICY" not in query_string + assert query.hybrid_policy == "BATCHES" + assert query.batch_size == 2 + assert query.params["HYBRID_POLICY"] == "BATCHES" + assert query.params["BATCH_SIZE"] == 2 + + # Execute basic query with filter but without hybrid policy + basic_filter_query = VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score", "age", "job"], + filter_expression=credit_filter, + distance_threshold=0.5, + ) + + results = index.query(basic_filter_query) + + # Check results + for result in results: + assert result["credit_score"] == "high" + assert float(result["vector_distance"]) <= 0.5 diff --git a/tests/unit/test_query_types.py b/tests/unit/test_query_types.py index 1e9fdb08..bceaa215 100644 --- a/tests/unit/test_query_types.py +++ b/tests/unit/test_query_types.py @@ -5,6 +5,7 @@ from redisvl.index.index import process_results from redisvl.query import CountQuery, FilterQuery, RangeQuery, VectorQuery from redisvl.query.filter import Tag +from redisvl.query.query import VectorRangeQuery # Sample data for testing sample_vector = [0.1, 0.2, 0.3, 0.4] @@ -48,8 +49,8 @@ def test_filter_query(): assert isinstance(filter_query.params, dict) assert filter_query.params == {} assert filter_query._dialect == 2 - assert filter_query._sortby == None - assert filter_query._in_order == False + assert filter_query._sortby is None + assert filter_query._in_order is False # Test set_filter functionality new_filter_expression = Tag("category") == "Sportswear" @@ -92,7 +93,7 @@ def test_vector_query(): assert vector_query.params != {} assert vector_query._dialect == 3 assert vector_query._sortby.args[0] == VectorQuery.DISTANCE_ID - assert vector_query._in_order == False + assert vector_query._in_order is False # Test set_filter functionality new_filter_expression = Tag("category") == "Sportswear" @@ -277,3 +278,271 @@ def test_string_filter_expressions(query): query.set_filter("~(@desciption:(hello | world))") assert query._filter_expression == "~(@desciption:(hello | world))" assert query.query_string().__contains__("~(@desciption:(hello | world))") + + +def test_vector_query_hybrid_policy(): + """Test that VectorQuery correctly handles hybrid policy parameters.""" + # Create a vector query with hybrid policy + vector_query = VectorQuery( + [0.1, 0.2, 0.3, 0.4], "vector_field", hybrid_policy="BATCHES" + ) + + # Check properties + assert vector_query.hybrid_policy == "BATCHES" + assert vector_query.batch_size is None + + # Check query string + query_string = str(vector_query) + assert "HYBRID_POLICY BATCHES" in query_string + + # Test with batch size + vector_query = VectorQuery( + [0.1, 0.2, 0.3, 0.4], "vector_field", hybrid_policy="BATCHES", batch_size=50 + ) + + # Check properties + assert vector_query.hybrid_policy == "BATCHES" + assert vector_query.batch_size == 50 + + # Check query string + query_string = str(vector_query) + assert "HYBRID_POLICY BATCHES BATCH_SIZE 50" in query_string + + # Test with ADHOC_BF policy + vector_query = VectorQuery( + [0.1, 0.2, 0.3, 0.4], "vector_field", hybrid_policy="ADHOC_BF" + ) + + # Check properties + assert vector_query.hybrid_policy == "ADHOC_BF" + + # Check query string + query_string = str(vector_query) + assert "HYBRID_POLICY ADHOC_BF" in query_string + + +def test_vector_query_set_hybrid_policy(): + """Test that VectorQuery setter methods work properly.""" + # Create a vector query + vector_query = VectorQuery([0.1, 0.2, 0.3, 0.4], "vector_field") + + # Initially no hybrid policy + assert vector_query.hybrid_policy is None + assert "HYBRID_POLICY" not in str(vector_query) + + # Set hybrid policy + vector_query.set_hybrid_policy("BATCHES") + + # Check properties + assert vector_query.hybrid_policy == "BATCHES" + + # Check query string + query_string = str(vector_query) + assert "HYBRID_POLICY BATCHES" in query_string + + # Set batch size + vector_query.set_batch_size(100) + + # Check properties + assert vector_query.batch_size == 100 + + # Check query string + query_string = str(vector_query) + assert "HYBRID_POLICY BATCHES BATCH_SIZE 100" in query_string + + +def test_vector_query_invalid_hybrid_policy(): + """Test error handling for invalid hybrid policy values.""" + # Test with invalid hybrid policy + with pytest.raises(ValueError, match=r"hybrid_policy must be one of.*"): + VectorQuery([0.1, 0.2, 0.3, 0.4], "vector_field", hybrid_policy="INVALID") + + # Create a valid vector query + vector_query = VectorQuery([0.1, 0.2, 0.3, 0.4], "vector_field") + + # Test with invalid hybrid policy + with pytest.raises(ValueError, match=r"hybrid_policy must be one of.*"): + vector_query.set_hybrid_policy("INVALID") + + # Test with invalid batch size types + with pytest.raises(TypeError, match="batch_size must be an integer"): + vector_query.set_batch_size("50") + + # Test with invalid batch size values + with pytest.raises(ValueError, match="batch_size must be positive"): + vector_query.set_batch_size(0) + + with pytest.raises(ValueError, match="batch_size must be positive"): + vector_query.set_batch_size(-10) + + +def test_vector_range_query_epsilon(): + """Test that VectorRangeQuery correctly handles epsilon parameter.""" + # Create a range query with epsilon + range_query = VectorRangeQuery( + [0.1, 0.2, 0.3, 0.4], "vector_field", epsilon=0.05, distance_threshold=0.3 + ) + + # Check properties + assert range_query.epsilon == 0.05 + assert range_query.distance_threshold == 0.3 + + # Check query string + query_string = str(range_query) + assert "$EPSILON: 0.05" in query_string + + # Test setting epsilon + range_query.set_epsilon(0.1) + assert range_query.epsilon == 0.1 + assert "$EPSILON: 0.1" in str(range_query) + + +def test_vector_range_query_invalid_epsilon(): + """Test error handling for invalid epsilon values.""" + # Test with invalid epsilon type + with pytest.raises(TypeError, match="epsilon must be of type float or int"): + VectorRangeQuery([0.1, 0.2, 0.3, 0.4], "vector_field", epsilon="0.05") + + # Test with negative epsilon + with pytest.raises(ValueError, match="epsilon must be non-negative"): + VectorRangeQuery([0.1, 0.2, 0.3, 0.4], "vector_field", epsilon=-0.05) + + # Create a valid range query + range_query = VectorRangeQuery([0.1, 0.2, 0.3, 0.4], "vector_field") + + # Test with invalid epsilon + with pytest.raises(TypeError, match="epsilon must be of type float or int"): + range_query.set_epsilon("0.05") + + with pytest.raises(ValueError, match="epsilon must be non-negative"): + range_query.set_epsilon(-0.05) + + +def test_vector_range_query_construction(): + """Unit test: Test the construction of VectorRangeQuery with various parameters.""" + # Basic range query + basic_query = VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score"], + distance_threshold=0.2, + ) + + query_string = str(basic_query) + assert "VECTOR_RANGE $distance_threshold $vector" in query_string + assert "$YIELD_DISTANCE_AS: vector_distance" in query_string + assert "HYBRID_POLICY" not in query_string + + # Range query with epsilon + epsilon_query = VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score"], + distance_threshold=0.2, + epsilon=0.05, + ) + + query_string = str(epsilon_query) + assert "VECTOR_RANGE $distance_threshold $vector" in query_string + assert "$YIELD_DISTANCE_AS: vector_distance" in query_string + assert "$EPSILON: 0.05" in query_string + assert epsilon_query.epsilon == 0.05 + assert "EPSILON" not in epsilon_query.params + + # Range query with hybrid policy + hybrid_query = VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score"], + distance_threshold=0.2, + hybrid_policy="BATCHES", + ) + + query_string = str(hybrid_query) + # Hybrid policy should not be in the query string + assert "HYBRID_POLICY" not in query_string + assert hybrid_query.hybrid_policy == "BATCHES" + assert hybrid_query.params["HYBRID_POLICY"] == "BATCHES" + + # Range query with hybrid policy and batch size + batch_query = VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + return_fields=["user", "credit_score"], + distance_threshold=0.2, + hybrid_policy="BATCHES", + batch_size=50, + ) + + query_string = str(batch_query) + # Hybrid policy and batch size should not be in the query string + assert "HYBRID_POLICY" not in query_string + assert "BATCH_SIZE" not in query_string + assert batch_query.hybrid_policy == "BATCHES" + assert batch_query.batch_size == 50 + assert batch_query.params["HYBRID_POLICY"] == "BATCHES" + assert batch_query.params["BATCH_SIZE"] == 50 + + +def test_vector_range_query_setter_methods(): + """Unit test: Test setter methods for VectorRangeQuery parameters.""" + # Create a basic query + query = VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + distance_threshold=0.2, + ) + + # Verify initial state + assert query.epsilon is None + assert query.hybrid_policy is None + assert query.batch_size is None + assert "$EPSILON" not in str(query) + assert "HYBRID_POLICY" not in query.params + assert "BATCH_SIZE" not in query.params + + # Set epsilon + query.set_epsilon(0.1) + assert query.epsilon == 0.1 + assert "$EPSILON: 0.1" in str(query) + + # Set hybrid policy + query.set_hybrid_policy("BATCHES") + assert query.hybrid_policy == "BATCHES" + assert query.params["HYBRID_POLICY"] == "BATCHES" + + # Set batch size + query.set_batch_size(25) + assert query.batch_size == 25 + assert query.params["BATCH_SIZE"] == 25 + + +def test_vector_range_query_error_handling(): + """Unit test: Test error handling for invalid VectorRangeQuery parameters.""" + # Create a basic query + query = VectorRangeQuery( + vector=[0.1, 0.1, 0.5], + vector_field_name="user_embedding", + distance_threshold=0.2, + ) + + # Test invalid epsilon + with pytest.raises(TypeError, match="epsilon must be of type float or int"): + query.set_epsilon("0.1") + + with pytest.raises(ValueError, match="epsilon must be non-negative"): + query.set_epsilon(-0.1) + + # Test invalid hybrid policy + with pytest.raises(ValueError, match="hybrid_policy must be one of"): + query.set_hybrid_policy("INVALID") + + # Test invalid batch size + with pytest.raises(TypeError, match="batch_size must be an integer"): + query.set_batch_size(10.5) + + with pytest.raises(ValueError, match="batch_size must be positive"): + query.set_batch_size(0) + + with pytest.raises(ValueError, match="batch_size must be positive"): + query.set_batch_size(-10) From cba2998afb96585b8e5c4e9b677b1efdccad9882 Mon Sep 17 00:00:00 2001 From: Tyler Hutcherson Date: Tue, 25 Mar 2025 09:13:12 -0400 Subject: [PATCH 2/2] use enum for hybrid policy --- redisvl/query/query.py | 42 ++++++++++++++++++++++++++++-------------- 1 file changed, 28 insertions(+), 14 deletions(-) diff --git a/redisvl/query/query.py b/redisvl/query/query.py index 8182952e..7dbddc91 100644 --- a/redisvl/query/query.py +++ b/redisvl/query/query.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Any, Dict, List, Optional, Union from redis.commands.search.query import Query as RedisQuery @@ -175,6 +176,13 @@ class BaseVectorQuery: VECTOR_PARAM: str = "vector" +class HybridPolicy(str, Enum): + """Enum for valid hybrid policy options in vector queries.""" + + BATCHES = "BATCHES" + ADHOC_BF = "ADHOC_BF" + + class VectorQuery(BaseVectorQuery, BaseQuery): def __init__( self, @@ -236,7 +244,7 @@ def __init__( self._vector_field_name = vector_field_name self._dtype = dtype self._num_results = num_results - self._hybrid_policy: Optional[str] = None + self._hybrid_policy: Optional[HybridPolicy] = None self._batch_size: Optional[int] = None self.set_filter(filter_expression) query_string = self._build_query_string() @@ -279,10 +287,10 @@ def _build_query_string(self) -> str: # Add hybrid policy parameters if specified if self._hybrid_policy: - knn_query += f" HYBRID_POLICY {self._hybrid_policy}" + knn_query += f" HYBRID_POLICY {self._hybrid_policy.value}" # Add batch size if specified and using BATCHES policy - if self._hybrid_policy == "BATCHES" and self._batch_size: + if self._hybrid_policy == HybridPolicy.BATCHES and self._batch_size: knn_query += f" BATCH_SIZE {self._batch_size}" # Add distance field alias @@ -300,9 +308,12 @@ def set_hybrid_policy(self, hybrid_policy: str): Raises: ValueError: If hybrid_policy is not one of the valid options """ - if hybrid_policy not in {"BATCHES", "ADHOC_BF"}: - raise ValueError("hybrid_policy must be one of {'BATCHES', 'ADHOC_BF'}") - self._hybrid_policy = hybrid_policy + try: + self._hybrid_policy = HybridPolicy(hybrid_policy) + except ValueError: + raise ValueError( + f"hybrid_policy must be one of {', '.join([p.value for p in HybridPolicy])}" + ) # Reset the query string self._query_string = self._build_query_string() @@ -333,7 +344,7 @@ def hybrid_policy(self) -> Optional[str]: Returns: Optional[str]: The hybrid policy for the query. """ - return self._hybrid_policy + return self._hybrid_policy.value if self._hybrid_policy else None @property def batch_size(self) -> Optional[int]: @@ -433,7 +444,7 @@ def __init__( self._num_results = num_results self._distance_threshold: float = 0.2 # Initialize with default self._epsilon: Optional[float] = None - self._hybrid_policy: Optional[str] = None + self._hybrid_policy: Optional[HybridPolicy] = None self._batch_size: Optional[int] = None if epsilon is not None: @@ -517,9 +528,12 @@ def set_hybrid_policy(self, hybrid_policy: str): Raises: ValueError: If hybrid_policy is not one of the valid options """ - if hybrid_policy not in {"BATCHES", "ADHOC_BF"}: - raise ValueError("hybrid_policy must be one of {'BATCHES', 'ADHOC_BF'}") - self._hybrid_policy = hybrid_policy + try: + self._hybrid_policy = HybridPolicy(hybrid_policy) + except ValueError: + raise ValueError( + f"hybrid_policy must be one of {', '.join([p.value for p in HybridPolicy])}" + ) # Reset the query string self._query_string = self._build_query_string() @@ -592,7 +606,7 @@ def hybrid_policy(self) -> Optional[str]: Returns: Optional[str]: The hybrid policy for the query. """ - return self._hybrid_policy + return self._hybrid_policy.value if self._hybrid_policy else None @property def batch_size(self) -> Optional[int]: @@ -622,9 +636,9 @@ def params(self) -> Dict[str, Any]: # Add hybrid policy and batch size as query parameters (not in query string) if self._hybrid_policy: - params[self.HYBRID_POLICY_PARAM] = self._hybrid_policy + params[self.HYBRID_POLICY_PARAM] = self._hybrid_policy.value - if self._hybrid_policy == "BATCHES" and self._batch_size: + if self._hybrid_policy == HybridPolicy.BATCHES and self._batch_size: params[self.BATCH_SIZE_PARAM] = self._batch_size return params