Chat with History #885

Merged — 5 commits, merged on Oct 31, 2024
Changes from all commits
1 change: 1 addition & 0 deletions backend/chat/__init__.py
@@ -0,0 +1 @@
from .chat import ChatAgentWithMemory
79 changes: 79 additions & 0 deletions backend/chat/chat.py
@@ -0,0 +1,79 @@
from fastapi import WebSocket
import uuid

from gpt_researcher.utils.llm import get_llm
from gpt_researcher.memory import Memory
from gpt_researcher.config.config import Config

from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver

from langchain_community.vectorstores import InMemoryVectorStore
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.tools import Tool, tool

class ChatAgentWithMemory:
    def __init__(
        self,
        report: str,
        config_path,
        headers,
        vector_store=None
    ):
        self.report = report
        self.headers = headers
        self.config = Config(config_path)
        self.vector_store = vector_store
        self.graph = self.create_agent()

    def create_agent(self):
        """Create the React agent graph."""
        llm_provider_name = getattr(self.config, "llm_provider")
        fast_llm_model = getattr(self.config, "fast_llm_model")
        temperature = getattr(self.config, "temperature")
        fast_token_limit = getattr(self.config, "fast_token_limit")

        provider = get_llm(llm_provider_name, model=fast_llm_model, temperature=temperature, max_tokens=fast_token_limit, **self.config.llm_kwargs).llm
        self.chat_config = {"configurable": {"thread_id": str(uuid.uuid4())}}
        # If no vector store was supplied, split the report into chunks and index it
        # so the agent can retrieve relevant context from it.
        if not self.vector_store:
            documents = self._process_document(self.report)
            self.embedding = Memory(getattr(self.config, 'embedding_provider', None), self.headers).get_embeddings()
            self.vector_store = InMemoryVectorStore(self.embedding)
            self.vector_store.add_texts(documents)
        graph = create_react_agent(provider, tools=[self.vector_store_tool(self.vector_store)], checkpointer=MemorySaver())
        return graph

    def vector_store_tool(self, vector_store) -> Tool:
        """Create the vector store retrieval tool."""
        @tool
        def retrieve_info(query):
            """
            Consult the report for relevant contexts whenever you don't know something
            """
            retriever = vector_store.as_retriever(search_kwargs={"k": 4})
            return retriever.invoke(query)
        return retrieve_info

    def _process_document(self, report):
        """Split the report into chunks."""
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1024,
            chunk_overlap=20,
            length_function=len,
            is_separator_regex=False,
        )
        documents = text_splitter.split_text(report)
        return documents

    async def chat(self, message, websocket):
        """Chat with the React agent and stream the reply over the websocket."""
        inputs = {"messages": [("user", message)]}
        response = await self.graph.ainvoke(inputs, config=self.chat_config)
        ai_message = response["messages"][-1].content
        if websocket is not None:
            await websocket.send_json({"type": "chat", "content": ai_message})

    def get_context(self):
        """Return the current context of the chat."""
        return self.report
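
For orientation, here is a minimal standalone usage sketch of the new class (illustrative only, not part of the diff). It assumes the default config, no pre-built vector store, and that `headers=None` is acceptable to `Memory`; it queries the React graph directly instead of going through the WebSocket layer.

# Illustrative sketch only — not part of this PR.
import asyncio
from backend.chat import ChatAgentWithMemory

async def demo():
    report = "...markdown report previously produced by GPT Researcher..."
    agent = ChatAgentWithMemory(report, config_path="default", headers=None)
    result = await agent.graph.ainvoke(
        {"messages": [("user", "What are the report's main conclusions?")]},
        config=agent.chat_config,
    )
    print(result["messages"][-1].content)

asyncio.run(demo())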
5 changes: 5 additions & 0 deletions backend/server/server_utils.py
@@ -38,6 +38,9 @@ async def handle_human_feedback(data: str):
print(f"Received human feedback: {feedback_data}")
# TODO: Add logic to forward the feedback to the appropriate agent or update the research state

async def handle_chat(websocket, data: str, manager):
json_data = json.loads(data[4:])
await manager.chat(json_data.get("message"), websocket)
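
A note on the wire format: `data[4:]` assumes the client sends the literal prefix `chat` immediately followed by a JSON body, mirroring the existing `start` and `human_feedback` commands. An illustrative message (the question text is made up):

# Illustrative only — example of what the frontend sends for a follow-up question.
import json

data = 'chat{"message": "Which sources support the second claim?"}'
payload = json.loads(data[4:])  # strip the 4-character "chat" prefix
print(payload["message"])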

async def generate_report_files(report: str, filename: str) -> Dict[str, str]:
    pdf_path = await write_md_to_pdf(report, filename)
@@ -117,6 +120,8 @@ async def handle_websocket_communication(websocket, manager):
            await handle_start_command(websocket, data, manager)
        elif data.startswith("human_feedback"):
            await handle_human_feedback(data)
        elif data.startswith("chat"):
            await handle_chat(websocket, data, manager)
        else:
            print("Error: Unknown command or not enough parameters provided.")

19 changes: 16 additions & 3 deletions backend/server/websocket_manager.py
@@ -5,6 +5,8 @@
from fastapi import WebSocket

from backend.report_type import BasicReport, DetailedReport
from backend.chat import ChatAgentWithMemory

from gpt_researcher.utils.enum import ReportType, Tone
from multi_agents.main import run_research_task
from gpt_researcher.orchestrator.actions import stream_output # Import stream_output
@@ -18,6 +20,7 @@ def __init__(self):
        self.active_connections: List[WebSocket] = []
        self.sender_tasks: Dict[WebSocket, asyncio.Task] = {}
        self.message_queues: Dict[WebSocket, asyncio.Queue] = {}
        self.chat_agent = None

    async def start_sender(self, websocket: WebSocket):
        """Start the sender task."""
@@ -58,14 +61,24 @@ async def disconnect(self, websocket: WebSocket):
    async def start_streaming(self, task, report_type, report_source, source_urls, tone, websocket, headers=None):
        """Start streaming the output."""
        tone = Tone[tone]
        report = await run_agent(task, report_type, report_source, source_urls, tone, websocket, headers)
        # Add a customized JSON config file path here if needed
        config_path = "default"
        report = await run_agent(task, report_type, report_source, source_urls, tone, websocket, headers=headers, config_path=config_path)
        # Create a new chat agent whenever a new report is written
        self.chat_agent = ChatAgentWithMemory(report, config_path, headers)
        return report

    async def chat(self, message, websocket):
        """Chat with the agent about the most recently generated report."""
        if self.chat_agent:
            await self.chat_agent.chat(message, websocket)
        else:
            await websocket.send_json({"type": "chat", "content": "The knowledge base is empty. Please run a research task first."})

async def run_agent(task, report_type, report_source, source_urls, tone: Tone, websocket, headers=None):
async def run_agent(task, report_type, report_source, source_urls, tone: Tone, websocket, headers=None, config_path=""):
    """Run the agent."""
    start_time = datetime.datetime.now()
    config_path = ""
    # Instead of running the agent directly, run it through the different report type classes
    if report_type == "multi_agents":
        report = await run_research_task(query=task, websocket=websocket, stream_output=stream_output, tone=tone, headers=headers)
        report = report.get("report", "")
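
Taken together with the server_utils.py change above, this enables a simple follow-up loop over the existing WebSocket endpoint. Below is a hedged client-side sketch using the `websockets` package added to pyproject.toml in this PR; the `ws://localhost:8000/ws` URL is an assumption based on the default server setup, and a research run must already have completed on the connection so that the manager has a chat agent.

# Illustrative client sketch — not part of this PR.
import asyncio
import json
import websockets

async def ask_followup(question: str) -> str:
    async with websockets.connect("ws://localhost:8000/ws") as ws:
        # ...run a research task on this connection first so the manager
        # has created a ChatAgentWithMemory for the finished report...
        await ws.send("chat" + json.dumps({"message": question}))
        while True:
            reply = json.loads(await ws.recv())
            if reply.get("type") == "chat":
                return reply["content"]

print(asyncio.run(ask_followup("Which finding has the strongest evidence?")))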
34 changes: 27 additions & 7 deletions frontend/nextjs/app/page.tsx
@@ -86,8 +86,10 @@ export default function Home() {
          setAnswer((prev) => prev + data.output);
        } else if (data.type === 'path') {
          setLoading(false);
          newSocket.close();
          setSocket(null);
          // newSocket.close();  // Keep the connection open so the user can continue chatting
          // setSocket(null);
        } else if (data.type === 'chat') {
          setLoading(false);
        }
      }

@@ -125,9 +127,20 @@ export default function Home() {
    setShowHumanFeedback(false);
  };

  const handleDisplayResult = async (newQuestion?: string) => {
    newQuestion = newQuestion || promptValue;

  const handleChat = async (message: string) => {
    if (socket) {
      setShowResult(true);
      setQuestion(message);
      setLoading(true);
      setPromptValue("");
      setAnswer(""); // Reset answer for new query
      setOrderedData((prevOrder) => [...prevOrder, { type: 'question', content: message }]);
      const data: string = "chat" + JSON.stringify({ "message": message });
      socket.send(data);
    }
  };

  const handleDisplayResult = async (newQuestion: string) => {
    setShowResult(true);
    setLoading(true);
    setQuestion(newQuestion);
@@ -222,7 +235,10 @@
        groupedData.push({ type: 'langgraphButton', link });
      } else if (type === 'question') {
        groupedData.push({ type: 'question', content });
      } else {
      } else if (type === 'chat') {
        groupedData.push({ type: 'chat', content });
      } else {
        if (currentReportGroup) {
          currentReportGroup = null;
        }
@@ -324,6 +340,9 @@
      } else if (data.type === 'question') {
        const uniqueKey = `question-${index}`;
        return <Question key={uniqueKey} question={data.content} />;
      } else if (data.type === 'chat') {
        const uniqueKey = `chat-${index}`;
        return <Answer key={uniqueKey} answer={data.content} />;
      } else {
        const { type, content, metadata, output } = data;
        const uniqueKey = `${type}-${content}-${index}`;
@@ -378,7 +397,8 @@
            <InputArea
              promptValue={promptValue}
              setPromptValue={setPromptValue}
              handleDisplayResult={handleDisplayResult}
              handleSubmit={handleChat}
              handleSecondary={handleDisplayResult}
              disabled={loading}
              reset={reset}
            />
4 changes: 2 additions & 2 deletions frontend/nextjs/components/Hero.tsx
@@ -5,7 +5,7 @@ import InputArea from "./InputArea";
type THeroProps = {
  promptValue: string;
  setPromptValue: React.Dispatch<React.SetStateAction<string>>;
  handleDisplayResult: () => void;
  handleDisplayResult: (query: string) => void;
};

const Hero: FC<THeroProps> = ({
@@ -45,7 +45,7 @@ const Hero: FC<THeroProps> = ({
        <InputArea
          promptValue={promptValue}
          setPromptValue={setPromptValue}
          handleDisplayResult={handleDisplayResult}
          handleSubmit={handleDisplayResult}
        />
      </div>

42 changes: 38 additions & 4 deletions frontend/nextjs/components/InputArea.tsx
@@ -5,30 +5,64 @@ import TypeAnimation from "./TypeAnimation";
type TInputAreaProps = {
  promptValue: string;
  setPromptValue: React.Dispatch<React.SetStateAction<string>>;
  handleDisplayResult: () => void;
  handleSubmit: (query: string) => void;
  handleSecondary?: (query: string) => void;
  disabled?: boolean;
  reset?: () => void;
};

const InputArea: FC<TInputAreaProps> = ({
  promptValue,
  setPromptValue,
  handleDisplayResult,
  handleSubmit,
  handleSecondary,
  disabled,
  reset,
}) => {
  const placeholder = handleSecondary ? "Follow up questions..." : "What would you like to research next?";
  return (
    <form
      className="mx-auto flex h-[66px] w-full items-center justify-between rounded-lg border bg-white px-3 shadow-[2px_2px_38px_0px_rgba(0,0,0,0.25),0px_-2px_4px_0px_rgba(0,0,0,0.25)_inset,1px_2px_4px_0px_rgba(0,0,0,0.25)_inset]"
      onSubmit={(e) => {
        e.preventDefault();
        if (reset) reset();
        handleDisplayResult();
        handleSubmit(promptValue);
      }}
    >
      {
        handleSecondary &&
        <div
          role="button"
          aria-disabled={disabled}
          className="relative flex h-[50px] w-[50px] shrink-0 items-center justify-center rounded-[3px] bg-[linear-gradient(154deg,#1B1B16_23.37%,#565646_91.91%)] disabled:pointer-events-none disabled:opacity-75"
          onClick={(e) => {
            if (!disabled) {
              e.preventDefault();
              if (reset) reset();
              handleSecondary(promptValue);
            }
          }}
        >
          {disabled && (
            <div className="absolute inset-0 flex items-center justify-center">
              <TypeAnimation />
            </div>
          )}

          <Image
            unoptimized
            src={"/img/search.svg"}
            alt="search"
            width={24}
            height={24}
            className={disabled ? "invisible" : ""}
          />
        </div>
      }
      <input
        type="text"
        placeholder="What would you like me to research next?"
        placeholder={placeholder}
        className="focus-visible::outline-0 my-1 w-full pl-5 font-light not-italic leading-[normal] text-[#1B1B16]/30 text-black outline-none focus-visible:ring-0 focus-visible:ring-offset-0 sm:text-xl"
        disabled={disabled}
        value={promptValue}
1 change: 1 addition & 0 deletions frontend/nextjs/public/img/search.svg
4 changes: 4 additions & 0 deletions pyproject.toml
@@ -38,6 +38,10 @@ python-docx = "^1.1.0"
lxml = { version = ">=4.9.2", extras = ["html_clean"] }
unstructured = ">=0.13,<0.16"
tiktoken = ">=0.7.0"
json-repair = "^0.29.8"
json5 = "^0.9.25"
loguru = "^0.7.2"
websockets = "^13.1"

[build-system]
requires = ["poetry-core"]