-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding the updates to the backend-api and frontend
- Loading branch information
1 parent
bfc4da4
commit d181e9a
Showing
3 changed files
with
94 additions
and
346 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,73 +1,110 @@ | ||
import gradio as gr | ||
import re | ||
import requests | ||
|
||
# URLs of your custom LLMs (replace with actual URLs) | ||
MODEL_URLS = { | ||
"modelA": "http://104.198.185.85:8000/v1/chat/completions", | ||
# "modelB": "YOUR_MODEL_B_API_URL", | ||
# Add more models as needed | ||
} | ||
|
||
TEXT_EMBEDDING_ENDPOINT = "http://0.0.0.0:8000/embeddings/text" | ||
MULTIMODAL_EMBEDDING_ENDPOINT = "http://0.0.0.0:8000/embeddings/multimodal" | ||
|
||
def chatbot(message, history): | ||
history = history or [] | ||
history.append([message, None]) | ||
|
||
# Get the selected model's URL | ||
# model_url = MODEL_URLS.get(model) | ||
# if not model_url: | ||
# raise ValueError(f"Invalid model selected: {model}") | ||
# Function to validate GCS URI | ||
def validate_gcs_uri(uri): | ||
""" | ||
Validates if the provided URI is a valid GCS URI. | ||
# Construct the prompt for the selected model | ||
# (Adjust based on your API's requirements) | ||
prompt = "" | ||
for user_msg, bot_msg in history: | ||
prompt += f"Customer: {user_msg}\nRetail Bot: {bot_msg}\n" | ||
prompt += f"Customer: {message}\nRetail Bot: " | ||
Args: | ||
uri: The URI string to validate. | ||
# model_url = "http://34.171.174.67:8000/v1/chat/completions" | ||
model_url = MODEL_URLS.get("modelA") | ||
model_name = "/data/models/model-gemma2-a100/experiment-a2aa2c3it1" | ||
Returns: | ||
True if the URI is valid, False otherwise. | ||
""" | ||
pattern = ( | ||
r"^gs://[a-z0-9][a-z0-9._-]{1,253}/[a-zA-Z0-9_.!@#$%^&*()/-]+[a-zA-Z0-9]+$" | ||
) | ||
return bool(re.match(pattern, uri)) | ||
|
||
|
||
# Function to process text input | ||
def process_text(text): | ||
""" | ||
Processes the text input (prompt) to generate embeddings. | ||
Args: | ||
text: The input prompt. | ||
Returns: | ||
Embeddings as a list of floats. | ||
""" | ||
# Replace this with your actual embedding generation logic | ||
# This is a placeholder example | ||
# embeddings = [0.1, 0.2, 0.3, 0.4, 0.5] | ||
embeddings = [] | ||
|
||
# Send request to the selected model API | ||
response = requests.post( | ||
model_url, | ||
headers={"content-type": "application/json"}, | ||
timeout=100, | ||
json={ | ||
"model": model_name, | ||
"messages": [{"role": "user", "content": prompt}], | ||
"temperature": 0.70, | ||
"top_p": 1.0, | ||
"top_k": 1.0, | ||
"max_tokens": 256, | ||
}, | ||
stream=False, | ||
TEXT_EMBEDDING_ENDPOINT, | ||
json={"text": text}, | ||
headers={"Content-Type": "application/json"}, | ||
timeout=1000, | ||
) | ||
response.raise_for_status() | ||
embeddings = response.json()["text_embeds"] | ||
|
||
return embeddings | ||
|
||
|
||
# Function to process text and image URI input | ||
def process_text_image(text, image_uri): | ||
""" | ||
Processes the text prompt and image URI to generate embeddings. | ||
# Extract the generated response | ||
# (Adjust based on your API's response format) | ||
try: | ||
bot_message = response.json()["choices"][0]["message"]["content"] | ||
print(bot_message) | ||
except KeyError: | ||
raise ValueError("Invalid response format from the model API") | ||
Args: | ||
text: The input prompt. | ||
image_uri: The GCS URI of the image. | ||
history[-1][1] = bot_message | ||
return history, history | ||
Returns: | ||
Embeddings as a list of floats. | ||
""" | ||
if not validate_gcs_uri(image_uri): # Validate the URI | ||
return "Invalid GCS URI provided." # Return an error message | ||
|
||
# Replace this with your actual embedding generation logic | ||
# This is a placeholder example | ||
# embeddings = [0.6, 0.7, 0.8, 0.9, 1.0] | ||
embeddings = [] | ||
|
||
response = requests.post( | ||
MULTIMODAL_EMBEDDING_ENDPOINT, | ||
json={"text": text, "image_uri": image_uri}, | ||
headers={"Content-Type": "application/json"}, | ||
timeout=1000, | ||
) | ||
response.raise_for_status() | ||
embeddings = response.json()["multimodal_embeds"] | ||
|
||
return embeddings | ||
|
||
|
||
# Create the Gradio interface | ||
iface = gr.Interface( | ||
fn=chatbot, | ||
inputs=["text"], | ||
outputs=[ | ||
gr.Chatbot(label="Retail Chatbot"), | ||
gr.Chatbot(label="Conversation History"), | ||
], | ||
title="Retail Customer Chatbot", | ||
description="Ask your retail questions here!", | ||
theme=gr.themes.Soft(), | ||
) | ||
|
||
iface.launch(debug=True) | ||
with gr.Blocks() as demo: | ||
gr.Markdown("## Retail Chatbot with Embeddings") | ||
|
||
with gr.Tab("Text Prompt"): | ||
text_input = gr.Textbox(lines=5, label="Enter your prompt") | ||
text_output = gr.Textbox(label="Embeddings") | ||
text_button = gr.Button("Generate Embeddings") | ||
text_button.click(fn=process_text, inputs=text_input, outputs=text_output) | ||
|
||
with gr.Tab("Text Prompt + Image"): | ||
text_input_2 = gr.Textbox(lines=5, label="Enter your prompt") | ||
image_uri_input = gr.Textbox(label="Enter GCS image URI") | ||
image_uri_output = gr.Textbox(label="Embeddings") | ||
image_uri_button = gr.Button("Generate Embeddings") | ||
image_uri_button.click( | ||
fn=process_text_image, | ||
inputs=[text_input_2, image_uri_input], | ||
outputs=image_uri_output, | ||
) | ||
|
||
# Launch the demo | ||
if __name__ == "__main__": | ||
demo.launch() |
Oops, something went wrong.