Skip to content

Commit

Permalink
Merge pull request #885 from khoangothe/features/chat-with-history-2
Browse files Browse the repository at this point in the history
Chat with History
  • Loading branch information
ElishaKay authored Oct 31, 2024
2 parents e714999 + 845dd2b commit cc9f73a
Show file tree
Hide file tree
Showing 9 changed files with 173 additions and 16 deletions.
1 change: 1 addition & 0 deletions backend/chat/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .chat import ChatAgentWithMemory
79 changes: 79 additions & 0 deletions backend/chat/chat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from fastapi import WebSocket
import uuid

from gpt_researcher.utils.llm import get_llm
from gpt_researcher.memory import Memory
from gpt_researcher.config.config import Config

from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver

from langchain_community.vectorstores import InMemoryVectorStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import Tool, tool

class ChatAgentWithMemory:
def __init__(
self,
report: str,
config_path,
headers,
vector_store = None
):
self.report = report
self.headers = headers
self.config = Config(config_path)
self.vector_store = vector_store
self.graph = self.create_agent()

def create_agent(self):
"""Create React Agent Graph"""
#If not vector store, split and talk to the report
llm_provider_name = getattr(self.config, "llm_provider")
fast_llm_model = getattr(self.config, "fast_llm_model")
temperature = getattr(self.config, "temperature")
fast_token_limit = getattr(self.config, "fast_token_limit")

provider = get_llm(llm_provider_name, model=fast_llm_model, temperature=temperature, max_tokens=fast_token_limit, **self.config.llm_kwargs).llm
if not self.vector_store:
documents = self._process_document(self.report)
self.chat_config = {"configurable": {"thread_id": str(uuid.uuid4())}}
self.embedding = Memory(getattr(self.config, 'embedding_provider', None), self.headers).get_embeddings()
self.vector_store = InMemoryVectorStore(self.embedding)
self.vector_store.add_texts(documents)
graph = create_react_agent(provider, tools=[self.vector_store_tool(self.vector_store)], checkpointer=MemorySaver())
return graph

def vector_store_tool(self, vector_store) -> Tool:
"""Create Vector Store Tool"""
@tool
def retrieve_info(query):
"""
Consult the report for relevant contexts whenever you don't know something
"""
retriever = vector_store.as_retriever(k = 4)
return retriever.invoke(query)
return retrieve_info

def _process_document(self, report):
"""Split Report into Chunks"""
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=1024,
chunk_overlap=20,
length_function=len,
is_separator_regex=False,
)
documents = text_splitter.split_text(report)
return documents

async def chat(self, message, websocket):
"""Chat with React Agent"""
inputs = {"messages": [("user", message)]}
response = await self.graph.ainvoke(inputs, config=self.chat_config)
ai_message = response["messages"][-1].content
if websocket is not None:
await websocket.send_json({"type": "chat", "content": ai_message})

def get_context(self):
"""return the current context of the chat"""
return self.report
5 changes: 5 additions & 0 deletions backend/server/server_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ async def handle_human_feedback(data: str):
print(f"Received human feedback: {feedback_data}")
# TODO: Add logic to forward the feedback to the appropriate agent or update the research state

async def handle_chat(websocket, data: str, manager):
json_data = json.loads(data[4:])
await manager.chat(json_data.get("message"), websocket)

async def generate_report_files(report: str, filename: str) -> Dict[str, str]:
pdf_path = await write_md_to_pdf(report, filename)
Expand Down Expand Up @@ -117,6 +120,8 @@ async def handle_websocket_communication(websocket, manager):
await handle_start_command(websocket, data, manager)
elif data.startswith("human_feedback"):
await handle_human_feedback(data)
elif data.startswith("chat"):
await handle_chat(websocket, data, manager)
else:
print("Error: Unknown command or not enough parameters provided.")

Expand Down
19 changes: 16 additions & 3 deletions backend/server/websocket_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from fastapi import WebSocket

from backend.report_type import BasicReport, DetailedReport
from backend.chat import ChatAgentWithMemory

