From 227619102e51d9e750e5b7c6b37937a1c5c8d12a Mon Sep 17 00:00:00 2001
From: "P. Taylor Goetz" <tgoetz@monetate.com>
Date: Wed, 10 Apr 2024 15:21:41 -0400
Subject: [PATCH 1/2] Consolidate LLMType into AgentType

---
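Notes: with LLMType removed, callers that previously selected a model with
LLMType now pass the consolidated AgentType through the same "llm_type"
configurable field on the chatbot and chat_retrieval runnables. A minimal
usage sketch, assuming the module is importable as "app.agent" and using
LangChain's with_config(configurable=...) interface; the import path and the
invocation input shape are assumptions, not verified against the rest of the
repo:

    # Hypothetical sketch; import path and invocation input are assumptions.
    from langchain_core.messages import HumanMessage

    from app.agent import AgentType, chatbot

    configured = chatbot.with_config(
        configurable={
            "llm_type": AgentType.GPT_4,  # previously LLMType.GPT_4
            "system_message": "You are a terse assistant.",
        }
    )
    # configured.invoke([HumanMessage(content="Hello")]) would then run the
    # chatbot with the GPT-4 backend chosen in get_chatbot().
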
 backend/app/agent.py | 61 ++++++++++++++++++--------------------------
 1 file changed, 25 insertions(+), 36 deletions(-)

diff --git a/backend/app/agent.py b/backend/app/agent.py
index a633c071a..105213721 100644
--- a/backend/app/agent.py
+++ b/backend/app/agent.py
@@ -64,9 +64,9 @@ class AgentType(str, Enum):
     CLAUDE2 = "Claude 2"
     BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
     GEMINI = "GEMINI"
+    MIXTRAL = "Mixtral"
     OLLAMA = "Ollama"
 
-
 DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
 
 CHECKPOINTER = PostgresCheckpoint(at=CheckpointAt.END_OF_STEP)
@@ -175,36 +175,25 @@ def __init__(
         )
 
 
-class LLMType(str, Enum):
-    GPT_35_TURBO = "GPT 3.5 Turbo"
-    GPT_4 = "GPT 4 Turbo"
-    AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
-    CLAUDE2 = "Claude 2"
-    BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
-    GEMINI = "GEMINI"
-    MIXTRAL = "Mixtral"
-    OLLAMA = "Ollama"
-
-
 def get_chatbot(
-    llm_type: LLMType,
+    llm_type: AgentType,
     system_message: str,
 ):
-    if llm_type == LLMType.GPT_35_TURBO:
+    if llm_type == AgentType.GPT_35_TURBO:
         llm = get_openai_llm()
-    elif llm_type == LLMType.GPT_4:
+    elif llm_type == AgentType.GPT_4:
         llm = get_openai_llm(gpt_4=True)
-    elif llm_type == LLMType.AZURE_OPENAI:
+    elif llm_type == AgentType.AZURE_OPENAI:
         llm = get_openai_llm(azure=True)
-    elif llm_type == LLMType.CLAUDE2:
+    elif llm_type == AgentType.CLAUDE2:
         llm = get_anthropic_llm()
-    elif llm_type == LLMType.BEDROCK_CLAUDE2:
+    elif llm_type == AgentType.BEDROCK_CLAUDE2:
         llm = get_anthropic_llm(bedrock=True)
-    elif llm_type == LLMType.GEMINI:
+    elif llm_type == AgentType.GEMINI:
         llm = get_google_llm()
-    elif llm_type == LLMType.MIXTRAL:
+    elif llm_type == AgentType.MIXTRAL:
         llm = get_mixtral_fireworks()
-    elif llm_type == LLMType.OLLAMA:
+    elif llm_type == AgentType.OLLAMA:
         llm = get_ollama_llm()
     else:
         raise ValueError("Unexpected llm type")
@@ -212,14 +201,14 @@ def get_chatbot(
 
 
 class ConfigurableChatBot(RunnableBinding):
-    llm: LLMType
+    llm: AgentType
     system_message: str = DEFAULT_SYSTEM_MESSAGE
     user_id: Optional[str] = None
 
     def __init__(
         self,
         *,
-        llm: LLMType = LLMType.GPT_35_TURBO,
+        llm: AgentType = AgentType.GPT_35_TURBO,
         system_message: str = DEFAULT_SYSTEM_MESSAGE,
         kwargs: Optional[Mapping[str, Any]] = None,
         config: Optional[Mapping[str, Any]] = None,
@@ -238,7 +227,7 @@ def __init__(
 
 
 chatbot = (
-    ConfigurableChatBot(llm=LLMType.GPT_35_TURBO, checkpoint=CHECKPOINTER)
+    ConfigurableChatBot(llm=AgentType.GPT_35_TURBO, checkpoint=CHECKPOINTER)
     .configurable_fields(
         llm=ConfigurableField(id="llm_type", name="LLM Type"),
         system_message=ConfigurableField(id="system_message", name="Instructions"),
@@ -248,7 +237,7 @@ def __init__(
 
 
 class ConfigurableRetrieval(RunnableBinding):
-    llm_type: LLMType
+    llm_type: AgentType
     system_message: str = DEFAULT_SYSTEM_MESSAGE
     assistant_id: Optional[str] = None
     thread_id: Optional[str] = None
@@ -257,7 +246,7 @@ class ConfigurableRetrieval(RunnableBinding):
     def __init__(
         self,
         *,
-        llm_type: LLMType = LLMType.GPT_35_TURBO,
+        llm_type: AgentType = AgentType.GPT_35_TURBO,
         system_message: str = DEFAULT_SYSTEM_MESSAGE,
         assistant_id: Optional[str] = None,
         thread_id: Optional[str] = None,
@@ -267,21 +256,21 @@ def __init__(
     ) -> None:
         others.pop("bound", None)
         retriever = get_retriever(assistant_id, thread_id)
-        if llm_type == LLMType.GPT_35_TURBO:
+        if llm_type == AgentType.GPT_35_TURBO:
             llm = get_openai_llm()
-        elif llm_type == LLMType.GPT_4:
+        elif llm_type == AgentType.GPT_4:
             llm = get_openai_llm(gpt_4=True)
-        elif llm_type == LLMType.AZURE_OPENAI:
+        elif llm_type == AgentType.AZURE_OPENAI:
             llm = get_openai_llm(azure=True)
-        elif llm_type == LLMType.CLAUDE2:
+        elif llm_type == AgentType.CLAUDE2:
             llm = get_anthropic_llm()
-        elif llm_type == LLMType.BEDROCK_CLAUDE2:
+        elif llm_type == AgentType.BEDROCK_CLAUDE2:
             llm = get_anthropic_llm(bedrock=True)
-        elif llm_type == LLMType.GEMINI:
+        elif llm_type == AgentType.GEMINI:
             llm = get_google_llm()
-        elif llm_type == LLMType.MIXTRAL:
+        elif llm_type == AgentType.MIXTRAL:
             llm = get_mixtral_fireworks()
-        elif llm_type == LLMType.OLLAMA:
+        elif llm_type == AgentType.OLLAMA:
             llm = get_ollama_llm()
         else:
             raise ValueError("Unexpected llm type")
@@ -296,7 +285,7 @@ def __init__(
 
 
 chat_retrieval = (
-    ConfigurableRetrieval(llm_type=LLMType.GPT_35_TURBO, checkpoint=CHECKPOINTER)
+    ConfigurableRetrieval(llm_type=AgentType.GPT_35_TURBO, checkpoint=CHECKPOINTER)
     .configurable_fields(
         llm_type=ConfigurableField(id="llm_type", name="LLM Type"),
         system_message=ConfigurableField(id="system_message", name="Instructions"),
@@ -319,7 +308,7 @@ def __init__(
         thread_id=None,
     )
     .configurable_fields(
-        agent=ConfigurableField(id="agent_type", name="Agent Type"),
+        agent=ConfigurableField(id="agent_type", name="LLM Type"),
         system_message=ConfigurableField(id="system_message", name="Instructions"),
         interrupt_before_action=ConfigurableField(
             id="interrupt_before_action",

From dd1245a3fcfc6e5f439a49e52695e07254589156 Mon Sep 17 00:00:00 2001
From: "P. Taylor Goetz" <tgoetz@monetate.com>
Date: Wed, 10 Apr 2024 15:29:08 -0400
Subject: [PATCH 2/2] Fix linting: restore blank line after AgentType

---
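Notes: restoring the second blank line keeps two blank lines between the
AgentType enum and the module-level constants, per PEP 8 top-level spacing
(the specific linter rule, e.g. flake8/pycodestyle E305, is an assumption
about the tooling in use).
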
 backend/app/agent.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/backend/app/agent.py b/backend/app/agent.py
index 105213721..c5acffec3 100644
--- a/backend/app/agent.py
+++ b/backend/app/agent.py
@@ -67,6 +67,7 @@ class AgentType(str, Enum):
     MIXTRAL = "Mixtral"
     OLLAMA = "Ollama"
 
+
 DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant."
 
 CHECKPOINTER = PostgresCheckpoint(at=CheckpointAt.END_OF_STEP)