# -*- coding: utf-8 -*-
"""
This module is an integration of the LlamaIndex RAG
into the AgentScope package.
"""
import os.path
from typing import Any, Optional, List, Union
from loguru import logger
try:
    import llama_index
    from llama_index.core.base.base_retriever import BaseRetriever
    from llama_index.core.base.embeddings.base import (
        BaseEmbedding,
        Embedding,
    )
    from llama_index.core.ingestion import IngestionPipeline
    from llama_index.core.bridge.pydantic import PrivateAttr
    from llama_index.core.node_parser import SentenceSplitter
    from llama_index.core import (
        VectorStoreIndex,
        StorageContext,
        load_index_from_storage,
    )
    from llama_index.core.schema import (
        Document,
        TransformComponent,
    )
except ImportError:
    llama_index = None
    BaseRetriever = None
    BaseEmbedding = None
    Embedding = None
    IngestionPipeline = None
    SentenceSplitter = None
    VectorStoreIndex = None
    StorageContext = None
    load_index_from_storage = None
    PrivateAttr = None
    Document = None
    TransformComponent = None
from agentscope.manager import FileManager
from agentscope.models import ModelWrapperBase
from agentscope.constants import (
    DEFAULT_TOP_K,
    DEFAULT_CHUNK_SIZE,
    DEFAULT_CHUNK_OVERLAP,
)
from agentscope.rag.knowledge import Knowledge

try:

    class _EmbeddingModel(BaseEmbedding):
        """
        Wrapper to adapt a ModelWrapperBase embedding model so that it can
        be used in a LlamaIndex pipeline.
        """

        _emb_model_wrapper: ModelWrapperBase = PrivateAttr()

        def __init__(
            self,
            emb_model: ModelWrapperBase,
            embed_batch_size: int = 1,
        ) -> None:
            """
            Dummy wrapper to convert a ModelWrapperBase into a LlamaIndex
            embedding model.

            Args:
                emb_model (ModelWrapperBase):
                    embedding model in ModelWrapperBase
                embed_batch_size (int):
                    batch size, defaults to 1
            """
            super().__init__(
                model_name="Temporary_embedding_wrapper",
                embed_batch_size=embed_batch_size,
            )
            self._emb_model_wrapper = emb_model

        def _get_query_embedding(self, query: str) -> List[float]:
            """
            Get the embedding for a query.

            Args:
                query (str): query to be embedded
            """
            # Note: the AgentScope embedding model wrapper returns a list
            # of embeddings
            return list(self._emb_model_wrapper(query).embedding[0])

        def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
            """
            Get embeddings for a list of strings.

            Args:
                texts (List[str]): texts to be embedded
            """
            results = [
                list(self._emb_model_wrapper(t).embedding[0]) for t in texts
            ]
            return results

        def _get_text_embedding(self, text: str) -> Embedding:
            """
            Get the embedding for a single string.

            Args:
                text (str): text to be embedded
            """
            return list(self._emb_model_wrapper(text).embedding[0])

        # TODO: use proper async methods, but depends on model wrapper
        async def _aget_query_embedding(self, query: str) -> List[float]:
            """The asynchronous version of _get_query_embedding."""
            return self._get_query_embedding(query)

        async def _aget_text_embedding(self, text: str) -> List[float]:
            """Asynchronously get text embedding."""
            return self._get_text_embedding(text)

        async def _aget_text_embeddings(
            self,
            texts: List[str],
        ) -> List[List[float]]:
            """Asynchronously get text embeddings."""
            return self._get_text_embeddings(texts)

except Exception:

    class _EmbeddingModel:  # type: ignore[no-redef]
        """
        A dummy embedding model for passing tests when
        llama-index is not installed.
        """

        def __init__(self, emb_model: ModelWrapperBase):
            self._emb_model_wrapper = emb_model
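
# Illustrative sketch (not executed): given an AgentScope embedding model
# wrapper `emb_wrapper` (a hypothetical name for a model loaded from your
# model configs), the adapter in the try-branch above simply unwraps the
# wrapper's response -- `response.embedding[0]` -- into the plain
# `List[float]` that LlamaIndex expects:
#
#     li_embedding = _EmbeddingModel(emb_wrapper)
#     vector = li_embedding.get_text_embedding("hello world")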


class LlamaIndexKnowledge(Knowledge):
    """
    This class is a wrapper around the LlamaIndex RAG.
    """

    def __init__(
        self,
        knowledge_id: str,
        emb_model: Union[ModelWrapperBase, BaseEmbedding, None] = None,
        knowledge_config: Optional[dict] = None,
        model: Optional[ModelWrapperBase] = None,
        persist_root: Optional[str] = None,
        overwrite_index: Optional[bool] = False,
        showprogress: Optional[bool] = True,
        **kwargs: Any,
    ) -> None:
        """
        Initialize the knowledge component based on the llama-index
        framework: https://github.com/run-llama/llama_index

        Notes:
            In LlamaIndex, one of the most important concepts is the index,
            a data structure composed of Document objects, designed to
            enable querying by an LLM. The core workflow of initializing
            RAG is to convert data to an index, and to retrieve information
            from the index. For example:
            1) preprocess documents with data loaders
            2) generate embeddings by configuring a pipeline with
               embedding models
            3) store the embedding-content pairs in a vector database
            The default persist directory is "./rag_storage/knowledge_id".

        Args:
            knowledge_id (str):
                The id of the RAG knowledge unit.
            emb_model (ModelWrapperBase):
                The embedding model used to generate embeddings.
            knowledge_config (dict):
                The configuration for llama-index to
                generate or load the index.
            model (ModelWrapperBase):
                The language model used for final synthesis.
            persist_root (str):
                The root directory for index persisting.
            overwrite_index (Optional[bool]):
                Whether to overwrite the index while refreshing.
            showprogress (Optional[bool]):
                Whether to show the indexing progress.
        """
        super().__init__(
            knowledge_id=knowledge_id,
            emb_model=emb_model,
            knowledge_config=knowledge_config,
            model=model,
            **kwargs,
        )
        if llama_index is None:
            raise ImportError(
                "LlamaIndexKnowledge requires llama-index to be installed. "
                "Try a stable llama-index version, such as "
                "`pip install llama-index==0.10.30`",
            )

        if persist_root is None:
            persist_root = FileManager.get_instance().cache_dir or "./"
        self.persist_dir = os.path.join(persist_root, knowledge_id)
        logger.info(f"** persist_dir: {self.persist_dir}")
        self.emb_model = emb_model
        self.overwrite_index = overwrite_index
        self.showprogress = showprogress
        self.index = None
        # ensure the emb_model is compatible with LlamaIndex
        if isinstance(emb_model, ModelWrapperBase):
            self.emb_model = _EmbeddingModel(emb_model)
        elif isinstance(self.emb_model, BaseEmbedding):
            pass
        else:
            raise TypeError(
                f"Embedding model does not support {type(self.emb_model)}.",
            )
        # then we can initialize the RAG
        self._init_rag()
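
    # A minimal sketch of the `knowledge_config` this class consumes. The
    # nested schema inside "load_data" is resolved by
    # `_prepare_args_from_config` in the parent `Knowledge` class (not shown
    # in this file), so the keys below are assumptions based on AgentScope's
    # examples, and the input path is a hypothetical placeholder:
    #
    #     {
    #         "knowledge_id": "agentscope_code_rag",
    #         "data_processing": [
    #             {
    #                 "load_data": {
    #                     "loader": {
    #                         "create_object": True,
    #                         "module": "llama_index.core",
    #                         "class": "SimpleDirectoryReader",
    #                         "init_args": {"input_dir": "path/to/docs"},
    #                     },
    #                 },
    #             },
    #         ],
    #     }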

    def _init_rag(self, **kwargs: Any) -> None:
        """
        Initialize the RAG. This includes:
            * if the persist_dir exists, load the persisted index
            * if not, convert the data to an index
            * if needed, update the index
            * set the retriever to retrieve information from the index

        Notes:
            * the index is persisted in self.persist_dir
            * the refresh_index method is placed here for testing; it can be
              called externally. For example, update the index periodically
              by calling rag.refresh_index() during the execution of the
              agent.
        """
        if os.path.exists(self.persist_dir):
            self._load_index()
            # self.refresh_index()
        else:
            self._data_to_index()
        self._get_retriever()
        logger.info(
            f"RAG with knowledge ids: {self.knowledge_id} "
            f"initialization completed!\n",
        )

    def _load_index(self) -> None:
        """
        Load the persisted index from persist_dir.
        """
        # load the storage_context
        storage_context = StorageContext.from_defaults(
            persist_dir=self.persist_dir,
        )
        # construct the index from the storage context
        self.index = load_index_from_storage(
            storage_context=storage_context,
            embed_model=self.emb_model,
        )
        logger.info(f"index loaded from {self.persist_dir}")

    def _data_to_index(self) -> None:
        """
        Convert the data to an index according to the configs. This includes:
            * load the data into documents using information from the configs
            * set the transformations associated with the documents
            * convert the documents to nodes
            * convert the nodes to an index

        Notes:
            As each selected file type may need a different loader
            and transformations, knowledge_config is a list of configs.
        """
        nodes = []
        # load data to documents and set transformations
        # using information in knowledge_config
        for config in self.knowledge_config.get("data_processing"):
            documents = self._data_to_docs(config=config)
            transformations = self._set_transformations(config=config).get(
                "transformations",
            )
            nodes_docs = self._docs_to_nodes(
                documents=documents,
                transformations=transformations,
            )
            nodes = nodes + nodes_docs
        # convert nodes to index
        self.index = VectorStoreIndex(
            nodes=nodes,
            embed_model=self.emb_model,
        )
        logger.info("index calculation completed.")
        # persist the calculated index
        self.index.storage_context.persist(persist_dir=self.persist_dir)
        logger.info("index persisted.")

    def _data_to_docs(
        self,
        query: Optional[str] = None,
        config: dict = None,
    ) -> Any:
        """
        This method sets the loader as needed, or just uses the default
        setting, then uses the loader to load data from a directory into
        documents.

        Notes:
            We can use the simple directory loader (SimpleDirectoryReader)
            to load general documents, including Markdown, PDFs,
            Word documents, PowerPoint decks, images, audio and video.
            Or use the SQL loader (DatabaseReader) to load a database.

        Args:
            query (Optional[str]):
                optional, used when the data is in a database.
            config (dict):
                optional, used when the loader config is in a config file.
        Returns:
            Any: loaded documents
        """
        loader = self._set_loader(config=config).get("loader")
        # let the doc_id be the filename for each document
        loader.filename_as_id = True
        if query is None:
            documents = loader.load_data()
        else:
            # this is for querying a database,
            # does not work for loading a document directory
            documents = loader.load_data(query)
        logger.info(f"loaded {len(documents)} documents")
        return documents
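
    # Illustrative sketch (not executed) of the two loading modes described
    # above. `SimpleDirectoryReader` is imported from `llama_index.core`
    # elsewhere in this file; the database example assumes the optional
    # `llama-index-readers-database` package and a hypothetical connection
    # string.
    #
    #     # directory of files -> documents
    #     docs = SimpleDirectoryReader(input_dir="path/to/docs").load_data()
    #
    #     # database -> documents (the query is passed through `_data_to_docs`)
    #     from llama_index.readers.database import DatabaseReader
    #     reader = DatabaseReader(uri="sqlite:///example.db")
    #     docs = reader.load_data(query="SELECT text FROM notes")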

    def _docs_to_nodes(
        self,
        documents: List[Document],
        transformations: Optional[list[Optional[TransformComponent]]] = None,
    ) -> Any:
        """
        Convert the loaded documents to nodes using transformations.

        Args:
            documents (List[Document]):
                documents to be processed, usually expected to be
                llama-index Documents.
            transformations (Optional[list[TransformComponent]]):
                optional, specifies the transformations (operators) used to
                process documents (e.g., split the documents into smaller
                chunks)
        Return:
            Any: the nodes of the processed documents
        """
        # nodes, also called chunks, are a representation of the documents;
        # we build nodes by using the IngestionPipeline
        # for each document with corresponding transformations
        pipeline = IngestionPipeline(
            transformations=transformations,
        )
        # stack up the nodes from the pipeline
        nodes = pipeline.run(
            documents=documents,
            show_progress=self.showprogress,
        )
        logger.info("nodes generated.")
        return nodes

    def _set_loader(self, config: dict) -> Any:
        """
        Set the loader as needed, or just use the default setting.

        Args:
            config (dict): a dictionary containing configurations
        """
        if "load_data" in config:
            # we prepare the loader from the configs
            loader = self._prepare_args_from_config(
                config=config.get("load_data", {}),
            )
        else:
            # we prepare the loader by default
            try:
                from llama_index.core import SimpleDirectoryReader
            except ImportError as exc_inner:
                raise ImportError(
                    "LlamaIndexAgent requires llama-index to be installed. "
                    "Please run `pip install llama-index`",
                ) from exc_inner
            loader = {
                "loader": SimpleDirectoryReader(
                    input_dir="set_default_data_path",
                ),
            }
        logger.info("loaders are ready.")
        return loader

    def _set_transformations(self, config: dict) -> Any:
        """
        Set the transformations as needed, or just use the default setting.

        Args:
            config (dict): a dictionary containing configurations.
        """
        if "store_and_index" in config:
            temp = self._prepare_args_from_config(
                config=config.get("store_and_index", {}),
            )
            transformations = temp.get("transformations")
        else:
            transformations = [
                SentenceSplitter(
                    chunk_size=self.knowledge_config.get(
                        "chunk_size",
                        DEFAULT_CHUNK_SIZE,
                    ),
                    chunk_overlap=self.knowledge_config.get(
                        "chunk_overlap",
                        DEFAULT_CHUNK_OVERLAP,
                    ),
                ),
            ]
        # add the embedding model as the last step of the transformations
        # https://docs.llamaindex.ai/en/stable/module_guides/loading/ingestion_pipeline/root.html
        transformations.append(self.emb_model)
        logger.info("transformations are ready.")
        # as the last step, we need to repackage the transformations in a dict
        transformations = {"transformations": transformations}
        return transformations
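
    # A sketch of the optional "store_and_index" entry that overrides the
    # default SentenceSplitter above. It sits alongside "load_data" in a
    # `data_processing` entry; the nested schema is an assumption based on
    # how `_prepare_args_from_config` is used in this class:
    #
    #     "store_and_index": {
    #         "transformations": [
    #             {
    #                 "create_object": True,
    #                 "module": "llama_index.core.node_parser",
    #                 "class": "CodeSplitter",
    #                 "init_args": {"language": "python", "chunk_lines": 100},
    #             },
    #         ],
    #     }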

    def _get_retriever(
        self,
        similarity_top_k: int = None,
        **kwargs: Any,
    ) -> BaseRetriever:
        """
        Set the retriever as needed, or just use the default setting.

        Args:
            similarity_top_k (int):
                the number of most similar nodes returned by the retriever;
                defaults to DEFAULT_TOP_K when not provided.
        """
        # set the retriever
        logger.info(
            f"similarity_top_k={similarity_top_k or DEFAULT_TOP_K}",
        )
        retriever = self.index.as_retriever(
            embed_model=self.emb_model,
            similarity_top_k=similarity_top_k or DEFAULT_TOP_K,
            **kwargs,
        )
        logger.info("retriever is ready.")
        return retriever

    def retrieve(
        self,
        query: str,
        similarity_top_k: int = None,
        to_list_strs: bool = False,
        retriever: Optional[BaseRetriever] = None,
        **kwargs: Any,
    ) -> list[Any]:
        """
        This is a basic retrieval function for the knowledge.
        It will build a retriever on the fly and return the
        result of the query.

        Args:
            query (str):
                query is expected to be a question in string
            similarity_top_k (int):
                the number of most similar data returned by the
                retriever.
            to_list_strs (bool):
                whether to return a list of strings;
                if False, return NodeWithScore
            retriever (BaseRetriever):
                for advanced usage, users can pass their own retriever.
        Return:
            list[Any]: list of str or NodeWithScore

        For more advanced query processing, refer to
        https://docs.llamaindex.ai/en/stable/examples/query_transformations/query_transform_cookbook.html
        """
        if retriever is None:
            retriever = self._get_retriever(similarity_top_k)
        retrieved = retriever.retrieve(str(query))
        if to_list_strs:
            results = []
            for node in retrieved:
                results.append(node.get_text())
            return results
        return retrieved
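
    # Illustrative sketch (not executed): retrieving the top-3 chunks as
    # plain strings from a hypothetical knowledge instance `knowledge`:
    #
    #     chunks = knowledge.retrieve(
    #         query="How does AgentScope wrap embedding models?",
    #         similarity_top_k=3,
    #         to_list_strs=True,
    #     )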

    def refresh_index(self) -> None:
        """
        Refresh the index when needed.
        """
        for config in self.knowledge_config.get("data_processing"):
            documents = self._data_to_docs(config=config)
            # store and indexing for each file type
            transformations = self._set_transformations(config=config).get(
                "transformations",
            )
            self._insert_docs_to_index(
                documents=documents,
                transformations=transformations,
            )

    def _insert_docs_to_index(
        self,
        documents: List[Document],
        transformations: TransformComponent,
    ) -> None:
        """
        Add documents to the index. Given a list of documents, we first test
        whether the doc_id is already in the index. If not, we add the doc
        to the list. If yes, and the overwrite flag is enabled,
        we delete the old doc and add the new doc to the list.
        Lastly, we generate nodes for all documents on the list, and insert
        the nodes into the index.

        Args:
            documents (List[Document]): list of documents to be added.
            transformations (TransformComponent): transformations that
                convert the documents into nodes.
        """
        # this is the pipeline that generates the nodes
        pipeline = IngestionPipeline(
            transformations=transformations,
        )
        # we need to generate nodes from this list of documents
        insert_docs_list = []
        for doc in documents:
            if doc.doc_id not in self.index.ref_doc_info.keys():
                # if the doc_id is not in the index, we add it to the list
                insert_docs_list.append(doc)
                logger.info(
                    f"add new documents to index, doc_id={doc.doc_id}",
                )
            else:
                if self.overwrite_index:
                    # if overwrite_index is enabled, we delete the old doc
                    self.index.delete_ref_doc(
                        ref_doc_id=doc.doc_id,
                        delete_from_docstore=True,
                    )
                    # then add the same doc to the list
                    insert_docs_list.append(doc)
                    logger.info(
                        f"replace document in index, doc_id={doc.doc_id}",
                    )
        logger.info("documents scan completed.")
        # we generate nodes for documents on the list
        nodes = pipeline.run(
            documents=insert_docs_list,
            show_progress=True,
        )
        logger.info("nodes generated.")
        # insert the new nodes into the index
        self.index.insert_nodes(nodes=nodes)
        logger.info("nodes inserted to index.")
        # persist the updated index
        self.index.storage_context.persist(persist_dir=self.persist_dir)

    def _delete_docs_from_index(
        self,
        documents: List[Document],
    ) -> None:
        """
        Delete the nodes that are associated with a list of documents.

        Args:
            documents (List[Document]): list of documents to be deleted.
        """
        doc_id_list = [doc.doc_id for doc in documents]
        for key in self.index.ref_doc_info.keys():
            if key in doc_id_list:
                self.index.delete_ref_doc(
                    ref_doc_id=key,
                    delete_from_docstore=True,
                )
                logger.info(f"docs deleted from index, doc_id={key}")
        # persist the updated index
        self.index.storage_context.persist(persist_dir=self.persist_dir)
        logger.info("node deletion completed.")