-
Notifications
You must be signed in to change notification settings - Fork 10
/
Copy pathrtx_api_3_5.py
67 lines (58 loc) · 2.19 KB
/
rtx_api_3_5.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import requests
import random
import string
import psutil
import json
# Cached port of the local "Chat with RTX" server. Stays None until
# find_chat_with_rtx_port() finds a matching server and stores its port here.
port = None
def find_chat_with_rtx_port():
    """Locate the local port served by a running "ChatWithRTX" process.

    Walks the system's inet connections. For every connection owned by a
    process whose executable path contains "ChatWithRTX", an empty POST is
    sent to ``/queue/join`` on that connection's local port. A 422 reply is
    taken as the signature of the right server (presumably the endpoint
    rejecting the empty body -- confirm against the server), and that port is
    stored in the module-level ``port``.

    Returns:
        None. The result is published through the global ``port``; if no
        candidate answers with 422 the global is left unchanged.
    """
    global port
    for conn in psutil.net_connections(kind='inet'):
        # Entries without an owning pid or without a local address cannot be
        # attributed to a process or probed, so skip them up front.
        if not conn.pid or not conn.laddr:
            continue
        try:
            if "ChatWithRTX" not in psutil.Process(conn.pid).exe():
                continue
            candidate = conn.laddr.port
            url = f"http://127.0.0.1:{candidate}/queue/join"
            # Very short timeout: many candidate ports may be probed and most
            # of them will not be speaking HTTP at all.
            response = requests.post(url, data="", timeout=0.05)
        except (psutil.Error, OSError, requests.RequestException):
            # The process vanished, access was denied, or the port did not
            # answer as an HTTP server -- best effort, move to the next one.
            continue
        if response.status_code == 422:
            port = candidate
            return
def join_queue(session_hash, fn_index, port, chatdata):
    """Submit a request to the server's ``/queue/join`` endpoint.

    Args:
        session_hash: Identifier sent as ``session_hash``; the same value is
            later used to read the result stream for this request.
        fn_index: Index of the server-side function to invoke.
        port: Local TCP port of the server (on 127.0.0.1).
        chatdata: Value sent as the ``data`` field of the request.

    Returns:
        None.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
            Surfacing this here avoids waiting for a result on a request
            that was never queued.
        requests.RequestException: On connection-level failures.
    """
    # fn_indexes are some gradio generated indexes from rag/trt/ui/user_interface.py
    payload = {
        "data": chatdata,
        "event_data": None,
        "fn_index": fn_index,
        "session_hash": session_hash,
    }
    url = f"http://127.0.0.1:{port}/queue/join"
    response = requests.post(url, data=json.dumps(payload))
    response.raise_for_status()
def listen_for_updates(session_hash, port):
    """Read the server's ``/queue/data`` stream until the request completes.

    Each non-empty line of the stream is expected to start with ``data:``
    followed by a JSON document. Messages whose ``msg`` is
    ``'process_completed'`` carry the final result; all other messages
    (for example intermediate ``'process_generating'`` updates) are ignored.

    Args:
        session_hash: The session identifier used when joining the queue.
        port: Local TCP port of the server (on 127.0.0.1).

    Returns:
        The value at ``output.data[0][0][1]`` of the completed message, or
        ``""`` if the stream ends without a usable completed message.

    Raises:
        requests.RequestException: On connection-level failures.
    """
    url = f"http://127.0.0.1:{port}/queue/data"
    # The context manager closes the streamed connection even when we return
    # in the middle of the stream.
    with requests.get(url, params={"session_hash": session_hash}, stream=True) as response:
        for line in response.iter_lines():
            # iter_lines() yields bytes; skip blanks and non-data lines.
            if not line.startswith(b"data:"):
                continue
            try:
                data = json.loads(line[5:])
                if data['msg'] == 'process_completed':
                    return data['output']['data'][0][0][1]
            except (ValueError, LookupError, TypeError):
                # Malformed JSON or a message without the expected shape:
                # keep listening for a later, well-formed message.
                continue
    return ""
def send_message(message, fn_index=34):
    """Send one chat message to the local "Chat with RTX" server.

    Discovers the server port on first use (cached in the module-level
    ``port``), queues the message under a fresh random session hash and
    blocks until the server reports the request as completed.

    Args:
        message: The user message to send.
        fn_index: Gradio-generated index of the server function to call.
            Defaults to 34, the value this module has always used; override
            it if the server's UI assigns a different index.

    Returns:
        Whatever ``listen_for_updates`` returns for this session: the
        server's reply, or ``""`` if no completed message was received.

    Raises:
        RuntimeError: If no "Chat with RTX" server port could be found.
    """
    if not port:
        find_chat_with_rtx_port()
    if not port:
        raise RuntimeError("Failed to find a server port for 'Chat with RTX'. Ensure the server is running.")
    alphabet = string.ascii_letters + string.digits
    session_hash = ''.join(random.choices(alphabet, k=10))
    # add chat history here -v
    chatdata = [[[message, None]], None]
    join_queue(session_hash, fn_index, port, chatdata)
    return listen_for_updates(session_hash, port)