from gpt_researcher.utils.enum import ReportType, Tone
from multi_agents.main import run_research_task
from gpt_researcher.actions import stream_output # Import stream_output
Expand All @@ -18,6 +20,7 @@ def __init__(self):
self.active_connections: List[WebSocket] = []
self.sender_tasks: Dict[WebSocket, asyncio.Task] = {}
self.message_queues: Dict[WebSocket, asyncio.Queue] = {}
self.chat_agent = None

async def start_sender(self, websocket: WebSocket):
"""Start the sender task."""
Expand Down Expand Up @@ -58,14 +61,24 @@ async def disconnect(self, websocket: WebSocket):
async def start_streaming(self, task, report_type, report_source, source_urls, tone, websocket, headers=None):
"""Start streaming the output."""
tone = Tone[tone]
report = await run_agent(task, report_type, report_source, source_urls, tone, websocket, headers)
# add customized JSON config file path here
config_path = "default"
report = await run_agent(task, report_type, report_source, source_urls, tone, websocket, headers = headers, config_path = config_path)
#Create new Chat Agent whenever a new report is written
self.chat_agent = ChatAgentWithMemory(report, config_path, headers)
return report

async def chat(self, message, websocket):
"""Chat with the agent based message diff"""
if self.chat_agent:
await self.chat_agent.chat(message, websocket)
else:
await websocket.send_json({"type": "chat", "content": "Knowledge empty, please run the research first to obtain knowledge"})

async def run_agent(task, report_type, report_source, source_urls, tone: Tone, websocket, headers=None):
async def run_agent(task, report_type, report_source, source_urls, tone: Tone, websocket, headers=None, config_path=""):
"""Run the agent."""
start_time = datetime.datetime.now()
config_path = ""
# Instead of running the agent directly run it through the different report type classes
if report_type == "multi_agents":
report = await run_research_task(query=task, websocket=websocket, stream_output=stream_output, tone=tone, headers=headers)
report = report.get("report", "")
Expand Down
34 changes: 27 additions & 7 deletions frontend/nextjs/app/page.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ export default function Home() {
setAnswer((prev:any) => prev + data.output);
} else if (data.type === 'path') {
setLoading(false);
newSocket.close();
setSocket(null);
// newSocket.close(); We do not want to close the connection since we are chatting
// setSocket(null);
} else if (data.type === 'chat'){
setLoading(false);
}
}

Expand Down Expand Up @@ -147,9 +149,20 @@ export default function Home() {
setShowHumanFeedback(false);
};

