-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
266 lines (229 loc) · 12.1 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
import datetime
from flask import jsonify
from simpleaichat import AIChat
from env_loader import get_env
from utils.openai_helper import initialize_openai, moderate_content, num_tokens_from_string, generate_image
from utils.elevenlabs import get_voices_data, text_to_speech
from utils.misc import generate_unique_card_id, get_docs
# Load environment variables
# NOTE(review): get_env's type conversion isn't visible in this file. handle_message
# compares MAX_TURNS, TTL and MAX_TOKENS_INPUT against numbers, while MODERATION is
# compared to the string "True" — confirm get_env returns numeric values for the
# former, otherwise those comparisons raise TypeError.
OPENAI_API_KEY = get_env("OPENAI_API_KEY")
MODEL_NAME = get_env("MODEL_NAME")  # model passed to AIChat when API_URL is not set
SYSTEM_PROMPT = get_env("SYSTEM_PROMPT")  # system prompt for every AIChat session; also token-counted
PROMPT_PREFIX = get_env("PROMPT_PREFIX")  # wrapped around the user message only when API_URL is set
PROMPT_SUFFIX = get_env("PROMPT_SUFFIX")
MAX_TURNS = get_env("MAX_TURNS")  # turns after which a user's AIChat session is recreated
TTL = get_env("TTL")  # idle seconds after which a user's AIChat session is recreated
MAX_TOKENS_INPUT = get_env("MAX_TOKENS_INPUT")  # token limit for user message + SYSTEM_PROMPT
MAX_TOKENS_OUTPUT = get_env("MAX_TOKENS_OUTPUT")  # forwarded to initialize_openai
TEMPERATURE = get_env("TEMPERATURE")  # forwarded to initialize_openai
IMAGE_SIZE = get_env("IMAGE_SIZE")  # /image options, all forwarded to generate_image
IMAGE_STYLE = get_env("IMAGE_STYLE")
IMAGE_QUALITY = get_env("IMAGE_QUALITY")
DALLE_MODEL = get_env("DALLE_MODEL")
API_URL = get_env("API_URL")  # when set, AIChat is built with api_url=API_URL and no API key
ELEVENLABS_API_KEY = get_env("ELEVENLABS_API_KEY")  # when falsy, /voices and /tts are disabled
# Initialize OpenAI parameters
params = initialize_openai(OPENAI_API_KEY, TEMPERATURE, MAX_TOKENS_OUTPUT)  # passed as params= to every AIChat
enable_moderation = get_env("MODERATION")  # moderation runs only when this equals the string "True"
# Define globals
# Per-user state held in this process's memory, keyed by user_id (event['user']['name']).
# Nothing is persisted, and entries are only removed from user_sessions on /reset,
# so these dicts grow with the number of distinct users.
user_sessions = {} # A dictionary to track the AIChat instances for each user
turn_counts = {} # A dictionary to track the turn count for each user
last_received_times = {} # A dictionary to track the last received time for each user
def process_event(request):
    """Handle one incoming chat event and return the bot's JSON reply.

    Args:
        request: the incoming request. Its JSON body must contain ``type`` and
            ``user.name`` and, for ``MESSAGE`` events, a ``message`` object.

    Returns:
        A ``jsonify`` response. ``ADDED_TO_SPACE`` gets a greeting, ``MESSAGE``
        is delegated to handle_message, and any other event type gets a
        refusal. Errors (for example missing keys) are printed and answered
        with an apology instead of propagating.
    """
    try:
        event = request.get_json()
        event_type = event['type']
        user_id = event['user']['name']
        if event_type == 'ADDED_TO_SPACE':
            return jsonify({'text': 'Hello! I am your Chat bot. How can I assist you today?'})
        if event_type != 'MESSAGE':
            return jsonify({'text': 'Sorry, I can only process messages and being added to a space.'})
        message = event['message']
        # ``argumentText`` is the message body with this app's @-mention
        # removed, so "@Bot /reset" in a space arrives as "/reset" and the
        # slash commands in handle_message match. Comparing a mentioned user's
        # name against the space name (as a user -> space check would) can
        # never match, so the mention has to be stripped by the platform field.
        # Fall back to the raw text for payloads that lack argumentText.
        argument_text = message.get('argumentText')
        if argument_text is None:
            user_message = message['text']
        else:
            user_message = argument_text.strip()
        return handle_message(user_id, user_message)
    except Exception as e:
        print(f"Error processing event: {str(e)}")
        return jsonify({'text': 'Sorry, I encountered an error while processing your message.'})
