Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hydra config #289

Merged
merged 15 commits into from
Feb 22, 2024
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/.env-sample
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ FLASK_DEBUG=True # True or False
# OpenAI. Mandatory. Enables language modeling.
OPENAI_API_KEY= # OpenAI API key

# Temperature configuration for OpenAI. Optional. Default is 0.
TEMPERATURE= # Only applies to the legacy task agent

# Serper.dev. Optional. Enables Google web search capability
# SERPER_API_KEY= # Serper.dev API key

Expand Down
1,603 changes: 826 additions & 777 deletions src/apps/slackapp/poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/apps/slackapp/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ flask-cors = "^4.0.0"
flask = "^2.3.3"
loguru = "^0.7.0"
sherpa-ai = {path = "../..", develop = true}
hydra-core = "^1.3.2"

[tool.poetry.scripts]
sherpa_slack = 'slackapp.bolt_app:main'
Expand Down
46 changes: 21 additions & 25 deletions src/apps/slackapp/slackapp/bolt_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,16 @@
##############################################

import time
from typing import Dict, List
from typing import Dict, List, Optional

from flask import Flask, request
from langchain.schema import AIMessage, BaseMessage, HumanMessage
from loguru import logger
from omegaconf import OmegaConf
from slack_bolt import App
from slack_bolt.adapter.flask import SlackRequestHandler
from slackapp.routes.whitelist import whitelist_blueprint
from slackapp.utils import get_qa_agent_from_config_file

