Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add multiple gsite feature #239

Merged
41 changes: 37 additions & 4 deletions src/sherpa_ai/config/task_config.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,33 @@
import re
from argparse import ArgumentParser
from functools import cached_property
from typing import List, Optional, Tuple
from urllib.parse import urlparse

from pydantic import BaseModel
from pydantic import BaseModel, computed_field, validator


class AgentConfig(BaseModel):
verbose: bool = False
gsite: Optional[str] = None
gsite: list[str] = []
do_reflect: bool = False


@validator("gsite", pre=True)
def parse_gsite(cls, value) -> list[str]:
    """Normalize raw ``gsite`` input into a list of URL strings.

    Accepts ``None`` (no sites configured), an already-split list
    (passed through with surrounding whitespace stripped), or a
    comma-separated string.  Note: an empty string yields ``[""]``,
    matching ``"".split(",")`` — existing tests rely on this.
    """
    if value is None:
        return []
    # The field is declared list[str]; callers may reasonably pass a
    # list directly.  The original code crashed here (no .split on list).
    if isinstance(value, list):
        return [str(url).strip() for url in value]
    return [url.strip() for url in value.split(",")]

@computed_field
@cached_property
def search_domains(self) -> List[str]:
    """The subset of ``gsite`` entries that parse as valid absolute URLs."""
    valid: List[str] = []
    for candidate in self.gsite:
        if validate_url(candidate):
            valid.append(candidate)
    return valid

@computed_field
@cached_property
def invalid_domains(self) -> List[str]:
    """The subset of ``gsite`` entries rejected by ``validate_url``."""
    rejected: List[str] = []
    for candidate in self.gsite:
        if not validate_url(candidate):
            rejected.append(candidate)
    return rejected

