Skip to content

Commit

Permalink
Feat/rag use llm (#483)
Browse files Browse the repository at this point in the history
Co-authored-by: skyline2006 <skyline2006@163.com>
Co-authored-by: Zhikaiiii <1658973216@qq.com>
  • Loading branch information
3 people authored Jun 13, 2024
1 parent 2f2660c commit fe1bc43
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 26 deletions.
10 changes: 4 additions & 6 deletions apps/agentfabric/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,16 +603,14 @@ def preview_send_message(chatbot, input, _state, uuid_str):
# get chat history from memory
history = user_memory.get_history()

# get knowledge from memory, currently get one file
uploaded_file = None
if len(append_files) > 0:
uploaded_file = append_files[0]
use_llm = True if len(user_agent.function_list) else False
ref_doc = user_memory.run(
query=input.text,
url=uploaded_file,
url=append_files,
max_token=4000,
top_k=2,
checked=True)
checked=True,
use_llm=use_llm)

response = ''
try:
Expand Down
7 changes: 2 additions & 5 deletions apps/agentfabric/appBot.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,9 @@ def send_message(chatbot, input, _state):
# get short term memory history
history = user_memory.get_history()

# get long term memory knowledge, currently get one file
uploaded_file = None
if len(append_files) > 0:
uploaded_file = append_files[0]
use_llm = True if len(user_agent.function_list) else False
ref_doc = user_memory.run(
query=input.text, url=uploaded_file, checked=True)
query=input.text, url=append_files, checked=True, use_llm=use_llm)

response = ''
try:
Expand Down
12 changes: 6 additions & 6 deletions apps/agentfabric/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,15 +402,15 @@ def generate():
f'load history method: time consumed {time.time() - start_time}'
)

# get knowledge from memory, currently get one file
uploaded_file = None
if len(file_paths) > 0:
uploaded_file = file_paths[0]
use_llm = True if len(user_agent.function_list) else False
ref_doc = user_memory.run(
query=input_content, url=uploaded_file, checked=True)
query=input_content,
url=file_paths,
checked=True,
use_llm=use_llm)
logger.info(
f'load knowledge method: time consumed {time.time() - start_time}, '
f'the uploaded_file name is {uploaded_file}') # noqa
f'the uploaded_file name is {file_paths}') # noqa

response = ''

Expand Down
16 changes: 12 additions & 4 deletions modelscope_agent/memory/memory_with_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,21 @@ def _run(self,
query: str = None,
url: str = None,
max_token: int = 4000,
top_k: int = 3,
**kwargs) -> Union[str, Iterator[str]]:
if isinstance(url, str):
url = [url]
if url and len(url):
self.store_knowledge.add(files=url)
if query:
summary_result = self.store_knowledge.run(query, files=url)
# limit length
return summary_result[0:max_token - 1]
summary_result = self.store_knowledge.run(
query, files=url, **kwargs)
# limit length
if isinstance(summary_result, list):
single_max_token = int(max_token / len(summary_result))
concatenated_records = '\n'.join([
record[0:single_max_token - 1] for record in summary_result
])

return concatenated_records
else:
return summary_result[0:max_token - 1]
22 changes: 17 additions & 5 deletions modelscope_agent/rag/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.query_engine import BaseQueryEngine, RetrieverQueryEngine
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document, QueryBundle, TransformComponent
from llama_index.core.schema import (Document, MetadataMode, QueryBundle,
TransformComponent)
from llama_index.core.settings import Settings
from llama_index.core.vector_stores.types import (MetadataFilter,
MetadataFilters)
Expand Down Expand Up @@ -299,15 +300,26 @@ def set_filter(self, files: List[str]):
]
retriever._filters = MetadataFilters(filters=filters)

def run(self, query: str, files: List[str] = [], **kwargs) -> str:
def run(self,
        query: str,
        files: Union[str, List[str], None] = None,
        use_llm: bool = True,
        **kwargs) -> Union[str, List[str]]:
    """Query the knowledge store.

    Args:
        query: The user question to retrieve knowledge for.
        files: Optional file name(s) used to filter retrieval to specific
            documents. A single string is normalized to a one-element list.
            (Default changed from a mutable ``[]`` to ``None`` — behaviorally
            identical under the truthiness check below, but avoids the shared
            mutable-default pitfall.)
        use_llm: When True, run full RAG (retrieve + LLM synthesis) and
            return the synthesized answer as a string. When False, skip the
            LLM and return the raw retrieved chunk texts.
        **kwargs: Accepted for interface compatibility; not forwarded.

    Returns:
        A synthesized answer string when ``use_llm`` is True, otherwise a
        list of retrieved node contents.
    """
    query_bundle = FileQueryBundle(query)
    if isinstance(files, str):
        files = [files]

    # Restrict retrieval to the given files, if any were supplied.
    if files:
        self.set_filter(files)

    if use_llm:
        # Full RAG pipeline: retrieve relevant nodes, then have the LLM
        # synthesize a single answer.
        return str(self.query_engine.query(query_bundle))

    # Retrieval-only path: return the raw chunk texts so the caller can
    # inject them into its own prompt.
    nodes = self.query_engine.retrieve(query_bundle)
    return [
        n.node.get_content(metadata_mode=MetadataMode.LLM) for n in nodes
    ]

def add(self, files: List[str]):
if isinstance(files, str):
Expand All @@ -329,4 +341,4 @@ def add(self, files: List[str]):
knowledge = BaseKnowledge('./data2', use_cache=False, llm=llm)

knowledge.add(['./data/常见QA.pdf'])
print(knowledge.run('高德天气API申请', files=['常见QA.pdf']))
print(knowledge.run('高德天气API申请', files=['常见QA.pdf'], use_llm=False))
12 changes: 12 additions & 0 deletions tests/test_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,15 @@ def test_memory_with_rag_multi_modal():
summary_str = memory.run('我想看rag的流程图')
print(summary_str)
assert 'rag.png' in summary_str


def test_memory_with_rag_no_use_llm():
    """Retrieval-only mode (use_llm=False) should surface raw chunk text
    containing the expected metadata and answer keywords."""
    memory = MemoryWithRag(use_knowledge_cache=False)

    result = memory.run(
        query='模型大文件上传失败怎么办',
        url=['tests/samples/modelscope_qa_2.txt'],
        use_llm=False)
    print(result)
    for expected in ('file_path', 'git-lfs'):
        assert expected in result

0 comments on commit fe1bc43

Please sign in to comment.