Skip to content

Commit

Permalink
Enable hybrid search (#794)
Browse files Browse the repository at this point in the history
Fixes ml6team/fondant-usecase-RAG#70

* No need to specify the vectorizer and it's module (needed for
embedding the hybrid search query), it will automatically use the one
specified in the indexing component based on the initial schema.
* Tested the functionality indirectly,
* Small fixes to default args of the indexing component
  • Loading branch information
PhilippeMoussalli authored Jan 19, 2024
1 parent 71a8f72 commit b402b32
Show file tree
Hide file tree
Showing 5 changed files with 75 additions and 5 deletions.
4 changes: 2 additions & 2 deletions components/index_weaviate/fondant_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ args:
description: |
Additional configuration to pass to the weaviate client.
type: dict
default: None
default: {}
additional_headers:
description: |
Additional headers to pass to the weaviate client.
Expand All @@ -133,7 +133,7 @@ args:
.io/developers/weaviate/modules/retriever-vectorizer-modules
Set this to None if you want to insert your own embeddings.
type: str
default: {}
default: None
module_config:
description: |
The configuration of the vectorizer module.
Expand Down
8 changes: 8 additions & 0 deletions components/retrieve_from_weaviate/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ The component takes the following arguments to alter its behavior:
| weaviate_url | str | The URL of the weaviate instance. | http://localhost:8080 |
| class_name | str | The name of the weaviate class that will be queried | / |
| top_k | int | Number of chunks to retrieve | / |
| additional_config | dict | Additional configuration to pass to the weaviate client. | / |
| additional_headers | dict | Additional headers to pass to the weaviate client. | / |
| hybrid_query | str | The hybrid query to be used for retrieval. Optional parameter. | / |
| hybrid_alpha | float | Argument to change how much each search affects the results. An alpha of 1 is a pure vector search. An alpha of 0 is a pure keyword search. | / |

<a id="retrieve_from_weaviate#usage"></a>
## Usage
Expand All @@ -55,6 +59,10 @@ dataset = dataset.apply(
# "weaviate_url": "http://localhost:8080",
# "class_name": ,
# "top_k": 0,
# "additional_config": {},
# "additional_headers": {},
# "hybrid_query": ,
# "hybrid_alpha": 0.0,
},
)
```
Expand Down
20 changes: 20 additions & 0 deletions components/retrieve_from_weaviate/fondant_component.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,23 @@ args:
top_k:
description: Number of chunks to retrieve
type: int
additional_config:
description: |
Additional configuration to pass to the weaviate client.
type: dict
default: {}
additional_headers:
description: |
Additional headers to pass to the weaviate client.
type: dict
default: {}
hybrid_query:
description: |
The hybrid query to be used for retrieval. Optional parameter.
type: str
default: None
hybrid_alpha:
description: |
Argument to change how much each search affects the results. An alpha of 1 is a pure vector search. An alpha of 0 is a pure keyword search.
type: float
default: None
44 changes: 41 additions & 3 deletions components/retrieve_from_weaviate/src/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import typing as t

import pandas as pd
import weaviate
from fondant.component import PandasTransformComponent
Expand All @@ -10,6 +12,10 @@ def __init__(
weaviate_url: str,
class_name: str,
top_k: int,
additional_config: t.Optional[dict],
additional_headers: t.Optional[dict],
hybrid_query: t.Optional[str],
hybrid_alpha: t.Optional[float],
**kwargs,
) -> None:
"""
Expand All @@ -18,24 +24,56 @@ def __init__(
class_name: Name of class to query
top_k: Amount of context to return.
kwargs: Unhandled keyword arguments passed in by Fondant.
additional_config: Additional configuration passed to the weaviate client.
additional_headers: Additional headers passed to the weaviate client.
hybrid_query: The hybrid query to be used for retrieval. Optional parameter.
hybrid_alpha: Argument to change how much each search affects the results. An alpha
of 1 is a pure vector search. An alpha of 0 is a pure keyword search.
"""
# Initialize your component here based on the arguments
self.client = weaviate.Client(weaviate_url)
self.client = weaviate.Client(
url=weaviate_url,
additional_config=additional_config if additional_config else None,
additional_headers=additional_headers if additional_headers else None,
)
self.class_name = class_name
self.k = top_k
self.hybrid_query, self.hybrid_alpha = self.validate_hybrid_query(
hybrid_query,
hybrid_alpha,
)

@staticmethod
def validate_hybrid_query(
hybrid_query: t.Optional[str],
hybrid_alpha: t.Optional[float],
):
if hybrid_query is not None and hybrid_alpha is None:
msg = (
"If hybrid_query is specified, hybrid_alpha must be specified as well."
)
raise ValueError(
msg,
)

return hybrid_query, hybrid_alpha

def teardown(self) -> None:
del self.client

def retrieve_chunks(self, vector_query: list):
"""Get results from weaviate database."""
result = (
query = (
self.client.query.get(self.class_name, ["passage"])
.with_near_vector({"vector": vector_query})
.with_limit(self.k)
.with_additional(["distance"])
.do()
)
if self.hybrid_query is not None:
query = query.with_hybrid(query=self.hybrid_query, alpha=self.hybrid_alpha)

result = query.do()

result_dict = result["data"]["Get"][self.class_name]
return [retrieved_chunk["passage"] for retrieved_chunk in result_dict]

Expand Down
4 changes: 4 additions & 0 deletions components/retrieve_from_weaviate/tests/test_component.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ def test_component():
weaviate_url=url,
class_name="Test",
top_k=2,
additional_config={},
additional_headers={},
hybrid_query=None,
hybrid_alpha=None,
)

output_dataframe = component.transform(input_dataframe)
Expand Down

0 comments on commit b402b32

Please sign in to comment.