-
Notifications
You must be signed in to change notification settings - Fork 15.9k
/
Copy pathweaviate_hybrid_search.py
162 lines (133 loc) · 6.03 KB
/
weaviate_hybrid_search.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from __future__ import annotations
from typing import Any, Dict, List, Optional, cast
from uuid import uuid4
from langchain_core.callbacks import CallbackManagerForRetrieverRun
from langchain_core.documents import Document
from langchain_core.retrievers import BaseRetriever
from pydantic import ConfigDict, model_validator
class WeaviateHybridSearchRetriever(BaseRetriever):
    """`Weaviate hybrid search` retriever.

    Wraps a ``weaviate.Client`` so that documents can be uploaded to a
    Weaviate class (``index_name``) and retrieved with a hybrid query.

    See the documentation:
      https://weaviate.io/blog/hybrid-search-explained
    """

    client: Any = None
    """The Weaviate client; must be an instance of ``weaviate.Client``."""
    index_name: str
    """The name of the index to use."""
    text_key: str
    """The name of the text key to use."""
    alpha: float = 0.5
    """Default ``alpha`` passed to the hybrid query. In Weaviate this balances
    keyword search against vector search (0 = keyword only, 1 = vector only)."""
    k: int = 4
    """The number of results to return."""
    attributes: List[str]
    """The attributes to return in the results."""
    create_schema_if_missing: bool = True
    """Whether to create the schema if it doesn't exist."""

    @model_validator(mode="before")
    @classmethod
    def validate_client(
        cls,
        values: Dict[str, Any],
    ) -> Any:
        """Validate the client and normalise ``attributes`` before field parsing.

        Ensures ``client`` is a ``weaviate.Client``, makes sure ``text_key`` is
        among the returned ``attributes``, and (unless
        ``create_schema_if_missing`` is false) creates the Weaviate class for
        ``index_name`` when it does not exist yet.

        Args:
            values: Raw keyword arguments given to the constructor.

        Returns:
            The same mapping, with ``attributes`` replaced by a fresh list.

        Raises:
            ImportError: If the ``weaviate`` package is not installed.
            ValueError: If ``client`` is missing or not a ``weaviate.Client``,
                or if ``attributes`` is a plain string.
        """
        try:
            import weaviate
        except ImportError as err:
            raise ImportError(
                "Could not import weaviate python package. "
                "Please install it with `pip install weaviate-client`."
            ) from err

        # ``client`` defaults to None, so it may be absent from ``values``;
        # use .get() so that case reports a ValueError instead of a KeyError.
        client = values.get("client")
        if not isinstance(client, weaviate.Client):
            raise ValueError(
                f"client should be an instance of weaviate.Client, got {type(client)}"
            )

        requested = values.get("attributes")
        if isinstance(requested, str):
            # list("abc") would silently split the string into characters.
            raise ValueError(
                "attributes should be a list of property names, got a str"
            )
        # Work on a copy: the caller's list must not be mutated by validation.
        attributes = [] if requested is None else list(requested)
        text_key = values["text_key"]
        # The text property is always fetched because it becomes page_content.
        if text_key not in attributes:
            attributes.append(text_key)
        values["attributes"] = attributes

        if values.get("create_schema_if_missing", True):
            index_name = values["index_name"]
            if not client.schema.exists(index_name):
                client.schema.create_class(
                    {
                        "class": index_name,
                        "properties": [{"name": text_key, "dataType": ["text"]}],
                        "vectorizer": "text2vec-openai",
                    }
                )

        return values

    model_config = ConfigDict(
        arbitrary_types_allowed=True,
    )

    # added text_key
    def add_documents(self, docs: List[Document], **kwargs: Any) -> List[str]:
        """Upload documents to Weaviate.

        Each document is stored with its ``page_content`` under ``text_key``
        together with its metadata as further properties.

        Args:
            docs: Documents to upload.
            **kwargs: May contain ``uuids``, a sequence giving the id for each
                document by position (it is indexed once per document). When
                omitted or ``None``, a random UUID is generated per document.

        Returns:
            The ids used for the uploaded documents, in input order.
        """
        from weaviate.util import get_valid_uuid

        # Treat an explicit ``uuids=None`` the same as not passing it.
        uuids = kwargs.get("uuids")
        ids: List[str] = []
        with self.client.batch as batch:
            for i, doc in enumerate(docs):
                metadata = doc.metadata or {}
                data_properties = {self.text_key: doc.page_content, **metadata}

                # If the UUID of one of the objects already exists
                # then the existing object will be replaced by the new object.
                if uuids is not None:
                    _id = uuids[i]
                else:
                    _id = get_valid_uuid(uuid4())

                batch.add_data_object(data_properties, self.index_name, _id)
                ids.append(_id)
        return ids

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        where_filter: Optional[Dict[str, object]] = None,
        score: bool = False,
        hybrid_search_kwargs: Optional[Dict[str, object]] = None,
    ) -> List[Document]:
        """Look up similar documents in Weaviate.

        Args:
            query: The query to search for relevant documents
                using Weaviate hybrid search.
            run_manager: Callback manager for the retriever run (not used here).
            where_filter: A filter to apply to the query.
                https://weaviate.io/developers/weaviate/guides/querying/#filtering
            score: Whether to include the score, and score explanation
                in the returned Documents metadata.
            hybrid_search_kwargs: Used to pass additional arguments
                to the .with_hybrid() method. An ``alpha`` entry overrides
                ``self.alpha`` for this call.
                The primary use cases for this are:

                1) Search specific properties only -
                   specify which properties to be used during hybrid search
                   portion. Note: this is not the same as the
                   (self.attributes) to be returned.
                   Example - hybrid_search_kwargs={"properties": ["question", "answer"]}
                   https://weaviate.io/developers/weaviate/search/hybrid#selected-properties-only
                2) Weight boosted searched properties -
                   Boost the weight of certain properties during the hybrid
                   search portion.
                   Example - hybrid_search_kwargs={"properties": ["question^2", "answer"]}
                   https://weaviate.io/developers/weaviate/search/hybrid#weight-boost-searched-properties
                3) Search with a custom vector - Define a different vector
                   to be used during the hybrid search portion.
                   Example - hybrid_search_kwargs={"vector": [0.1, 0.2, 0.3, ...]}
                   https://weaviate.io/developers/weaviate/search/hybrid#with-a-custom-vector
                4) Use Fusion ranking method
                   Example - from weaviate.gql.get import HybridFusion
                   hybrid_search_kwargs={"fusion_type": HybridFusion.RELATIVE_SCORE}
                   https://weaviate.io/developers/weaviate/search/hybrid#fusion-ranking-method

        Returns:
            One Document per hit: ``page_content`` is the value stored under
            ``text_key`` and ``metadata`` holds the remaining returned fields.

        Raises:
            ValueError: If the Weaviate response contains an ``errors`` entry.
        """
        query_obj = self.client.query.get(self.index_name, self.attributes)
        if where_filter:
            query_obj = query_obj.with_where(where_filter)
        if score:
            query_obj = query_obj.with_additional(["score", "explainScore"])

        # Merge so a caller-supplied "alpha" overrides the default instead of
        # colliding with it as a duplicate keyword argument.
        hybrid_kwargs = {"alpha": self.alpha, **(hybrid_search_kwargs or {})}

        result = (
            query_obj.with_hybrid(query, **hybrid_kwargs).with_limit(self.k).do()
        )
        if "errors" in result:
            raise ValueError(f"Error during query: {result['errors']}")

        docs = []
        for res in result["data"]["Get"][self.index_name]:
            # Whatever is left after removing the text becomes the metadata.
            text = res.pop(self.text_key)
            docs.append(Document(page_content=text, metadata=res))
        return docs