-
Notifications
You must be signed in to change notification settings - Fork 2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #885 from khoangothe/features/chat-with-history-2
Chat with History
- Loading branch information
Showing
9 changed files
with
173 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .chat import ChatAgentWithMemory |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
from fastapi import WebSocket | ||
import uuid | ||
|
||
from gpt_researcher.utils.llm import get_llm | ||
from gpt_researcher.memory import Memory | ||
from gpt_researcher.config.config import Config | ||
|
||
from langgraph.prebuilt import create_react_agent | ||
from langgraph.checkpoint.memory import MemorySaver | ||
|
||
from langchain_community.vectorstores import InMemoryVectorStore | ||
from langchain.text_splitter import RecursiveCharacterTextSplitter | ||
from langchain.tools import Tool, tool | ||
|
||
class ChatAgentWithMemory: | ||
def __init__( | ||
self, | ||
report: str, | ||
config_path, | ||
headers, | ||
vector_store = None | ||
): | ||
self.report = report | ||
self.headers = headers | ||
self.config = Config(config_path) | ||
self.vector_store = vector_store | ||
self.graph = self.create_agent() | ||
|
||
def create_agent(self): | ||
"""Create React Agent Graph""" | ||
#If not vector store, split and talk to the report | ||
llm_provider_name = getattr(self.config, "llm_provider") | ||
fast_llm_model = getattr(self.config, "fast_llm_model") | ||
temperature = getattr(self.config, "temperature") | ||
fast_token_limit = getattr(self.config, "fast_token_limit") | ||
|
||
provider = get_llm(llm_provider_name, model=fast_llm_model, temperature=temperature, max_tokens=fast_token_limit, **self.config.llm_kwargs).llm | ||
if not self.vector_store: | ||
documents = self._process_document(self.report) | ||
self.chat_config = {"configurable": {"thread_id": str(uuid.uuid4())}} | ||
self.embedding = Memory(getattr(self.config, 'embedding_provider', None), self.headers).get_embeddings() | ||
self.vector_store = InMemoryVectorStore(self.embedding) | ||
self.vector_store.add_texts(documents) | ||
graph = create_react_agent(provider, tools=[self.vector_store_tool(self.vector_store)], checkpointer=MemorySaver()) | ||
return graph | ||
|
||
def vector_store_tool(self, vector_store) -> Tool: | ||
"""Create Vector Store Tool""" | ||
@tool | ||
def retrieve_info(query): | ||
""" | ||
Consult the report for relevant contexts whenever you don't know something | ||
""" | ||
retriever = vector_store.as_retriever(k = 4) | ||
return retriever.invoke(query) | ||
return retrieve_info | ||
|
||
def _process_document(self, report): | ||
"""Split Report into Chunks""" | ||
text_splitter = RecursiveCharacterTextSplitter( | ||
chunk_size=1024, | ||
chunk_overlap=20, | ||
length_function=len, | ||
is_separator_regex=False, | ||
) | ||
documents = text_splitter.split_text(report) | ||
return documents | ||
|
||
async def chat(self, message, websocket): | ||
"""Chat with React Agent""" | ||
inputs = {"messages": [("user", message)]} | ||
response = await self.graph.ainvoke(inputs, config=self.chat_config) | ||
ai_message = response["messages"][-1].content | ||
if websocket is not None: | ||
await websocket.send_json({"type": "chat", "content": ai_message}) | ||
|
||
def get_context(self): | ||
"""return the current context of the chat""" | ||
return self.report |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters