-
Notifications
You must be signed in to change notification settings - Fork 17
/
llamaindex_adaptive_rag.py
396 lines (346 loc) · 12.7 KB
/
llamaindex_adaptive_rag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
import asyncio
import os
from dataclasses import dataclass
from typing import Any, List, Literal, Dict
import nest_asyncio
from pypdf import mult
import streamlit as st
from llama_index.core import (
Settings,
SimpleDirectoryReader,
VectorStoreIndex,
get_response_synthesizer,
)
from llama_index.core.agent import AgentRunner, FunctionCallingAgentWorker
from llama_index.core.indices.document_summary.base import DocumentSummaryIndex
from llama_index.core.indices.query.query_transform.base import (
StepDecomposeQueryTransform,
)
from llama_index.core.llms.llm import LLM
from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.core.postprocessor.metadata_replacement import (
MetadataReplacementPostProcessor,
)
from llama_index.core.query_engine import (
BaseQueryEngine,
CustomQueryEngine,
RetrieverQueryEngine,
)
from llama_index.core.query_engine.router_query_engine import RouterQueryEngine
from llama_index.core.response_synthesizers.type import ResponseMode
from llama_index.core.retrievers import RecursiveRetriever
from llama_index.core.schema import Document
from llama_index.core.selectors.llm_selectors import LLMSingleSelector
from llama_index.core.tools import QueryEngineTool, ToolMetadata
from llama_index.embeddings.cohere import CohereEmbedding
from llama_index.embeddings.nomic import NomicEmbedding
from llama_index.legacy.postprocessor import CohereRerank, SentenceTransformerRerank
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.cohere import Cohere
from llama_index.llms.gemini import Gemini
from llama_index.llms.groq import Groq
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import MultiStepQueryEngine
from llama_index.legacy.embeddings.langchain import LangchainEmbedding
from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
from loguru import logger
from rich.pretty import pprint
from tqdm.asyncio import tqdm
# Allow re-entrant event loops: Streamlit owns a running loop while this
# script awaits LlamaIndex coroutines.
nest_asyncio.apply()
# Toggle verbose printing/logging across the whole pipeline.
VERBOSE = True
# Sentence-window size passed to SentenceWindowNodeParser below.
WIN_SZ = 3
# Number of nodes fetched by vector retrieval before reranking.
SIM_TOP_K = 5
# Number of nodes kept by the reranker.
RERANK_TOP_K = 3
# Maximum decomposition steps for the multi-step query engine.
N_MULTI_STEPS = 5
def pretty_print(title: "str | None" = None, content: Any = None):
    """Debug helper: print ``content`` (optionally titled) when VERBOSE is on.

    With no ``title`` the content is printed plainly; otherwise the title is
    printed first and the content rendered with rich's ``pprint``.

    Note: the original annotated ``title`` as ``str`` while defaulting to
    ``None``; the annotation is corrected to reflect the optional value.
    """
    if not VERBOSE:
        return
    if title is None:
        print(content)
        return
    print(title)
    pprint(content)
