Skip to content

Commit 6da10a8

Browse files
committed
Add linear hybrid search ranker
1 parent c944dca commit 6da10a8

File tree

6 files changed

+174
-4
lines changed

6 files changed

+174
-4
lines changed

src/neo4j_graphrag/exceptions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,7 @@ class PdfLoaderError(Neo4jGraphRagError):
124124

125125
class PromptMissingPlaceholderError(Neo4jGraphRagError):
126126
"""Exception raised when a prompt is missing an expected placeholder."""
127+
128+
129+
class InvalidHybridSearchRankerError(Neo4jGraphRagError):
    """Raised when hybrid search is asked to use a ranker type that is not valid."""

src/neo4j_graphrag/neo4j_queries.py

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,11 @@
1515
from __future__ import annotations
1616

1717
import warnings
18-
from typing import Any, Optional
18+
from typing import Any, Optional, Union
1919

20+
from neo4j_graphrag.exceptions import InvalidHybridSearchRankerError
2021
from neo4j_graphrag.filters import get_metadata_filter
21-
from neo4j_graphrag.types import EntityType, SearchType
22+
from neo4j_graphrag.types import EntityType, SearchType, HybridSearchRanker
2223

2324
NODE_VECTOR_INDEX_QUERY = (
2425
"CALL db.index.vector.queryNodes"
@@ -171,6 +172,50 @@ def _get_hybrid_query(neo4j_version_is_5_23_or_above: bool) -> str:
171172
return call_prefix + query_body
172173

173174

175+
def _get_hybrid_query_linear(neo4j_version_is_5_23_or_above: bool, alpha: float) -> str:
    """
    Build the Cypher query for hybrid search with the linear (weighted-sum) ranker.

    Each index is queried inside one ``CALL`` subquery; every branch divides its
    scores by that branch's maximum score and tags its rows with a ``source``
    marker. The outer query then aggregates per node:

    ```
    final_score = alpha * (vector normalized score) + (1 - alpha) * (fulltext normalized score)
    ```

    A node returned by only one index contributes 0 for the other term.

    Args:
        neo4j_version_is_5_23_or_above (bool): Selects the ``CALL () { ... }`` form
            when True and the ``CALL { ... }`` form otherwise.
        alpha (float): Weight applied to the normalized vector score; the normalized
            full-text score is weighted by ``1 - alpha``. Both weights are rendered
            into the query text with ``str()``.

    Returns:
        str: The constructed Cypher query string.
    """
    if neo4j_version_is_5_23_or_above:
        subquery_open = "CALL () { "
    else:
        subquery_open = "CALL { "

    vector_weight = str(alpha)
    fulltext_weight = str(1 - alpha)

    # Vector branch: normalize by the best vector score and tag the rows.
    vector_branch = (
        f"{NODE_VECTOR_INDEX_QUERY} "
        "WITH collect({node: node, score: score}) AS nodes, max(score) AS vector_index_max_score "
        "UNWIND nodes AS n "
        "RETURN n.node AS node, (n.score / vector_index_max_score) AS score, 'vector' AS source "
    )
    # Full-text branch: same normalization; this branch also closes the subquery.
    fulltext_branch = (
        f"{FULL_TEXT_SEARCH_QUERY} "
        "WITH collect({node: node, score: score}) AS nodes, max(score) AS ft_index_max_score "
        "UNWIND nodes AS n "
        "RETURN n.node AS node, (n.score / ft_index_max_score) AS score, 'ft' AS source } "
    )
    # Outer aggregation: weighted sum of both sources per node.
    vector_term = (
        f"sum(CASE WHEN source = 'vector' THEN score * {vector_weight} ELSE 0 END)"
    )
    fulltext_term = (
        f"sum(CASE WHEN source = 'ft' THEN score * {fulltext_weight} ELSE 0 END)"
    )
    ranking = (
        f"WITH node, {vector_term} + {fulltext_term} AS score "
        "ORDER BY score DESC LIMIT $top_k"
    )

    return "".join(
        [subquery_open, vector_branch, "UNION ", fulltext_branch, ranking]
    )
217+
218+
174219
def _get_filtered_vector_query(
175220
filters: dict[str, Any],
176221
node_label: str,
@@ -223,6 +268,8 @@ def get_search_query(
223268
filters: Optional[dict[str, Any]] = None,
224269
neo4j_version_is_5_23_or_above: bool = False,
225270
use_parallel_runtime: bool = False,
271+
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
272+
alpha: Optional[float] = None,
226273
) -> tuple[str, dict[str, Any]]:
227274
"""
228275
Constructs a search query for vector or hybrid search, including optional pre-filtering
@@ -243,6 +290,8 @@ def get_search_query(
243290
neo4j_version_is_5_23_or_above (Optional[bool]): Whether the Neo4j version is 5.23 or above.
244291
use_parallel_runtime (bool): Whether or not use the parallel runtime to run the query.
245292
Defaults to False.
293+
ranker (HybridSearchRanker): Type of ranker to order the results from retrieval.
294+
alpha (Optional[float]): Weight for the vector score when using the linear ranker. Only used when ranker is 'linear'. Defaults to 0.5 if not provided.
246295
247296
Returns:
248297
tuple[str, dict[str, Any]]: A tuple containing the constructed query string and
@@ -262,7 +311,14 @@ def get_search_query(
262311
if search_type == SearchType.HYBRID:
263312
if filters:
264313
raise Exception("Filters are not supported with hybrid search")
265-
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
314+
if ranker == HybridSearchRanker.NAIVE:
315+
query = _get_hybrid_query(neo4j_version_is_5_23_or_above)
316+
elif ranker == HybridSearchRanker.LINEAR and alpha:
317+
query = _get_hybrid_query_linear(
318+
neo4j_version_is_5_23_or_above, alpha=alpha
319+
)
320+
else:
321+
raise InvalidHybridSearchRankerError()
266322
params: dict[str, Any] = {}
267323
elif search_type == SearchType.VECTOR:
268324
if filters:

src/neo4j_graphrag/retrievers/hybrid.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import copy
1818
import logging
19-
from typing import Any, Callable, Optional
19+
from typing import Any, Callable, Optional, Union
2020

2121
import neo4j
2222
from pydantic import ValidationError
@@ -39,6 +39,7 @@
3939
RawSearchResult,
4040
RetrieverResultItem,
4141
SearchType,
42+
HybridSearchRanker,
4243
)
4344

4445
logger = logging.getLogger(__name__)
@@ -142,6 +143,8 @@ def get_search_results(
142143
query_vector: Optional[list[float]] = None,
143144
top_k: int = 5,
144145
effective_search_ratio: int = 1,
146+
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
147+
alpha: Optional[float] = None,
145148
) -> RawSearchResult:
146149
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
147150
Both query_vector and query_text can be provided.
@@ -162,6 +165,8 @@ def get_search_results(
162165
top_k (int, optional): The number of neighbors to return. Defaults to 5.
163166
effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
164167
accuracy and performance. Defaults to 1.
168+
ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
169+
alpha (Optional[float]): Weight for the vector score when using the linear ranker. Only used when ranker is 'linear'. Defaults to 0.5 if not provided.
165170
166171
Raises:
167172
SearchValidationError: If validation of the input arguments fail.
@@ -176,6 +181,8 @@ def get_search_results(
176181
query_text=query_text,
177182
top_k=top_k,
178183
effective_search_ratio=effective_search_ratio,
184+
ranker=ranker,
185+
alpha=alpha,
179186
)
180187
except ValidationError as e:
181188
raise SearchValidationError(e.errors()) from e
@@ -197,7 +204,15 @@ def get_search_results(
197204
return_properties=self.return_properties,
198205
embedding_node_property=self._embedding_node_property,
199206
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
207+
ranker=validated_data.ranker,
208+
alpha=validated_data.alpha,
200209
)
210+
211+
if "ranker" in parameters:
212+
del parameters["ranker"]
213+
if "alpha" in parameters:
214+
del parameters["alpha"]
215+
201216
sanitized_parameters = copy.deepcopy(parameters)
202217
if "query_vector" in sanitized_parameters:
203218
sanitized_parameters["query_vector"] = "..."
@@ -301,6 +316,8 @@ def get_search_results(
301316
top_k: int = 5,
302317
effective_search_ratio: int = 1,
303318
query_params: Optional[dict[str, Any]] = None,
319+
ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE,
320+
alpha: Optional[float] = None,
304321
) -> RawSearchResult:
305322
"""Get the top_k nearest neighbor embeddings for either provided query_vector or query_text.
306323
Both query_vector and query_text can be provided.
@@ -320,6 +337,8 @@ def get_search_results(
320337
effective_search_ratio (int): Controls the candidate pool size for the vector index by multiplying top_k to balance query
321338
accuracy and performance. Defaults to 1.
322339
query_params (Optional[dict[str, Any]]): Parameters for the Cypher query. Defaults to None.
340+
ranker (str, HybridSearchRanker): Type of ranker to order the results from retrieval.
341+
alpha (Optional[float]): Weight for the vector score when using the linear ranker. Only used when ranker is 'linear'. Defaults to 0.5 if not provided.
323342
324343
Raises:
325344
SearchValidationError: If validation of the input arguments fail.
@@ -334,6 +353,8 @@ def get_search_results(
334353
query_text=query_text,
335354
top_k=top_k,
336355
effective_search_ratio=effective_search_ratio,
356+
ranker=ranker,
357+
alpha=alpha,
337358
query_params=query_params,
338359
)
339360
except ValidationError as e:
@@ -361,7 +382,15 @@ def get_search_results(
361382
search_type=SearchType.HYBRID,
362383
retrieval_query=self.retrieval_query,
363384
neo4j_version_is_5_23_or_above=self.neo4j_version_is_5_23_or_above,
385+
ranker=validated_data.ranker,
386+
alpha=validated_data.alpha,
364387
)
388+
389+
if "ranker" in parameters:
390+
del parameters["ranker"]
391+
if "alpha" in parameters:
392+
del parameters["alpha"]
393+
365394
sanitized_parameters = copy.deepcopy(parameters)
366395
if "query_vector" in sanitized_parameters:
367396
sanitized_parameters["query_vector"] = "..."

src/neo4j_graphrag/types.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
from __future__ import annotations
1616

17+
import warnings
1718
from enum import Enum
1819
from typing import Any, Callable, Literal, Optional, Union
1920

@@ -137,11 +138,53 @@ class VectorCypherSearchModel(VectorSearchModel):
137138
query_params: Optional[dict[str, Any]] = None
138139

139140

141+
class HybridSearchRanker(Enum):
    """Rankers that can be selected to order hybrid search results."""

    # Member values are the lowercase names accepted as ranker strings.
    NAIVE = "naive"
    LINEAR = "linear"
146+
147+
140148
class HybridSearchModel(BaseModel):
    """Validated arguments for a hybrid (vector + full-text) search.

    ``ranker`` accepts either a ``HybridSearchRanker`` member or its string value
    (case-insensitive) and is always stored as a ``HybridSearchRanker``.
    ``alpha`` is only meaningful for the linear ranker: it defaults to 0.5 there
    and must lie in [0, 1]; for any other ranker a supplied ``alpha`` is dropped
    with a ``UserWarning``.
    """

    query_text: str
    query_vector: Optional[list[float]] = None
    top_k: PositiveInt = 5
    effective_search_ratio: PositiveInt = 1
    ranker: Union[str, HybridSearchRanker] = HybridSearchRanker.NAIVE
    # Weight of the vector score for the linear ranker; None for other rankers.
    alpha: Optional[float] = None

    @field_validator("ranker", mode="before")
    @classmethod
    def validate_ranker(cls, v: Union[str, HybridSearchRanker]) -> HybridSearchRanker:
        """Coerce ``v`` to a ``HybridSearchRanker``.

        Raises:
            ValueError: If ``v`` is a string that names no ranker, or is neither
                a string nor a ``HybridSearchRanker``.
        """
        if isinstance(v, HybridSearchRanker):
            return v
        allowed = ", ".join(r.value for r in HybridSearchRanker)
        if isinstance(v, str):
            try:
                return HybridSearchRanker(v.lower())
            except ValueError:
                raise ValueError(
                    f"Invalid ranker value. Allowed values are: {allowed}."
                ) from None
        raise ValueError(f"Invalid ranker type. Allowed values are: {allowed}.")

    @model_validator(mode="before")
    @classmethod
    def validate_alpha(cls, values: dict[str, Any]) -> dict[str, Any]:
        """Apply the alpha default / range check, or discard alpha when unused.

        Raises:
            ValueError: If the linear ranker is selected and ``alpha`` is a
                number outside [0, 1].
        """
        if not isinstance(values, dict):
            # Nothing to inspect with .get(); leave it to regular validation.
            return values

        ranker = values.get("ranker")
        alpha = values.get("alpha")

        # This hook runs before field validation, so ``ranker`` may still be the
        # raw string. Normalize it here; otherwise "linear" would never compare
        # equal to the enum member and alpha would be wrongly discarded.
        if isinstance(ranker, str):
            try:
                ranker = HybridSearchRanker(ranker.lower())
            except ValueError:
                # Unknown name: validate_ranker reports the error.
                return values

        if ranker == HybridSearchRanker.LINEAR:
            if alpha is None:
                values["alpha"] = 0.5
            else:
                try:
                    numeric_alpha = float(alpha)
                except (TypeError, ValueError):
                    # Not number-like: the float field validation rejects it.
                    numeric_alpha = None
                # Range-check any numeric input (ints and numeric strings too).
                if numeric_alpha is not None and not (0.0 <= numeric_alpha <= 1.0):
                    raise ValueError("alpha must be between 0 and 1")
        elif alpha is not None:
            warnings.warn(
                "alpha parameter is only used when ranker is 'linear'. Ignoring alpha.",
                UserWarning,
            )
            values["alpha"] = None
        return values
145188

146189

147190
class HybridCypherSearchModel(HybridSearchModel):

tests/e2e/test_hybrid_e2e.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,3 +176,25 @@ def test_hybrid_retriever_return_properties(driver: Driver) -> None:
176176
assert len(results.items) == 5
177177
for result in results.items:
178178
assert isinstance(result, RetrieverResultItem)
179+
180+
181+
@pytest.mark.usefixtures("setup_neo4j_for_retrieval")
def test_hybrid_retriever_search_text_linear_ranker(
    driver: Driver, random_embedder: Embedder
) -> None:
    """Hybrid text search with the linear ranker returns ``top_k`` result items."""
    # Imported locally so this test does not depend on the module's import block.
    from neo4j_graphrag.types import HybridSearchRanker

    retriever = HybridRetriever(
        driver, "vector-index-name", "fulltext-index-name", random_embedder
    )

    top_k = 5
    effective_search_ratio = 2
    # Select the linear ranker explicitly; without these arguments the search
    # falls back to the default naive ranker and the linear query is never run.
    results = retriever.search(
        query_text="Find me a book about Fremen",
        top_k=top_k,
        effective_search_ratio=effective_search_ratio,
        ranker=HybridSearchRanker.LINEAR,
        alpha=0.7,
    )

    assert isinstance(results, RetrieverResult)
    assert len(results.items) == top_k
    for result in results.items:
        assert isinstance(result, RetrieverResultItem)

tests/unit/test_neo4j_queries.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@
1616
from unittest.mock import patch
1717

1818
import pytest
19+
20+
from neo4j_graphrag.exceptions import InvalidHybridSearchRankerError
1921
from neo4j_graphrag.neo4j_queries import (
2022
get_query_tail,
2123
get_search_query,
24+
_get_hybrid_query_linear,
2225
)
2326
from neo4j_graphrag.types import EntityType, SearchType
2427

@@ -249,3 +252,16 @@ def test_get_query_tail_ordering_no_retrieval_query() -> None:
249252
fallback_return=fallback,
250253
)
251254
assert result.strip() == expected.strip()
255+
256+
257+
def test_get_hybrid_query_linear_with_alpha() -> None:
    """The linear hybrid query embeds alpha and (1 - alpha) as score weights."""
    linear_query = _get_hybrid_query_linear(
        neo4j_version_is_5_23_or_above=True, alpha=0.7
    )
    # Substring checks only: the test does not pin the full rendering of the weights.
    expected_fragments = (
        "CASE WHEN source = 'vector' THEN score * 0.7",
        "CASE WHEN source = 'ft' THEN score * 0.3",
    )
    for fragment in expected_fragments:
        assert fragment in linear_query
263+
264+
265+
def test_invalid_hybrid_search_ranker_error() -> None:
    """An unrecognized ranker is rejected when building a hybrid search query."""
    unknown_ranker = "invalid"
    with pytest.raises(InvalidHybridSearchRankerError):
        get_search_query(SearchType.HYBRID, ranker=unknown_ranker)

0 commit comments

Comments
 (0)