-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchat_ui.py
195 lines (167 loc) · 6.3 KB
/
chat_ui.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
"""Chat UI functions for the Streamlit interface."""
from typing import Literal
import streamlit as st
from streamlit.delta_generator import DeltaGenerator
import chat_data
from llm import respond
from llm_functions import call_function, functions
def show_existing_messages() -> None:
    """Render every displayable message stored in the session's chat.

    Iterates ``st.session_state.chat.messages``. Function messages are
    shown together with their ``name``; user and assistant messages are
    shown only when their content is truthy. Messages with any other role
    are skipped.
    """
    for entry in st.session_state.chat.messages:
        role = entry["role"]
        if role == "function":
            show_message(role, entry["content"], name=entry["name"])
            continue
        # Assistant entries with empty content are not displayed.
        if role in ("user", "assistant") and entry["content"]:
            show_message(role, entry["content"])
def show_message(
    role: str,
    content: str | None,
    content_placeholder: DeltaGenerator | None = None,
    name: str | None = None,
    state: Literal["running", "complete", "error"] = "complete",
) -> DeltaGenerator:
    """Show a message by dispatching to the renderer for its role.

    Args:
        role: One of ``"function"``, ``"assistant"`` or ``"user"``.
        content: Text to display. May be ``None`` (callers in this module
            pass ``None`` while a function call is still streaming).
        content_placeholder: Existing element to reuse; when falsy the
            role-specific renderer creates a new one.
        name: Function name; only forwarded for ``"function"`` messages.
        state: Status state; only forwarded for ``"function"`` messages.

    Returns:
        The placeholder returned by the role-specific renderer, so callers
        can pass it back in on the next update.

    Raises:
        ValueError: If ``role`` is not one of the supported roles.
    """
    if role == "function":
        return show_function_message(
            content, content_placeholder, name, state
        )
    if role == "assistant":
        return show_assistant_message(content, content_placeholder)
    if role == "user":
        return show_user_message(content, content_placeholder)
    raise ValueError(f"Invalid role: {role}")
def show_function_message(
    content: str | None = None,
    content_placeholder: DeltaGenerator | None = None,
    name: str | None = "function",
    state: Literal["running", "complete", "error"] = "complete",
) -> DeltaGenerator:
    """Show a function message inside a status container.

    Args:
        content: Function output to render; nothing is rendered when it is
            ``None`` or empty.
        content_placeholder: Existing status container to reuse. When
            falsy, a new one is created with ``st.status``.
        name: Label for a newly created status container. Falls back to
            ``"function"`` when ``None``.
        state: State for the status container.

    Returns:
        The status container that was created or reused.
    """
    if not content_placeholder:
        # Callers may forward name=None; the label must still be a string.
        content_placeholder = st.status(label=name or "function", state=state)
    if state in ("complete", "error"):
        # Push terminal states onto a container that may have been created
        # earlier in the "running" state.
        content_placeholder.update(state=state)
    if content:
        with content_placeholder:
            # Streamlit markdown ignores \n. Mainly an issue with function
            # content containing \n. So just adding this here right now.
            # NOTE(review): unsafe_allow_html renders any HTML contained in
            # the function output -- confirm that output is trusted.
            st.markdown(content.replace("\n", "<br/>"), unsafe_allow_html=True)
    return content_placeholder
def show_assistant_message(
    content: str | None, content_placeholder: DeltaGenerator | None
) -> DeltaGenerator:
    """Write assistant content as markdown and return the element used.

    When no placeholder is supplied, a new empty element is created inside
    an ``"assistant"`` chat message and used instead.
    """
    slot = content_placeholder
    if not slot:
        with st.chat_message("assistant"):
            slot = st.empty()
    # Writing into the same slot replaces what it showed before.
    slot.markdown(content)
    return slot
def show_user_message(
    content: str | None, content_placeholder: DeltaGenerator | None
) -> DeltaGenerator:
    """Write user content as markdown and return the element used.

    When no placeholder is supplied, a new empty element is created inside
    a ``"user"`` chat message and used instead.
    """
    target = content_placeholder
    if not target:
        bubble = st.chat_message("user")
        with bubble:
            target = st.empty()
    target.markdown(content)
    return target