# Pool of candidate LLMs; the sidebar lets the user assign one per role below.
llm_map: Dict[str, LLM] = {
    "anthropic": Anthropic(temperature=0, model="claude-3-haiku-20240307"),
    "openai": OpenAI(temperature=0, model="gpt-4-turbo"),
    "cohere": Cohere(temperature=0, max_tokens=2048),
    "groq": Groq(model="mixtral-8x7b-32768", temperature=0, timeout=60),
}
# LLM that writes the short summary used as each document tool's description.
summary_llm = llm_map[
    st.sidebar.selectbox(
        "Summary LLM", list(llm_map.keys()), index=2, key="summary_llm"
    )
]
# LLM driving step decomposition in the multi-step query engines.
multi_step_query_engine_llm = llm_map[
    st.sidebar.selectbox(
        "Multi-step Query Engine LLM",
        list(llm_map.keys()),
        index=3,
        key="multi_step_query_engine_llm",
    )
]
# LLM answering single-step (standalone) retrieval queries.
standalone_query_engine_llm = llm_map[
    st.sidebar.selectbox(
        "Standalone Query Engine LLM",
        list(llm_map.keys()),
        index=3,
        key="standalone_query_engine_llm",
    )
]
# LLM used by the function-calling agent that spans multiple documents.
agent_llm = llm_map[
    st.sidebar.selectbox("Agent LLM", list(llm_map.keys()), index=0, key="agent_llm")
]
# LLM used by the router chain that selects which tool handles a query.
chain_llm = llm_map[
    st.sidebar.selectbox("Chain LLM", list(llm_map.keys()), index=0, key="chain_llm")
]
# LLM behind the fallback tool for general, non-document queries.
general_llm = llm_map[
    st.sidebar.selectbox(
        "General LLM", list(llm_map.keys()), index=0, key="general_llm"
    )
]
# Default LLM for any LlamaIndex component not given an explicit one.
Settings.llm = llm_map[
    st.sidebar.selectbox(
        "Settings LLM", list(llm_map.keys()), index=1, key="settings_llm"
    )
]
# Global embedding model: NVIDIA endpoint wrapped via the LangChain adapter.
Settings.embed_model = LangchainEmbedding(NVIDIAEmbeddings(model="nvolveqa_40k"))
# Parse documents into sentence nodes carrying a WIN_SZ sentence window in
# metadata; MetadataReplacementPostProcessor restores it at query time.
Settings.node_parser = SentenceWindowNodeParser.from_defaults(
    window_size=WIN_SZ,
    window_metadata_key="window",
    original_text_metadata_key="original_text",
)
@dataclass
class DataSource:
    """One uploaded document plus the query engines built over it."""

    name: str
    description: str
    query_engine: BaseQueryEngine
    multi_step_query_engine: BaseQueryEngine

    def __hash__(self):
        # Identity is the (name, description) pair; engines are excluded.
        return hash((self.name, self.description))

    def __eq__(self, other):
        # Mirror __hash__: only name and description participate in equality.
        return isinstance(other, DataSource) and (
            (self.name, self.description) == (other.name, other.description)
        )
async def load_docs(file_paths: List[str]) -> List[DataSource]:
    """Read every file and build a DataSource for each, indexing concurrently."""
    jobs = []
    for path in file_paths:
        stem = os.path.basename(path).split(".")[0]
        docs = SimpleDirectoryReader(input_files=[path]).load_data()
        jobs.append(index_and_chunks(stem, docs))
    # tqdm.gather runs the indexing coroutines concurrently with a progress bar.
    return list(await tqdm.gather(*jobs))
async def index_and_chunks(file_name: str, raw_docs: List[Document]) -> DataSource:
    """Index ``raw_docs`` and return a DataSource bundling its query engines.

    Builds a sentence-window vector index (recursive retriever + reranker),
    generates a one-shot summary used as the tool description, and wraps the
    standalone engine in a multi-step variant.
    """
    import re  # stdlib; used only for tool-name validation below

    pretty_print("Raw docs", file_name)
    name = file_name
    # LlamaIndex requires tool names to match '^[a-zA-Z0-9_-]{1,64}$'.
    # The previous isalnum() test wrongly rejected names containing '_' or
    # '-' (which the pattern allows) and ignored the 64-character limit;
    # validate against the actual pattern instead.
    if not re.fullmatch(r"[a-zA-Z0-9_-]{1,64}", name):
        # Fallback name; hash() of a str varies per process (PYTHONHASHSEED),
        # which is acceptable here since names only live for one session.
        name = "file_" + str(hash(name))
    # Replace each node's text with its surrounding sentence window.
    postproc = MetadataReplacementPostProcessor(target_metadata_key="window")
    rerank = SentenceTransformerRerank(
        top_n=RERANK_TOP_K, model="BAAI/bge-reranker-base"
    )
    # Vector indexing: recursive retriever over a single vector index.
    retriever = RecursiveRetriever(
        "vector",
        retriever_dict={
            "vector": VectorStoreIndex.from_documents(
                raw_docs, show_progress=True
            ).as_retriever(similarity_top_k=SIM_TOP_K)
        },
        verbose=VERBOSE,
    )
    # Summarize the document once; the summary becomes the tool description.
    summary = await RetrieverQueryEngine.from_args(
        DocumentSummaryIndex.from_documents(
            raw_docs,
            show_progress=True,
        ).as_retriever(),
        llm=summary_llm,
        response_synthesizer=get_response_synthesizer(
            response_mode=ResponseMode.SIMPLE_SUMMARIZE
        ),
        node_postprocessors=[postproc, rerank],
        verbose=VERBOSE,
    ).aquery("Provide the shortest description of the content.")
    query_engine = RetrieverQueryEngine.from_args(
        retriever,
        llm=standalone_query_engine_llm,
        node_postprocessors=[postproc, rerank],
        verbose=VERBOSE,
    )
    return DataSource(
        name=name,
        description=summary.response,
        query_engine=query_engine,
        multi_step_query_engine=MultiStepQueryEngine(
            query_engine=query_engine,
            query_transform=StepDecomposeQueryTransform(
                llm=multi_step_query_engine_llm, verbose=VERBOSE
            ),
            num_steps=N_MULTI_STEPS,
        ),
    )
