import os
import re
import time
import shutil
import gradio as gr
from repolya.rag.vdb_faiss import (
    get_faiss_HuggingFace,
    merge_faiss_HuggingFace,
)
from repolya.rag.qa_chain import (
    qa_vdb_multi_query_textgen,
    qa_with_context_as_mio_textgen,
    create_rag_subtask_list_textgen,
)
from repolya.autogen.workflow import (
    search_faiss_textgen,
)
from repolya.rag.digest_dir import (
    calculate_md5,
    dir_to_faiss_HuggingFace,
)
from repolya.rag.doc_loader import clean_txt
from repolya.toolset.tool_bshr import bshr_vdb
from repolya._const import LOG_ROOT, WORKSPACE_RAG
from repolya._log import logger_rag


_upload_dir = WORKSPACE_RAG / 'lj_rag_upload'
_db_name = str(WORKSPACE_RAG / 'lj_rag_hf')
_clean_txt_dir = str(WORKSPACE_RAG / 'lj_rag_clean_txt')
os.makedirs(_upload_dir, exist_ok=True)
os.makedirs(_db_name, exist_ok=True)
os.makedirs(_clean_txt_dir, exist_ok=True)

_log_ans_fast = LOG_ROOT / '_ans_fast.txt'
_log_ref_fast = LOG_ROOT / '_ref_fast.txt'
_log_ans_autogen = LOG_ROOT / '_ans_autogen.txt'
_log_ref_autogen = LOG_ROOT / '_ref_autogen.txt'

def write_log_ans(_log_ans, _txt, _status=None):
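    """Overwrite the answer log file with _txt; when _status is "continue", append a "calculating, please wait" notice for the UI."""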
    with open(_log_ans, 'w', encoding='utf-8') as wf:
        if _status == "continue":
            _txt += "\n\n计算中,请稍候..."
        # elif _status == "done":
        #     _txt += "\n\n[完成]"
        wf.write(_txt)

def write_log_ref(_log_ref, _txt):
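    """Overwrite the reference log file with _txt."""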
    with open(_log_ref, 'w', encoding='utf-8') as wf:
        wf.write(_txt)

def rag_read_logs():
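    """Read the four answer/reference log files and return their contents for the UI to display."""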
    with open(_log_ans_fast, "r", encoding="utf-8") as f:
        _ans_fast = f.read()
    with open(_log_ref_fast, "r", encoding="utf-8") as f:
        _ref_fast = f.read()
    with open(_log_ans_autogen, "r", encoding="utf-8") as f:
        _ans_autogen = f.read()
    with open(_log_ref_autogen, "r", encoding="utf-8") as f:
        _ref_autogen = f.read()
    return [_ans_fast, _ref_fast, _ans_autogen, _ref_autogen]

def rag_clean_logs():
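    """Truncate all four answer/reference log files."""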
    write_log_ans(_log_ans_fast, '')
    write_log_ref(_log_ref_fast, '')
    write_log_ans(_log_ans_autogen, '')
    write_log_ref(_log_ref_autogen, '')

# clear any stale log content when the module is loaded
rag_clean_logs()

def rag_clean_all():
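    """Clear the logs and return fresh components so the UI can blank the query box and reset the button style."""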
    rag_clean_logs()
    print('rag_clean_logs()')
    return [gr.Textbox(value=""), gr.Button(variant="secondary")]

def is_dir_empty(dir_path):
"""
检查指定目录是否为空
"""
if not os.path.exists(dir_path):
return True
return len(os.listdir(dir_path)) == 0
def move_dir_content(src_dir, dst_dir):
"""
将 src_dir 下的所有内容移动到 dst_dir 下
"""
for item in os.listdir(src_dir):
src_path = os.path.join(src_dir, item)
dst_path = os.path.join(dst_dir, item)
if os.path.isfile(src_path):
shutil.move(src_path, dst_path)
elif os.path.isdir(src_path):
if not os.path.exists(dst_path):
os.makedirs(dst_path)
move_dir_content(src_path, dst_path)
##### RAG
def rag_handle_upload(_tmp_path):
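    """Ingest uploaded files: rename each file to its MD5 hash, build a per-file FAISS index, and merge it into the shared vector store; files already ingested are skipped."""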
    _tmp_files = []
    _out = []
    for i in _tmp_path:
        i_fp = i.name
        _tmp_files.append(i_fp)
        i_fn = os.path.basename(i_fp)
        i_dir = os.path.dirname(i_fp)
        # print(i_dir)
        i_md5 = calculate_md5(i_fp)
        # print(i_md5)
        i_fn_new = f"{i_md5}" + os.path.splitext(os.path.basename(i_fp))[1]
        i_fp_new = os.path.join(_upload_dir, i_fn_new)
        i_db_name = os.path.join(_upload_dir, f"{i_md5}_hf")
        # print(i_fp_new)
        if not os.path.exists(i_fp_new):
            logger_rag.info(f"upload {i_fn} to {i_fn_new}")
            dir_to_faiss_HuggingFace(i_dir, i_db_name, _clean_txt_dir)
            shutil.move(i_fp, i_fp_new)
            if is_dir_empty(_db_name):
                print(i_db_name, _db_name)
                move_dir_content(i_db_name, _db_name)
            else:
                merge_faiss_HuggingFace(_db_name, i_db_name)
                shutil.rmtree(i_db_name)
            logger_rag.info("done upload process")
            _out.append(f"upload {i_fn} to {i_fn_new}")
        else:
            logger_rag.info(f"{i_fn} ({i_fn_new}) exists")
            _out.append(f"{i_fn} ({i_fn_new}) exists")
    return "\n".join(_out)

def qa_faiss_HuggingFace(_query, _vdb):
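    """Run the multi-query QA chain against the FAISS store through the local text-generation endpoint and time the call."""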
    start_time = time.time()
    _textgen_url = "http://127.0.0.1:5552"
    _ans, _step, _token_cost = qa_vdb_multi_query_textgen(_query, _vdb, 'stuff', _textgen_url)
    end_time = time.time()
    execution_time = end_time - start_time
    _time = f"Time: {execution_time:.1f} seconds"
    logger_rag.info(f"{_time}")
    return [_ans, _step, _token_cost, _time]

def rag_helper_fast(_query):
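    """Fast path: answer the query with a single multi-query QA pass and write the answer/references to the "fast" log files."""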
    _vdb = get_faiss_HuggingFace(_db_name)
    _ans, _ref = "", ""
    write_log_ans(_log_ans_fast, '')
    write_log_ref(_log_ref_fast, '')
    write_log_ans(_log_ans_fast, '', 'continue')
    _ans, _step, _token_cost, _time = qa_faiss_HuggingFace(_query, _vdb)
    _ref = f"{_token_cost}\n{_time}\n\n{_step}"
    write_log_ans(_log_ans_fast, clean_txt(_ans), 'done')
    write_log_ref(_log_ref_fast, _ref)
    return

def rag_helper_autogen_textgen(_query):
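    """Multi-agent path: split the query into sub-questions, retrieve context for each from FAISS, then compose a final answer, logging each stage as it completes."""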
    _textgen_url = "http://127.0.0.1:5552"
    _vdb = get_faiss_HuggingFace(_db_name)
    _ans, _ref = "", ""
    write_log_ans(_log_ans_autogen, '')
    write_log_ref(_log_ref_autogen, '')
    start_time = time.time()
    write_log_ans(_log_ans_autogen, '', 'continue')
    # ChatCompletion.start_logging(reset_counter=True, compact=False)
    ### task list
    _subtask, _token_cost = create_rag_subtask_list_textgen(_query, _textgen_url)
    write_log_ans(_log_ans_autogen, f"生成的子问题列表:\n\n{_subtask}", 'continue')
    end_time = time.time()
    execution_time = end_time - start_time
    _time = f"Time: {execution_time:.1f} seconds"
    write_log_ref(_log_ref_autogen, f"\n\n{_time}")
    # print(f"cost_usage: {cost_usage(ChatCompletion.logged_history)}")
    ### context
    _context, _token_cost = search_faiss_textgen(_subtask, _vdb, _textgen_url)
    write_log_ans(_log_ans_autogen, f"生成的 QA 上下文:\n\n{_context}", 'continue')
    end_time = time.time()
    execution_time = end_time - start_time
    _time = f"Time: {execution_time:.1f} seconds"
    write_log_ref(_log_ref_autogen, f"\n\n{_time}")
    ### qa
    _qa, _tc = qa_with_context_as_mio_textgen(_query, _context, _textgen_url)
    write_log_ans(_log_ans_autogen, f"生成的最终答案:\n\n{_qa}", 'done')
    end_time = time.time()
    execution_time = end_time - start_time
    _time = f"Time: {execution_time:.1f} seconds"
    write_log_ref(_log_ref_autogen, f"\n\n{_time}\n\n{'='*40}\n\n{_context}")
    return

def rag_helper(_query, _radio):
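    """Dispatch the query to the fast pipeline ("快速") or the multi-agent pipeline ("多智") based on the radio selection."""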
    if _radio == "快速":
        logger_rag.info("[快速]")
        rag_helper_fast(_query)
    elif _radio == "多智":
        logger_rag.info("[多智]")
        rag_helper_autogen_textgen(_query)
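

# The block below is a minimal wiring sketch showing how these handlers could be
# mounted in a Gradio Blocks UI; it is not part of the original module. It assumes
# a Gradio version whose File-upload event passes objects exposing a .name attribute
# (as rag_handle_upload expects), and it pulls results out of the log files with a
# manual "Refresh" button instead of streaming. All component names are illustrative.
if __name__ == "__main__":
    with gr.Blocks(title="local RAG") as demo:
        _file = gr.File(label="Upload documents", file_count="multiple")
        _upload_log = gr.Textbox(label="Upload log", lines=4)
        _query_box = gr.Textbox(label="Question", lines=2)
        # choices must match the strings rag_helper() compares against
        _mode = gr.Radio(choices=["快速", "多智"], value="快速", label="Mode")
        _ask = gr.Button("Ask", variant="primary")
        _refresh = gr.Button("Refresh")
        _clear = gr.Button("Clear")
        _ans_fast_box = gr.Textbox(label="Fast answer", lines=8)
        _ref_fast_box = gr.Textbox(label="Fast references", lines=8)
        _ans_autogen_box = gr.Textbox(label="Autogen answer", lines=8)
        _ref_autogen_box = gr.Textbox(label="Autogen references", lines=8)
        # index uploaded files into the shared FAISS store
        _file.upload(rag_handle_upload, [_file], [_upload_log])
        # run a query; answers/references are written to the log files
        _ask.click(rag_helper, [_query_box, _mode], None)
        # read the latest answers/references back out of the log files
        _refresh.click(rag_read_logs, None, [_ans_fast_box, _ref_fast_box, _ans_autogen_box, _ref_autogen_box])
        # wipe the logs and reset the query box / button styling
        _clear.click(rag_clean_all, None, [_query_box, _ask])
    demo.launch()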