-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
47ea4b9
commit 9755c41
Showing
28 changed files
with
3,403 additions
and
2,668 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# @Project : AI. @by PyCharm | ||
# @File : chatocr | ||
# @Time : 2023/8/25 16:45 | ||
# @Author : betterme | ||
# @WeChat : meutils | ||
# @Software : PyCharm | ||
# @Description : https://aistudio.baidu.com/modelsdetail?modelId=332 | ||
|
||
from meutils.pipe import * | ||
from IPython.display import Image | ||
from langchain.chat_models import ChatOpenAI | ||
|
||
llm = ChatOpenAI() | ||
|
||
from rapidocr_onnxruntime import RapidOCR | ||
|
||
rapid_ocr = RapidOCR() | ||
|
||
p = "/Users/betterme/PycharmProjects/AI/MeUtils/meutils/ai_cv/invoice.jpg" | ||
ocr_result, _ = rapid_ocr(p) | ||
Image(p) | ||
|
||
key = '识别编号,公司名称,开票日期,开票人,收款人,复核人,金额' | ||
|
||
prompt = f"""你现在的任务是从OCR文字识别的结果中提取我指定的关键信息。OCR的文字识别结果使用```符号包围,包含所识别出来的文字, | ||
顺序在原始图片中从左至右、从上至下。我指定的关键信息使用[]符号包围。请注意OCR的文字识别结果可能存在长句子换行被切断、不合理的分词、 | ||
对应错位等问题,你需要结合上下文语义进行综合判断,以抽取准确的关键信息。 | ||
在返回结果时使用json格式,包含一个key-value对,key值为我指定的关键信息,value值为所抽取的结果。 | ||
如果认为OCR识别结果中没有关键信息key,则将value赋值为“未找到相关信息”。 请只输出json格式的结果,不要包含其它多余文字!下面正式开始: | ||
OCR文字:```{ocr_result}``` | ||
要抽取的关键信息:[{key}]。""" | ||
print(llm.predict(prompt)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,155 +1,45 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# @Project : AI. @by PyCharm | ||
# @File : chatdoc | ||
# @Time : 2023/7/15 20:53 | ||
# @File : base | ||
# @Time : 2023/8/9 15:04 | ||
# @Author : betterme | ||
# @WeChat : meutils | ||
# @Software : PyCharm | ||
# @Description : | ||
|
||
from meutils.pipe import * | ||
|
||
from chatllm.llmchain.utils import docs2dataframe | ||
from chatllm.llmchain.decorators import llm_stream | ||
from chatllm.llmchain.vectorstores import Milvus, FAISS | ||
from chatllm.llmchain.embeddings import OpenAIEmbeddings, DashScopeEmbeddings | ||
from chatllm.llmchain.document_loaders import FilesLoader | ||
from chatllm.llmchain.prompts.prompt_templates import context_prompt_template | ||
from meutils.pipe import * | ||
|
||
from langchain.text_splitter import * | ||
from langchain.chat_models import ChatOpenAI | ||
from langchain.chains.question_answering import load_qa_chain | ||
from langchain.embeddings.base import Embeddings | ||
from langchain.prompts import ChatPromptTemplate | ||
from langchain.schema.language_model import BaseLanguageModel | ||
from chatllm.llmchain.vectorstores import Milvus | ||
from langchain.vectorstores.base import VectorStore | ||
from chatllm.llmchain.decorators import llm_stream, llm_astream | ||
|
||
|
||
class ChatBase(object): | ||
|
||
def __init__( | ||
self, | ||
# 初始化 openai_api_key | ||
llm: Optional[BaseLanguageModel] = None, | ||
embeddings: Optional[Embeddings] = None, | ||
vectorstore_cls=None, | ||
|
||
get_api_key: Optional[Callable[[int], List[str]]] = None, | ||
prompt_template=context_prompt_template, | ||
|
||
get_api_key: Optional[Callable[[int], List[str]]] = None, # 队列 | ||
**kwargs | ||
): | ||
self.llm = llm or ChatOpenAI(model="gpt-3.5-turbo-16k-0613", temperature=0, streaming=True) | ||
self.llm.streaming = True | ||
self.embeddings = embeddings or OpenAIEmbeddings(chunk_size=5) | ||
self.vectorstore = None | ||
self.vectorstore_cls: VectorStore = vectorstore_cls or FAISS | ||
|
||
self.prompt_template = prompt_template | ||
|
||
if get_api_key: | ||
self.llm.openai_api_key = get_api_key(1)[0] | ||
self.embeddings.get_api_key = get_api_key | ||
self.embeddings.openai_api_key = get_api_key(1)[0] | ||
|
||
self.chain = load_qa_chain( | ||
self.llm, | ||
chain_type="stuff", | ||
prompt=ChatPromptTemplate.from_template(prompt_template) # todo: 增加上下文信息 | ||
) | ||
|
||
def pipeline(self): | ||
pass | ||
|
||
def create_index(self, docs: List[Document], **kwargs): | ||
"""初始化 drop_old=True""" | ||
self.vectorstore = self.vectorstore_cls.from_documents(docs, **{**self.vdb_kwargs, **kwargs}) | ||
|
||
def llm_qa(self, query: str, k: int = 5, threshold: float = 0.5, **kwargs: Any): | ||
assert self.vectorstore is not None, "Please create index." | ||
def chat(self, prompt, **kwargs): | ||
yield from llm_stream(self.llm.predict)(prompt) | ||
|
||
docs = self.vectorstore.similarity_search(query, k=k, threshold=threshold, **kwargs) | ||
docs = docs | xUnique_plus(lambda doc: doc.page_content.strip()) # 按内容去重,todo: 按语义相似度去重 | ||
docs = docs[:k] | ||
if docs: | ||
return llm_stream(self.chain.run)({"input_documents": docs, "question": query}) # todo: 空文档报错吗? | ||
else: | ||
logger.warning("Retrieval is empty, Please check the vector database !!!") | ||
# yield from "无相关文档" | ||
|
||
@staticmethod | ||
def load_file( | ||
file_paths, | ||
max_workers=3, | ||
chunk_size=2000, | ||
chunk_overlap=200, | ||
separators: Optional[List[str]] = None | ||
) -> List[Document]: | ||
"""支持多文件""" | ||
loader = FilesLoader(file_paths, max_workers=max_workers) | ||
separators = separators or ['\n\n', '\r', '\n', '\r\n', '。', '!', '!', '\\?', '?', '……', '…'] | ||
textsplitter = RecursiveCharacterTextSplitter( | ||
chunk_size=chunk_size, | ||
chunk_overlap=chunk_overlap, | ||
add_start_index=True, | ||
separators=separators | ||
) | ||
docs = loader.load_and_split(textsplitter) | ||
return docs | ||
|
||
@property | ||
def vdb_kwargs(self): | ||
""" | ||
# 向量数据库 | ||
self.collection_name = collection_name | ||
# _vdb_kwargs = self.vdb_kwargs.copy() | ||
# _vdb_kwargs['embedding_function'] = _vdb_kwargs.pop('embedding') # 参数一致性 | ||
# # _vdb_kwargs['drop_old'] = True # 重新创建 | ||
# self.vectorstore = vectorstore or Milvus(**_vdb_kwargs) # 耗时吗 | ||
:return: | ||
""" | ||
connection_args = { | ||
'uri': os.getenv('ZILLIZ_ENDPOINT'), | ||
'token': os.getenv('ZILLIZ_TOKEN') | ||
} | ||
address = os.getenv('MILVUS_ADDRESS') # 该参数优先 | ||
if address: | ||
connection_args.pop('uri') | ||
connection_args['address'] = address | ||
|
||
index_params = { | ||
"metric_type": "IP", | ||
"index_type": "IVF_FLAT", | ||
"params": {"nlist": 1024} | ||
} | ||
|
||
embedding_function = self.embeddings | ||
|
||
vdb_kwargs = dict( | ||
embedding=embedding_function, | ||
connection_args=connection_args, | ||
index_params=index_params, | ||
search_params=None, | ||
collection_name=None, | ||
drop_old=False, | ||
) | ||
|
||
return vdb_kwargs | ||
def achat(self, prompt, **kwargs): | ||
close_event_loop() | ||
yield from async2sync_generator(llm_astream(self.llm.apredict)(prompt)) | ||
|
||
|
||
if __name__ == '__main__': | ||
from chatllm.llmchain.applications import ChatBase | ||
from chatllm.llmchain.embeddings import HuggingFaceEmbeddings | ||
from chatllm.llmchain.vectorstores import FAISS, Milvus | ||
|
||
model_name = '/Users/betterme/PycharmProjects/AI/m3e-small' | ||
encode_kwargs = {'normalize_embeddings': True, "show_progress_bar": True} | ||
embeddings = HuggingFaceEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs) | ||
|
||
docs = [Document(page_content='1')] * 10 | ||
faiss = FAISS.from_documents(docs, embeddings) | ||
# ChatBase().chat('1+1') | xprint(end='\n') | ||
ChatBase().achat('周杰伦是谁') | xprint(end='\n') | ||
|
||
cb = ChatBase(vectorstore=FAISS) | ||
cb.create_index(docs) | ||
# for i in ChatBase().achat('周杰伦是谁'): | ||
# print(i, end='') |
Oops, something went wrong.