import sherpa_ai.config as cfg
from sherpa_ai.agents import QAAgent
Expand Down Expand Up @@ -92,7 +94,9 @@ def get_response(
previous_messages: List[BaseMessage],
verbose_logger: BaseVerboseLogger,
bot_info: Dict[str, str],
llm: SherpaChatOpenAI = None,
llm=None,
team_id: Optional[str] = None,
user_id: Optional[str] = None,
) -> str:
"""
Get response from the task agent for the question
Expand All @@ -103,6 +107,8 @@ def get_response(
verbose_logger (BaseVerboseLogger): verbose logger to be used
bot_info (Dict[str, str]): information of the Slack bot
llm (SherpaChatOpenAI, optional): LLM to be used. Defaults to None.
team_id (str, optional): team id of the Slack workspace. Defaults to "".
user_id (str, optional): user id of the Slack user. Defaults to "".

Returns:
str: response from the task agent
Expand All @@ -120,6 +126,13 @@ def get_response(
tools = get_tools(memory, agent_config)

if agent_config.use_task_agent:
llm = SherpaChatOpenAI(
openai_api_key=cfg.OPENAI_API_KEY,
user_id=user_id,
team_id=team_id,
temperature=cfg.TEMPERATURE,
)

verbose_logger.log("⚠️🤖 Use task agent (obsolete)...")
task_agent = TaskAgent.from_llm_and_tools(
ai_name=ai_name,
Expand All @@ -135,21 +148,11 @@ def get_response(

response = error_handler.run_with_error_handling(task_agent.run, task=question)
else:
memory = SharedMemory(objective="Answer the question")

agent = get_qa_agent_from_config_file("conf/config.yaml", team_id, user_id, llm)
Eyobyb marked this conversation as resolved.
Show resolved Hide resolved
for message in previous_messages:
memory.add(EventType.result, message.type, message.content)
memory.add(EventType.task, "human", question)

agent = QAAgent(
llm=llm,
name=ai_name,
num_runs=1,
shared_memory=memory,
agent_config=agent_config,
require_meta=True,
verbose_logger=verbose_logger,
)
agent.shared_memory.add(EventType.result, message.type, message.content)
agent.shared_memory.add(EventType.task, "human", question)
agent.verbose_logger = verbose_logger

error_handler = AgentErrorHandler()
response = error_handler.run_with_error_handling(agent.run)
Expand Down Expand Up @@ -242,20 +245,13 @@ def event_test(client, say, event):
)
question = reconstructor.reconstruct_prompt()

llm = SherpaChatOpenAI(
openai_api_key=cfg.OPENAI_API_KEY,
user_id=user_id,
team_id=team_id,
verbose_logger = slack_verbose_logger,
temperature=cfg.TEMPRATURE,
)

results = get_response(
question,
previous_messages,
verbose_logger=slack_verbose_logger,
bot_info=bot,
llm=llm,
team_id=team_id,
user_id=user_id,
)

say(results, thread_ts=thread_ts)
Expand Down
33 changes: 33 additions & 0 deletions src/apps/slackapp/slackapp/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import Optional

from hydra.utils import instantiate
from langchain.base_language import BaseLanguageModel
from omegaconf import OmegaConf

from sherpa_ai.agents.qa_agent import QAAgent
from sherpa_ai.config.task_config import AgentConfig


def get_qa_agent_from_config_file(
config_path: str,
user_id: Optional[str] = None,
team_id: Optional[str] = None,
llm: Optional[BaseLanguageModel] = None,
) -> QAAgent:
config = OmegaConf.load(config_path)

agent_config: AgentConfig = instantiate(config.agent_config)
if user_id is not None:
config["user_id"] = user_id

if team_id is not None:
config["team_id"] = team_id

if llm is None:
qa_agent: QAAgent = instantiate(config.qa_agent, agent_config=agent_config)
else:
qa_agent: QAAgent = instantiate(
config.qa_agent, agent_config=agent_config, llm=llm
)

return qa_agent
51 changes: 51 additions & 0 deletions src/apps/slackapp/tests/data/test_get_agent.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
shared_memory:
_target_: sherpa_ai.memory.shared_memory.SharedMemory
oshoma marked this conversation as resolved.
Show resolved Hide resolved
objective: Answer the question

user_id: none
team_id: none

llm:
_target_: sherpa_ai.models.sherpa_base_chat_model.SherpaChatOpenAI
model_name: gpt-4
temperature: 0.7
user_id: ${user_id}
team_id: ${team_id}

agent_config:
_target_: sherpa_ai.config.task_config.AgentConfig

citation_validation:
_target_: sherpa_ai.output_parsers.citation_validation.CitationValidation
sequence_threshold: 0.8
jaccard_threshold: 0.7
token_overlap: 0.6

arxiv_search:
_target_: sherpa_ai.actions.arxiv_search.ArxivSearch
role_description: Act as a question answering agent
task: Question answering
llm: ${llm}
max_results: 3

google_search:
_target_: sherpa_ai.actions.GoogleSearch
role_description: Act as a question answering agent
task: Question answering
llm: ${llm}
include_metadata: true
config: ${agent_config}

qa_agent:
_target_: sherpa_ai.agents.qa_agent.QAAgent
llm: ${llm}
shared_memory: ${shared_memory}
name: QA Sherpa
description: Act as a question answering agent
agent_config: ${agent_config}
num_runs: 1
actions:
- ${arxiv_search}
- ${google_search}
validations:
- ${citation_validation}
20 changes: 20 additions & 0 deletions src/apps/slackapp/tests/test_get_agent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from slackapp.utils import get_qa_agent_from_config_file

from sherpa_ai.actions import ArxivSearch, GoogleSearch
from sherpa_ai.agents.qa_agent import QAAgent
from sherpa_ai.output_parsers.citation_validation import CitationValidation
from sherpa_ai.test_utils.data import get_test_data_file_path


def test_get_agent(get_test_data_file_path): # noqa: F811
config_filename = get_test_data_file_path(__file__, "test_get_agent.yaml")
agent = get_qa_agent_from_config_file(config_filename)

assert agent is not None
assert type(agent) is QAAgent

assert len(agent.actions) == 2
assert type(agent.actions[0]) is ArxivSearch
assert type(agent.actions[1]) is GoogleSearch
assert len(agent.validations) == 1
assert type(agent.validations[0]) is CitationValidation
55 changes: 55 additions & 0 deletions src/conf/config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
shared_memory:
_target_: sherpa_ai.memory.shared_memory.SharedMemory
objective: Answer the question

user_id: none
team_id: none

llm:
_target_: sherpa_ai.models.sherpa_base_chat_model.SherpaChatOpenAI
model_name: gpt-3.5-turbo
temperature: 0
user_id: ${user_id}
team_id: ${team_id}

agent_config:
_target_: sherpa_ai.config.task_config.AgentConfig

citation_validation:
_target_: sherpa_ai.output_parsers.citation_validation.CitationValidation
sequence_threshold: 0.5
jaccard_threshold: 0.5
token_overlap: 0.5

number_validation:
_target_: sherpa_ai.output_parsers.number_validation.NumberValidation

arxiv_search:
_target_: sherpa_ai.actions.arxiv_search.ArxivSearch
role_description: Act as a question answering agent
task: Question answering
llm: ${llm}
max_results: 3

google_search:
_target_: sherpa_ai.actions.GoogleSearch
role_description: Act as a question answering agent
task: Question answering
llm: ${llm}
include_metadata: true
config: ${agent_config}

qa_agent:
_target_: sherpa_ai.agents.qa_agent.QAAgent
llm: ${llm}
shared_memory: ${shared_memory}
name: QA Sherpa
description: Act as a question answering agent
agent_config: ${agent_config}
num_runs: 1
validation_steps: 1
actions:
- ${google_search}
validations:
- ${number_validation}
- ${citation_validation}
Loading