def build_mulit_step_query_engine_tools(
    ds_list: List[DataSource],
) -> List[QueryEngineTool]:
    """Wrap each data source's multi-step query engine in a QueryEngineTool.

    Note: the misspelling "mulit" is kept because existing callers use it;
    prefer the correctly spelled alias below for new code.
    """
    desc_fmt = "Useful for complex queries on the content with multi-step that covers the following dedicated topic:\n{topic}\n"
    return [
        QueryEngineTool(
            query_engine=ds.multi_step_query_engine,
            metadata=ToolMetadata(
                name=ds.name, description=desc_fmt.format(topic=ds.description)
            ),
        )
        for ds in ds_list
    ]


# Backward-compatible, correctly spelled alias for the function above.
build_multi_step_query_engine_tools = build_mulit_step_query_engine_tools
def build_standalone_query_engine_tools(
    ds_list: List[DataSource],
) -> List[QueryEngineTool]:
    """Build one QueryEngineTool per data source, backed by its standalone engine."""
    template = (
        "Useful for simple queries on the content that covers the following dedicated topic:\n{topic}\n"
    )
    tools = []
    for source in ds_list:
        # Each tool advertises the document's summary so the router/agent
        # can pick the right one.
        tools.append(
            QueryEngineTool(
                query_engine=source.query_engine,
                metadata=ToolMetadata(
                    name=source.name,
                    description=template.format(topic=source.description),
                ),
            )
        )
    return tools
def build_query_engine_tools_agent_tool(
    query_engine_tools: List[QueryEngineTool],
    base_description: str,
) -> QueryEngineTool:
    """Bundle several query-engine tools behind one function-calling agent.

    The returned tool's description concatenates ``base_description`` with the
    description of every wrapped tool, so a router can see all covered topics.
    """
    worker = FunctionCallingAgentWorker.from_tools(
        query_engine_tools,
        llm=agent_llm,
        verbose=VERBOSE,
        allow_parallel_tool_calls=True,
    )
    runner = AgentRunner(
        worker,
        llm=agent_llm,
        verbose=VERBOSE,
    )
    parts = [base_description] + [
        f"Description of {tool.metadata.name}:\n{tool.metadata.description}\n"
        for tool in query_engine_tools
    ]
    return QueryEngineTool(
        query_engine=runner,
        metadata=ToolMetadata(description="\n\n".join(parts)),
    )
class LLMQueryEngine(CustomQueryEngine):
    """RAG String Query Engine."""

    # LLM used to answer queries directly, with no retrieval step.
    llm: LLM

    def custom_query(self, query_str: str):
        # Forward the raw query to the LLM and return its completion as text.
        return str(self.llm.complete(query_str))
def build_fallback_query_engine_tool() -> QueryEngineTool:
    """Tool answering with the general LLM when no data-source tool fits."""
    fallback_engine = LLMQueryEngine(llm=general_llm)
    fallback_metadata = ToolMetadata(
        name="General queries as fallback",
        description=(
            "Useful for information about general queries other than specific data sources, as fallback action if no other tool is selected."
        ),
    )
    return QueryEngineTool(query_engine=fallback_engine, metadata=fallback_metadata)
