Skip to content

community: add include_comment_forest to RedditSearch #29699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 21 additions & 2 deletions libs/community/langchain_community/tools/reddit_search/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,30 @@ class RedditSearchSchema(BaseModel):
description="a positive integer indicating the maximum number \
of results to return"
)
include_comment_forest: bool = Field(
default=False,
description="A boolean indicating whether to include the comment "
"forest in the results. Defaults to False, which avoids large data pulls.",
)


class RedditSearchRun(BaseTool): # type: ignore[override, override]
"""Tool that queries for posts on a subreddit."""
"""Tool that queries for posts on a subreddit,
optionally fetching full comment forest."""

name: str = "reddit_search"
description: str = (
"A tool that searches for posts on Reddit."
"A tool that searches for posts on Reddit. "
"Optionally, it can fetch the entire comment forest. "
"Useful when you need to know post information on a subreddit."
)
api_wrapper: RedditSearchAPIWrapper = Field(default_factory=RedditSearchAPIWrapper) # type: ignore[arg-type]
# Add a constructor param to allow (or disallow) comment forest fetching at all.
allow_comment_forest: bool = Field(
default=False,
description="If False, all calls will ignore `include_comment_forest=True`. "
"Set to True to allow the agent/model to fetch full comment trees.",
)
args_schema: Type[BaseModel] = RedditSearchSchema

def _run(
Expand All @@ -52,13 +65,19 @@ def _run(
time_filter: str,
subreddit: str,
limit: str,
include_comment_forest: bool = False,
run_manager: Optional[CallbackManagerForToolRun] = None,
) -> str:
"""Use the tool."""
# If we do not allow comment forest fetching, force it to False
if not self.allow_comment_forest:
include_comment_forest = False

return self.api_wrapper.run(
query=query,
sort=sort,
time_filter=time_filter,
subreddit=subreddit,
limit=int(limit),
include_comment_forest=include_comment_forest,
)
148 changes: 115 additions & 33 deletions libs/community/langchain_community/utilities/reddit_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,13 @@ def validate_environment(cls, values: Dict) -> Any:
return values

def run(
self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
self,
query: str,
sort: str,
time_filter: str,
subreddit: str,
limit: int,
include_comment_forest: bool = False,
) -> str:
"""Search Reddit and return posts as a single string."""
results: List[Dict] = self.results(
Expand All @@ -77,46 +83,122 @@ def run(
time_filter=time_filter,
subreddit=subreddit,
limit=limit,
include_comment_forest=include_comment_forest,
)
if len(results) > 0:
output: List[str] = [f"Searching r/{subreddit} found {len(results)} posts:"]
for r in results:
category = "N/A" if r["post_category"] is None else r["post_category"]
p = f"Post Title: '{r['post_title']}'\n\
User: {r['post_author']}\n\
Subreddit: {r['post_subreddit']}:\n\
Text body: {r['post_text']}\n\
Post URL: {r['post_url']}\n\
Post Category: {category}.\n\
Score: {r['post_score']}\n"
output.append(p)
return "\n".join(output)
else:
return f"Searching r/{subreddit} did not find any posts:"
if len(results) == 0:
return f"Searching r/{subreddit} did not find any posts."

output: List[str] = [f"Searching r/{subreddit} found {len(results)} posts:"]
for r in results:
category = "N/A" if r["post_category"] is None else r["post_category"]
p = (
f"Post Title: '{r['post_title']}'\n"
f" Created: {r['post_created']}\n"
f" Post ID: {r['post_id']}\n"
f" User: {r['post_author']}\n"
f" Subreddit: {r['post_subreddit']}\n"
f" Text body: {r['post_text']}\n"
f" Post URL: {r['post_url']}\n"
f" Post Category: {category}\n"
f" Score: {r['post_score']}\n"
f" Upvote Ratio: {r['post_upvote_ratio']}\n"
)
output.append(p)

# If requested, format and display the entire comment forest
if include_comment_forest and r.get("post_comments"):
output.append(" Comments (entire comment tree):")
comment_str = self._format_comment_forest(r["post_comments"], indent=4)
output.append(comment_str)
output.append("===")

return "\n".join(output)

def _parse_comment_forest(self, comment_forest: Any) -> List[Dict[str, Any]]:
"""Recursively traverse the entire comment forest and return a list
of dictionaries with comment info, including nested replies.
"""
comments_data = []
for comment in comment_forest:
# Sometimes comment could be 'MoreComments' object
if hasattr(comment, "body"):
comment_info = {
"id": comment.id,
"body": comment.body,
"score": comment.score,
"ups": comment.ups,
"author": str(comment.author),
"created_utc": comment.created_utc,
# Recursively parse any replies (the nested forest)
"replies": self._parse_comment_forest(comment.replies),
}
comments_data.append(comment_info)
return comments_data

def _format_comment_forest(
self, comments: List[Dict[str, Any]], indent: int = 0
) -> str:
"""Recursively build a readable string of the entire comment forest."""
lines = []
for idx, c in enumerate(comments, start=1):
prefix = " " * indent + f"{idx}. "
# Replace newlines in comment body to avoid messing up formatting
body_single_line = c["body"].replace("\n", " ")
lines.append(
f"{prefix}[id: {c['id']}, score: {c['score']}, ups: {c['ups']}] "
f"(by {c['author']}) {body_single_line}"
)
# If there are replies, recurse deeper
if c["replies"]:
replies_str = self._format_comment_forest(
c["replies"], indent=indent + 4
)
lines.append(replies_str)
return "\n".join(lines)

def results(
self, query: str, sort: str, time_filter: str, subreddit: str, limit: int
self,
query: str,
sort: str,
time_filter: str,
subreddit: str,
limit: int,
include_comment_forest: bool = False,
) -> List[Dict]:
"""Use praw to search Reddit and return a list of dictionaries,
one for each post.
one for each post. If include_comments is True, fetch the entire
nested comment forest.
"""
subredditObject = self.reddit_client.subreddit(subreddit)
search_results = subredditObject.search(
subreddit_obj = self.reddit_client.subreddit(subreddit)
search_results = subreddit_obj.search(
query=query, sort=sort, time_filter=time_filter, limit=limit
)
search_results = [r for r in search_results]

results_object = []
for submission in search_results:
results_object.append(
{
"post_subreddit": submission.subreddit_name_prefixed,
"post_category": submission.category,
"post_title": submission.title,
"post_text": submission.selftext,
"post_score": submission.score,
"post_id": submission.id,
"post_url": submission.url,
"post_author": submission.author,
}
)
post_data = {
"post_subreddit": submission.subreddit_name_prefixed,
"post_category": submission.category,
"post_title": submission.title,
"post_text": submission.selftext,
"post_score": submission.score,
"post_id": submission.id,
"post_url": submission.url,
"post_author": str(submission.author),
"post_created": submission.created_utc,
"post_upvote_ratio": submission.upvote_ratio,
"post_ups": submission.ups,
}

# If include_comments, get the entire nested comment tree
if include_comment_forest:
submission.comments.replace_more(
limit=None
) # fetch all nested comments
post_data["post_comments"] = self._parse_comment_forest(
submission.comments
)

results_object.append(post_data)

return results_object
Loading