-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpage_task.py
191 lines (154 loc) · 6.13 KB
/
page_task.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import pandas as pd
import streamlit as st
from dotenv import load_dotenv
from helper import format_milliseconds, get_mysql_session, task_status_icon
from page_task_edit import task_form
from task_cache import TaskCache
from tables import Tasks
from task_count import task_count
from task_metrics import task_metrics
from task_diff import diff_tasks
from task_loads import current_user, is_admin, load_all_requests, load_all_tasks
from logger import logger
load_dotenv()
def task_page(task_id: int):
    """Render the detail page for the task identified by ``task_id``.

    Admins may open any task; other users only see tasks whose ``user_id``
    matches the current user. When no matching row exists an error is shown
    and nothing else is rendered.

    Args:
        task_id: Primary key of the ``Tasks`` row to display.
    """
    session = get_mysql_session()
    try:
        criteria = [Tasks.id == task_id]
        if not is_admin():
            # Non-admins are restricted to their own tasks.
            criteria.append(Tasks.user_id == current_user().id)
        task = session.query(Tasks).filter(*criteria).first()
    finally:
        # Always release the session, even if the lookup above raises.
        session.close()

    if not task:
        st.error("task not found")
        return

    # The percentage badge is only shown once the status is 2 or higher.
    progress_percentage = "" if task.status < 2 else f"`{task.progress_percentage}%`"
    st.markdown(
        f"## {task_status_icon(task.status)} {task.name} `{task.status_text}` {progress_percentage}"
    )

    if task.status > 1 and task.progress_percentage > 0:
        st.progress(task.progress_percentage)

    if task.error_message:
        st.error(f"{task.error_message}", icon="🚨")

    with st.container(border=True):
        task_form(task, True)

    if task.status > 1:
        requests = load_all_requests(task.id)
        render_count(task)
        render_metrics(task)
        render_charts(requests)
        render_requests(task, requests, 0, "❌ Failed Requests")
        render_requests(task, requests, 1, "✅ Succeed Requests")

        # NOTE(review): status 4 presumably means "finished" - confirm against
        # the Tasks model before relying on this.
        if task.status == 4:
            diff_tasks_page(task)
def render_count(task):
    """Show the "Overview" section for ``task``.

    Writes one line per entry returned by ``task_count(task)``, followed by
    the task's failed and succeeded request totals. Nothing is rendered when
    ``task_count`` returns an empty/falsy result.
    """
    counts = task_count(task)
    if not counts:
        return

    st.markdown("## 🪧 Overview")
    with st.container(border=True):
        for label in counts:
            st.write(f"{label}: `{counts[label]}`")
        st.write(f"Request Failed: `{task.request_failed}`")
        st.write(f"Request Succeed: `{task.request_succeed}`")
def diff_tasks_page(current_task: Tasks):
    """Render the "Diff Tasks" section for comparing ``current_task`` with another task.

    Offers every task returned by ``load_all_tasks()`` except ``current_task``
    itself, plus a choice of latency field. Once a task other than the
    ``"NONE"`` placeholder is selected, ``diff_tasks`` is run for the pair.

    Args:
        current_task: The task whose page is being displayed.
    """
    other_tasks = [task for task in load_all_tasks() if task.id != current_task.id]
    # The current task is already excluded above, so a single remaining task
    # is enough to offer a comparison.
    if not other_tasks:
        return

    st.markdown("## 🔰 Diff Tasks")
    options = ["NONE"] + [
        f"{task.id} - {task.name} ({task.model_id})" for task in other_tasks
    ]

    col1, col2 = st.columns(2)
    with col1:
        task_selected = st.selectbox(
            f"Select task to compare with `{current_task.name}`",
            options,
            index=0,
        )
    with col2:
        compare_field = st.selectbox(
            "Select field to compare",
            ["first_token_latency_ms", "request_latency_ms"],
            index=0,
        )

    if not task_selected or task_selected == "NONE":
        return

    # Option labels start with "<id> - ", so the id is everything before the
    # first separator.
    task_selected_id = int(task_selected.partition(" - ")[0])
    with st.spinner("Comparing tasks..."):
        diff_tasks(current_task.id, task_selected_id, compare_field)
def render_charts(requests):
    """Plot latency and chunk/token charts for the successful requests.

    Only requests whose ``success`` equals 1 are charted. When none remain,
    the function returns without rendering anything.
    """
    succeeded = [item for item in requests if item.success == 1]
    if not succeeded:
        return

    first_token_latencies = [item.first_token_latency_ms for item in succeeded]
    request_latencies = [item.request_latency_ms for item in succeeded]
    chunk_and_token_counts = [
        (item.chunks_count, item.output_token_count) for item in succeeded
    ]

    # Every series has exactly one entry per successful request, so all of
    # them are non-empty here and each chart is always drawn.
    st.markdown("## 📉 Charts")

    with st.container(border=True):
        st.markdown("#### First Token Latency")
        st.line_chart(
            pd.DataFrame(first_token_latencies, columns=["First Token Latency"])
        )

    with st.container(border=True):
        st.markdown("#### Request Latency")
        st.line_chart(pd.DataFrame(request_latencies, columns=["Request Latency"]))

    with st.container(border=True):
        st.markdown("#### Chunks Count / Output Token Count")
        st.bar_chart(
            pd.DataFrame(
                chunk_and_token_counts,
                columns=["Chunks Count", "Output Token Count"],
            )
        )
def render_metrics(task):
    """Display the metrics table for ``task`` and the current queue length.

    The metrics returned by ``task_metrics(task)`` are shown as a table with
    one row per key. If the ``TaskCache`` queue is non-empty, a notice asks
    the user to wait and refresh. Any exception is logged and shown in the
    page instead of being raised.
    """
    with st.spinner(text="Loading Report..."):
        try:
            data = task_metrics(task)
            df = pd.DataFrame.from_dict(data, orient="index")
            queue_len = TaskCache().len()

            st.markdown("## 📊 Metrics")
            st.markdown(f"Name: `{task.name}`")
            if queue_len > 0:
                st.markdown(
                    f"`{queue_len}` chunks in queue, please wait them to finish and refresh report."
                )
            st.table(df)
        except Exception as e:
            # Log as well as display, matching render_requests, so failures
            # are not visible only in the browser.
            logger.error(e)
            st.error(e)
def render_requests(task, requests, status, title):
    """List the requests whose ``success`` equals ``status`` under ``title``.

    Each entry shows the formatted start time, the request id, the output
    token count and a link to the request log. Nothing is rendered when no
    request matches. Any exception is logged and shown in the page.
    """
    try:
        matching = [entry for entry in requests if entry.success == status]
        total = len(matching)
        if total == 0:
            return

        st.markdown(f"## {title} ({total})")
        # Long lists get a fixed-height container; short ones size themselves.
        box_height = 450 if total > 10 else None
        with st.container(border=True, height=box_height):
            for entry in matching:
                line = (
                    f"`{format_milliseconds(entry.start_req_time)}` {entry.id} | "
                    f"{entry.output_token_count} "
                    f'<a href="/?request_id={entry.id}&task_id={task.id}" target="_blank">👀 Log</a>'
                )
                st.markdown(line, unsafe_allow_html=True)
    except Exception as e:
        logger.error(e)
        st.error(e)