def build_adaptive_rag_chain(ds_list: List[DataSource]) -> RouterQueryEngine:
    """Assemble the adaptive-RAG router over all per-document tools.

    Routing candidates:
      * one multi-step tool per data source, plus an agent tool spanning them,
      * one standalone tool per data source, plus an agent tool spanning them,
      * a general-LLM fallback tool.
    """
    standalone_query_engine_tools = build_standalone_query_engine_tools(ds_list)
    standalone_query_engine_tools_agent_tool = build_query_engine_tools_agent_tool(
        # Reuse the list built above; the original rebuilt it from scratch
        # with a second build_standalone_query_engine_tools(ds_list) call.
        standalone_query_engine_tools,
        "Useful for queries that span multiple and cross-docs, the docs should cover different topics:\n",
    )
    multi_step_query_engine_tools = build_mulit_step_query_engine_tools(ds_list)
    multi_step_query_engine_tools_agent_tool = build_query_engine_tools_agent_tool(
        # Same reuse as above for the multi-step tool list.
        multi_step_query_engine_tools,
        "Useful for complex queries that span multiple and cross-docs with the help of multi-step, the docs should cover different topics:\n",
    )
    fallback_query_engine_tool = build_fallback_query_engine_tool()
    query_engine_tools = (
        multi_step_query_engine_tools
        + [multi_step_query_engine_tools_agent_tool]
        + standalone_query_engine_tools
        + [standalone_query_engine_tools_agent_tool]
        + [fallback_query_engine_tool]
    )
    return RouterQueryEngine.from_defaults(
        llm=chain_llm,
        selector=LLMSingleSelector.from_defaults(llm=chain_llm),
        query_engine_tools=query_engine_tools,
        verbose=VERBOSE,
    )
async def doc_uploader() -> BaseQueryEngine:
    """Sidebar uploader: persist uploaded files, index them, cache the engine.

    Returns the cached query engine when the exact same file list is
    re-submitted; otherwise writes the uploads to ./tmp, indexes them, stores
    a fresh RouterQueryEngine in st.session_state, and returns it. Returns
    None when nothing is uploaded.
    """
    with st.sidebar:
        uploaded_docs = st.file_uploader(
            "# Upload files",
            key="doc_uploader",
            accept_multiple_files=True,
        )
        if not uploaded_docs:
            # Reset cached state so a later upload triggers fresh indexing.
            st.session_state["file_names"] = None
            st.session_state["query_engine"] = None
            logger.debug("No file uploaded")
            return None
        pretty_print("Uploaded files", uploaded_docs)
        tmp_dir = "tmp/"
        os.makedirs(tmp_dir, exist_ok=True)
        file_names = []
        for uploaded_doc in uploaded_docs:
            temp_file_path = os.path.join(tmp_dir, f"{uploaded_doc.name}")
            with open(temp_file_path, "wb") as file:
                file.write(uploaded_doc.getvalue())
            logger.debug(f"Uploaded {uploaded_doc.name}")
            uploaded_doc.flush()
            uploaded_doc.close()
            file_names.append(temp_file_path)
        # Whole-list equality replaces the original element-wise check, which
        # raised IndexError when more files were uploaded than previously
        # stored and reported "same" when files were removed.
        if st.session_state.get("file_names") == file_names:
            logger.debug("Same files uploaded")
            return st.session_state["query_engine"]
        logger.debug("New files, new queries, indexing needed")
        st.session_state["file_names"] = file_names
        pretty_print("File names", st.session_state["file_names"])
        with st.spinner("Indexing, it takes a while depending on the system..."):
            ds_list = await load_docs(st.session_state["file_names"])
            pretty_print("Data sources", ds_list)
            st.session_state["query_engine"] = build_adaptive_rag_chain(ds_list)
            return st.session_state["query_engine"]
async def main():
    """Streamlit entry point: upload documents, then answer a single query."""
    st.sidebar.title("Upload file")
    query_engine = await doc_uploader()
    if query_engine is None:
        pretty_print("Has query_engine", "No query_engine")
        return
    query_text = st.text_input(
        "Query",
        key="query_text",
        placeholder="Enter your query here",
    ).strip()
    # After .strip() the value is a str, so truthiness covers the original
    # "is not None and != ''" check.
    if query_text:
        final_res = await query_engine.aquery(query_text)
        st.write(str(final_res))
if __name__ == "__main__":
    # Run the async app entry point on a fresh event loop.
    asyncio.run(main())