const handleDisplayResult = async (newQuestion?: string) => {
newQuestion = newQuestion || promptValue;

const handleChat = async (message : string) =>{
if(socket){
setShowResult(true);
setQuestion(message);
setLoading(true);
setPromptValue("");
setAnswer(""); // Reset answer for new query
setOrderedData((prevOrder) => [...prevOrder, { type: 'question', content: message }]);
const data : string = "chat" + JSON.stringify({"message": message});
socket.send(data)
}
}

const handleDisplayResult = async (newQuestion: string) => {
setShowResult(true);
setLoading(true);
setQuestion(newQuestion);
Expand Down Expand Up @@ -245,7 +258,10 @@ export default function Home() {
groupedData.push({ type: 'langgraphButton', link });
} else if (type === 'question') {
groupedData.push({ type: 'question', content });
} else {
} else if (type == 'chat'){
groupedData.push({ type: 'chat', content: content });
}
else {
if (currentReportGroup) {
currentReportGroup = null;
}
Expand Down Expand Up @@ -361,6 +377,9 @@ export default function Home() {
} else if (data.type === 'question') {
const uniqueKey = `question-${index}`;
return <Question key={uniqueKey} question={data.content} />;
} else if (data.type === 'chat'){
const uniqueKey = `chat-${index}`;
return <Answer key={uniqueKey} answer={data.content} />;
} else {
const { type, content, metadata, output } = data;
const uniqueKey = `${type}-${content}-${index}`;
Expand Down Expand Up @@ -425,7 +444,8 @@ export default function Home() {
<InputArea
promptValue={promptValue}
setPromptValue={setPromptValue}
handleDisplayResult={handleDisplayResult}
handleSubmit={handleChat}
handleSecondary={handleDisplayResult}
disabled={loading}
reset={reset}
/>
Expand Down
4 changes: 2 additions & 2 deletions frontend/nextjs/components/Hero.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import InputArea from "./InputArea";
type THeroProps = {
promptValue: string;
setPromptValue: React.Dispatch<React.SetStateAction<string>>;
handleDisplayResult: () => void;
handleDisplayResult: (query : string) => void;
};

const Hero: FC<THeroProps> = ({
Expand Down Expand Up @@ -43,7 +43,7 @@ const Hero: FC<THeroProps> = ({
<InputArea
promptValue={promptValue}
setPromptValue={setPromptValue}
handleDisplayResult={handleDisplayResult}
handleSubmit={handleDisplayResult}
/>
</div>

Expand Down
42 changes: 38 additions & 4 deletions frontend/nextjs/components/InputArea.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,64 @@ import TypeAnimation from "./TypeAnimation";
type TInputAreaProps = {
promptValue: string;
setPromptValue: React.Dispatch<React.SetStateAction<string>>;
handleDisplayResult: () => void;
handleSubmit: (query: string) => void;
handleSecondary?: (query: string) => void;
disabled?: boolean;
reset?: () => void;
};

const InputArea: FC<TInputAreaProps> = ({
promptValue,
setPromptValue,
handleDisplayResult,
handleSubmit: handleSubmit,
handleSecondary: handleSecondary,
disabled,
reset,
}) => {
const placeholder = handleSecondary ? "Follow up questions..." : "What would you like to research next?"
return (
<form
className="mx-auto flex h-[66px] w-full items-center justify-between rounded-lg border bg-white px-3 shadow-[2px_2px_38px_0px_rgba(0,0,0,0.25),0px_-2px_4px_0px_rgba(0,0,0,0.25)_inset,1px_2px_4px_0px_rgba(0,0,0,0.25)_inset]"
onSubmit={(e) => {
e.preventDefault();
if (reset) reset();
handleDisplayResult();
handleSubmit(promptValue);
}}
>
{
handleSecondary &&
<div
role="button"
aria-disabled={disabled}
className="relative flex h-[50px] w-[50px] shrink-0 items-center justify-center rounded-[3px] bg-[linear-gradient(154deg,#1B1B16_23.37%,#565646_91.91%)] disabled:pointer-events-none disabled:opacity-75"
onClick={(e) =>{
if (!disabled){
e.preventDefault();
if (reset) reset();
handleSecondary(promptValue);
}
}
}
>
{disabled && (
<div className="absolute inset-0 flex items-center justify-center">
<TypeAnimation />
</div>
)}

<Image
unoptimized
src={"/img/search.svg"}
alt="search"
width={24}
height={24}
className={disabled ? "invisible" : ""}
/>
</div>
}
<input
type="text"
placeholder="What would you like me to research next?"
placeholder={placeholder}
className="focus-visible::outline-0 my-1 w-full pl-5 font-light not-italic leading-[normal] text-[#1B1B16]/30 text-black outline-none focus-visible:ring-0 focus-visible:ring-offset-0 sm:text-xl"
disabled={disabled}
value={promptValue}
Expand Down
1 change: 1 addition & 0 deletions frontend/nextjs/public/img/search.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ python-docx = "^1.1.0"
lxml = { version = ">=4.9.2", extras = ["html_clean"] }
unstructured = ">=0.13,<0.16"
tiktoken = ">=0.7.0"
json-repair = "^0.29.8"
json5 = "^0.9.25"
loguru = "^0.7.2"
websockets = "^13.1"

[build-system]
requires = ["poetry-core"]
Expand Down

0 comments on commit cc9f73a

Please sign in to comment.