
add: ChatOCR
yuanjie-ai committed Sep 5, 2023
1 parent 47ea4b9 commit 9755c41
Showing 28 changed files with 3,403 additions and 2,668 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -5,6 +5,7 @@ include LICENSE
include README*

recursive-include tests *
recursive-exclude data *
recursive-exclude * __pycache__
recursive-exclude * *.py[co]
recursive-exclude docs *
30 changes: 24 additions & 6 deletions README.md
@@ -60,6 +60,25 @@ for c in completion:

</details>

## ChatOCR

<details markdown="1">
<summary>Click to ChatOCR</summary>

```python
from meutils.pipe import *
from chatllm.llmchain.applications import ChatOCR

llm = ChatOCR()
file_path = "data/invoice.jpg"
llm.display(file_path, 700)
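# Extract key invoice fields from the OCR result: invoice number, company name, issue date, issuer, payee, reviewer, amount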
llm.chat('识别编号,公司名称,开票日期,开票人,收款人,复核人,金额', file_path=file_path) | xprint
```

![ocr](data/imgs/chatocr.png)

</details>

## ChatMind

<details markdown="1">
@@ -109,12 +128,11 @@ for i in qa(query='东北证券主营业务'):

- ChatGLM-6B model hardware requirements (see the loading sketch after this list)

| **Quantization level** | **Minimum GPU memory (inference)** | **Minimum GPU memory (parameter-efficient fine-tuning)** |
| ---------------------- | ---------------------------------- | --------------------------------------------------------- |
| FP16 (no quantization) | 13 GB                              | 14 GB                                                      |
| INT8                   | 8 GB                               | 9 GB                                                       |
| INT4                   | 6 GB                               | 7 GB                                                       |

- Load the model from a local path
- [Installation guide](docs/INSTALL.md)
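As a hedged illustration only (not code from this repository), the quantization levels in the table map onto the loading pattern published for `THUDM/chatglm-6b`; the model id and the `.half()` / `.quantize()` / `.chat()` calls below follow that upstream usage:

```python
# Sketch: loading ChatGLM-6B at the quantization levels listed above.
# Follows the upstream THUDM/chatglm-6b usage; not part of this repository.
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)

# FP16 (no quantization): roughly 13 GB of GPU memory for inference
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()

# INT8 (~8 GB) or INT4 (~6 GB): quantize before moving the model to the GPU
# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(8).cuda()
# model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().quantize(4).cuda()

model = model.eval()
response, _history = model.chat(tokenizer, "你好", history=[])
print(response)
```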
2 changes: 2 additions & 0 deletions chatllm/llmchain/TODO.md
@@ -3,3 +3,5 @@ https://github.com/amosjyng/langchain-visualizer/blob/main/tests/demo.ipynb

# Research structured output
# Research agents

Whole-document understanding for texts up to 10,000 characters
4 changes: 3 additions & 1 deletion chatllm/llmchain/applications/__init__.py
@@ -8,5 +8,7 @@
# @Software : PyCharm
# @Description :

from chatllm.llmchain.applications.chatbase import ChatBase
from chatllm.llmchain.applications.chatfile import ChatFile
from chatllm.llmchain.applications.summarizer import Summarizer

from chatllm.llmchain.applications.chatocr import ChatOCR
34 changes: 34 additions & 0 deletions chatllm/llmchain/applications/_chatocr.py
@@ -0,0 +1,34 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Project : AI. @by PyCharm
# @File : chatocr
# @Time : 2023/8/25 16:45
# @Author : betterme
# @WeChat : meutils
# @Software : PyCharm
# @Description : https://aistudio.baidu.com/modelsdetail?modelId=332

from meutils.pipe import *
from IPython.display import Image
from langchain.chat_models import ChatOpenAI

llm = ChatOpenAI()

from rapidocr_onnxruntime import RapidOCR

rapid_ocr = RapidOCR()

p = "/Users/betterme/PycharmProjects/AI/MeUtils/meutils/ai_cv/invoice.jpg"
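# Run OCR on a local invoice image; RapidOCR returns (recognition results, elapsed times)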
ocr_result, _ = rapid_ocr(p)
Image(p)

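# Key fields to extract: invoice number, company name, issue date, issuer, payee, reviewer, amount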
key = '识别编号,公司名称,开票日期,开票人,收款人,复核人,金额'

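# Prompt (in Chinese): extract the key fields listed above from the OCR output and return a JSON object;
# if a field is not found, its value is set to "未找到相关信息" ("no relevant information found").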
prompt = f"""你现在的任务是从OCR文字识别的结果中提取我指定的关键信息。OCR的文字识别结果使用```符号包围,包含所识别出来的文字,
顺序在原始图片中从左至右、从上至下。我指定的关键信息使用[]符号包围。请注意OCR的文字识别结果可能存在长句子换行被切断、不合理的分词、
对应错位等问题,你需要结合上下文语义进行综合判断,以抽取准确的关键信息。
在返回结果时使用json格式,包含一个key-value对,key值为我指定的关键信息,value值为所抽取的结果。
如果认为OCR识别结果中没有关键信息key,则将value赋值为“未找到相关信息”。 请只输出json格式的结果,不要包含其它多余文字!下面正式开始:
OCR文字:```{ocr_result}```
要抽取的关键信息:[{key}]。"""
print(llm.predict(prompt))
138 changes: 14 additions & 124 deletions chatllm/llmchain/applications/chatbase.py
@@ -1,155 +1,45 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Project : AI. @by PyCharm
# @File : chatdoc
# @Time : 2023/7/15 20:53
# @File : base
# @Time : 2023/8/9 15:04
# @Author : betterme
# @WeChat : meutils
# @Software : PyCharm
# @Description :

from meutils.pipe import *

from chatllm.llmchain.utils import docs2dataframe
from chatllm.llmchain.decorators import llm_stream
from chatllm.llmchain.vectorstores import Milvus, FAISS
from chatllm.llmchain.embeddings import OpenAIEmbeddings, DashScopeEmbeddings
from chatllm.llmchain.document_loaders import FilesLoader
from chatllm.llmchain.prompts.prompt_templates import context_prompt_template
from meutils.pipe import *

from langchain.text_splitter import *
from langchain.chat_models import ChatOpenAI
from langchain.chains.question_answering import load_qa_chain
from langchain.embeddings.base import Embeddings
from langchain.prompts import ChatPromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from chatllm.llmchain.vectorstores import Milvus
from langchain.vectorstores.base import VectorStore
from chatllm.llmchain.decorators import llm_stream, llm_astream


class ChatBase(object):

def __init__(
self,
# initialize openai_api_key
llm: Optional[BaseLanguageModel] = None,
embeddings: Optional[Embeddings] = None,
vectorstore_cls=None,

get_api_key: Optional[Callable[[int], List[str]]] = None,
prompt_template=context_prompt_template,

get_api_key: Optional[Callable[[int], List[str]]] = None,  # key queue
**kwargs
):
self.llm = llm or ChatOpenAI(model="gpt-3.5-turbo-16k-0613", temperature=0, streaming=True)
self.llm.streaming = True
self.embeddings = embeddings or OpenAIEmbeddings(chunk_size=5)
self.vectorstore = None
self.vectorstore_cls: VectorStore = vectorstore_cls or FAISS

self.prompt_template = prompt_template

if get_api_key:
self.llm.openai_api_key = get_api_key(1)[0]
self.embeddings.get_api_key = get_api_key
self.embeddings.openai_api_key = get_api_key(1)[0]

self.chain = load_qa_chain(
self.llm,
chain_type="stuff",
prompt=ChatPromptTemplate.from_template(prompt_template)  # todo: add context information
)

def pipeline(self):
pass

def create_index(self, docs: List[Document], **kwargs):
"""初始化 drop_old=True"""
self.vectorstore = self.vectorstore_cls.from_documents(docs, **{**self.vdb_kwargs, **kwargs})

def llm_qa(self, query: str, k: int = 5, threshold: float = 0.5, **kwargs: Any):
assert self.vectorstore is not None, "Please create index."
def chat(self, prompt, **kwargs):
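# stream the completion token by token via the llm_stream decorator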
yield from llm_stream(self.llm.predict)(prompt)

docs = self.vectorstore.similarity_search(query, k=k, threshold=threshold, **kwargs)
docs = docs | xUnique_plus(lambda doc: doc.page_content.strip())  # deduplicate by content; todo: deduplicate by semantic similarity
docs = docs[:k]
if docs:
return llm_stream(self.chain.run)({"input_documents": docs, "question": query})  # todo: does an empty document list raise an error?
else:
logger.warning("Retrieval is empty, Please check the vector database !!!")
# yield from "无相关文档"

@staticmethod
def load_file(
file_paths,
max_workers=3,
chunk_size=2000,
chunk_overlap=200,
separators: Optional[List[str]] = None
) -> List[Document]:
"""支持多文件"""
loader = FilesLoader(file_paths, max_workers=max_workers)
separators = separators or ['\n\n', '\r', '\n', '\r\n', '。', '!', '!', '\\?', '?', '……', '…']
textsplitter = RecursiveCharacterTextSplitter(
chunk_size=chunk_size,
chunk_overlap=chunk_overlap,
add_start_index=True,
separators=separators
)
docs = loader.load_and_split(textsplitter)
return docs

@property
def vdb_kwargs(self):
"""
# vector database
self.collection_name = collection_name
# _vdb_kwargs = self.vdb_kwargs.copy()
# _vdb_kwargs['embedding_function'] = _vdb_kwargs.pop('embedding')  # keep parameter names consistent
# # _vdb_kwargs['drop_old'] = True  # recreate the collection
# self.vectorstore = vectorstore or Milvus(**_vdb_kwargs)  # is this slow?
:return:
"""
connection_args = {
'uri': os.getenv('ZILLIZ_ENDPOINT'),
'token': os.getenv('ZILLIZ_TOKEN')
}
address = os.getenv('MILVUS_ADDRESS')  # this variable takes precedence
if address:
connection_args.pop('uri')
connection_args['address'] = address

index_params = {
"metric_type": "IP",
"index_type": "IVF_FLAT",
"params": {"nlist": 1024}
}

embedding_function = self.embeddings

vdb_kwargs = dict(
embedding=embedding_function,
connection_args=connection_args,
index_params=index_params,
search_params=None,
collection_name=None,
drop_old=False,
)

return vdb_kwargs
def achat(self, prompt, **kwargs):
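# reset any stale event loop, then consume the async apredict stream as a synchronous generator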
close_event_loop()
yield from async2sync_generator(llm_astream(self.llm.apredict)(prompt))


if __name__ == '__main__':
from chatllm.llmchain.applications import ChatBase
from chatllm.llmchain.embeddings import HuggingFaceEmbeddings
from chatllm.llmchain.vectorstores import FAISS, Milvus

model_name = '/Users/betterme/PycharmProjects/AI/m3e-small'
encode_kwargs = {'normalize_embeddings': True, "show_progress_bar": True}
embeddings = HuggingFaceEmbeddings(model_name=model_name, encode_kwargs=encode_kwargs)

docs = [Document(page_content='1')] * 10
faiss = FAISS.from_documents(docs, embeddings)
# ChatBase().chat('1+1') | xprint(end='\n')
ChatBase().achat('周杰伦是谁') | xprint(end='\n')

cb = ChatBase(vectorstore=FAISS)
cb.create_index(docs)
# for i in ChatBase().achat('周杰伦是谁'):
# print(i, end='')
