Skip to content

Commit

Permalink
Adding the updates to the backend-api and frontend
Browse files Browse the repository at this point in the history
  • Loading branch information
IshmeetMehta committed Nov 11, 2024
1 parent bfc4da4 commit d181e9a
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 346 deletions.
179 changes: 0 additions & 179 deletions use-cases/rag-on-gke/frontend/src/interface.1.py

This file was deleted.

151 changes: 94 additions & 57 deletions use-cases/rag-on-gke/frontend/src/interface.py
Original file line number Diff line number Diff line change
@@ -1,73 +1,110 @@
import gradio as gr
import re
import requests

# URLs of your custom LLMs (replace with actual URLs)
MODEL_URLS = {
"modelA": "http://104.198.185.85:8000/v1/chat/completions",
# "modelB": "YOUR_MODEL_B_API_URL",
# Add more models as needed
}

TEXT_EMBEDDING_ENDPOINT = "http://0.0.0.0:8000/embeddings/text"
MULTIMODAL_EMBEDDING_ENDPOINT = "http://0.0.0.0:8000/embeddings/multimodal"

def chatbot(message, history):
history = history or []
history.append([message, None])

# Get the selected model's URL
# model_url = MODEL_URLS.get(model)
# if not model_url:
# raise ValueError(f"Invalid model selected: {model}")
# Function to validate GCS URI
def validate_gcs_uri(uri):
"""
Validates if the provided URI is a valid GCS URI.
# Construct the prompt for the selected model
# (Adjust based on your API's requirements)
prompt = ""
for user_msg, bot_msg in history:
prompt += f"Customer: {user_msg}\nRetail Bot: {bot_msg}\n"
prompt += f"Customer: {message}\nRetail Bot: "
Args:
uri: The URI string to validate.
# model_url = "http://34.171.174.67:8000/v1/chat/completions"
model_url = MODEL_URLS.get("modelA")
model_name = "/data/models/model-gemma2-a100/experiment-a2aa2c3it1"
Returns:
True if the URI is valid, False otherwise.
"""
pattern = (
r"^gs://[a-z0-9][a-z0-9._-]{1,253}/[a-zA-Z0-9_.!@#$%^&*()/-]+[a-zA-Z0-9]+$"
)
return bool(re.match(pattern, uri))


# Function to process text input
def process_text(text):
"""
Processes the text input (prompt) to generate embeddings.
Args:
text: The input prompt.
Returns:
Embeddings as a list of floats.
"""
# Replace this with your actual embedding generation logic
# This is a placeholder example
# embeddings = [0.1, 0.2, 0.3, 0.4, 0.5]
embeddings = []

# Send request to the selected model API
response = requests.post(
model_url,
headers={"content-type": "application/json"},
timeout=100,
json={
"model": model_name,
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.70,
"top_p": 1.0,
"top_k": 1.0,
"max_tokens": 256,
},
stream=False,
TEXT_EMBEDDING_ENDPOINT,
json={"text": text},
headers={"Content-Type": "application/json"},
timeout=1000,
)
response.raise_for_status()
embeddings = response.json()["text_embeds"]

return embeddings


# Function to process text and image URI input
def process_text_image(text, image_uri):
"""
Processes the text prompt and image URI to generate embeddings.
# Extract the generated response
# (Adjust based on your API's response format)
try:
bot_message = response.json()["choices"][0]["message"]["content"]
print(bot_message)
except KeyError:
raise ValueError("Invalid response format from the model API")
Args:
text: The input prompt.
image_uri: The GCS URI of the image.
history[-1][1] = bot_message
return history, history
Returns:
Embeddings as a list of floats.
"""
if not validate_gcs_uri(image_uri): # Validate the URI
return "Invalid GCS URI provided." # Return an error message

# Replace this with your actual embedding generation logic
# This is a placeholder example
# embeddings = [0.6, 0.7, 0.8, 0.9, 1.0]
embeddings = []

response = requests.post(
MULTIMODAL_EMBEDDING_ENDPOINT,
json={"text": text, "image_uri": image_uri},
headers={"Content-Type": "application/json"},
timeout=1000,
)
response.raise_for_status()
embeddings = response.json()["multimodal_embeds"]

return embeddings


# Create the Gradio interface
iface = gr.Interface(
fn=chatbot,
inputs=["text"],
outputs=[
gr.Chatbot(label="Retail Chatbot"),
gr.Chatbot(label="Conversation History"),
],
title="Retail Customer Chatbot",
description="Ask your retail questions here!",
theme=gr.themes.Soft(),
)

iface.launch(debug=True)
with gr.Blocks() as demo:
gr.Markdown("## Retail Chatbot with Embeddings")

with gr.Tab("Text Prompt"):
text_input = gr.Textbox(lines=5, label="Enter your prompt")
text_output = gr.Textbox(label="Embeddings")
text_button = gr.Button("Generate Embeddings")
text_button.click(fn=process_text, inputs=text_input, outputs=text_output)

with gr.Tab("Text Prompt + Image"):
text_input_2 = gr.Textbox(lines=5, label="Enter your prompt")
image_uri_input = gr.Textbox(label="Enter GCS image URI")
image_uri_output = gr.Textbox(label="Embeddings")
image_uri_button = gr.Button("Generate Embeddings")
image_uri_button.click(
fn=process_text_image,
inputs=[text_input_2, image_uri_input],
outputs=image_uri_output,
)

# Launch the demo
if __name__ == "__main__":
demo.launch()
Loading

0 comments on commit d181e9a

Please sign in to comment.