def _is_flagged(text):
    """Return a truthy value when moderation is enabled and moderate_content() flags *text*."""
    return enable_moderation == "True" and moderate_content(text)["flagged"]


def _image_card(prompt, image_url):
    """Return a ``cardsV2`` entry showing *image_url*, with *prompt* as the header subtitle."""
    return {
        'cardId': generate_unique_card_id(),
        'card': {
            'header': {
                'title': 'Generated Image',
                'subtitle': prompt,
            },
            'sections': [
                {
                    'widgets': [
                        {
                            'image': {
                                'imageUrl': image_url,
                                'onClick': {'openLink': {'url': image_url}},
                            }
                        }
                    ]
                }
            ],
        },
    }


def _audio_card(audio_url):
    """Return a ``cardsV2`` entry with a single "Play" button that opens *audio_url*."""
    return {
        'cardId': generate_unique_card_id(),
        'card': {
            'header': {
                'title': 'Generated Audio',
                'subtitle': 'Click to Play Audio',
            },
            'sections': [
                {
                    'collapsible': False,
                    'uncollapsibleWidgetsCount': 1,
                    'widgets': [
                        {
                            'buttonList': {
                                'buttons': [
                                    {
                                        'text': 'Play ▶️',
                                        'onClick': {'openLink': {'url': audio_url}},
                                    }
                                ]
                            }
                        }
                    ],
                }
            ],
        },
    }


def _handle_image_command(user_id, command):
    """Handle ``/image <prompt>``: generate an image and reply with an image card.

    *command* is the stripped user message whose lower-cased form starts with
    ``/image``. The prompt is everything after that prefix. Slicing by length
    (rather than splitting on the literal "/image") keeps "/IMAGE sunset"
    working, because the keyword match is case-insensitive.
    """
    prompt = command[len('/image'):].strip()
    if not prompt:
        return jsonify({'text': 'Please provide a prompt for the image generation. Example: `/image sunset over a beach`.'})
    try:
        image_resp = generate_image(
            prompt=prompt,
            n=1,
            size=IMAGE_SIZE,
            model=DALLE_MODEL,
            style=IMAGE_STYLE,
            quality=IMAGE_QUALITY,
            user=user_id,
        )
        image_url = image_resp["data"][0]["url"]
        return jsonify({
            'text': 'Processing your image request...',
            'cardsV2': [_image_card(prompt, image_url)],
        })
    except Exception as e:
        print(f"Error generating image: {str(e)}")
        return jsonify({'text': "Sorry, I encountered an internal error generating the image. Please try again later."})


def _handle_voices_command():
    """Handle ``/voices``: reply with the comma-separated names from get_voices_data()."""
    if not ELEVENLABS_API_KEY:
        return jsonify({'text': 'This function is disabled.'})
    voices_data, error = get_voices_data()
    if error:
        print(f"Error: {error}")
        return jsonify({'text': "An internal error has occurred. Please try again later."})
    voices_string = ', '.join(voices_data.keys())
    return jsonify({'text': f"Available voices: {voices_string}"})


def _handle_tts_command(command):
    """Handle ``/tts <voice> <message>``: synthesize *message* and reply with a play button.

    *command* is the stripped user message. It is split on runs of whitespace
    at most twice, so extra spaces between arguments don't produce empty
    tokens, and the message text after the voice is kept as typed.
    """
    if not ELEVENLABS_API_KEY:
        return jsonify({'text': 'This function is disabled.'})
    parts = command.split(maxsplit=2)
    if len(parts) < 3:  # need the command, a voice and a non-empty message
        return jsonify({'text': 'Please use the format /tts <voice> <message>.'})
    voice = parts[1].lower()
    prompt = parts[2]
    voices_data, error = get_voices_data()
    if error:
        print(f"Error: {error}")
        return jsonify({'text': "An internal error has occurred. Please try again later."})
    if voice not in voices_data:
        return jsonify({'text': f"Sorry, I couldn't recognize the voice {voice}. Please choose a valid voice."})
    audio_url, error = text_to_speech(prompt, voice)
    if not audio_url:
        print(f"Error generating audio: {error}")
        return jsonify({'text': "Sorry, I encountered an internal error generating the audio. Please try again later."})
    return jsonify({
        'text': 'Processing your TTS request...',
        'cardsV2': [_audio_card(audio_url)],
    })


