diff --git a/src/Makefile b/src/Makefile index fbf20323..d152c6b0 100644 --- a/src/Makefile +++ b/src/Makefile @@ -11,4 +11,7 @@ lint: black . --check test: - pytest . \ No newline at end of file + echo "Running tests for sherpa" + pytest tests + echo "Running tests for sherpa slackapp" + cd apps/slackapp && pytest tests \ No newline at end of file diff --git a/src/apps/slackapp/pyproject.toml b/src/apps/slackapp/pyproject.toml index 6239107e..b43d3b30 100644 --- a/src/apps/slackapp/pyproject.toml +++ b/src/apps/slackapp/pyproject.toml @@ -33,6 +33,9 @@ build-backend = "poetry.core.masonry.api" pythonpath = [ "." ] +markers = [ + "external_api: this test calls 3rd party APIs" +] [tool.black] line-length = 88 diff --git a/src/apps/slackapp/slackapp/bolt_app.py b/src/apps/slackapp/slackapp/bolt_app.py index b3c988de..7f456399 100644 --- a/src/apps/slackapp/slackapp/bolt_app.py +++ b/src/apps/slackapp/slackapp/bolt_app.py @@ -7,12 +7,14 @@ from typing import Dict, List from flask import Flask, request +from langchain.schema import AIMessage, BaseMessage, HumanMessage from loguru import logger from slack_bolt import App from slack_bolt.adapter.flask import SlackRequestHandler from slackapp.routes.whitelist import whitelist_blueprint import sherpa_ai.config as cfg +from sherpa_ai.config import AgentConfig from sherpa_ai.connectors.vectorstores import get_vectordb from sherpa_ai.database.user_usage_tracker import UserUsageTracker from sherpa_ai.error_handling import AgentErrorHandler @@ -47,23 +49,48 @@ def hello_command(ack, body): ack(f"Hi, <@{user_id}>!") -def contains_verbose(query: str) -> bool: - """looks for -verbose in the question and returns True or False""" - return "-verbose" in query.lower() +def convert_thread_history_messages(messages: List[dict]) -> List[BaseMessage]: + results = [] + for message in messages: + logger.info(message) + if message["type"] != "message" and message["type"] != "text": + continue -def contains_verbosex(query: str) -> bool: - """looks for -verbosex in the question and returns True or False""" - return "-verbosex" in query.lower() + message_cls = AIMessage if message["user"] == self.ai_id else HumanMessage + # replace the at in the message with the name of the bot + text = message["text"].replace(f"@{self.ai_id}", f"@{self.ai_name}") + + text = text.split("#verbose", 1)[0] # remove everything after #verbose + text = text.replace("-verbose", "") # remove -verbose if it exists + results.append(message_cls(content=text)) + + return results def get_response( question: str, - previous_messages: List[Dict], - verbose_logger: BaseVerboseLogger, + previous_messages: List[BaseMessage], user_id: str, team_id: str, -): + verbose_logger: BaseVerboseLogger, + bot_info: Dict[str, str], +) -> str: + """ + Get response from the task agent for the question + + Args: + question (str): question to be answered + previous_messages (List[BaseMessage]): previous messages in the thread + user_id (str): user id of the user who asked the question + team_id (str): team id of the workspace + verbose_logger (BaseVerboseLogger): verbose logger to be used + bot_info (Dict[str, str]): information of the Slack bot + + Returns: + str: response from the task agent + """ + llm = SherpaChatOpenAI( openai_api_key=cfg.OPENAI_API_KEY, request_timeout=120, @@ -74,63 +101,54 @@ def get_response( memory = get_vectordb() - tools = get_tools(memory) + question, agent_config = AgentConfig.from_input(question) + verbose_logger = verbose_logger if agent_config.verbose else DummyVerboseLogger() + + tools = get_tools(memory, agent_config) ai_name = "Sherpa" - ai_id = bot["user_id"] + ai_id = bot_info["user_id"] task_agent = TaskAgent.from_llm_and_tools( ai_name="Sherpa", ai_role="assistant", - ai_id=bot["user_id"], + ai_id=bot_info["user_id"], memory=memory, tools=tools, previous_messages=previous_messages, llm=llm, verbose_logger=verbose_logger, + agent_config=agent_config, ) error_handler = AgentErrorHandler() question = question.replace(f"@{ai_id}", f"@{ai_name}") - if contains_verbosex(query=question): - logger.info("Verbose mode is on, show all") - question = question.replace("-verbose", "") - response = error_handler.run_with_error_handling(task_agent.run, task=question) - agent_log = task_agent.logger # logger is updated after running task_agent.run - try: # in case log_formatter fails - verbose_message = log_formatter(agent_log) - except KeyError: - verbose_message = str(agent_log) - return response, verbose_message - - elif contains_verbose(query=question): - logger.info("Verbose mode is on, commands only") - question = question.replace("-verbose", "") - response = error_handler.run_with_error_handling(task_agent.run, task=question) - - agent_log = task_agent.logger # logger is updated after running task_agent.run - try: # in case log_formatter fails - verbose_message = show_commands_only(agent_log) - except KeyError: - verbose_message = str(agent_log) - return response, verbose_message + response = error_handler.run_with_error_handling(task_agent.run, task=question) - else: - logger.info("Verbose mode is off") - response = error_handler.run_with_error_handling(task_agent.run, task=question) - return response, None - -def file_event_handler(say , files , team_id ,user_id , thread_ts , question): - if files[0]['size'] > cfg.FILE_SIZE_LIMIT: - say("Sorry, the file you attached is larger than 2mb. Please try again with a smaller file" , thread_ts=thread_ts) - return { "status":"error" } - file_prompt = QuestionWithFileHandler( question=question , team_id=team_id , user_id=user_id, files=files , token=cfg.SLACK_OAUTH_TOKEN ) + return response + + +def file_event_handler(say, files, team_id, user_id, thread_ts, question): + if files[0]["size"] > cfg.FILE_SIZE_LIMIT: + say( + "Sorry, the file you attached is larger than 2mb. Please try again with a smaller file", + thread_ts=thread_ts, + ) + return {"status": "error"} + file_prompt = QuestionWithFileHandler( + question=question, + team_id=team_id, + user_id=user_id, + files=files, + token=cfg.SLACK_OAUTH_TOKEN, + ) file_prompt_data = file_prompt.reconstruct_prompt_with_file() - if file_prompt_data['status']=='success': - question = file_prompt_data['data'] - return {"status":"success" , "question":question} + if file_prompt_data["status"] == "success": + question = file_prompt_data["data"] + return {"status": "success", "question": question} else: - say(file_prompt_data['message'] , thread_ts=thread_ts) - return { "status":"error" } + say(file_prompt_data["message"], thread_ts=thread_ts) + return {"status": "error"} + @app.event("app_mention") def event_test(client, say, event): @@ -138,20 +156,18 @@ def event_test(client, say, event): thread_ts = event.get("thread_ts", None) or event["ts"] replies = client.conversations_replies(channel=event["channel"], ts=thread_ts) previous_messages = replies["messages"][:-1] - - - # check if the verbose is on - verbose_on = contains_verbose(question) - verbose_logger = ( - SlackVerboseLogger(say, thread_ts) if verbose_on else DummyVerboseLogger() - ) + previous_messages = convert_thread_history_messages(previous_messages) input_message = replies["messages"][-1] - user_id = input_message["user"] - - # teamid is found on different places depending on the message from slack + user_id = input_message["user"] + + # teamid is found on different places depending on the message from slack # if file exist it will be inside one of the files other wise on the parent message - team_id = input_message['files'][0]["user_team"] if 'files' in input_message else input_message["team"] + team_id = ( + input_message["files"][0]["user_team"] + if "files" in input_message + else input_message["team"] + ) combined_id = user_id + "_" + team_id if cfg.FLASK_DEBUG: @@ -173,12 +189,19 @@ def event_test(client, say, event): if can_excute: if "files" in event: - files = event['files'] - file_event = file_event_handler( files=files ,say=say ,team_id=team_id , thread_ts=thread_ts , user_id=user_id , question=question) - if file_event['status']=="error": + files = event["files"] + file_event = file_event_handler( + files=files, + say=say, + team_id=team_id, + thread_ts=thread_ts, + user_id=user_id, + question=question, + ) + if file_event["status"] == "error": return else: - question = file_event['question'] + question = file_event["question"] else: # used to reconstruct the question. if the question contains a link recreate # them so that they contain scraped and summarized content of the link @@ -186,13 +209,22 @@ def event_test(client, say, event): question=question, slack_message=[replies["messages"][-1]] ) question = reconstructor.reconstruct_prompt() - results, _ = get_response( - question, previous_messages, verbose_logger, user_id, team_id + + results = get_response( + question, + previous_messages, + user_id, + team_id, + verbose_logger=SlackVerboseLogger(say, thread_ts), + bot_info=bot, ) say(results, thread_ts=thread_ts) else: - say(f"""I'm sorry for any inconvenience, but it appears you've gone over your daily token limit. Don't worry, you'll be able to use our service again in approximately {usage_cheker['time_left']}.Thank you for your patience and understanding.""", thread_ts=thread_ts) + say( + f"""I'm sorry for any inconvenience, but it appears you've gone over your daily token limit. Don't worry, you'll be able to use our service again in approximately {usage_cheker['time_left']}.Thank you for your patience and understanding.""", + thread_ts=thread_ts, + ) @app.event("app_home_opened") diff --git a/src/apps/slackapp/tests/test_get_response.py b/src/apps/slackapp/tests/test_get_response.py new file mode 100644 index 00000000..11321fef --- /dev/null +++ b/src/apps/slackapp/tests/test_get_response.py @@ -0,0 +1,57 @@ +from datetime import datetime + +import pytest +from slackapp.bolt_app import get_response + +import sherpa_ai.config as cfg +from sherpa_ai.verbose_loggers import DummyVerboseLogger + + +@pytest.mark.external_api +def test_get_response_contains_todays_date(): + question = "What is the date today, using the following format: YYYY-MM-DD?" + date = datetime.now().strftime("%Y-%m-%d") + + if cfg.SERPER_API_KEY is None: + pytest.skip( + "SERPER_API_KEY not found in environment variables, skipping this test" + ) + + verbose_logger = DummyVerboseLogger() + + response = get_response( + question=question, + previous_messages=[], + user_id="", + team_id="", + verbose_logger=verbose_logger, + bot_info={"user_id": "Sherpa"}, + ) + assert date in response, "Today's date not found in response" + + +@pytest.mark.external_api +def test_response_contains_correct_info(): + question = "What is AutoGPT and how does it compare with MetaGPT" + + if cfg.SERPER_API_KEY is None: + pytest.skip( + "SERPER_API_KEY not found in environment variables, skipping this test" + ) + + verbose_logger = DummyVerboseLogger() + + response = get_response( + question=question, + previous_messages=[], + user_id="", + team_id="", + verbose_logger=verbose_logger, + bot_info={"user_id": "Sherpa"}, + ) + + print(response) + assert response is not None + assert response != "" + assert "AutoGPT" in response + assert "MetaGPT" in response diff --git a/src/pyproject.toml b/src/pyproject.toml index 9eae2d04..cc14c6c2 100644 --- a/src/pyproject.toml +++ b/src/pyproject.toml @@ -43,7 +43,7 @@ pythonpath = [ "." ] markers = [ - "real: this test calls 3rd party APIs" + "external_api: this test calls 3rd party APIs" ] [tool.black] diff --git a/src/sherpa_ai/actions/google_search.py b/src/sherpa_ai/actions/google_search.py index 36af9266..e2f6a1c9 100644 --- a/src/sherpa_ai/actions/google_search.py +++ b/src/sherpa_ai/actions/google_search.py @@ -2,6 +2,7 @@ from loguru import logger from sherpa_ai.actions.base import BaseAction +from sherpa_ai.config.task_config import AgentConfig from sherpa_ai.tools import SearchTool SEARCH_SUMMARY_DESCRIPTION = """Role Description: {role_description} @@ -23,6 +24,7 @@ def __init__( task: str, llm: BaseLanguageModel, description: str = SEARCH_SUMMARY_DESCRIPTION, + config: AgentConfig = AgentConfig(), n: int = 5, ): self.role_description = role_description @@ -32,7 +34,7 @@ def __init__( self.llm = llm self.n = n - self.search_tool = SearchTool() + self.search_tool = SearchTool(config=config) def execute(self, query) -> str: result = self.search_tool._run(query) diff --git a/src/sherpa_ai/actions/planning.py b/src/sherpa_ai/actions/planning.py index 4ec6c646..e49b1c7a 100644 --- a/src/sherpa_ai/actions/planning.py +++ b/src/sherpa_ai/actions/planning.py @@ -116,8 +116,8 @@ def execute( self, task: str, agent_pool_description: str, - last_plan: Optional[str], - feedback: Optional[str], + last_plan: Optional[str] = None, + feedback: Optional[str] = None, ) -> Plan: """ Execute the action diff --git a/src/sherpa_ai/config.py b/src/sherpa_ai/config/__init__.py similarity index 97% rename from src/sherpa_ai/config.py rename to src/sherpa_ai/config/__init__.py index 033e7a72..7b1a82e6 100644 --- a/src/sherpa_ai/config.py +++ b/src/sherpa_ai/config/__init__.py @@ -17,6 +17,8 @@ from dotenv import find_dotenv, load_dotenv from loguru import logger +from sherpa_ai.config.task_config import AgentConfig + env_path = find_dotenv(usecwd=True) load_dotenv(env_path) @@ -109,3 +111,7 @@ def check_vectordb_setting(): logger.info("Config: OpenAI environment variables are set") check_vectordb_setting() + +__all__ = [ + "AgentConfig", +] diff --git a/src/sherpa_ai/config/task_config.py b/src/sherpa_ai/config/task_config.py new file mode 100644 index 00000000..f3a89c70 --- /dev/null +++ b/src/sherpa_ai/config/task_config.py @@ -0,0 +1,54 @@ +import re +from argparse import ArgumentParser +from typing import List, Optional, Tuple + +from pydantic import BaseModel + + +class AgentConfig(BaseModel): + verbose: bool = False + gsite: Optional[str] = None + do_reflect: bool = False + + @classmethod + def from_input(cls, input_str: str) -> Tuple[str, "AgentConfig"]: + """ + parse input string into AgentConfig. The configurations are + at the end of the string + """ + parts = re.split(r"(?=--)", input_str) + configs = [] + + for part in parts[1:]: + part = part.strip() + configs.extend(part.split()) + + return parts[0].strip(), cls.from_config(configs) + + @classmethod + def from_config(cls, configs: List[str]) -> "AgentConfig": + parser = ArgumentParser() + + parser.add_argument( + "--verbose", + action="store_true", + help="enable verbose messaging during agent execution", + ) + parser.add_argument( + "--gsite", + type=str, + default=None, + help="site to be used for the Google search tool.", + ) + parser.add_argument( + "--do-reflect", + action="store_true", + help="enable performing the reflection step for each agent.", + ) + + args, unknown = parser.parse_known_args(configs) + + if len(unknown) > 0: + raise ValueError(f"Invalid configuration, check your input: {unknown}") + + return AgentConfig(**args.__dict__) diff --git a/src/sherpa_ai/task_agent.py b/src/sherpa_ai/task_agent.py index 765620ba..d349d4cf 100644 --- a/src/sherpa_ai/task_agent.py +++ b/src/sherpa_ai/task_agent.py @@ -21,6 +21,7 @@ from sherpa_ai.action_planner import SelectiveActionPlanner from sherpa_ai.action_planner.base import BaseActionPlanner +from sherpa_ai.config import AgentConfig from sherpa_ai.output_parser import BaseTaskOutputParser, TaskOutputParser from sherpa_ai.output_parsers import LinkParser, MDToSlackParse from sherpa_ai.post_processors import md_link_to_slack @@ -42,14 +43,14 @@ def __init__( action_planner: BaseActionPlanner, output_parser: BaseTaskOutputParser, tools: List[BaseTool], - previous_messages: List[dict], + previous_messages: List[BaseMessage], verbose_logger: BaseVerboseLogger, feedback_tool: Optional[HumanInputRun] = None, + agent_config: AgentConfig = AgentConfig(), max_iterations: int = 5, ): self.ai_name = ai_name self.memory = memory - # self.full_message_history: List[BaseMessage] = [] self.next_action_count = 0 self.llm = llm self.output_parser = output_parser @@ -61,16 +62,15 @@ def __init__( self.max_iterations = max_iterations self.loop_count = 0 self.ai_id = ai_id - self.previous_message = self.process_chat_history(previous_messages) - self.logger = [] # added by JF + self.agent_config = agent_config + self.previous_message = previous_messages + self.logger = [] link_parser = LinkParser() slack_link_paerser = MDToSlackParse() self.tool_output_parsers = [link_parser] self.output_parsers = [link_parser, slack_link_paerser] - # print(self.full_message_history) - # print("message:", self.previous_message) @classmethod def from_llm_and_tools( @@ -86,6 +86,7 @@ def from_llm_and_tools( human_in_the_loop: bool = False, output_parser: Optional[BaseTaskOutputParser] = None, max_iterations: int = 1, + agent_config: AgentConfig = AgentConfig(), verbose_logger: BaseVerboseLogger = DummyVerboseLogger(), ): if action_planner is None: @@ -105,6 +106,7 @@ def from_llm_and_tools( previous_messages, verbose_logger, feedback_tool=human_feedback_tool, + agent_config=agent_config, max_iterations=max_iterations, ) @@ -129,9 +131,9 @@ def run(self, task: str) -> str: user_input = ( "Use the above information to respond to the user's message:" f"\n{task}\n\n" - "If you use any resource, then create inline citation by adding " - " of the reference document at the end of the sentence in the format " - "of 'Sentence [DocID]'\n" + "If you use any resource, then create inline citation by adding" + "the DocID of the reference document at the end of the sentence in " + "the format of 'Sentence [DocID]'\n" "Example:\n" "Sentence1 [1]. Sentence2. Sentence3 [2].\n" "Only use the reference document. DO NOT use any links" @@ -152,14 +154,20 @@ def run(self, task: str) -> str: logger_step["reply"] = reply_json except json.JSONDecodeError: logger_step["reply"] = assistant_reply # last reply is a string + if self.agent_config.verbose: + self.verbose_logger.log(f"```{assistant_reply}```") + self.logger.append(logger_step) - ########## Serial Verbose Feature ####### - try: - formatted_logger_step = show_commands_only(logger_step) - except Exception as e: - logger.error(e) + # Serial Verbose Feature + if self.agent_config.verbose: formatted_logger_step = logger_step + else: + try: + formatted_logger_step = show_commands_only(logger_step) + except KeyError as e: + logger.error(e) + formatted_logger_step = logger_step logger.info(f"```{formatted_logger_step}```") self.verbose_logger.log(f"```{formatted_logger_step}```") @@ -179,7 +187,7 @@ def run(self, task: str) -> str: ): result = result["command"]["args"]["response"] except json.JSONDecodeError: - result = assistant_reply + result = str(assistant_reply) for output_parser in self.output_parsers: result = output_parser.parse_output(result) @@ -251,24 +259,6 @@ def set_user_input(self, user_input: str): self.memory.add_documents([Document(page_content=memory_to_add)]) - def process_chat_history(self, messages: List[dict]) -> List[BaseMessage]: - results = [] - - for message in messages: - logger.info(message) - if message["type"] != "message" and message["type"] != "text": - continue - - message_cls = AIMessage if message["user"] == self.ai_id else HumanMessage - # replace the at in the message with the name of the bot - text = message["text"].replace(f"@{self.ai_id}", f"@{self.ai_name}") - # added by JF - text = text.split("#verbose", 1)[0] # remove everything after #verbose - text = text.replace("-verbose", "") # remove -verbose if it exists - results.append(message_cls(content=text)) - - return results - def process_output(self, output: str) -> str: """ Process the output of the AI to remove the bot's name and replace it with @bot diff --git a/src/sherpa_ai/tools.py b/src/sherpa_ai/tools.py index 8b26ccce..3146cb5f 100644 --- a/src/sherpa_ai/tools.py +++ b/src/sherpa_ai/tools.py @@ -16,16 +16,17 @@ from typing_extensions import Literal import sherpa_ai.config as cfg +from sherpa_ai.config.task_config import AgentConfig -def get_tools(memory): +def get_tools(memory, config): tools = [] # tools.append(ContextTool(memory=memory)) tools.append(UserInputTool()) if cfg.SERPER_API_KEY is not None: - search_tool = SearchTool() + search_tool = SearchTool(config=config) tools.append(search_tool) else: logger.warning( @@ -68,7 +69,7 @@ def _run(self, query: str) -> str: ) logger.debug(f"Arxiv Search Result: {result_list}") - + return " ".join(result_list) def _arun(self, query: str) -> str: @@ -77,12 +78,18 @@ def _arun(self, query: str) -> str: class SearchTool(BaseTool): name = "Search" + config = AgentConfig() description = ( "Access the internet to search for the information. Only use this tool when " "you cannot find the information using internal search." ) + def augment_query(self, query) -> str: + return query + " site:" + self.config.gsite if self.config.gsite else query + def _run(self, query: str) -> str: + query = self.augment_query(query) + logger.debug(f"Search query: {query}") google_serper = GoogleSerperAPIWrapper() search_results = google_serper._google_serper_api_results(query) @@ -190,7 +197,7 @@ def _arun(self, query: str) -> str: class UserInputTool(BaseTool): - # TODO: Make an action for the user input + # TODO: Make an action for the user input name = "UserInput" description = ( "Access the user input for the task." diff --git a/src/sherpa_ai/utils.py b/src/sherpa_ai/utils.py index 50abcc12..4a496277 100644 --- a/src/sherpa_ai/utils.py +++ b/src/sherpa_ai/utils.py @@ -246,7 +246,7 @@ def show_commands_only(logs): log_strings.append(formatted_reply) else: # for final response - formatted_reply = f"""đź’ˇThought process finished!""" + formatted_reply = """đź’ˇThought process finished!""" log_strings.append(formatted_reply) log_string = "\n".join(log_strings) diff --git a/src/tests/integration_tests/test_task_agent.py b/src/tests/integration_tests/test_task_agent.py index 3892c505..97fefdd0 100644 --- a/src/tests/integration_tests/test_task_agent.py +++ b/src/tests/integration_tests/test_task_agent.py @@ -31,7 +31,7 @@ def config_task_agent( return task_agent -@pytest.mark.real +@pytest.mark.external_api def test_task_solving_with_search(): """Test task solving with search""" question = "What is the date today, using the following format: YYYY-MM-DD?" @@ -42,7 +42,7 @@ def test_task_solving_with_search(): "SERPER_API_KEY not found in environment variables, skipping this test" ) memory = get_vectordb() - tools = [SearchTool(api_wrapper=GoogleSerperAPIWrapper())] + tools = [SearchTool()] task_agent = config_task_agent(tools=tools, memory=memory) @@ -50,7 +50,7 @@ def test_task_solving_with_search(): assert date in response, "Today's date not found in response" -@pytest.mark.real +@pytest.mark.external_api def test_task_solving_with_context_search(): question = "What is langchain?" diff --git a/src/tests/unit_tests/test_context_search.py b/src/tests/unit_tests/actions/test_context_search.py similarity index 100% rename from src/tests/unit_tests/test_context_search.py rename to src/tests/unit_tests/actions/test_context_search.py diff --git a/src/tests/unit_tests/test_planning.py b/src/tests/unit_tests/actions/test_planning.py similarity index 92% rename from src/tests/unit_tests/test_planning.py rename to src/tests/unit_tests/actions/test_planning.py index c441ee88..ba49d967 100644 --- a/src/tests/unit_tests/test_planning.py +++ b/src/tests/unit_tests/actions/test_planning.py @@ -4,8 +4,6 @@ import sherpa_ai.config as cfg from sherpa_ai.actions.planning import TaskPlanning -llm = OpenAI(openai_api_key=cfg.OPENAI_API_KEY, temperature=0) - def test_planning(): llm = OpenAI(openai_api_key=cfg.OPENAI_API_KEY, temperature=0) diff --git a/src/tests/unit_tests/test_critic_agent.py b/src/tests/unit_tests/agents/test_critic_agent.py similarity index 92% rename from src/tests/unit_tests/test_critic_agent.py rename to src/tests/unit_tests/agents/test_critic_agent.py index cb3ffd95..d64bd440 100644 --- a/src/tests/unit_tests/test_critic_agent.py +++ b/src/tests/unit_tests/agents/test_critic_agent.py @@ -21,7 +21,7 @@ def test_evaluation_matrices(): llm = OpenAI(openai_api_key=cfg.OPENAI_API_KEY, temperature=0) - critic_agent = Critic(name="CriticAgent", llm=llm, ratio=1) + critic_agent = Critic(llm=llm, ratio=1) i_score, i_evaluation = critic_agent.get_importance_evaluation(task, plan) assert type(i_score) is int @@ -34,7 +34,7 @@ def test_evaluation_matrices(): def test_get_feedback(): llm = OpenAI(openai_api_key=cfg.OPENAI_API_KEY, temperature=0) - critic_agent = Critic(name="CriticAgent", llm=llm, ratio=1) + critic_agent = Critic(llm=llm, ratio=1) feedback_list = critic_agent.get_feedback(task, plan) assert len(feedback_list) == 3 # assert type(feedback) is str diff --git a/src/tests/unit_tests/test_ml_engineer.py b/src/tests/unit_tests/agents/test_ml_engineer.py similarity index 100% rename from src/tests/unit_tests/test_ml_engineer.py rename to src/tests/unit_tests/agents/test_ml_engineer.py diff --git a/src/tests/unit_tests/test_physicist.py b/src/tests/unit_tests/agents/test_physicist.py similarity index 100% rename from src/tests/unit_tests/test_physicist.py rename to src/tests/unit_tests/agents/test_physicist.py diff --git a/src/tests/unit_tests/test_planner.py b/src/tests/unit_tests/agents/test_planner.py similarity index 68% rename from src/tests/unit_tests/test_planner.py rename to src/tests/unit_tests/agents/test_planner.py index a26873f3..83f85d6c 100644 --- a/src/tests/unit_tests/test_planner.py +++ b/src/tests/unit_tests/agents/test_planner.py @@ -4,35 +4,40 @@ from sherpa_ai.agents.agent_pool import AgentPool from sherpa_ai.agents.physicist import Physicist from sherpa_ai.agents.planner import Planner -from sherpa_ai.agents.programmer import Programmer + +# from sherpa_ai.agents.programmer import Programmer from sherpa_ai.memory.shared_memory import SharedMemory def test_planner(): - programmer_description = ( - "The programmer receives requirements about a program and write it" - ) - programmer = Programmer(name="Programmer", description=programmer_description) + # programmer_description = ( + # "The programmer receives requirements about a program and write it" + # ) + # programmer = Programmer(name="Programmer", description=programmer_description) + + llm = OpenAI(openai_api_key=cfg.OPENAI_API_KEY, temperature=0) physicist_description = ( "The physicist agent answers questions or research about physics-related topics" ) - physicist = Physicist(name="Physicist", description=physicist_description) + physicist = Physicist( + name="Physicist", + description=physicist_description, + llm=llm, + ) agent_pool = AgentPool() - agent_pool.add_agents([programmer, physicist]) + agent_pool.add_agents([physicist]) - shared_memeory = SharedMemory( + shared_memory = SharedMemory( objective="Share the information across different agents.", agent_pool=agent_pool, ) - llm = OpenAI(openai_api_key=cfg.OPENAI_API_KEY, temperature=0) - planner = Planner( name="planner", agent_pool=agent_pool, - shared_memory=shared_memeory, + shared_memory=shared_memory, llm=llm, ) diff --git a/src/tests/unit_tests/test_qa_agent.py b/src/tests/unit_tests/agents/test_qa_agent.py similarity index 100% rename from src/tests/unit_tests/test_qa_agent.py rename to src/tests/unit_tests/agents/test_qa_agent.py diff --git a/src/tests/unit_tests/config/test_task_config.py b/src/tests/unit_tests/config/test_task_config.py new file mode 100644 index 00000000..22fbc559 --- /dev/null +++ b/src/tests/unit_tests/config/test_task_config.py @@ -0,0 +1,33 @@ +import pytest + +from sherpa_ai.config import AgentConfig + + +def test_all_parameters_parse_successfully(): + site = "https://www.google.com" + input_str = f"Test input. --verbose --gsite {site} --do-reflect" + + parsed, config = AgentConfig.from_input(input_str) + + assert parsed == "Test input." + assert config.verbose + assert config.gsite == site + assert config.do_reflect + + +def test_no_gsite_parses_successfully(): + input_str = "Test input. --verbose" + + parsed, config = AgentConfig.from_input(input_str) + + assert parsed == "Test input." + assert config.verbose + assert config.gsite is None + + +def test_parse_args_noise(): + site = "https://www.google.com" + input_str = f"This is an input with -- but--should not be considered --verbose --verbosex --gsite {site} --do-reflect" # noqa: E501 + + with pytest.raises(ValueError): + AgentConfig.from_input(input_str) diff --git a/src/tests/unit_tests/test_extract_github_readme.py b/src/tests/unit_tests/scrape/test_extract_github_readme.py similarity index 100% rename from src/tests/unit_tests/test_extract_github_readme.py rename to src/tests/unit_tests/scrape/test_extract_github_readme.py diff --git a/src/tests/unit_tests/test_prompt_reconstructor.py b/src/tests/unit_tests/scrape/test_prompt_reconstructor.py similarity index 100% rename from src/tests/unit_tests/test_prompt_reconstructor.py rename to src/tests/unit_tests/scrape/test_prompt_reconstructor.py diff --git a/src/tests/unit_tests/tools/test_search_tool.py b/src/tests/unit_tests/tools/test_search_tool.py new file mode 100644 index 00000000..ab78e80e --- /dev/null +++ b/src/tests/unit_tests/tools/test_search_tool.py @@ -0,0 +1,20 @@ +from sherpa_ai.config import AgentConfig +from sherpa_ai.tools import SearchTool + + +def test_search_query_includes_gsite_config(): + site = "https://www.google.com" + config = AgentConfig( + verbose=True, + gsite=site, + ) + + assert config.gsite == site + + search_tool = SearchTool(config=config) + + query = "What is the weather today?" + + updated_query = search_tool.augment_query(query) + + assert f"site:{site}" in updated_query