-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
100 lines (74 loc) · 3.15 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
from pypdf import PdfReader
import onnxruntime_genai as og
import os
import pre_processing
from pre_processing import embedding_model
# Locate the ONNX model directory relative to the current working directory,
# so the app has to be started from the folder that contains the model.
base_path = os.getcwd()
model_path = os.path.join(base_path, 'cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4')
# Load the model and tokenizer once at import time; response_generator reuses
# these module-level objects for every question.
model = og.Model(model_path)
tokenizer = og.Tokenizer(model)
# Streaming decoder used to turn each newly generated token id into text.
tokenizer_stream = tokenizer.create_stream()
# Generator params are built per request inside response_generator, so the
# module-level variant below is left disabled.
# params = og.GeneratorParams(model)
# params.try_graph_capture_with_max_batch_size(1)
def doc_processing(uploaded_pdf, var):
    """Parse an uploaded PDF, build its embedding index, and reveal the question box.

    Args:
        uploaded_pdf: Value supplied by the upload button; handed to ``PdfReader``.
        var: Current session state. It is not read here; a fresh state is returned.

    Returns:
        A dict of component updates: ``input_box`` becomes visible with the
        value "Ask a question", and ``state_var`` is set to
        ``[context_list, index]``.
    """
    pdf_reader = PdfReader(uploaded_pdf)
    # "abstract" and "references" are passed as the first-section and
    # ignore-after markers; presumably only the text between those headings
    # is kept -- confirm against pre_processing.parese_doc.
    sections = pre_processing.parese_doc(pdf_reader, "abstract", "references")
    embedding_index = pre_processing.create_embedding(sections)
    question_box = gr.Textbox(value="Ask a question", visible=True)
    return {
        input_box: question_box,
        state_var: [sections, embedding_index],
    }
def response_generator(text, var1):
    """Stream an answer to ``text`` using the best-matching PDF chunk as context.

    Args:
        text: The user's question.
        var1: Session state produced by ``doc_processing``, i.e.
            ``[context_list, index]``. It is empty until a PDF has been
            processed.

    Yields:
        ``{output_box: answer_so_far}`` once per generated token, where
        ``answer_so_far`` is all text decoded up to that point.
    """
    if not var1:
        # The session state starts out as an empty list; unpacking it would
        # raise ValueError, so tell the user what is missing instead.
        yield {output_box: "Please upload a PDF before asking a question."}
        return
    context_list, index = var1

    chat_template = '<|user|>\nYou are an Research Assistant. You will provide short and precise answer.<|end|>\n<|assistant|>\nYes I will keep the answer short and precise.<|end|>\n<|user|>\n{input} <|end|>\n<|assistant|>'
    search_options = {'temperature': 1, 'max_length': 2000}

    # Retrieve the single closest chunk for the question.
    query_embedding = embedding_model.encode(text).reshape(1, -1)
    top_k = 1
    _scores, binary_ids = index.search(query_embedding, top_k)
    # NOTE(review): presumably a FAISS-style index that pads missing hits
    # with -1; a negative id would silently wrap around to the last chunk,
    # so such ids are skipped -- confirm against pre_processing.
    context = '. '.join(context_list[idx] for idx in binary_ids[0] if idx >= 0)

    question = text + " with respect to context: " + context
    prompt = chat_template.format(input=question)
    input_tokens = tokenizer.encode(prompt)

    params = og.GeneratorParams(model)
    params.try_graph_capture_with_max_batch_size(1)
    params.set_search_options(**search_options)
    params.input_ids = input_tokens
    generator = og.Generator(model, params)

    output = ""
    try:
        while not generator.is_done():
            generator.compute_logits()
            generator.generate_next_token()
            new_token = generator.get_next_tokens()[0]
            output += tokenizer_stream.decode(new_token)
            # Yield the cumulative text so the output box shows the whole
            # answer so far rather than only the newest piece.
            yield {output_box: output}
    finally:
        # Release the generator even if the consumer stops iterating early
        # or decoding raises.
        del generator
def submit():
    """Return a component update that makes the question box visible."""
    visible_box = gr.Textbox(visible=True)
    return {input_box: visible_box}
# Build the UI: a header, an upload button, and a question/answer pair wired
# to the RAG handlers defined above.
with gr.Blocks() as demo:
    # The literal is kept flush-left so the Markdown text carries no leading
    # indentation.
    gr.Markdown(
"""
# Phi3 3.8B
## RAG - Topic based pdf Q/A
- ***LLM:*** Phi3 Mini
- ***Embedding:*** nomic-embed-text-v1
""")
    # Per-session state; empty until doc_processing stores
    # [context_list, index] in it.
    state_var = gr.State([])
    # NOTE(review): Row membership below is reconstructed from statement
    # order -- confirm which components are meant to sit inside the row.
    with gr.Row():
        upload_button = gr.UploadButton("📁 Upload PDF", file_types=[".pdf"])
        # Hidden error field; no handler visible in this file writes to it.
        error_box = gr.Textbox(label="Error", visible=False)
    # The question box starts hidden and is revealed by the upload handlers.
    input_box = gr.Textbox(autoscroll=True,visible=False,label='User')
    output_box = gr.Textbox(autoscroll=True,max_lines=30,value="Output",label='Assistant')
    # Question -> streamed answer. response_generator yields dicts keyed by
    # output_box only, so state_var is listed as an output but never updated
    # by it. delete_cache is (frequency, age) in seconds per the Gradio docs
    # -- confirm for the installed version.
    gr.Interface(fn=response_generator, inputs=[input_box,state_var], outputs=[output_box,state_var],delete_cache=(20,10))
    # Two listeners fire on the same upload event: the first parses and
    # indexes the PDF, the second only makes the question box visible.
    upload_button.upload(doc_processing,inputs=[upload_button,state_var],outputs=[input_box,state_var],queue=False,show_progress=True,trigger_mode="once")
    upload_button.upload(submit,None,input_box)
# Enable the event queue, then start the server; both run at import time.
demo.queue()
demo.launch()