Skip to content

Commit

Permalink
Openapi agent (#68)
Browse files Browse the repository at this point in the history
* Add an `OpenAPI` Agent

* Remove `print` statement

* Add support for an Agent that can query API with NL

* Add support for adding a prompt when creating an Agent

* Small tweaks
  • Loading branch information
homanp authored May 30, 2023
1 parent be3c1d8 commit 6fb4308
Show file tree
Hide file tree
Showing 9 changed files with 80 additions and 24 deletions.
2 changes: 2 additions & 0 deletions app/api/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ async def create_agent(body: Agent, token=Depends(JWTBearer())):
"llm": json.dumps(body.llm),
"hasMemory": body.has_memory,
"userId": decoded["userId"],
"documentId": body.documentId,
"promptId": body.promptId,
},
include={"user": True},
)
Expand Down
23 changes: 12 additions & 11 deletions app/api/documents.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

from fastapi import APIRouter, Depends, HTTPException, status

from app.lib.auth.prisma import JWTBearer, decodeJWT
Expand All @@ -14,23 +16,22 @@ async def create_document(body: Document, token=Depends(JWTBearer())):

try:
decoded = decodeJWT(token)
document_type = body.type
document_url = body.url
document_name = body.name
document = prisma.document.create(
{
"type": document_type,
"url": document_url,
"type": body.type,
"url": body.url,
"userId": decoded["userId"],
"name": document_name,
"name": body.name,
"authorization": json.dumps(body.authorization),
}
)

upsert_document(
url=document_url,
type=document_type,
document_id=document.id,
)
if body.type == "TXT" or body.type == "PDF":
upsert_document(
url=body.url,
type=body.type,
document_id=document.id,
)

return {"success": True, "data": document}

Expand Down
36 changes: 34 additions & 2 deletions app/lib/agents.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from typing import Any

from decouple import config
from langchain.agents import AgentType, initialize_agent
from langchain.agents.agent_toolkits import NLAToolkit
from langchain.chains import ConversationalRetrievalChain, LLMChain
from langchain.chains.conversational_retrieval.prompts import (
CONDENSE_QUESTION_PROMPT,
Expand All @@ -12,11 +14,12 @@
from langchain.llms import Cohere, OpenAI
from langchain.memory import ChatMessageHistory, ConversationBufferMemory
from langchain.prompts.prompt import PromptTemplate
from langchain.requests import Requests
from langchain.vectorstores.pinecone import Pinecone

from app.lib.callbacks import StreamingCallbackHandler
from app.lib.prisma import prisma
from app.lib.prompts import default_chat_prompt
from app.lib.prompts import default_chat_prompt, openapi_format_instructions


class Agent:
Expand Down Expand Up @@ -171,7 +174,8 @@ def get_agent(self) -> Any:
llm = self._get_llm()
memory = self._get_memory()
document = self._get_document()
if document:

if self.document and self.document.type != "OPENAPI":
question_generator = LLMChain(
llm=OpenAI(temperature=0), prompt=CONDENSE_QUESTION_PROMPT
)
Expand All @@ -186,6 +190,34 @@ def get_agent(self) -> Any:
get_chat_history=lambda h: h,
)

elif self.document and self.document.type == "OPENAPI":
requests = (
Requests(
headers={
self.document.authorization["key"]: self.document.authorization[
"value"
]
}
)
if self.document.authorization
else Requests()
)
openapi_toolkit = NLAToolkit.from_llm_and_url(
llm, self.document.url, requests=requests, max_text_length=1800
)
tools = openapi_toolkit.get_tools()[:30]
mrkl = initialize_agent(
tools=tools,
llm=llm,
agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
verbose=True,
max_iterations=1,
early_stopping_method="generate",
agent_kwargs={"format_instructions": openapi_format_instructions},
)

return mrkl

else:
agent = LLMChain(
llm=llm, memory=memory, verbose=True, prompt=self._get_prompt()
Expand Down
2 changes: 2 additions & 0 deletions app/lib/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ class Agent(BaseModel):
type: str
llm: dict = None
has_memory: bool = False
documentId: str = None
promptId: str = None


class PredictAgent(BaseModel):
Expand Down
1 change: 1 addition & 0 deletions app/lib/models/document.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ class Document(BaseModel):
type: str
url: str
name: str
authorization: dict
12 changes: 12 additions & 0 deletions app/lib/prompts.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
# flake8: noqa
from langchain.prompts.prompt import PromptTemplate

openapi_format_instructions = """Use the following format:
Question: the input question you must answer
Thought: you should always think about what to do
Action: the action to take, should be one of [{tool_names}]
Action Input: what to instruct the AI Action representative.
Observation: The Agent's response
... (this Thought/Action/Action Input/Observation can repeat N times)
Thought: I now know the final answer. User can't see any of my observations, API responses, links, or tools.
Final Answer: the final answer to the original input question in markdown.
When responding with your Final Answer, use a conversational answer without referencing the API.
"""

default_chat_template = """Assistant is designed to be able to assist with a wide range of tasks, from answering
simple questions to providing in-depth explanations and discussions on a wide range of
topics.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- AlterEnum
ALTER TYPE "DocumentType" ADD VALUE 'OPENAPI';
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
-- AlterTable
ALTER TABLE "Document" ADD COLUMN "authorization" JSONB;
24 changes: 13 additions & 11 deletions prisma/schema.prisma
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ generator client {

datasource db {
provider = "postgresql"
url = env("DATABASE_URL")
url = env("DATABASE_MIGRATION_URL")
shadowDatabaseUrl = env("DATABASE_SHADOW_URL")
}

Expand All @@ -23,6 +23,7 @@ enum DocumentType {
TXT
PDF
YOUTUBE
OPENAPI
}

model User {
Expand All @@ -48,16 +49,17 @@ model Profile {
}

model Document {
id String @id @default(cuid()) @db.VarChar(255)
userId String @db.VarChar(255)
user User @relation(fields: [userId], references: [id])
type DocumentType @default(TXT)
url String @db.Text()
name String
createdAt DateTime? @default(now())
updatedAt DateTime? @default(now())
index Json?
Agent Agent[]
id String @id @default(cuid()) @db.VarChar(255)
userId String @db.VarChar(255)
user User @relation(fields: [userId], references: [id])
type DocumentType @default(TXT)
url String @db.Text()
name String
createdAt DateTime? @default(now())
updatedAt DateTime? @default(now())
index Json?
authorization Json?
Agent Agent[]
}

model Agent {
Expand Down

0 comments on commit 6fb4308

Please sign in to comment.