forked from patil-suraj/question_generation
-
Notifications
You must be signed in to change notification settings - Fork 4
/
server.py
120 lines (86 loc) · 3.21 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
from flask import Flask, request, send_file, render_template
from werkzeug.utils import secure_filename
from queue import Queue, Empty
import time
import threading
from pipelines import pipeline
import pandas as pd
app = Flask(__name__, template_folder='templates')
# Cap incoming request bodies at 1 MiB (Flask's MAX_CONTENT_LENGTH setting).
app.config['MAX_CONTENT_LENGTH'] = 1 << 20

# Hand-off queue between the HTTP handlers and the background worker. Each
# item is a dict with an 'input' list; the worker adds an 'output' key.
requests_queue = Queue()
# Number of requests the worker collects before running the models; the
# handlers also reject new work (429) once this many are already queued.
BATCH_SIZE = 1
# Polling interval in seconds, used both for the worker's queue reads and the
# handlers' wait for 'output'.
CHECK_INTERVAL = 0.1

# Load both pipelines once at import time so requests do not pay the load cost.
nlp = pipeline("multitask-qa-qg")
qg = pipeline("e2e-qg")
def handle_requests_by_batch():
    """Background worker: pull queued requests and run the models on them.

    Loops forever. Each round blocks until ``BATCH_SIZE`` requests have been
    taken from ``requests_queue``. It then runs ``run`` on each request's first
    input and stores the result under the request dict's ``'output'`` key,
    which the HTTP handlers poll for.

    Any exception from ``run`` is logged and reported as ``'error'``. This keeps
    the worker alive, so one bad input cannot hang every later request.
    """
    while True:
        batch = []
        while len(batch) < BATCH_SIZE:
            try:
                batch.append(requests_queue.get(timeout=CHECK_INTERVAL))
            except Empty:
                continue

        # `req`, not `request`, to avoid shadowing flask.request.
        for req in batch:
            try:
                output = run(req['input'][0])
            except Exception:
                # run() only maps ValueError to 'error'. Without this, any
                # other exception would end the thread, and every handler
                # waiting on 'output' would spin forever.
                app.logger.exception('Question generation failed for queued request')
                output = 'error'
            # Publish each result as soon as it is ready.
            req['output'] = output
# Start the single background worker at import time. It never returns, so run
# it as a daemon thread; otherwise it would keep the interpreter alive after the
# server's main thread exits (e.g. on Ctrl+C).
threading.Thread(target=handle_requests_by_batch, name='batch-worker', daemon=True).start()
def run(input_text):
    """Run both question-generation pipelines on one piece of text.

    Args:
        input_text: The text submitted by the client.

    Returns:
        ``[frame, questions]`` on success:
        - ``frame`` is a pandas DataFrame built from the "multitask-qa-qg" output.
        - ``questions`` is the raw "e2e-qg" output.
        Returns the string ``'error'`` if any step raises ValueError. Other
        exceptions propagate to the caller.
    """
    try:
        qa_output = nlp(input_text)
        questions = qg(input_text)
        frame = pd.DataFrame(qa_output)
    except ValueError:
        return 'error'
    else:
        return [frame, questions]
# Web server
@app.route('/', methods=['GET', 'POST'])
@app.route('/index', methods=['GET', 'POST'])
def upload_file():
    """Serve the HTML front end and, on POST, generate questions for the form text.

    Non-POST requests get the empty ``index.html`` page. On POST, the form
    field ``input`` is queued for the background worker, and this handler
    polls every ``CHECK_INTERVAL`` seconds until the worker sets ``'output'``.

    Returns:
        - 400 with ``error='No Input'`` for an empty input.
        - 429 when ``BATCH_SIZE`` requests are already queued.
        - 400 with ``error='Invalid text. please try again.'`` when the worker
          reports ``'error'``.
        - Otherwise, the page rendered with the result table and the generated
          questions.
    """
    if request.method != 'POST':
        return render_template('index.html')

    text = str(request.form['input'])
    if not text:
        return render_template('index.html', error = 'No Input'), 400
    if requests_queue.qsize() >= BATCH_SIZE:
        return render_template('index.html', error = 'TooMany requests. please try again'), 429

    job = {'input': [text]}
    requests_queue.put(job)
    # Busy-wait until the worker thread publishes a result on the job dict.
    while 'output' not in job:
        time.sleep(CHECK_INTERVAL)

    outcome = job['output']
    if outcome == 'error':
        return render_template('index.html', error = 'Invalid text. please try again.'), 400

    frame, questions = outcome
    return render_template(
        'index.html',
        result=[frame.to_html(classes='data')],
        titles=frame.columns.values,
        question=questions,
        input_text=text,
    )
# API server
@app.route('/generate', methods=['POST'])
def generate_q():
    """JSON API: run question generation on the posted form field ``input``.

    The text is queued for the background worker, and this handler polls every
    ``CHECK_INTERVAL`` seconds until the worker sets ``'output'``.

    Returns:
        - 400 with the plain-text body ``'No input'`` for an empty input.
        - 429 with JSON ``{'error': ...}`` when ``BATCH_SIZE`` requests are
          already queued.
        - 400 with JSON ``{'error': 'Invalid text. please try again.'}`` when
          the worker reports ``'error'``.
        - Otherwise, the "multitask-qa-qg" result DataFrame converted with
          ``DataFrame.to_dict()`` (column -> {row index -> value}), which Flask
          serializes as JSON.
    """
    # The route only accepts POST, so no method check is needed; the old
    # `return None` fallback was unreachable (and invalid as a Flask response).
    input_text = str(request.form['input'])
    if len(input_text) == 0:
        return 'No input', 400
    if requests_queue.qsize() >= BATCH_SIZE:
        return {'error': 'TooMany requests. please try again'}, 429

    req = {'input': [input_text]}
    requests_queue.put(req)
    # Busy-wait until the worker thread publishes a result on the request dict.
    while 'output' not in req:
        time.sleep(CHECK_INTERVAL)

    if req['output'] == 'error':
        # This is the JSON API: report the failure as JSON, consistent with the
        # 429 branch, rather than rendering the HTML front-end page.
        return {'error': 'Invalid text. please try again.'}, 400

    # NOTE(review): the e2e-qg questions are computed but not part of this
    # API's response; only the multitask-qa-qg table is returned.
    df, _generated_q = req['output']
    return df.to_dict()
@app.route('/healthz', methods=['GET'])
def checkHealth():
    """Health-check endpoint: always answers HTTP 200 with the body "Alive"."""
    status_code = 200
    return "Alive", status_code
if __name__ == '__main__':
    # Direct-execution entry point: Flask's built-in server on all interfaces,
    # port 8080, debug mode off.
    app.run(host='0.0.0.0', port=8080, debug=False)