
ReAct Agent #16

Merged: 19 commits, Jan 4, 2024

This PR adds a paper-style `ReActAgent` and a LangChain-adapted `ZeroShotReActAgent`, and loosens `llm` type hints from `LLM` to `Any` across the generative-agents modules.
17 changes: 10 additions & 7 deletions discussion_agents/cog/agent/generative_agents.py
@@ -13,11 +13,10 @@
"""
from datetime import datetime
from itertools import chain
from typing import Any, Dict, List, Optional, Tuple, Type, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_core.language_models import LLM
from pydantic.v1 import root_validator

from discussion_agents.cog.agent.base import BaseAgent
@@ -69,7 +68,7 @@ class GenerativeAgent(BaseAgent):
     memory, and persona attributes.
     """

-    llm: LLM
+    llm: Any
     memory: GenerativeAgentMemory
     reflector: Optional[BaseReflector] = None
     scorer: Optional[BaseScorer] = None
@@ -610,27 +609,31 @@ def retrieve(
         last_k: Optional[int] = None,
         consumed_tokens: Optional[int] = None,
         max_tokens_limit: Optional[int] = None,
-        llm: Optional[LLM] = None,
+        llm: Optional[Any] = None,
         now: Optional[datetime] = None,
         queries_key: str = "relevant_memories",
         most_recent_key: str = "most_recent_memories",
         consumed_tokens_key: str = "most_recent_memories_limit",
     ) -> Dict[str, Any]:
-        """Wraps around the memory's `load_memories` method."""
+        """Wraps around the memory's `load_memories` method.
+
+        If `load_memories` uses `consumed_tokens` and `max_tokens_limit`,
+        llm will default to `GenerativeAgent` llm if not specified.
+        """
         return self.memory.load_memories(
             queries=queries,
             last_k=last_k,
             consumed_tokens=consumed_tokens,
             max_tokens_limit=max_tokens_limit,
-            llm=llm,
+            llm=llm if llm else self.llm,
             now=now,
             queries_key=queries_key,
             most_recent_key=most_recent_key,
             consumed_tokens_key=consumed_tokens_key,
         )

     def generate(
-        self, is_react: bool, observation: str, now: Optional[datetime] = None
+        self, observation: str, is_react: bool, now: Optional[datetime] = None
     ) -> Tuple[bool, str]:
         """Wrapper around `generate_reaction` and `generate_dialogue_response`."""
         if is_react:
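As a quick illustration of the new defaulting behavior, here is a minimal sketch (hypothetical: `llm` stands in for any LangChain-compatible model, `memory` for a configured `GenerativeAgentMemory`, and other required `GenerativeAgent` fields are omitted):

# Minimal sketch of the `retrieve` change above; `llm` and `memory` are
# illustrative stand-ins, not part of this PR.
from discussion_agents.cog.agent.generative_agents import GenerativeAgent

agent = GenerativeAgent(llm=llm, memory=memory)

# No `llm` argument passed: with `llm=llm if llm else self.llm`, the agent's
# own model is forwarded to `load_memories` for token-limited retrieval.
memories = agent.retrieve(
    queries=["What happened at the meeting?"],
    consumed_tokens=500,
    max_tokens_limit=1200,
)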
268 changes: 268 additions & 0 deletions discussion_agents/cog/agent/react.py
@@ -0,0 +1,268 @@
"""ReAct Agent implementation adapted from LangChain's zero-shot ReAct.

A LangChain-adapted zero-shot ReAct whose default tool is a Wikipedia searcher.
This implementation uses parts of the zero-shot ReAct prompt from langchain-hub, but is
structured to match the original paper's implementation. It is open to other tools.

Original Paper: https://arxiv.org/abs/2210.03629
Paper Repository: https://github.com/ysymyth/ReAct
LangChain: https://github.com/langchain-ai/langchain
LangChain ReAct: https://python.langchain.com/docs/modules/agents/agent_types/react
"""
from typing import Any, Dict, List, Optional

import requests

from bs4 import BeautifulSoup
from langchain import hub
from langchain.agents import AgentExecutor, create_react_agent
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.document_loaders.wikipedia import WikipediaLoader
from langchain_core.tools import BaseTool, tool
from pydantic.v1 import root_validator

from discussion_agents.cog.agent.base import BaseAgent
from discussion_agents.cog.prompts.react import HOTPOTQA_FEWSHOT_EXAMPLES, INSTRUCTION
from discussion_agents.utils.parse import clean_str, construct_lookup_list, get_page_obs


class ReActAgent(BaseAgent):
"""ReAct agent from the original paper.

This agent has 2 methods: `search` and `generate`. It does not
have any memory, planning, reflecting, or scoring capabilities.
Given a question, this agent, equipped with Wikipedia search,
attempts to answer it in a maximum of 7 steps. Each step
is a thought-action-observation sequence.

Available actions are:
- Search[], search for relevant info on Wikipedia (5 sentences)
- Lookup[], lookup keywords in Wikipedia search
- Finish[], finish task

Note:
By default, HOTPOTQA_FEWSHOT_EXAMPLES are used as fewshot context examples.
You have the option to provide your own fewshot examples in the `generate` method.

Attributes:
        llm (Any): An instance of a language model used for processing and generating content.

See: https://github.com/ysymyth/ReAct
"""

llm: Any # TODO: Why is `LLM` not usable here?

page: str = "" #: :meta private:
result_titles: list = [] #: :meta private:
lookup_keyword: str = "" #: :meta private:
lookup_list: list = [] #: :meta private:
lookup_cnt: int = 0 #: :meta private:

def search(self, entity: str, k: Optional[int] = 5) -> str:
"""Performs a search operation for a given entity on Wikipedia.

It parses the search results and either returns a list of similar topics
(if the exact entity is not found) or the content of the Wikipedia page related to the entity.

Args:
entity (str): The entity to be searched for.
k (Optional[int]): An optional argument to specify the number of sentences to be returned
from the Wikipedia page content.

Returns:
str: A string containing either the Wikipedia page content (trimmed to 'k' sentences) or
a list of similar topics if the exact match is not found.
"""
entity_ = entity.replace(" ", "+")
search_url = f"https://en.wikipedia.org/w/index.php?search={entity_}"
response_text = requests.get(search_url).text
soup = BeautifulSoup(response_text, features="html.parser")
result_divs = soup.find_all("div", {"class": "mw-search-result-heading"})
if result_divs: # Mismatch.
self.result_titles = [
clean_str(div.get_text().strip()) for div in result_divs
]
obs = f"Could not find {entity}. Similar: {self.result_titles[:5]}."
else:
page = [
p.get_text().strip() for p in soup.find_all("p") + soup.find_all("ul")
]
if any("may refer to:" in p for p in page):
obs = self.search("[" + entity + "]")

else:
self.page = ""
for p in page:
if len(p.split(" ")) > 2:
self.page += clean_str(p)
if not p.endswith("\n"):
self.page += "\n"
obs = get_page_obs(self.page, k=k)

# Reset lookup attributes.
self.lookup_keyword = ""
self.lookup_list = []
self.lookup_cnt = 0

return obs

def generate(
self,
observation: str,
fewshot_examples: Optional[str] = HOTPOTQA_FEWSHOT_EXAMPLES,
) -> str:
"""It takes an observation/question as input and generates a multi-step reasoning process.

The method involves generating thoughts and corresponding actions based on the observation,
and executing those actions which may include web searches, lookups, or concluding the reasoning process.

Args:
observation (str): The observation based on which the reasoning process is to be performed.
fewshot_examples (Optional[str]): A string containing few-shot examples to guide the language model.
Defaults to HOTPOTQA_FEWSHOT_EXAMPLES.

Returns:
str: A string representing the entire reasoning process, including thoughts, actions, and observations
at each step, culminating in a final answer or conclusion.
"""
prompt_template = [
INSTRUCTION,
fewshot_examples,
"\n",
"Question: ",
"{observation}",
"\n",
"Thought {i}: ",
]

# TODO: Find a way to enforce llm outputs.
done = False
out = ""
for i in range(1, 8):
# Create and run prompt.
prompt = PromptTemplate.from_template(
"".join(prompt_template) # type: ignore
if not out
else "".join(prompt_template[:-1]) + out # type: ignore
)
chain = LLMChain(llm=self.llm, prompt=prompt)
thought_action = chain.run(observation=observation, i=i).split(
f"\nObservation {i}:"
)[0]

# Get thought and action.
try:
thought, action = thought_action.strip().split(f"\nAction {i}: ")
thought = thought.split(f"Thought {i}: ")[-1]
            except ValueError:  # LLM output had no "Action {i}: " delimiter.
thought = thought_action.strip().split("\n")[0]
revised_prompt_template = (
(
"".join(prompt_template) # type: ignore
if not out
else "".join(prompt_template[:-1]) + out # type: ignore
)
+ f"{thought}\n"
+ "Action {i}: "
)
revised_prompt = PromptTemplate.from_template(revised_prompt_template)
chain = LLMChain(llm=self.llm, prompt=revised_prompt)
action = chain.run(observation=observation, i=i).strip().split("\n")[0]

# Execute action and get observation.
if action.lower().startswith("search[") and action.endswith("]"):
query = action[len("search[") : -1].lower()
obs = self.search(query)
if not obs.endswith("\n"):
obs = obs + "\n"
elif action.lower().startswith("lookup[") and action.endswith("]"):
keyword = action[len("lookup[") : -1].lower()

# Reset lookup.
if self.lookup_keyword != keyword:
self.lookup_keyword = keyword
self.lookup_list = construct_lookup_list(keyword, page=self.page)
self.lookup_cnt = 0

# All lookups used.
if self.lookup_cnt >= len(self.lookup_list):
obs = "No more results.\n"

else:
obs = (
f"(Result {self.lookup_cnt + 1} / {len(self.lookup_list)}) "
+ self.lookup_list[self.lookup_cnt]
)
self.lookup_cnt += 1
elif action.lower().startswith("finish[") and action.endswith("]"):
answer = action[len("finish[") : -1].lower()
done = True
obs = f"Episode finished. Answer: {answer}\n"
else:
obs = "Invalid action: {}".format(action)

# Update out.
obs = obs.replace("\\n", "")
out += (
f"Thought {i}: {thought}\n"
+ f"Action {i}: {action}\n"
+ f"Observation {i}: {obs}\n"
)

# Break, if done.
if done:
break

return out


@tool
def search(query: str) -> str:
"""Searches Wikipedia with a given query and returns first document found."""
    docs = WikipediaLoader(
        query=query,
        load_max_docs=1,
    ).load()
return docs[0].page_content



class ZeroShotReActAgent(BaseAgent):
"""The Zero-Shot ReAct Agent class adapted from LangChain.

Attributes:
llm (Any): An attribute for a language model or a similar interface. The exact type is to be determined.
tools (List[BaseTool]): A list of tools that the agent can use to interact or perform tasks.
prompt (str, optional): An initial prompt for the agent. If not provided, a default prompt is fetched from a specified hub.

See: https://github.com/langchain-ai/langchain/tree/master/libs/langchain/langchain/agents/react
"""

llm: Any # TODO: Why is `LLM` not usable here?
tools: Optional[List[BaseTool]] = []
prompt: Optional[str] = None

@root_validator(pre=False)
def set_args(cls: Any, values: Dict[str, Any]) -> Dict[str, Any]:
"""Set default arguments."""
llm = values["llm"]
tools = values["tools"]
tools.append(search)
prompt = values["prompt"]
prompt = hub.pull("hwchase17/react") if not prompt else prompt
if llm and tools and prompt:
agent = create_react_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True) # type: ignore
values["agent"] = agent_executor
return values

def generate(self, observation_dict: Dict[str, str]) -> str:
"""Generates a response based on the provided observation dictionary.

This method wraps around the `AgentExecutor`'s `invoke` method.

Args:
observation_dict (Dict[str, str]): A dictionary containing observation data.

Returns:
str: The generated response.
"""
return self.agent.invoke(observation_dict) # type: ignore

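Taken together, a hedged usage sketch of the two new agents (the model class and its import path are assumptions for illustration, not part of this diff):

# Hypothetical usage of the agents added in this PR. Any LangChain-compatible
# model should work; OpenAI is assumed here only for the sketch.
from langchain.llms import OpenAI

from discussion_agents.cog.agent.react import ReActAgent, ZeroShotReActAgent

llm = OpenAI(temperature=0)

# Paper-style ReAct: up to 7 thought-action-observation steps over Wikipedia.
react_agent = ReActAgent(llm=llm)
trace = react_agent.generate(
    observation="What profession do Nicholas Ray and Elia Kazan have in common?"
)
print(trace)  # Thought 1: ... / Action 1: Search[...] / Observation 1: ...

# LangChain-style zero-shot ReAct: the wikipedia `search` tool and the
# hub-pulled prompt are wired up by the root validator; `generate` wraps
# `AgentExecutor.invoke`.
zs_agent = ZeroShotReActAgent(llm=llm)
result = zs_agent.generate({"input": "Who directed Rebel Without a Cause?"})

The `hwchase17/react` hub prompt expects an `input` variable, hence the dictionary key passed to `generate`.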
11 changes: 5 additions & 6 deletions discussion_agents/cog/functional/generative_agents.py
@@ -3,11 +3,10 @@
 import re

 from datetime import datetime
-from typing import List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

 from langchain.chains import LLMChain
 from langchain.prompts import PromptTemplate
-from langchain_core.language_models import LLM
 from langchain_core.retrievers import BaseRetriever

 from discussion_agents.utils.fetch import fetch_memories
@@ -18,7 +17,7 @@
 def score_memories_importance(
     memory_contents: Union[str, List[str]],
     relevant_memories: Union[str, List[str]],
-    llm: LLM,
+    llm: Any,
     importance_weight: float = 0.15,
 ) -> List[float]:
     """Calculate absolute importance scores for given memory contents.
@@ -84,7 +83,7 @@ def score_memories_importance(

 def get_topics_of_reflection(
     observations: Union[str, List[str]],
-    llm: LLM,
+    llm: Any,
 ) -> List[str]:
     """Generate three insightful high-level questions based on recent observation(s).

@@ -117,7 +116,7 @@
 def get_insights_on_topics(
     topics: Union[str, List[str]],
     related_memories: Union[str, List[str]],
-    llm: LLM,
+    llm: Any,
 ) -> List[List[str]]:
     """Generate high-level insights on specified topics using relevant memories.

@@ -171,7 +170,7 @@

 def reflect(
     observations: Union[str, List[str]],
-    llm: LLM,
+    llm: Any,
     retriever: BaseRetriever,
     now: Optional[datetime] = None,
 ) -> Tuple[List[str], List[List[str]]]:
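One practical effect of loosening `llm: LLM` to `llm: Any` in these helpers: chat models, which do not subclass `langchain_core`'s `LLM`, can now be passed without type friction. A sketch under that assumption (the model class and import path are illustrative):

# Hypothetical sketch: a chat model (not an `LLM` subclass) used with the
# functional helpers after the `llm: Any` change.
from langchain.chat_models import ChatOpenAI

from discussion_agents.cog.functional.generative_agents import get_topics_of_reflection

chat_llm = ChatOpenAI()  # a BaseChatModel, not an LLM subclass

topics = get_topics_of_reflection(
    observations="I saw a red fox cross the trail this morning.",
    llm=chat_llm,
)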
3 changes: 1 addition & 2 deletions discussion_agents/cog/modules/reflect/base.py
@@ -2,14 +2,13 @@
 from abc import ABC, abstractmethod
 from typing import Any, List, Union

-from langchain_core.language_models import LLM
 from pydantic.v1 import BaseModel


 class BaseReflector(BaseModel, ABC):
     """Base reflecting class."""

-    llm: LLM
+    llm: Any

     @abstractmethod
     def reflect(
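Finally, a sketch of what the loosened `BaseReflector` admits; the concrete subclass and the `reflect` signature below are illustrative only, since the real abstract signature is truncated in this diff:

# Hypothetical reflector subclass; with `llm: Any`, any model object can be
# stored on the reflector. `reflect`'s true signature is truncated above, so
# the one used here is illustrative.
from typing import Any, List, Union

from discussion_agents.cog.modules.reflect.base import BaseReflector

class NoOpReflector(BaseReflector):
    """Toy reflector that returns observations unchanged."""

    def reflect(self, observations: Union[str, List[str]], **kwargs: Any) -> List[str]:
        return [observations] if isinstance(observations, str) else list(observations)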