Add image features. Add repeat detection to save tokens #88

Open · wants to merge 1 commit into base: main
chat_paper.py: 75 changes (57 additions & 18 deletions)
@@ -39,10 +39,10 @@ def __init__(self, key_word, query, filter_keys,
         self.chat_api_list = [api.strip() for api in self.chat_api_list if len(api) > 5]
         self.cur_api = 0
         self.file_format = args.file_format
-        if args.save_image:
-            self.gitee_key = self.config.get('Gitee', 'api')
-        else:
-            self.gitee_key = ''
+        if args.save_image == 'gitee':
+            self.image_path = self.config.get('Gitee', 'api')
+        elif args.save_image == 'local':
+            self.image_path = root_path + 'export/' + 'attachments/'
         self.max_token_num = 4096
         self.encoding = tiktoken.get_encoding("gpt2")
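The new attribute does double duty: with 'local' it is a directory under export/, while with 'gitee' it holds the API token read from the config. A condensed standalone sketch of the dispatch (the function form and names are illustrative, not part of the PR):

```python
def resolve_image_path(save_image: str, config, root_path: str) -> str:
    """Map the --save_image mode onto the value stored in self.image_path."""
    if save_image == 'gitee':
        # mirrors the diff: the Gitee token, not a directory, ends up here
        return config.get('Gitee', 'api')
    if save_image == 'local':
        return root_path + 'export/' + 'attachments/'
    return ''  # empty string disables figure extraction entirely
```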

@@ -64,16 +64,26 @@ def filter_arxiv(self, max_results=30):
         filter_keys = self.filter_keys
 
         print("filter_keys:", self.filter_keys)
-        # A paper only counts as a hit if every keyword can be found in its abstract
-        for index, result in enumerate(search.results()):
-            abs_text = result.summary.replace('-\n', '-').replace('\n', ' ')
-            meet_num = 0
-            for f_key in filter_keys.split(" "):
-                if f_key.lower() in abs_text.lower():
-                    meet_num += 1
-            if meet_num == len(filter_keys.split(" ")):
-                filter_results.append(result)
+        # Exact match: a paper only counts as a hit if every keyword can be found in its abstract
+        if args.coarse == False:
+            print("Exact match")
+            for index, result in enumerate(search.results()):
+                abs_text = result.summary.replace('-\n', '-').replace('\n', ' ')
+                meet_num = 0
+                for f_key in filter_keys.split(" "):
+                    if f_key.lower() in abs_text.lower():
+                        meet_num += 1
+                if meet_num == len(filter_keys.split(" ")):
+                    filter_results.append(result)
+                    # break
+        else:
+            print("Coarse match")
+            for index, result in enumerate(search.results()):
+                abs_text = result.summary.replace('-\n', '-').replace('\n', ' ')
+                for f_key in filter_keys.split(" "):
+                    if f_key.lower() in abs_text.lower():
+                        filter_results.append(result)
+                        break
         print("Number of papers left after filtering:")
         print("filter_results:", len(filter_results))
         print("filter_papers:")
@@ -103,6 +113,13 @@ def download_pdf(self, filter_results):
             try:
                 title_str = self.validateTitle(result.title)
                 pdf_name = title_str+'.pdf'
+                # Try to avoid summarizing the same paper twice
+                for dir in os.listdir(os.path.join(self.root_path, 'pdf_files')):
+                    for file in os.listdir(os.path.join(self.root_path, 'pdf_files', dir)):
+                        if pdf_name == file and args.repeat == False:
+                            raise Exception('\033[91m' + pdf_name + " already exists; no summary will be made, to save your tokens. If you insist on summarizing it, pass --repeat to force it" + '\033[0m')
+                        elif pdf_name == file and args.repeat == True:
+                            print('\033[93m' + pdf_name + " already exists, summarizing it again anyway" + '\033[0m')
                 # result.download_pdf(path, filename=pdf_name)
                 self.try_download_pdf(result, path, pdf_name)
                 paper_path = os.path.join(path, pdf_name)
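The nested os.listdir walk above assumes pdf_files holds exactly one level of query subfolders. An equivalent check with os.walk, which also tolerates deeper nesting, might look like this (the function name is illustrative):

```python
import os

def already_downloaded(root_path: str, pdf_name: str) -> bool:
    """True if a file named pdf_name exists anywhere under root_path/pdf_files."""
    for _dirpath, _subdirs, files in os.walk(os.path.join(root_path, 'pdf_files')):
        if pdf_name in files:
            return True
    return False
```

Raising an Exception to mean "skip this paper" works because the enclosing try/except catches it, though a warning plus continue would avoid conflating it with real download errors.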
@@ -143,7 +160,7 @@ def upload_gitee(self, image_path, image_name='', ext='png'):
         path = image_name+ '-' +date_str
 
         payload = {
-            "access_token": self.gitee_key,
+            "access_token": self.config.get('Gitee', 'api'),
             "owner": self.config.get('Gitee', 'owner'),
             "repo": self.config.get('Gitee', 'repo'),
             "path": self.config.get('Gitee', 'path'),
@@ -244,15 +261,35 @@ def summary_with_chat(self, paper_list):
                 chat_conclusion_text = self.chat_conclusion(text=text, conclusion_prompt_token=conclusion_prompt_token)
                 htmls.append(chat_conclusion_text)
             htmls.append("\n"*4)
 
 
+            # Step 4: supplementary materials; the figures before the experiment/results sections are usually the most informative
+            htmls.append("**Supplement Materials:**\n")
+            img_list, ext = paper.get_image_path(self.image_path)
+            if img_list is None or args.save_image == '':
+                pass
+            elif args.save_image == 'local':
+                for i_page in range(len(img_list)):
+                    for i_image in range(len(img_list[i_page])):
+                        htmls.append("\n")
+                        htmls.append("![Fig]("+img_list[i_page][i_image].replace(' ', '%20').replace('./export', '.')+")")
+                        htmls.append("\n")
+            elif args.save_image == 'gitee':
+                for i_page in range(len(img_list)):
+                    for i_image in range(len(img_list[i_page])):
+                        image_title = self.validateTitle(paper.title)
+                        image_url = self.upload_gitee(image_path=img_list[i_page][i_image], image_name=image_title, ext=ext[i_page][i_image])
+                        htmls.append("\n")
+                        htmls.append("![Fig]("+image_url+")")
+                        htmls.append("\n")
 
             # Merge everything into a single file and save it.
             date_str = str(datetime.datetime.now())[:13].replace(' ', '-')
             try:
                 export_path = os.path.join(self.root_path, 'export')
                 os.makedirs(export_path)
             except:
                 pass
-            mode = 'w' if paper_index == 0 else 'a'
+            mode = 'w'  # unclear why the original appended; each paper gets its own file, so always overwrite
             file_name = os.path.join(export_path, date_str+'-'+self.validateTitle(paper.title[:80])+"."+self.file_format)
             self.export_to_markdown("\n".join(htmls), file_name=file_name, mode=mode)
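The local branch of the supplement block rewrites each saved image path so the markdown link resolves relative to the export folder. A worked example of the two replace calls (the file name is illustrative):

```python
saved = './export/attachments/Some Paper_0_1.png'
link = saved.replace(' ', '%20').replace('./export', '.')
print(link)  # ./attachments/Some%20Paper_0_1.png
```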

@@ -464,11 +501,13 @@ def main(args):
parser.add_argument("--pdf_path", type=str, default='', help="if none, the bot will download from arxiv with query")
parser.add_argument("--query", type=str, default='all: ChatGPT robot', help="the query string, ti: xx, au: xx, all: xx,")
parser.add_argument("--key_word", type=str, default='reinforcement learning', help="the key word of user research fields")
parser.add_argument("--filter_keys", type=str, default='ChatGPT robot', help="the filter key words, 摘要中每个单词都得有,才会被筛选为目标论文")
parser.add_argument("--filter_keys", type=str, default='ChatGPT robot', help="the filter key words, 摘要中每个单词都得有,才会被筛选为目标论文, separated by space")
parser.add_argument("--coarse", action='store_true', help="if every key word needs to be matched")
parser.add_argument("--repeat", action='store_true', help="if pdf files already exist, don't summarize again to save tokens")
parser.add_argument("--max_results", type=int, default=1, help="the maximum number of results")
# arxiv.SortCriterion.Relevance
parser.add_argument("--sort", type=str, default="Relevance", help="another is LastUpdatedDate")
parser.add_argument("--save_image", default=False, help="save image? It takes a minute or two to save a picture! But pretty")
parser.add_argument("--save_image", type=str, default='', help="save image? It takes a minute or two to save a picture! But pretty")
parser.add_argument("--file_format", type=str, default='md', help="导出的文件格式,如果存图片的话,最好是md,如果不是的话,txt的不会乱")
parser.add_argument("--language", type=str, default='zh', help="The other output lauguage is English, is en")

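A minimal, runnable sketch of how the new flags parse (only the arguments added or changed by this PR; the sample command line is illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--filter_keys", type=str, default='ChatGPT robot')
parser.add_argument("--coarse", action='store_true')
parser.add_argument("--repeat", action='store_true')
parser.add_argument("--save_image", type=str, default='')

# Equivalent to: python chat_paper.py --coarse --repeat --save_image local
args = parser.parse_args(['--coarse', '--repeat', '--save_image', 'local'])
print(args.coarse, args.repeat, args.save_image)  # True True local
```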
get_paper_from_pdf.py: 77 changes (48 additions & 29 deletions)
@@ -44,57 +44,76 @@ def get_paper_info(self):

     def get_image_path(self, image_path=''):
         """
-        Save the largest image in the PDF as image.png in a local directory and return the file name for gitee to read
-        :param filename: path of the source pdf, "C:\\Users\\Administrator\\Desktop\\nwd.pdf"
-        :param image_path: directory the extracted image is saved to
-        :return:
+        Save every image that appears before the Experiment/Evaluation section;
+        the method section usually carries the conceptual figures.
+        parameter:
+        - image_path: path in which the imgs are saved
+
+        return:
+        - img_path: list of list, path of images in each page
+        - ext: the associated extension
         """
+        # Create image folders
+        try:
+            os.makedirs(image_path)
+        except:
+            pass
 
         # open file
-        max_size = 0
         image_list = []
+        ext = []
+        stop_index = 0
+        exp_key = ["Materials and Methods", "Experiment Settings",
+                   'Experiment', "Experimental Results", "Evaluation", "Experiments",
+                   "Results", 'Findings', 'Data Analysis']
+
+        for key in self.section_page_dict.keys():
+            if key in exp_key:
+                stop_index = self.section_page_dict[key]
+                break
 
         with fitz.Document(self.path) as my_pdf_file:
-            # iterate over every page
-            for page_number in range(1, len(my_pdf_file) + 1):
+            # iterate over every page before the experiment section
+            for page_number in range(1, stop_index+1):
                 # look at the page on its own
                 page = my_pdf_file[page_number - 1]
+                # all images on the current page
+                images = page.get_images()
                 # iterate over every image on the current page
-                for image_number, image in enumerate(page.get_images(), start=1):
+                image_in_page = []
+                ext_in_page = []
+                for image_number, image in enumerate(images, start=1):
                     # image xref
                     xref_value = image[0]
                     # extract the image info
                     base_image = my_pdf_file.extract_image(xref_value)
                     # raw image bytes
                     image_bytes = base_image["image"]
                     # image extension
-                    ext = base_image["ext"]
+                    ext_in_page.append(base_image["ext"])
                     # load the image
                     image = Image.open(io.BytesIO(image_bytes))
-                    image_size = image.size[0] * image.size[1]
-                    if image_size > max_size:
-                        max_size = image_size
-                        image_list.append(image)
-        for image in image_list:
-            image_size = image.size[0] * image.size[1]
-            if image_size == max_size:
-                image_name = f"image.{ext}"
-                im_path = os.path.join(image_path, image_name)
-                print("im_path:", im_path)
-
-                max_pix = 480
-                origin_min_pix = min(image.size[0], image.size[1])
-
-                if image.size[0] > image.size[1]:
-                    min_pix = int(image.size[1] * (max_pix/image.size[0]))
-                    newsize = (max_pix, min_pix)
-                else:
-                    min_pix = int(image.size[0] * (max_pix/image.size[1]))
-                    newsize = (min_pix, max_pix)
-                image = image.resize(newsize)
-
-                image.save(open(im_path, "wb"))
-                return im_path, ext
+                    image_in_page.append(image)
+                image_list.append(image_in_page)
+                ext.append(ext_in_page)
+
+        img_path = []
+        for i_page in range(len(image_list)):
+            im_page_path = []
+            for i_image in range(len(image_list[i_page])):
+                image = image_list[i_page][i_image]
+                image_size = image.size[0] * image.size[1]
+                image_name = self.title + f"_{i_page}" + f"_{i_image}" + f".{ext[i_page][i_image]}"
+                path = os.path.join(image_path, image_name)
+                im_page_path.append(path)
+                image.save(open(path, "wb"))
+            img_path.append(im_page_path)
+
+        if len(img_path) != 0:
+            return img_path, ext
+        return None, None
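With the new return shape, a list of per-page lists plus matching extensions, a caller iterates both levels. A sketch assuming the Paper class defined in this file and an existing PDF at the given path (the paths are placeholders):

```python
# Hypothetical driver; the pdf path and output directory are placeholders.
paper = Paper(path='./pdf_files/demo/example.pdf')
img_list, ext = paper.get_image_path(image_path='./export/attachments/')
if img_list is not None:
    for i_page, page_images in enumerate(img_list):
        for i_image, img_file in enumerate(page_images):
            print(f"page {i_page}, image {i_image}: {img_file} ({ext[i_page][i_image]})")
```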

# Identify each section title by its font size and return them as a list
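The helper referenced here (truncated in this diff) keys section detection on font size. A minimal sketch of that idea with PyMuPDF; the function name and size threshold are illustrative, not the repo's actual implementation:

```python
import fitz  # PyMuPDF

def candidate_section_titles(pdf_path: str, min_size: float = 12.0):
    """Collect short, large-font text spans that look like section headings."""
    titles = []
    with fitz.Document(pdf_path) as doc:
        for page in doc:
            for block in page.get_text("dict")["blocks"]:
                for line in block.get("lines", []):
                    for span in line["spans"]:
                        text = span["text"].strip()
                        if span["size"] >= min_size and 0 < len(text) < 60:
                            titles.append(text)
    return titles
```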
Expand Down