Skip to content

Commit

Permalink
Merge pull request #337 from ptgoetz/gpt-4o-support
Browse files Browse the repository at this point in the history
Add GPT 4o as a model
  • Loading branch information
mkorpela authored May 16, 2024
2 parents 40a76db + 04de67a commit 2cec393
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
13 changes: 11 additions & 2 deletions backend/app/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
class AgentType(str, Enum):
GPT_35_TURBO = "GPT 3.5 Turbo"
GPT_4 = "GPT 4 Turbo"
GPT_4O = "GPT 4o"
AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
CLAUDE2 = "Claude 2"
BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
Expand All @@ -88,7 +89,12 @@ def get_agent_executor(
tools, llm, system_message, interrupt_before_action, CHECKPOINTER
)
elif agent == AgentType.GPT_4:
llm = get_openai_llm(gpt_4=True)
llm = get_openai_llm(model="gpt-4-turbo")
return get_tools_agent_executor(
tools, llm, system_message, interrupt_before_action, CHECKPOINTER
)
elif agent == AgentType.GPT_4O:
llm = get_openai_llm(model="gpt-4o")
return get_tools_agent_executor(
tools, llm, system_message, interrupt_before_action, CHECKPOINTER
)
Expand Down Expand Up @@ -182,6 +188,7 @@ def __init__(
class LLMType(str, Enum):
GPT_35_TURBO = "GPT 3.5 Turbo"
GPT_4 = "GPT 4 Turbo"
GPT_4O = "GPT 4o"
AZURE_OPENAI = "GPT 4 (Azure OpenAI)"
CLAUDE2 = "Claude 2"
BEDROCK_CLAUDE2 = "Claude 2 (Amazon Bedrock)"
Expand Down Expand Up @@ -277,7 +284,9 @@ def __init__(
if llm_type == LLMType.GPT_35_TURBO:
llm = get_openai_llm()
elif llm_type == LLMType.GPT_4:
llm = get_openai_llm(gpt_4=True)
llm = get_openai_llm(model="gpt-4-turbo")
elif llm_type == LLMType.GPT_4O:
llm = get_openai_llm(model="gpt-4o")
elif llm_type == LLMType.AZURE_OPENAI:
llm = get_openai_llm(azure=True)
elif llm_type == LLMType.CLAUDE2:
Expand Down
4 changes: 2 additions & 2 deletions backend/app/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@


@lru_cache(maxsize=4)
def get_openai_llm(gpt_4: bool = False, azure: bool = False):
def get_openai_llm(model: str = "gpt-3.5-turbo", azure: bool = False):
proxy_url = os.getenv("PROXY_URL")
http_client = None
if proxy_url:
Expand All @@ -27,7 +27,7 @@ def get_openai_llm(gpt_4: bool = False, azure: bool = False):

if not azure:
try:
openai_model = "gpt-4-turbo-preview" if gpt_4 else "gpt-3.5-turbo"
openai_model = model
llm = ChatOpenAI(
http_client=http_client,
model=openai_model,
Expand Down

0 comments on commit 2cec393

Please sign in to comment.