@classmethod
def from_input(cls, input_str: str) -> Tuple[str, "AgentConfig"]:
"""
Expand All @@ -21,7 +39,14 @@ def from_input(cls, input_str: str) -> Tuple[str, "AgentConfig"]:

for part in parts[1:]:
part = part.strip()
configs.extend(part.split())
if part.startswith("--gsite"):
20001LastOrder marked this conversation as resolved.
Show resolved Hide resolved
gsite_arg, gsite_val = part.split(maxsplit=1)
configs.append(gsite_arg)
urls = [url.strip() for url in gsite_val.split(",")]
concatenated_urls = ", ".join(urls)
configs.append(concatenated_urls)
else:
configs.extend(part.split())

return parts[0].strip(), cls.from_config(configs)

Expand Down Expand Up @@ -52,3 +77,11 @@ def from_config(cls, configs: List[str]) -> "AgentConfig":
raise ValueError(f"Invalid configuration, check your input: {unknown}")

return AgentConfig(**args.__dict__)


def validate_url(url: str) -> bool:
    """Return True when *url* is an absolute URL with a scheme and host.

    Relative paths (``/data/Python.html``) and bare tokens (``532``)
    fail because they lack a scheme and/or network location.
    """
    try:
        parts = urlparse(url)
    except ValueError:
        # urlparse can raise for malformed ports or bracketed hosts.
        return False
    return bool(parts.scheme) and bool(parts.netloc)
85 changes: 67 additions & 18 deletions src/sherpa_ai/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import urllib
import urllib.parse
import urllib.request
from typing import Any
from typing import Any, List, Tuple, Union
from urllib.parse import urlparse

import requests
from bs4 import BeautifulSoup
Expand All @@ -17,6 +18,7 @@

import sherpa_ai.config as cfg
from sherpa_ai.config.task_config import AgentConfig
from sherpa_ai.output_parser import TaskAction


def get_tools(memory, config):
Expand All @@ -39,7 +41,7 @@ def get_tools(memory, config):
class SearchArxivTool(BaseTool):
name = "Arxiv Search"
description = (
"Access all the papers from Arxiv to search for domain-specific scientific publication."
"Access all the papers from Arxiv to search for domain-specific scientific publication." # noqa: E501
"Only use this tool when you need information in the scientific paper."
)

Expand Down Expand Up @@ -79,17 +81,56 @@ def _arun(self, query: str) -> str:
class SearchTool(BaseTool):
name = "Search"
config = AgentConfig()
top_k: int = 10
description = (
"Access the internet to search for the information. Only use this tool when "
"you cannot find the information using internal search."
)

def augment_query(self, query) -> str:
return query + " site:" + self.config.gsite if self.config.gsite else query
def _run(
    self, query: str, require_meta=False
) -> Union[str, Tuple[str, List[dict]]]:
    """Run the search, fanning the query out over configured site filters.

    At most five site-restricted queries are issued (a warning line is
    prepended when the list is capped); invalid domains are skipped and
    reported.  Returns the concatenated result text, or a
    ``(text, metadata)`` tuple when *require_meta* is true.
    """
    output = ""
    if self.config.search_domains:
        queries = [
            query + " Site: " + str(domain)
            for domain in self.config.search_domains
        ]
        if len(queries) >= 5:
            queries = queries[:5]
            output += (
                "Warning: Only the first 5 URLs are taken into consideration.\n"
            )  # noqa: E501
    else:
        queries = [query]
    if self.config.invalid_domains:
        bad_domains = ", ".join(self.config.invalid_domains)
        # NOTE(review): "doman" typo is load-bearing — tests assert this exact text.
        output += f"Warning: The doman {bad_domains} is invalid and is not taken into consideration.\n"  # noqa: E501

    # Split the result budget evenly across the queries we will issue.
    per_query_k = int(self.top_k / len(queries))
    if require_meta:
        meta = []

    for single_query in queries:
        partial = self._run_single_query(single_query, per_query_k, require_meta)

        if require_meta:
            output += "\n" + partial[0]
            meta.extend(partial[1])
        else:
            output += "\n" + partial

    if require_meta:
        return output, meta
    return output

def _run_single_query(
self, query: str, top_k: int, require_meta=False
) -> Union[str, Tuple[str, List[dict]]]:
logger.debug(f"Search query: {query}")
google_serper = GoogleSerperAPIWrapper()
search_results = google_serper._google_serper_api_results(query)
Expand All @@ -107,7 +148,7 @@ def _run(self, query: str, require_meta=False) -> str:
title = search_results["organic"][0]["title"]
link = search_results["organic"][0]["link"]

response = "Answer: " + answer
response = "Answer: " + answer
meta = [{"Document": answer, "Source": link}]
if require_meta:
return response, meta
Expand All @@ -127,15 +168,15 @@ def _run(self, query: str, require_meta=False) -> str:
snippets.append(description)
for attribute, value in kg.get("attributes", {}).items():
snippets.append(f"{title} {attribute}: {value}.")
k = 10

search_type: Literal["news", "search", "places", "images"] = "search"
result_key_for_type = {
"news": "news",
"places": "places",
"images": "images",
"search": "organic",
}
for result in search_results[result_key_for_type[search_type]][:k]:
for result in search_results[result_key_for_type[search_type]][:top_k]:
if "snippet" in result:
snippets.append(result["snippet"])
for attribute, value in result.get("attributes", {}).items():
Expand All @@ -145,15 +186,19 @@ def _run(self, query: str, require_meta=False) -> str:
return ["No good Google Search Result was found"]

result = []

meta = []
for i in range(len(search_results["organic"][:10])):
for i in range(len(search_results["organic"][:top_k])):
r = search_results["organic"][i]
single_result = (
r["title"] + r["snippet"]
)
single_result = r["title"] + r["snippet"]

result.append(single_result)
meta.append({"Document": "Description: " + r["title"] + r["snippet"], "Source": r["link"]})
meta.append(
{
"Document": "Description: " + r["title"] + r["snippet"],
"Source": r["link"],
}
)
full_result = "\n".join(result)

# answer = " ".join(snippets)
Expand Down Expand Up @@ -183,7 +228,7 @@ class ContextTool(BaseTool):
name = "Context Search"
description = (
"Access internal technical documentation for AI related projects, including"
+ "Fixie, LangChain, GPT index, GPTCache, GPT4ALL, autoGPT, db-GPT, AgentGPT, sherpa."
+ "Fixie, LangChain, GPT index, GPTCache, GPT4ALL, autoGPT, db-GPT, AgentGPT, sherpa." # noqa: E501
+ "Only use this tool if you need information for these projects specifically."
)
memory: VectorStoreRetriever
Expand All @@ -201,8 +246,12 @@ def _run(self, query: str, need_meta=False) -> str:
+ "\n"
)
if need_meta:
metadata.append({'Document': doc.page_content,
"Source": doc.metadata.get("source", "")})
metadata.append(
{
"Document": doc.page_content,
"Source": doc.metadata.get("source", ""),
}
)

if need_meta:
return result, metadata
Expand All @@ -218,7 +267,7 @@ class UserInputTool(BaseTool):
name = "UserInput"
description = (
"Access the user input for the task."
"You use this tool if you need more context and would like to ask clarifying questions to solve the task"
"You use this tool if you need more context and would like to ask clarifying questions to solve the task" # noqa: E501
)

def _run(self, query: str) -> str:
Expand Down
25 changes: 23 additions & 2 deletions src/tests/unit_tests/config/test_task_config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from loguru import logger

from sherpa_ai.config import AgentConfig

Expand All @@ -11,8 +12,28 @@ def test_config_parses_all_parameters_correctly():

assert parsed == "Test input."
assert config.verbose
assert config.gsite == site
assert config.gsite == site.split(", ")
assert config.do_reflect
assert config.search_domains == ["https://www.google.com"]
assert config.invalid_domains == []


def test_config_parse_multiple_gsites_correctly():
    # Three valid URLs followed by two invalid entries (a path and a number).
    sites = [
        "https://www.google.com",
        "https://www.langchain.com",
        "https://openai.com",
        "/data/Python.html",
        "532",
    ]
    site = ", ".join(sites)
    input_str = f"Test input. --verbose --gsite {site} --do-reflect"

    parsed, config = AgentConfig.from_input(input_str)

    assert parsed == "Test input."
    assert config.verbose
    assert config.gsite == sites
    assert config.do_reflect
    assert config.search_domains == sites[:3]
    assert config.invalid_domains == sites[3:]


def test_config_parses_input_and_verbose_options_with_no_gsite():
Expand All @@ -22,7 +43,7 @@ def test_config_parses_input_and_verbose_options_with_no_gsite():

assert parsed == "Test input."
assert config.verbose
assert config.gsite is None
assert config.gsite == []


def test_config_raises_exception_for_unsupported_options():
Expand Down
67 changes: 64 additions & 3 deletions src/tests/unit_tests/tools/test_search_tool.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,76 @@
from sherpa_ai.config import AgentConfig
from sherpa_ai.output_parser import TaskAction
from sherpa_ai.tools import SearchTool


def test_search_query_includes_gsite_config():
    # NOTE(review): performs a live Serper search — needs network/API key.
    config = AgentConfig(verbose=True, gsite="https://www.google.com")
    tool = SearchTool(config=config)

    result = tool._run("What is the weather today?")

    assert result is not None
    assert result != ""


def test_search_query_includes_multiple_gsite_config():
    # NOTE(review): performs a live Serper search — needs network/API key.
    gsite = "https://www.google.com, https://www.langchain.com, https://openai.com"
    tool = SearchTool(config=AgentConfig(verbose=True, gsite=gsite))

    result = tool._run("What is the weather today?")

    assert result is not None
    assert result != ""


def test_search_query_includes_more_gsite_config_warning():
    urls = [
        "https://www.google.com",
        "https://www.langchain.com",
        "https://openai.com",
    ]
    # Six valid entries -> triggers the five-URL cap warning.
    site = ", ".join(urls * 2)
    config = AgentConfig(verbose=True, gsite=site)
    assert config.gsite == site.split(", ")

    tool = SearchTool(config=config)
    search_result = tool._run("What is the weather today?")

    assert (
        "Warning: Only the first 5 URLs are taken into consideration." in search_result
    )


def test_search_query_includes_more_gsite_config_empty():
    config = AgentConfig(verbose=True, gsite="")
    # An empty string still splits into [""] under the gsite validator.
    assert config.gsite == "".split(", ")
    tool = SearchTool(config=config)

    result = tool._run("What is the weather today?")

    assert result is not None
    assert result != ""


def test_search_query_includes_invalid_url():
    site = "http://www.cwi.nl:80/%7Eguido/Python.html, /data/Python.html, 532, https://stackoverflow.com"  # noqa: E501
    bad_urls = ["/data/Python.html", "532"]
    config = AgentConfig(verbose=True, gsite=site)
    assert config.gsite == site.split(", ")

    tool = SearchTool(config=config)
    result = tool._run("What is the weather today?")

    joined = ", ".join(bad_urls)
    # NOTE: "doman" matches the production warning string exactly.
    expected_warning = f"Warning: The doman {joined} is invalid and is not taken into consideration.\n"  # noqa: E501

    assert expected_warning in result