Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added support for google gemini #25

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 56 additions & 4 deletions webui/webui/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,20 @@
import json
import openai
import reflex as rx
import google.generativeai as genai

openai.api_key = os.getenv("OPENAI_API_KEY")
openai.api_base = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")

BAIDU_API_KEY = os.getenv("BAIDU_API_KEY")
BAIDU_SECRET_KEY = os.getenv("BAIDU_SECRET_KEY")

GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY')
genai.configure(api_key=GOOGLE_API_KEY)

if not openai.api_key and not BAIDU_API_KEY:
raise Exception("Please set OPENAI_API_KEY or BAIDU_API_KEY")

if not openai.api_key and not BAIDU_API_KEY and not GOOGLE_API_KEY:
raise Exception("Please set OPENAI_API_KEY or BAIDU_API_KEY or GOOGLE_API_KEY")


def get_access_token():
Expand Down Expand Up @@ -64,7 +68,10 @@ class State(rx.State):
# Whether the modal is open.
modal_open: bool = False

api_type: str = "baidu" if BAIDU_API_KEY else "openai"
api_type: str = (
"baidu" if BAIDU_API_KEY else "openai" if openai.api_key else "genai"
)


def create_chat(self):
"""Create a new chat."""
Expand Down Expand Up @@ -119,8 +126,10 @@ async def process_question(self, form_data: dict[str, str]):

if self.api_type == "openai":
model = self.openai_process_question
else:
elif self.api_type == "baidu":
model = self.baidu_process_question
else:
model = self.gemini_process_question

async for value in model(question):
yield value
Expand Down Expand Up @@ -208,3 +217,46 @@ async def baidu_process_question(self, question: str):
yield
# Toggle the processing flag.
self.processing = False



async def gemini_process_question(self, question: str):
"""Get the response from the Gemini API."""

# Add the question to the list of questions.
qa = QA(question=question, answer="")
self.chats[self.current_chat].append(qa)

# Clear the input and start the processing.
self.processing = True
yield

# Build the messages.
messages = [
{"role": "system", "content": "You are a friendly chatbot named Reflex."}
]
for qa in self.chats[self.current_chat]:
messages.append({"role": "user", "content": qa.question})
messages.append({"role": "assistant", "content": qa.answer})

# Remove the last mock answer.
messages = messages[:-1]

# Configure Gemini API
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
model = genai.GenerativeModel("gemini-pro")

# Get the response from Gemini
response = model.generate_content(
contents="\n".join([message["content"] for message in messages])
)
answer_text = response.text

# Update the chat with the answer
self.chats[self.current_chat][-1].answer += answer_text
self.chats = self.chats # Trigger reactivity
yield

# Toggle the processing flag.
self.processing = False