From 694e8992f4a2800664a26544845e2c4e0a075ebe Mon Sep 17 00:00:00 2001 From: Varun Date: Tue, 30 Jan 2024 17:28:06 +0530 Subject: [PATCH] Added support for google gemini --- webui/webui/state.py | 60 +++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 4 deletions(-) diff --git a/webui/webui/state.py b/webui/webui/state.py index 21a8e37..8814494 100644 --- a/webui/webui/state.py +++ b/webui/webui/state.py @@ -3,6 +3,7 @@ import json import openai import reflex as rx +import google.generativeai as genai openai.api_key = os.getenv("OPENAI_API_KEY") openai.api_base = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1") @@ -10,9 +11,12 @@ BAIDU_API_KEY = os.getenv("BAIDU_API_KEY") BAIDU_SECRET_KEY = os.getenv("BAIDU_SECRET_KEY") +GOOGLE_API_KEY = os.getenv('GOOGLE_API_KEY') +genai.configure(api_key=GOOGLE_API_KEY) -if not openai.api_key and not BAIDU_API_KEY: - raise Exception("Please set OPENAI_API_KEY or BAIDU_API_KEY") + +if not openai.api_key and not BAIDU_API_KEY and not GOOGLE_API_KEY: + raise Exception("Please set OPENAI_API_KEY or BAIDU_API_KEY or GOOGLE_API_KEY") def get_access_token(): @@ -64,7 +68,10 @@ class State(rx.State): # Whether the modal is open. modal_open: bool = False - api_type: str = "baidu" if BAIDU_API_KEY else "openai" + api_type: str = ( + "baidu" if BAIDU_API_KEY else "openai" if openai.api_key else "genai" +) + def create_chat(self): """Create a new chat.""" @@ -119,8 +126,10 @@ async def process_question(self, form_data: dict[str, str]): if self.api_type == "openai": model = self.openai_process_question - else: + elif self.api_type == "baidu": model = self.baidu_process_question + else: + model = self.gemini_process_question async for value in model(question): yield value @@ -208,3 +217,46 @@ async def baidu_process_question(self, question: str): yield # Toggle the processing flag. self.processing = False + + + + async def gemini_process_question(self, question: str): + """Get the response from the Gemini API.""" + + # Add the question to the list of questions. + qa = QA(question=question, answer="") + self.chats[self.current_chat].append(qa) + + # Clear the input and start the processing. + self.processing = True + yield + + # Build the messages. + messages = [ + {"role": "system", "content": "You are a friendly chatbot named Reflex."} + ] + for qa in self.chats[self.current_chat]: + messages.append({"role": "user", "content": qa.question}) + messages.append({"role": "assistant", "content": qa.answer}) + + # Remove the last mock answer. + messages = messages[:-1] + + # Configure Gemini API + genai.configure(api_key=os.getenv("GOOGLE_API_KEY")) + model = genai.GenerativeModel("gemini-pro") + + # Get the response from Gemini + response = model.generate_content( + contents="\n".join([message["content"] for message in messages]) + ) + answer_text = response.text + + # Update the chat with the answer + self.chats[self.current_chat][-1].answer += answer_text + self.chats = self.chats # Trigger reactivity + yield + + # Toggle the processing flag. + self.processing = False +