def handle_message(user_id, user_message):
    """Reply to one chat message from *user_id*.

    Slash commands (/reset, /image, /voices, /tts, /help) are handled directly.
    Any other text is sent to the user's AIChat session. That session is
    (re)created when missing, after MAX_TURNS turns, or when the previous
    message arrived more than TTL seconds ago.

    Args:
        user_id: key into the per-user session dictionaries (process_event
            passes the event's user name).
        user_message: the message text.

    Returns:
        A ``jsonify`` response whose ``text`` field holds the reply. /image and
        /tts replies also carry a ``cardsV2`` card. Unexpected errors are
        printed and turned into an apology instead of propagating.
    """
    try:
        # Check the user input for any policy violations
        if _is_flagged(user_message):
            return jsonify({'text': 'Sorry, your message does not comply with our content policy. Please refrain from inappropriate content.'})
        current_time = datetime.datetime.now()
        ai_chat = user_sessions.get(user_id)
        turn_count = turn_counts.get(user_id, 0)
        last_received_time = last_received_times.get(user_id)
        # Commands are matched on the stripped, lower-cased text. Handlers get
        # the stripped original-case text so their arguments keep their case.
        command = user_message.strip()
        keyword = command.lower()
        if keyword == '/reset':
            user_sessions.pop(user_id, None)  # drop the chat session if one exists
            turn_count = 0
            bot_message = "Your session has been reset. How can I assist you now?"
        elif keyword.startswith('/image'):
            return _handle_image_command(user_id, command)
        elif keyword == '/voices':
            return _handle_voices_command()
        elif keyword.startswith('/tts'):
            return _handle_tts_command(command)
        elif keyword == '/help':
            return jsonify({'text': get_docs("usage/help")})
        # Only text bound for the model is token-counted; commands never reach it.
        elif num_tokens_from_string(user_message + SYSTEM_PROMPT) > MAX_TOKENS_INPUT:
            return jsonify({'text': 'Sorry, your message is too large. Please try a shorter message.'})
        else:
            # Start a fresh session if none exists, the turn budget is used up,
            # or the user has been idle for longer than TTL seconds.
            if (
                ai_chat is None
                or turn_count >= MAX_TURNS
                or (
                    last_received_time is not None
                    and (current_time - last_received_time).total_seconds() > TTL
                )
            ):
                if API_URL:
                    ai_chat = AIChat(api_key=None, api_url=API_URL, system=SYSTEM_PROMPT, params=params)
                else:
                    ai_chat = AIChat(api_key=OPENAI_API_KEY, system=SYSTEM_PROMPT, model=MODEL_NAME, params=params)
                user_sessions[user_id] = ai_chat
                turn_count = 0
            # A custom API_URL endpoint gets the message wrapped in the configured prefix/suffix.
            if API_URL:
                response = ai_chat(f"{PROMPT_PREFIX}{user_message}{PROMPT_SUFFIX}")
            else:
                response = ai_chat(user_message)
            # Keep the reply under 4096 characters, leaving room for the marker.
            if len(response) > 4096:
                response = response[:4070] + "<MESSAGE TRUNCATED>"
            # Moderate the model's output (the input was already checked above).
            if _is_flagged(response):
                return jsonify({'text': 'Sorry, I generated a response that does not comply with our content policy, so I cannot share it. Please try rephrasing your request.'})
            bot_message = response
        # Update the turn count and the last received time
        turn_count += 1
        turn_counts[user_id] = turn_count
        last_received_times[user_id] = current_time
    except Exception as e:
        print(f"Error calling OpenAI API: {str(e)}")
        bot_message = "Sorry, I'm currently unable to generate a response."
    return jsonify({'text': bot_message})