1616
1717import copy
1818import logging
19- from typing import Any , Callable , Optional
19+ from typing import Any , Callable , Optional , Union
2020
2121import neo4j
2222from pydantic import ValidationError
3939 RawSearchResult ,
4040 RetrieverResultItem ,
4141 SearchType ,
42+ HybridSearchRanker ,
4243)
4344
4445logger = logging .getLogger (__name__ )
@@ -142,6 +143,8 @@ def get_search_results(
142143 query_vector : Optional [list [float ]] = None ,
143144 top_k : int = 5 ,
144145 effective_search_ratio : int = 1 ,
146+ ranker : Union [str , HybridSearchRanker ] = HybridSearchRanker .NAIVE ,
147+ alpha : Optional [float ] = None ,
145148 ) -> RawSearchResult :
146149 """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
147150 Both query_vector and query_text can be provided.
@@ -162,6 +165,8 @@ def get_search_results(
162165 top_k (int, optional): The number of neighbors to return. Defaults to 5.
163166 effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
164167 accuracy and performance. Defaults to 1.
168+ ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
169+ alpha (Optional[float]): Weight for the vector score when using the linear ranker. Only used when ranker is 'linear'. Defaults to 0.5 if not provided.
165170
166171 Raises:
167172 SearchValidationError: If validation of the input arguments fail.
@@ -176,6 +181,8 @@ def get_search_results(
176181 query_text = query_text ,
177182 top_k = top_k ,
178183 effective_search_ratio = effective_search_ratio ,
184+ ranker = ranker ,
185+ alpha = alpha ,
179186 )
180187 except ValidationError as e :
181188 raise SearchValidationError (e .errors ()) from e
@@ -197,7 +204,15 @@ def get_search_results(
197204 return_properties = self .return_properties ,
198205 embedding_node_property = self ._embedding_node_property ,
199206 neo4j_version_is_5_23_or_above = self .neo4j_version_is_5_23_or_above ,
207+ ranker = validated_data .ranker ,
208+ alpha = validated_data .alpha ,
200209 )
210+
211+ if "ranker" in parameters :
212+ del parameters ["ranker" ]
213+ if "alpha" in parameters :
214+ del parameters ["alpha" ]
215+
201216 sanitized_parameters = copy .deepcopy (parameters )
202217 if "query_vector" in sanitized_parameters :
203218 sanitized_parameters ["query_vector" ] = "..."
@@ -301,6 +316,8 @@ def get_search_results(
301316 top_k : int = 5 ,
302317 effective_search_ratio : int = 1 ,
303318 query_params : Optional [dict [str , Any ]] = None ,
319+ ranker : Union [str , HybridSearchRanker ] = HybridSearchRanker .NAIVE ,
320+ alpha : Optional [float ] = None ,
304321 ) -> RawSearchResult :
305322 """Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
306323 Both query_vector and query_text can be provided.
@@ -320,6 +337,8 @@ def get_search_results(
320337 effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
321338 accuracy and performance. Defaults to 1.
322339 query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None.
340+ ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
341+ alpha (Optional[float]): Weight for the vector score when using the linear ranker. Only used when ranker is 'linear'. Defaults to 0.5 if not provided.
323342
324343 Raises:
325344 SearchValidationError: If validation of the input arguments fail.
@@ -334,6 +353,8 @@ def get_search_results(
334353 query_text = query_text ,
335354 top_k = top_k ,
336355 effective_search_ratio = effective_search_ratio ,
356+ ranker = ranker ,
357+ alpha = alpha ,
337358 query_params = query_params ,
338359 )
339360 except ValidationError as e :
@@ -361,7 +382,15 @@ def get_search_results(
361382 search_type = SearchType .HYBRID ,
362383 retrieval_query = self .retrieval_query ,
363384 neo4j_version_is_5_23_or_above = self .neo4j_version_is_5_23_or_above ,
385+ ranker = validated_data .ranker ,
386+ alpha = validated_data .alpha ,
364387 )
388+
389+ if "ranker" in parameters :
390+ del parameters ["ranker" ]
391+ if "alpha" in parameters :
392+ del parameters ["alpha" ]
393+
365394 sanitized_parameters = copy .deepcopy (parameters )
366395 if "query_vector" in sanitized_parameters :
367396 sanitized_parameters ["query_vector" ] = "..."
0 commit comments