def run_response_loop(model: str = "gpt-4") -> None:
    """Request streamed responses until the assistant has answered.

    Each pass requests a streamed completion for the current chat messages
    and renders it chunk by chunk. A pass that finishes with a function
    call executes the function and starts another pass; any other finish
    reason stores the assistant content and ends the loop.

    Args:
        model: Model name passed to ``respond``.
    """
    assistant_responded = False
    while not assistant_responded:
        assistant_content = ""
        assistant_content_placeholder = None
        function_call = {"name": "", "arguments": ""}
        function_content_placeholder = None
        response = respond(
            model=model,
            functions=functions,
            messages=st.session_state.chat.messages,
            stream=True,
        )
        for chunk in response:
            choice = chunk["choices"][0]
            delta = choice["delta"] if "delta" in choice else {}
            function_call_delta = delta.get("function_call")
            if function_call_delta:
                function_content_placeholder = stream_function_call(
                    function_call,
                    function_call_delta,
                    function_content_placeholder,
                )
            elif delta.get("content") is not None:
                (
                    assistant_content,
                    assistant_content_placeholder,
                ) = stream_assistant_content(
                    assistant_content,
                    delta["content"],
                    assistant_content_placeholder,
                )
            finish_reason = choice["finish_reason"]
            if finish_reason == "function_call":
                function_content_placeholder = handle_complete_function_call(
                    function_call,
                    function_content_placeholder,
                )  # Resets function_placeholder to None
                break
            if finish_reason is not None:
                # "stop" and every other terminal reason end the loop.
                # Otherwise the while loop would re-request the same
                # completion forever.
                assistant_responded = handle_complete_assistant_content(
                    assistant_content,
                    assistant_content_placeholder,
                )  # Sets assistant_responded to True to exit while loop
                break
        else:
            # The stream ended without any finish reason. Keep whatever
            # text arrived and stop instead of requesting again.
            if assistant_content:
                handle_complete_assistant_content(
                    assistant_content,
                    assistant_content_placeholder,
                )
            assistant_responded = True
def stream_function_call(
    function_call: dict,
    function_call_delta: dict,
    content_placeholder: DeltaGenerator | None,
) -> DeltaGenerator | None:
    """Accumulate a streamed function call and show it as running.

    Args:
        function_call: Dict with ``"name"`` and ``"arguments"`` strings.
            It is updated in place with the text from the delta.
        function_call_delta: Partial function call from one stream chunk.
            Missing or ``None`` fields contribute nothing.
        content_placeholder: Existing function message element, or ``None``
            if none has been shown yet.

    Returns:
        The function message element once any argument text has arrived;
        otherwise ``content_placeholder`` unchanged.
    """
    # `or ""` also covers a key that is present with an explicit None.
    function_call["name"] += function_call_delta.get("name") or ""
    function_call["arguments"] += function_call_delta.get("arguments") or ""
    if function_call["arguments"]:
        content_placeholder = show_message(
            "function",
            None,
            content_placeholder,
            function_call["name"],
            "running",
        )
    return content_placeholder
def stream_assistant_content(
    content: str,
    content_delta: str | None,
    content_placeholder: DeltaGenerator | None,
) -> tuple[str, DeltaGenerator]:
    """Append a streamed delta to the content and show the result.

    Args:
        content: Assistant text accumulated so far.
        content_delta: New text from one stream chunk. ``None`` is treated
            as an empty string.
        content_placeholder: Existing assistant message element, or
            ``None`` if none has been shown yet.

    Returns:
        A ``(content, placeholder)`` tuple: the accumulated text and the
        element it was shown in.
    """
    content += content_delta or ""
    content_placeholder = show_message(
        "assistant", content, content_placeholder
    )
    return content, content_placeholder
def handle_complete_function_call(
    function_call: dict,
    content_placeholder: DeltaGenerator,
) -> None:
    """Record, execute and display a fully streamed function call.

    The assistant's function call is added to the chat data, the function
    is invoked with the call dict unpacked as keyword arguments, and its
    result is added to the chat data and shown as a completed function
    message.
    """
    chat_data.add_message("assistant", None, function_call)
    result = call_function(**function_call)
    fn_name = function_call["name"]
    chat_data.add_message("function", result, name=fn_name)
    show_message(
        role="function",
        content=result,
        content_placeholder=content_placeholder,
        name=fn_name,
        state="complete",
    )
def handle_complete_assistant_content(
    content: str,
    content_placeholder: DeltaGenerator,
) -> bool:
    """Store the finished assistant content and show it one last time.

    Always returns ``True``.
    """
    chat_data.add_message("assistant", content)
    show_message(
        role="assistant",
        content=content,
        content_placeholder=content_placeholder,
    )
    return True