[Code] update rag #122

Merged: 5 commits, Mar 22, 2024
rag/src/config/config.py (2 changes: 1 addition & 1 deletion)
@@ -33,5 +33,5 @@
{system_prompt}
根据下面检索回来的信息,回答问题。
{content}
问题:{question}
问题:{query}
"""
rag/src/data_processing.py (55 changes: 45 additions & 10 deletions)
@@ -12,7 +12,7 @@
from langchain_community.document_loaders import DirectoryLoader, TextLoader, JSONLoader
from langchain_text_splitters import CharacterTextSplitter, RecursiveCharacterTextSplitter, RecursiveJsonSplitter
from BCEmbedding import EmbeddingModel, RerankerModel
from util.pipeline import EmoLLMRAG
# from util.pipeline import EmoLLMRAG
from transformers import AutoTokenizer, AutoModelForCausalLM
from langchain.document_loaders.pdf import PyPDFDirectoryLoader
from langchain.document_loaders import UnstructuredFileLoader,DirectoryLoader
@@ -91,13 +91,13 @@ def extract_text_from_json(self, obj, content=None):
if isinstance(obj, dict):
for key, value in obj.items():
try:
self.extract_text_from_json(value, content)
content = self.extract_text_from_json(value, content)
except Exception as e:
print(f"Error processing value: {e}")
elif isinstance(obj, list):
for index, item in enumerate(obj):
try:
self.extract_text_from_json(item, content)
content = self.extract_text_from_json(item, content)
except Exception as e:
print(f"Error processing item: {e}")
elif isinstance(obj, str):
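The two rebound recursive calls above are the substance of this hunk: Python strings are immutable, so the old bare calls computed the nested text and then threw it away, leaving `content` unchanged. A self-contained sketch of the corrected accumulation pattern; the string branch is an assumption, since the diff truncates it:

    def extract_text(obj, content=""):
        # Recursively collect every string in a nested JSON structure.
        if isinstance(obj, dict):
            for value in obj.values():
                content = extract_text(value, content)  # rebind, don't discard
        elif isinstance(obj, list):
            for item in obj:
                content = extract_text(item, content)
        elif isinstance(obj, str):
            content += obj  # assumed: append leaf strings to the accumulator
        return content

    # extract_text({"q": "你好", "a": ["很高兴见到你"]}) == "你好很高兴见到你"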
@@ -157,14 +157,15 @@ def split_conversation(self, path):
logger.info(f'splitting file {file_path}')
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
print(data)
# print(data)
for conversation in data:
# for dialog in conversation['conversation']:
## Split by QA pair: convert each QA turn into a langchain_core.documents.base.Document
# content = self.extract_text_from_json(dialog,'')
# split_qa.append(Document(page_content = content))
# Split by conversation block
content = self.extract_text_from_json(conversation['conversation'], '')
logger.info(f'content====={content}')
split_qa.append(Document(page_content = content))
# logger.info(f'split_qa size====={len(split_qa)}')
return split_qa
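As the retained comments note, splitting now happens per conversation block rather than per QA turn, so a whole multi-turn exchange lands in a single Document and retrieval returns complete conversations as context. A compact sketch of that path, reusing `extract_text` from the sketch above in place of the class's `extract_text_from_json`, and assuming the JSON layout implied by the loop:

    import json
    from langchain_core.documents import Document

    def split_conversation(file_path):
        split_qa = []
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)  # assumed: a list of {"conversation": [...]} records
        for conversation in data:
            # One Document per conversation keeps multi-turn context together.
            text = extract_text(conversation["conversation"], "")
            split_qa.append(Document(page_content=text))
        return split_qa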
@@ -229,9 +230,8 @@ def retrieve(self, query, vector_db, k=5):
# compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
# compressed_docs = compression_retriever.get_relevant_documents(query)
# return compressed_docs


def rerank(self, query, docs):
def rerank(self, query, docs):
reranker = self.load_rerank_model()
passages = []
for doc in docs:
@@ -240,9 +240,41 @@ def rerank(self, query, docs):
sorted_pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
sorted_passages, sorted_scores = zip(*sorted_pairs)
return sorted_passages, sorted_scores
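The zip-sort-unzip at the end of the new rerank method is worth isolating: it pairs each passage with its score, sorts the pairs by score descending, and splits them back into two aligned tuples. A runnable sketch with a stand-in scorer, since the BCEmbedding scoring call falls outside the shown lines:

    def rerank_passages(query, passages, score_fn):
        # score_fn stands in for the reranker model: any (query, passage) -> float works.
        scores = [score_fn(query, p) for p in passages]
        pairs = sorted(zip(passages, scores), key=lambda x: x[1], reverse=True)
        sorted_passages, sorted_scores = zip(*pairs)
        return list(sorted_passages), list(sorted_scores)

    # Toy scorer: character overlap with the query.
    overlap = lambda q, p: len(set(q) & set(p))
    print(rerank_passages("心情差", ["今天天气", "心情不好怎么办"], overlap))
    # (['心情不好怎么办', '今天天气'], [2, 0])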




# def create_prompt(question, context):
# from langchain.prompts import PromptTemplate
# prompt_template = f"""请基于以下内容回答问题:

# {context}

# 问题: {question}
# 回答:"""
# prompt = PromptTemplate(
# template=prompt_template, input_variables=["context", "question"]
# )
# logger.info(f'Prompt: {prompt}')
# return prompt

def create_prompt(question, context):
prompt = f"""请基于以下内容: {context} 给出问题答案。问题如下: {question}。回答:"""
logger.info(f'Prompt: {prompt}')
return prompt

def test_zhipu(prompt):
from zhipuai import ZhipuAI
api_key = "" # 填写您自己的APIKey
if api_key == "":
raise ValueError("请填写api_key")
client = ZhipuAI(api_key=api_key)
response = client.chat.completions.create(
model="glm-4", # 填写需要调用的模型名称
messages=[
{"role": "user", "content": prompt[:100]}
],
)
print(response.choices[0].message)
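test_zhipu guards against oversized prompts by hard-truncating to 100 characters, which also discards most of the retrieved context. A hedged variant that keeps more of the prompt and backs off on failure; the 4000-character budget and the retry policy are assumptions, not part of this PR:

    from zhipuai import ZhipuAI

    def ask_glm(prompt, api_key, max_chars=4000):  # max_chars is a guess, tune per model
        client = ZhipuAI(api_key=api_key)
        try:
            response = client.chat.completions.create(
                model="glm-4",
                messages=[{"role": "user", "content": prompt[:max_chars]}],
            )
            return response.choices[0].message
        except Exception:
            # 'Server disconnected without sending a response.' has been seen when
            # the prompt overruns the context window; retry with a shorter prompt.
            if max_chars > 500:
                return ask_glm(prompt, api_key, max_chars // 2)
            raise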

if __name__ == "__main__":
logger.info(data_dir)
if not os.path.exists(data_dir):
@@ -254,7 +286,8 @@ def rerank(self, query, docs):
# query = "儿童心理学说明-内容提要-目录 《儿童心理学》1993年修订版说明 《儿童心理学》是1961年初全国高等学校文科教材会议指定朱智贤教授编 写的。1962年初版,1979年再版。"
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想,我现在感到非常孤独、累和迷茫。您能给我提供一些建议吗?"
# query = "这在一定程度上限制了其思维能力,特别是辩证 逻辑思维能力的发展。随着年龄的增长,初中三年级学生逐步克服了依赖性"
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
# query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
query = "我现在心情非常差,有什么解决办法吗?"
docs, retriever = dp.retrieve(query, vector_db, k=10)
logger.info(f'Query: {query}')
logger.info("Retrieve results:")
@@ -267,4 +300,6 @@ def rerank(self, query, docs):
logger.info("After reranking...")
for i in range(len(scores)):
logger.info(str(scores[i]) + '\n')
logger.info(passages[i])
logger.info(passages[i])
prompt = create_prompt(query, passages[0])
test_zhipu(prompt) ## if 'Server disconnected without sending a response.' appears, the prompt likely exceeded the context window
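Taken together, the new `__main__` exercises the whole offline path: retrieve, rerank, build the prompt from the top passage, and call GLM-4. A condensed sketch of that flow using the signatures visible in this diff; the `Data_process()` constructor name and `load_vector_db()` call are assumed from their uses elsewhere in the PR:

    dp = Data_process()              # assumed constructor name
    vector_db = dp.load_vector_db()  # assumed loader, as called in pipeline.py
    query = "我现在心情非常差,有什么解决办法吗?"
    docs, retriever = dp.retrieve(query, vector_db, k=10)
    passages, scores = dp.rerank(query, docs)
    prompt = create_prompt(query, passages[0])  # highest-scoring passage only
    test_zhipu(prompt)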
rag/src/main.py (62 changes: 44 additions & 18 deletions)
@@ -1,20 +1,17 @@
import os
import json
import pickle
import numpy as np
from typing import Tuple
from sentence_transformers import SentenceTransformer

from config.config import knowledge_json_path, knowledge_pkl_path, model_repo, model_dir, base_dir
from util.encode import load_embedding, encode_qa
from util.pipeline import EmoLLMRAG
import time
import jwt

from config.config import base_dir, data_dir
from data_processing import Data_process
from pipeline import EmoLLMRAG

from langchain_openai import ChatOpenAI
from loguru import logger
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import streamlit as st
from openxlab.model import download
from config.config import embedding_path, doc_dir, qa_dir, knowledge_pkl_path, data_dir
from data_processing import Data_process
'''
1) Build the complete RAG pipeline: input is the user query, output is the answer
2) Call the interface provided by the embedding module to vectorize the query
@@ -24,21 +21,45 @@
6) Assemble the prompt and call the model to return the result

'''
# download(
# model_repo=model_repo,
# output='model'
# )
def get_glm(temprature):
llm = ChatOpenAI(
model_name="glm-4",
openai_api_base="https://open.bigmodel.cn/api/paas/v4",
openai_api_key=generate_token("api-key"),
streaming=False,
temperature=temprature
)
return llm
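get_glm points LangChain's ChatOpenAI client at ZhipuAI's OpenAI-compatible endpoint and authenticates with a freshly signed JWT. A minimal usage sketch, assuming the hard-coded "api-key" placeholder has been replaced with a real `id.secret` key:

    llm = get_glm(0.7)  # raises "invalid apikey" until a real key is set
    reply = llm.invoke("你好,请用一句话介绍你自己。")
    print(reply.content)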

def generate_token(apikey: str, exp_seconds: int=100):
try:
id, secret = apikey.split(".")
except Exception as e:
raise Exception("invalid apikey", e)

payload = {
"api_key": id,
"exp": int(round(time.time() * 1000)) + exp_seconds * 1000,
"timestamp": int(round(time.time() * 1000)),
}

return jwt.encode(
payload,
secret,
algorithm="HS256",
headers={"alg": "HS256", "sign_type": "SIGN"},
)
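generate_token implements ZhipuAI's JWT scheme: the API key is an `id.secret` pair, `exp` and `timestamp` are in milliseconds, and the header carries `sign_type: SIGN`. A quick local sanity check of the claims, purely illustrative since the real verification happens server-side:

    import jwt  # PyJWT, as imported above

    token = generate_token("my-id.my-secret", exp_seconds=100)  # hypothetical key
    claims = jwt.decode(token, "my-secret", algorithms=["HS256"])
    assert claims["api_key"] == "my-id"
    assert claims["exp"] > claims["timestamp"]  # both are millisecond values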

@st.cache_resource
def load_model():
model_dir = os.path.join(base_dir,'../model')
logger.info(f'Loading model from {model_dir}')
model = (
AutoModelForCausalLM.from_pretrained(model_dir, trust_remote_code=True)
AutoModelForCausalLM.from_pretrained('model', trust_remote_code=True)
.to(torch.bfloat16)
.cuda()
)
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained('model', trust_remote_code=True)
return model, tokenizer

def main(query, system_prompt=''):
@@ -60,4 +81,9 @@ def main(query, system_prompt=''):

if __name__ == "__main__":
query = "我现在处于高三阶段,感到非常迷茫和害怕。我觉得自己从出生以来就是多余的,没有必要存在于这个世界。无论是在家庭、学校、朋友还是老师面前,我都感到被否定。我非常难过,对高考充满期望但成绩却不理想"
main(query)
main(query)
#model = get_glm(0.7)
#rag_obj = EmoLLMRAG(model, 3)
#res = rag_obj.main(query)
#logger.info(res)

rag/src/pipeline.py (23 changes: 11 additions & 12 deletions)
@@ -2,9 +2,8 @@
from langchain_core.prompts import PromptTemplate
from transformers.utils import logging

from data_processing import DataProcessing
from config.config import retrieval_num, select_num, system_prompt, prompt_template

from data_processing import Data_process
from config.config import system_prompt, prompt_template
logger = logging.get_logger(__name__)


@@ -28,10 +27,8 @@ def __init__(self, model, retrieval_num, rerank_flag=False, select_num=3) -> None:

"""
self.model = model
self.data_processing_obj = Data_process()
self.vectorstores = self._load_vector_db()
self.system_prompt = self._get_system_prompt()
self.prompt_template = self._get_prompt_template()
self.data_processing_obj = DataProcessing()
self.system_prompt = system_prompt
self.prompt_template = prompt_template
self.retrieval_num = retrieval_num
@@ -43,8 +40,6 @@ def _load_vector_db(self):
调用 embedding 模块给出接口 load vector DB
"""
vectorstores = self.data_processing_obj.load_vector_db()
if not vectorstores:
vectorstores = self.data_processing_obj.load_index_and_knowledge()

return vectorstores

@@ -57,13 +52,17 @@ def get_retrieval_content(self, query) -> str:
content = ''
documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)

# if reranking is enabled, call the interface to rerank the documents
if self.rerank_flag:
documents = self.data_processing_obj.rerank(documents, self.select_num)

for doc in documents:
content += doc.page_content

# if reranking is enabled, call the interface to rerank the documents
if self.rerank_flag:
documents, _ = self.data_processing_obj.rerank(documents, self.select_num)

content = ''
for doc in documents:
content += doc
logger.info(f'Retrieval data: {content}')
return content
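The reordering above is the point of this hunk: reranking now happens before the retrieved text is concatenated, and since rerank returns plain passage strings rather than Document objects, the rebuilt loop appends `doc` instead of `doc.page_content`. One wrinkle worth flagging: the call passes `(documents, self.select_num)`, while data_processing.py declares `rerank(self, query, docs)`. A sketch of the apparent intent with that argument mismatch resolved; this is an interpretation, not the repo's code:

    def get_retrieval_content(self, query):
        documents = self.vectorstores.similarity_search(query, k=self.retrieval_num)
        if self.rerank_flag:
            # Rerank against the query, then keep the top select_num passages.
            passages, _ = self.data_processing_obj.rerank(query, documents)
            passages = passages[: self.select_num]
        else:
            passages = [doc.page_content for doc in documents]
        content = "".join(passages)
        logger.info(f"Retrieval data: {content}")
        return content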

def generate_answer(self, query, content) -> str: